Understanding the CLIP image pipeline
Visualize the image preprocessing steps and build a custom PyTorch Dataset that maps images to embeddings.
This post builds on my Flickr 8k exploratory analysis post.
import os
# sample images to use in this post
relevant_image_names = [
os.environ["IMG_DIR"] = "data/2021-10-22-understanding-clip-image-pipeline"
import clip
model, preprocess = clip.load("ViT-B/32")
We can take a look at the image preprocessing pipeline.
The CLIP preprocessing pipeline assumes we have a PIL image as input, so that is what we will use to load images here.
from PIL import Image
images = [Image.open(os.path.join(os.environ["IMG_DIR"], image_file_name)) for image_file_name in relevant_image_names]
import matplotlib.pyplot as plt
def plot_pil_images(images):
    """Plot exactly four PIL images in a 2x2 grid."""
    assert len(images) == 4, "Number of images should be equal to 4"
    fig = plt.figure(figsize=(10, 10))
    for idx, image in enumerate(images):
        sub = fig.add_subplot(2, 2, idx + 1)
        sub.imshow(image)
Resize each image so that its shorter side (height or width) is 224 pixels. The aspect ratio is preserved.
processed_images = [preprocess.transforms[0](image) for image in images]
The effect is hard to notice just by looking at the images, because each image keeps its aspect ratio. It is easy to see, however, when we print the tensor shapes:
from torchvision.transforms import ToTensor
for idx, (original_image, processed_image) in enumerate(zip(images, processed_images)):
    print("Image {}:\nOriginal size: {}\nProcessed size: {}\n".format(
        idx + 1, ToTensor()(original_image).shape, ToTensor()(processed_image).shape))
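The new sizes follow a simple rule. The sketch below reproduces it, assuming torchvision's behaviour when Resize receives a single int: the shorter side becomes 224, and the longer side is scaled by the same factor and then truncated to an integer. The function name resized_size is hypothetical.

```python
def resized_size(width, height, size=224):
    """Return (width, height) after resizing the shorter side to `size`.

    Assumed rule: keep the aspect ratio and truncate the longer side to an int.
    """
    if width <= height:
        # Width is the shorter (or equal) side
        return size, int(size * height / width)
    # Height is the shorter side
    return int(size * width / height), size

print(resized_size(500, 375))  # (298, 224)
print(resized_size(333, 500))  # (224, 336)
```

A landscape 500x375 image therefore becomes 298x224, which is not yet square. Producing a square input is the job of the crop step that follows.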
Crop the center of each image so that the result has size (224, 224).
processed_images = [preprocess.transforms[1](image) for image in processed_images]
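To see which pixels survive the crop, the sketch below computes the crop box. It assumes that the offsets are rounded as int(round((dimension - 224) / 2)), which to my knowledge is what torchvision's center crop does. It also assumes the image is at least 224 pixels on each side, which holds after the resize step. The name center_crop_box is hypothetical.

```python
from PIL import Image

def center_crop_box(width, height, size=224):
    """Return the (left, top, right, bottom) box of a centered size x size crop."""
    left = int(round((width - size) / 2.0))  # equal margins left and right
    top = int(round((height - size) / 2.0))  # equal margins top and bottom
    return left, top, left + size, top + size

box = center_crop_box(298, 224)
print(box)  # (37, 0, 261, 224)

# PIL's crop takes the same (left, top, right, bottom) box
cropped = Image.new("RGB", (298, 224)).crop(box)
print(cropped.size)  # (224, 224)
```

For a landscape image, this trims 37 pixels from each side and leaves the height untouched.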
This transform converts the image to RGB mode. Since our images are already RGB, we will not see any difference:
from torch import equal
equal(ToTensor()(preprocess.transforms[2](processed_images[0])), ToTensor()(processed_images[0]))
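The RGB conversion matters for inputs that are not already RGB. The sketch below uses PIL's convert("RGB"), assuming this is what the conversion step relies on. A one-channel grayscale ("L") image gains three identical channels, while an RGB image comes through unchanged. That unchanged result is why the equal check returns True for our images.

```python
from PIL import Image

# A grayscale image has one channel; converting it copies that value into R, G and B
gray = Image.new("L", (4, 4), color=128)
rgb = gray.convert("RGB")
print(gray.mode, rgb.mode)   # L RGB
print(rgb.getpixel((0, 0)))  # (128, 128, 128)

# An image that is already RGB keeps identical pixel data
already_rgb = Image.new("RGB", (4, 4), color=(10, 20, 30))
print(already_rgb.convert("RGB").tobytes() == already_rgb.tobytes())  # True
```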
But we will apply this transform for completeness:
processed_images = [preprocess.transforms[2](image) for image in processed_images]
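Next, ToTensor converts each PIL image into a float tensor of shape (C, H, W), with 8-bit pixel values scaled to [0, 1]:
processed_images = [preprocess.transforms[3](image) for image in processed_images]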
We now have tensors:
print([x.shape for x in processed_images])
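For an 8-bit RGB image, the ToTensor step amounts to a rescale followed by a reordering of axes. The NumPy sketch below (to_chw_float is a hypothetical name) divides by 255 and moves the channel axis first, turning (H, W, C) into (C, H, W).

```python
import numpy as np
from PIL import Image

def to_chw_float(image):
    """NumPy sketch of ToTensor for an 8-bit RGB PIL image."""
    array = np.asarray(image, dtype=np.float32) / 255.0  # (H, W, C), values in [0, 1]
    return array.transpose(2, 0, 1)                       # channels first: (C, H, W)

image = Image.new("RGB", (224, 224), color=(255, 0, 51))
x = to_chw_float(image)
print(x.shape, x.dtype)  # (3, 224, 224) float32
print(x[:, 0, 0])        # per-channel values of the top-left pixel: about 1.0, 0.0, 0.2
```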
Normalize subtracts the mean and divides by the standard deviation, using one mean and one standard deviation for each of the three color channels.
processed_images = [preprocess.transforms[4](image) for image in processed_images]
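In formula form, each channel c becomes (x - mean_c) / std_c. The NumPy sketch below applies this per channel, assuming the mean and standard deviation values published with CLIP. Because the inputs lie in [0, 1], the outputs land roughly between -1.79 and 2.15. Values outside [0, 1] like these are why the normalized images no longer look like photos.

```python
import numpy as np

# Assumed CLIP per-channel statistics (R, G, B)
CLIP_MEAN = np.array([0.48145466, 0.4578275, 0.40821073], dtype=np.float32)
CLIP_STD = np.array([0.26862954, 0.26130258, 0.27577711], dtype=np.float32)

def normalize(x, mean=CLIP_MEAN, std=CLIP_STD):
    """x has shape (C, H, W); reshape mean/std to (C, 1, 1) so they broadcast per channel."""
    return (x - mean[:, None, None]) / std[:, None, None]

# Smallest (input 0) and largest (input 1) value each channel can take after normalization
low = normalize(np.zeros((3, 1, 1), dtype=np.float32)).ravel()
high = normalize(np.ones((3, 1, 1), dtype=np.float32)).ravel()
print(low.round(2), high.round(2))
```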
After normalization, the processed images look quite different from the originals.
from torchvision.transforms import ToPILImage
plot_pil_images([ToPILImage()(x) for x in processed_images])
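The normalized values fall outside [0, 1], so they do not map cleanly to pixel intensities. To view a processed image the way it looks after resizing and cropping, one option is to undo the normalization first. The sketch below assumes the same CLIP mean and standard deviation. It inverts the per-channel formula, clips the result to [0, 1], and returns an (H, W, C) uint8 array that plt.imshow accepts. The names denormalize and to_displayable are hypothetical.

```python
import numpy as np

CLIP_MEAN = np.array([0.48145466, 0.4578275, 0.40821073], dtype=np.float32)
CLIP_STD = np.array([0.26862954, 0.26130258, 0.27577711], dtype=np.float32)

def denormalize(x, mean=CLIP_MEAN, std=CLIP_STD):
    """Invert (x - mean) / std per channel for a (C, H, W) array and clip to [0, 1]."""
    return np.clip(x * std[:, None, None] + mean[:, None, None], 0.0, 1.0)

def to_displayable(x):
    """(C, H, W) normalized array -> (H, W, C) uint8 array for plt.imshow."""
    return (denormalize(x).transpose(1, 2, 0) * 255).round().astype(np.uint8)

rng = np.random.default_rng(0)
original = rng.random((3, 8, 8), dtype=np.float32)
normalized = (original - CLIP_MEAN[:, None, None]) / CLIP_STD[:, None, None]
print(np.allclose(denormalize(normalized), original, atol=1e-6))  # True
print(to_displayable(normalized).shape)                           # (8, 8, 3)
```

With the torch tensors in this post, calling x.numpy() on each processed image before to_displayable should work, assuming the tensors are on the CPU.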
Stack the four processed images into a single batch tensor of shape (4, 3, 224, 224).
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
# torch.stack adds a new leading batch dimension
image_input = torch.stack(processed_images).to(device)
Generate the image vectors. Each vector has 512 elements.
with torch.no_grad():
    # Shape: (4, 512), one embedding per image
    image_features = model.encode_image(image_input).float()
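Use unsqueeze(0) if you want to apply the encoder to a single image, since it expects a 4D tensor of shape (batch, channels, height, width).
with torch.no_grad():
    # Add a batch dimension and move the tensor to the same device as the model
    image_features = model.encode_image(processed_images[0].unsqueeze(0).to(device))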
Create a custom Dataset that loads an image and optionally applies a transform function to it.
import os
import glob
from PIL import Image
from torch.utils.data import Dataset
class ImageDataset(Dataset):
    """Loads every .jpg image in img_dir, optionally applying a transform."""
    def __init__(self, img_dir, transform=None):
        self.img_dir = img_dir
        # Sort so the index -> file mapping is deterministic
        self.image_file_names = sorted(glob.glob(os.path.join(img_dir, "*.jpg")))
        self.transform = transform
    def __len__(self):
        return len(self.image_file_names)
    def __getitem__(self, idx):
        image_file_name = self.image_file_names[idx]
        image = Image.open(image_file_name)
        if self.transform:
            image = self.transform(image)
        return image
Here is how it works without applying any transform function:
image_dataset = ImageDataset(img_dir=os.environ["IMG_DIR"])
Each iteration returns one of the original images:
image_dataset = ImageDataset(img_dir=os.environ["IMG_DIR"], transform=preprocess)
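When we pass the preprocess pipeline as the transform, it returns the processed tensor instead.
To map each image directly to its embedding, we can combine preprocessing and encoding into a single transform:
def from_image_to_vector(x, process_fn):
    with torch.no_grad():
        # Preprocess, add a batch dimension, and move to the model's device
        image_features = model.encode_image(process_fn(x).unsqueeze(0).to(device))
    # Drop the batch dimension so each item is a 512-dimensional vector
    return image_features.squeeze(0)
image_dataset = ImageDataset(
    img_dir=os.environ["IMG_DIR"],
    transform=lambda x: from_image_to_vector(x, process_fn=preprocess),
)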
Now the dataset returns a 512-dimensional vector associated with a specific image.
from torch.utils.data import DataLoader
image_dataloader = DataLoader(image_dataset, batch_size=64, shuffle=False)
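To see what the DataLoader yields, the sketch below swaps the CLIP-backed dataset for a hypothetical FakeEmbeddingDataset that returns random 512-dimensional vectors, so it runs without the model or the images. With batch_size=64 and shuffle=False, the default collate function stacks items into (batch, 512) tensors. The last batch is smaller, and concatenating all batches recovers one row per image in dataset order.

```python
import torch
from torch.utils.data import DataLoader, Dataset

class FakeEmbeddingDataset(Dataset):
    """Hypothetical stand-in: each item is a 512-dimensional vector."""
    def __init__(self, n_items, dim=512):
        self.vectors = torch.randn(n_items, dim)

    def __len__(self):
        return len(self.vectors)

    def __getitem__(self, idx):
        return self.vectors[idx]

dataset = FakeEmbeddingDataset(n_items=150)
loader = DataLoader(dataset, batch_size=64, shuffle=False)

batch_shapes = [tuple(batch.shape) for batch in loader]
print(batch_shapes)  # [(64, 512), (64, 512), (22, 512)]

# Concatenate batches into one (num_images, 512) matrix, in dataset order
all_embeddings = torch.cat([batch for batch in loader], dim=0)
print(tuple(all_embeddings.shape))  # (150, 512)
```

Encoding inside __getitem__ runs the model on one image at a time. An alternative design is to have the dataset return preprocessed tensors and call model.encode_image on each batch from the DataLoader, which lets the GPU process 64 images per forward pass.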