Compare pre-trained CLIP models for text-image retrieval
Create, deploy, feed and evaluate the Vespa app using the Vespa Python API (pyvespa)
There are multiple CLIP model variations available:
import clip
clip.available_models()
Create a PyTorch dataset that loads an image, creates a CLIP-based embedding and outputs the data in a pyvespa-compatible feed format. This makes it easier to feed the entire dataset into the search app that will be created below.
import os
import glob
import ntpath
import torch
from torch.utils.data import Dataset
from PIL import Image
def translate_model_names_to_valid_vespa_field_names(model_name):
    """Translate a CLIP model name such as 'ViT-B/32' into a valid Vespa field name such as 'vit_b_32'."""
    return model_name.replace("/", "_").replace("-", "_").lower()


class ImageFeedDataset(Dataset):
    def __init__(self, img_dir, model_name):
        valid_vespa_model_name = translate_model_names_to_valid_vespa_field_names(model_name)
        self.model, self.preprocess = clip.load(model_name)
        self.img_dir = img_dir
        self.image_file_names = glob.glob(os.path.join(img_dir, "*.jpg"))
        self.image_embedding_name = valid_vespa_model_name + "_image"

    def _from_image_to_vector(self, x):
        with torch.no_grad():
            image_features = self.model.encode_image(self.preprocess(x).unsqueeze(0)).float()
            image_features /= image_features.norm(dim=-1, keepdim=True)  # normalize to unit length
        return image_features

    def __len__(self):
        return len(self.image_file_names)

    def __getitem__(self, idx):
        image_file_name = self.image_file_names[idx]
        image = Image.open(image_file_name)
        image = self._from_image_to_vector(image)
        image_base_name = ntpath.basename(image_file_name)
        return {
            "id": image_base_name.split(".jpg")[0],
            "fields": {
                "image_file_name": image_base_name,
                self.image_embedding_name: {"values": image.tolist()[0]}
            },
            "create": True
        }

    def get_embedding_size(self):
        return len(self.__getitem__(0)["fields"][self.image_embedding_name]["values"])
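As a quick sanity check, we can instantiate the dataset for a single model and inspect the first feed point. This is a minimal sketch: it assumes the IMG_DIR environment variable points to a folder of .jpg files and uses "ViT-B/32" as an example model.
import os

# Build the dataset for one model and look at the structure of the first feed point.
sample_dataset = ImageFeedDataset(img_dir=os.environ["IMG_DIR"], model_name="ViT-B/32")
sample_point = sample_dataset[0]
print(sample_point["id"])                    # image file name without the .jpg extension
print(list(sample_point["fields"].keys()))   # ['image_file_name', 'vit_b_32_image']
print(sample_dataset.get_embedding_size())   # embedding size used by this model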
We need a text processor that maps a text query into an embedding when querying our search app:
import clip
import torch


class TextProcessor(object):
    def __init__(self, model_name):
        self.model, _ = clip.load(model_name)
        self.model_name = model_name

    def embed(self, text):
        text_tokens = clip.tokenize(text)
        with torch.no_grad():
            text_features = self.model.encode_text(text_tokens).float()
            text_features /= text_features.norm(dim=-1, keepdim=True)
        return text_features.tolist()[0]
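The text processor can be used on its own. The snippet below is a small usage sketch, with "ViT-B/32" as an example model and an arbitrary query string.
# Embed a text query with one of the CLIP models ("ViT-B/32" used as an example).
text_processor = TextProcessor(model_name="ViT-B/32")
query_embedding = text_processor.embed("a dog catching a frisbee")
print(len(query_embedding))  # matches the image embedding size of the same model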
We need to know the embedding size of each model when creating our search app. Note that the embedding size varies across model variations.
model_info = {}
for model_name in clip.available_models():
    image_dataset = ImageFeedDataset(
        img_dir=os.environ["IMG_DIR"],  # folder containing the image files
        model_name=model_name           # CLIP model name used to convert images into vectors
    )
    embedding_size = image_dataset.get_embedding_size()
    model_info[translate_model_names_to_valid_vespa_field_names(model_name)] = embedding_size
print(model_info)
from vespa.package import ApplicationPackage, Field, HNSW, RankProfile, QueryTypeField


def create_text_image_app(model_info):
    """
    Create a text-to-image search app based on a variety of CLIP models.

    :param model_info: dict containing (Vespa-compatible) model names as keys and embedding sizes as values.
        Use `clip.available_models()` to see which models are available.
    :return: A Vespa application package.
    """
    app_package = ApplicationPackage(name="image_search")
    app_package.schema.add_fields(
        Field(
            name="image_file_name",
            type="string",
            indexing=["summary", "attribute"]
        ),
    )
    for model_name, embedding_size in model_info.items():
        app_package.schema.add_fields(
            Field(
                name=model_name + "_image",
                type="tensor<float>(x[{}])".format(embedding_size),
                indexing=["attribute", "index"],
                ann=HNSW(
                    distance_metric="euclidean",
                    max_links_per_node=16,
                    neighbors_to_explore_at_insert=500
                )
            )
        )
        app_package.schema.add_rank_profile(
            RankProfile(
                name=model_name + "_similarity",
                inherits="default",
                first_phase="closeness({})".format(model_name + "_image")
            )
        )
        app_package.query_profile_type.add_fields(
            QueryTypeField(
                name="ranking.features.query({})".format(model_name + "_text"),
                type="tensor<float>(x[{}])".format(embedding_size)
            )
        )
    return app_package
app_package = create_text_image_app(model_info)
from vespa.deployment import VespaCloud

vespa_cloud = VespaCloud(
    tenant="vespa-team",
    application="pyvespa-integration",
    key_location=os.environ["USER_KEY_PATH"],
    application_package=app_package,
)

app = vespa_cloud.deploy(
    instance="clip-image-search", disk_folder=os.environ["DISK_FOLDER"]
)
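If you do not have access to Vespa Cloud, the same application package can be deployed to a local Docker container instead. The sketch below assumes a pyvespa version that exposes VespaDocker in vespa.deployment and that Docker is running locally.
# Alternative deployment (sketch): run the application in a local Docker container.
# Assumes Docker is running and VespaDocker is available in this pyvespa version.
from vespa.deployment import VespaDocker

vespa_docker = VespaDocker()
app = vespa_docker.deploy(application_package=app_package)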
import time
from collections import Counter
from aiohttp.client_exceptions import ClientConnectorError
from asyncio import TimeoutError
from torch.utils.data import DataLoader

# This is for demo purposes; this step should be run outside a notebook, in a multi-processing environment.
for model_name in clip.available_models():
    image_dataset = ImageFeedDataset(
        img_dir=os.environ["IMG_DIR"],  # folder containing the image files
        model_name=model_name           # CLIP model name used to convert images into vectors
    )
    dataloader = DataLoader(image_dataset, batch_size=128, shuffle=False, collate_fn=lambda x: x)
    for idx, batch in enumerate(dataloader):
        responses = None
        while responses is None:
            try:
                responses = app.update_batch(batch=batch)
            except (ClientConnectorError, TimeoutError):
                time.sleep(3)  # wait and retry on transient connection errors
        print("Model name: {}. Iteration: {}/{}".format(model_name, idx, len(dataloader)))
        print("Status code summary: {}".format(Counter([x.status_code for x in responses])))
Define search evaluation metrics:
from vespa.evaluation import MatchRatio, Recall, ReciprocalRank

eval_metrics = [
    MatchRatio(),
    Recall(at=5),
    Recall(at=100),
    ReciprocalRank(at=5),
    ReciprocalRank(at=100)
]
Create a function that takes a query string and returns the body of a query request based on the Vespa Query Language.
from vespa.query import QueryModel


def create_vespa_query(query, text_processor):
    valid_vespa_model_name = translate_model_names_to_valid_vespa_field_names(text_processor.model_name)
    image_field_name = valid_vespa_model_name + "_image"
    text_field_name = valid_vespa_model_name + "_text"
    ranking_name = valid_vespa_model_name + "_similarity"
    return {
        'yql': 'select * from sources * where ([{{"targetNumHits":100}}]nearestNeighbor({},{}));'.format(
            image_field_name,
            text_field_name
        ),
        'hits': 100,
        'ranking.features.query({})'.format(text_field_name): text_processor.embed(query),
        'ranking.profile': ranking_name,
        'timeout': 10
    }


def create_body_function(model_name):
    text_processor = TextProcessor(model_name=model_name)
    return lambda x: create_vespa_query(x, text_processor=text_processor)
Create one QueryModel for each of the CLIP models:
query_models = []
for model_name in clip.available_models():
    query_models.append(
        QueryModel(
            name=model_name,
            body_function=create_body_function(model_name)
        )
    )
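Before running the full evaluation, we can send a single query directly and inspect the top hit. The sketch below reuses create_vespa_query defined above; "ViT-B/32" and the query string are arbitrary examples.
# Run one query against the app and print the best-ranked image for it.
body = create_vespa_query(
    "a child playing on the beach",
    text_processor=TextProcessor(model_name="ViT-B/32")
)
result = app.query(body=body)
print(result.hits[0]["fields"]["image_file_name"])  # file name of the top-ranked image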
Load labeled data.
from pandas import read_csv
labeled_data = read_csv("/Users/tmartins/projects/data/flickr8k/labeled_data.csv", sep="\t")
labeled_data.head()
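The labeled data maps each text query to the image considered relevant for it. The exact column names expected by pyvespa's evaluate are an assumption here; the sketch below only illustrates the general layout, with one row per (query, relevant image) pair and hypothetical image ids.
from pandas import DataFrame

# Sketch of the labeled-data layout (column names are an assumption, check your pyvespa version).
example_labeled_data = DataFrame({
    "qid": [0, 1],
    "query": ["a dog catching a frisbee", "a child playing on the beach"],
    "doc_id": ["1234567890_abcdef1234", "2345678901_bcdef12345"],  # hypothetical image ids
    "relevance": [1, 1],
})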
Evaluate the application and return per-query results.
result = app.evaluate(
    labeled_data=labeled_data,
    eval_metrics=eval_metrics,
    query_model=query_models,
    id_field="image_file_name",
    per_query=True
)
result.head()
Visualize RR@100:
import plotly.express as px
fig = px.box(result, x="model", y="reciprocal_rank_100")
fig.show()
Compute mean and median across models:
result[["model", "reciprocal_rank_100"]].groupby(
"model"
).agg(
Mean=('reciprocal_rank_100', 'mean'),
Median=('reciprocal_rank_100', 'median')
)