IMAGEBIND: One Embedding Space To Bind Them All

Souvik Mandal · Published in ITNEXT · Dec 4, 2023 · 8 min read


IMAGEBIND is an approach to learning joint embeddings across six different modalities: image, text, audio, depth, thermal, and IMU data.

The main contribution of this paper is that we do not need paired data between every combination of modalities to train a multi-modal model; one common modality that is paired with each of the others is enough to bind them all together. In this paper, the embeddings of every modality are aligned to the image embeddings.

This is the first model capable of binding information from six modalities.

IMAGEBIND

Normally, contrastive learning is used to align the embedding spaces of two different modalities using related examples. Say we have image and text modalities: we can patchify the image and pass it through a ViT, and pass the text through a model like BERT, to obtain image and text token embeddings.

We then take the class token, since it carries global information about the input, and pass it through an MLP (projection) layer so that the class-token embeddings of both modalities end up with the same dimension.

Image Text alignment
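Concretely, this alignment uses an InfoNCE-style contrastive objective on the (L2-normalized) projected embeddings. Writing q_i for the image embedding and k_i for the paired text embedding of the i-th example, with temperature τ, the image-to-text term is:

$$
\mathcal{L}_{I,T} = -\log \frac{\exp(q_i^\top k_i / \tau)}{\exp(q_i^\top k_i / \tau) + \sum_{j \neq i} \exp(q_i^\top k_j / \tau)}
$$

The symmetric sum of both directions (image→text and text→image) is used in practice, and IMAGEBIND applies the same objective with whichever modality is paired with the image. The implementation at the end of this post computes exactly this as a cross-entropy over scaled cosine similarities.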

Since IMAGEBIND has six modalities, we will need six different encoders to process each type of data.

Also, previous multi-modal models mostly handled two modalities. The majority were vision-language models trained with paired image and text (image-caption) data. With six modalities, however, if we look for data points that carry information from all six modalities at once, very little open-source data is available, and creating such a dataset would be time-consuming.

IMAGEBIND can leverage recent large-scale vision-language models and extends their zero-shot capabilities to new modalities just by using their natural pairing with images.

We create the dataset so that each data point consists of an image and one of the other modalities. For example, the first data point can be an (image, text) pair, the second an (image, video) pair, and the third an (image, audio) pair. The image modality is present in every data point, so the image encoder is trained on every pair, while the other encoders only receive gradients for the pairs they appear in: the text encoder, for instance, accumulates no gradient when the data point is (image, video) or (image, audio). The sketch after the figure below illustrates this.

Data processing in case we have (image, text), (image, video), (image, audio) pairs
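To make the gradient-routing point concrete, here is a minimal, hypothetical sketch (the tiny nn.Linear stand-ins and the names encoders/batch are mine, not the paper's code): the image encoder runs for every pair, while only the second modality's encoder of each pair takes part in the forward and backward pass.

import torch
import torch.nn as nn

# Stand-in encoders: tiny linear layers instead of real ViT/BERT/audio encoders
image_encoder = nn.Linear(10, 4)
encoders = {"text": nn.Linear(8, 4), "audio": nn.Linear(6, 4)}

# A mixed batch: every data point pairs an image with exactly one other modality
batch = [
    (torch.randn(1, 10), ("text", torch.randn(1, 8))),
    (torch.randn(1, 10), ("audio", torch.randn(1, 6))),
]

for image, (modality, data) in batch:
    img_emb = image_encoder(image)        # the image encoder is used for every pair
    other_emb = encoders[modality](data)  # only this pair's second-modality encoder runs
    # stand-in loss; the real objective is the contrastive loss shown above
    loss = (1 - torch.cosine_similarity(img_emb, other_emb)).mean()
    loss.backward()  # encoders not used in this pair accumulate no gradient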

Now, since the image modality is common across all data points, the only way the model can converge is if all the modalities are aligned with the image embeddings.

What will happen if the model does not align all modality embeddings into image embeddings?

Let's say we train with an (image, text) pair and the image and text embeddings do not end up aligned in the image embedding space; instead, they are aligned in the text embedding space.

Then, the next time a different modality pair is passed, e.g., (image, audio), the image encoder will embed its input in the text embedding space. But audio is never trained together with text, so the audio encoder can only embed into the image or the audio embedding space. The loss will therefore increase when training on any other pair like (image, audio). The only way to decrease this loss is to align everything into the common modality's (image) embedding space.

Note: Image and video are considered a single modality in the paper and are processed by the same ViT encoder. In this explanation I treat video as a separate modality, simply because it is easier for me to explain with video than with audio, depth, thermal, or IMU data. You can mentally replace the video data, video encoder, and video pre-processing with audio, depth, or any of the other modalities. So in the paper the data points are actually (image/video, text), (image/video, audio), (image/video, depth), and so on.

Downstream Tasks

Cross-modality retrieval

We can generate the feature vector (after the projection layer) for one modality and search it against the feature vectors of the other modalities.
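A minimal sketch of the mechanics, assuming we already have projected 512-d embeddings (random placeholders here) for a query in one modality and a gallery in another:

import torch

query = torch.randn(1, 512)       # e.g. an audio embedding after its projection layer
gallery = torch.randn(1000, 512)  # e.g. image embeddings of a retrieval database

# L2-normalize so the dot product equals cosine similarity
query = query / query.norm(dim=1, keepdim=True)
gallery = gallery / gallery.norm(dim=1, keepdim=True)

scores = query @ gallery.t()            # [1, 1000] cosine similarities
top5 = scores.topk(k=5, dim=1).indices  # indices of the 5 most similar gallery items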

Embedding-space arithmetic-based retrieval

Adding embeddings of different modalities produces an embedding that carries features of both. For example, summing the embedding of an image of a bird in a forest with the embedding of the sound of waves produces an embedding close to that of an image of a bird with waves in the background.
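A sketch of that arithmetic with placeholder embeddings: sum the two embeddings, re-normalize, and retrieve against an image gallery exactly as above.

import torch

bird_img_emb = torch.randn(1, 512)     # placeholder: "bird in forest" image embedding
waves_audio_emb = torch.randn(1, 512)  # placeholder: "waves" audio embedding
gallery = torch.randn(1000, 512)       # placeholder: image gallery embeddings

combined = bird_img_emb + waves_audio_emb
combined = combined / combined.norm(dim=1, keepdim=True)
gallery = gallery / gallery.norm(dim=1, keepdim=True)

best_match = (combined @ gallery.t()).argmax(dim=1)  # image closest to the combined concept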

Audio-to-image generation

Audio embeddings can be fed to a pre-trained DALLE-2 decoder that was designed to work with CLIP text embeddings, generating images from audio.

Zero-Shot Classification

We can do a zero-shot classification of one modality with prompts from another modality. For example, just by training on (image, text) and (image, audio), IMAGEBIND can perform zero-shot classification of audio using text prompts.
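A sketch of that emergent zero-shot setup with placeholder embeddings: embed one text prompt per class, compare the audio embedding against all of them, and pick the most similar class.

import torch

class_names = ["dog barking", "rain", "car horn"]
prompt_embs = torch.randn(len(class_names), 512)  # placeholders for embedded prompts, e.g. "a sound of a dog barking"
audio_emb = torch.randn(1, 512)                   # placeholder for the audio clip's embedding

prompt_embs = prompt_embs / prompt_embs.norm(dim=1, keepdim=True)
audio_emb = audio_emb / audio_emb.norm(dim=1, keepdim=True)

probs = (audio_emb @ prompt_embs.t()).softmax(dim=1)  # [1, 3] class probabilities
predicted = class_names[probs.argmax(dim=1).item()]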

Results

Note: Prior zero-shot classification requires directly paired data. With CLIP, for example, we can do zero-shot image classification because it is trained on (image, text) pairs. With IMAGEBIND we don't need direct pairs: (image, text) and (image, audio) pairs are trained together, but there is no direct (text, audio) pair. Zero-shot classification of audio using text prompts is therefore what the authors call emergent zero-shot classification.

Few-shot classification on audio and depth: a linear classifier is trained on the frozen IMAGEBIND features for the ≥1-shot cases.

IMAGEBIND outperforms the self-supervised AudioMAE model. It even outperforms a supervised AudioMAE model up to 4-shot learning, showing its strong generalization.

For depth, the comparison is with the MultiMAE model trained with images, depth, and semantic segmentation masks. IMAGEBIND outperforms MultiMAE across all few-shot settings on depth classification.

Few-shot classification on audio and depth

Text retrieval: MSR-VTT is a large-scale dataset for open-domain video captioning. The table shows text retrieval recall with video (V), audio (A), or both (A+V) as queries, for different models.

Text retrieval performance using audio and video

Audio retrieval:

Prior work trains using paired data for that modality, e.g., AudioCLIP uses (audio, text) supervision and AVFIC uses automatically mined (audio, text) pairs.

Zero-shot audio retrieval and classification

Implementation

Here, I will show the implementation for three modalities (image, video, text), but it can be extended to the other modalities in a similar way.

Let's say we have a batch consisting of two data points, [(img1, text), (img2, video)].

import torch
import timm
import torch.nn as nn
import numpy as np

from transformers import BertTokenizer, BertModel

# Lets create a (image, caption/text) pair
img1 = torch.randn(224, 224, 3)
text = "open source is good"

# Lets create a (image, video) pair
img2 = torch.randn(224, 224, 3)
video = torch.randn(60, 224, 224, 3) # 2 sec video with 30 FPS

Next, let's define the three encoders. For the image encoder I will use timm, for the video encoder the vit-pytorch repo, and for the text encoder Hugging Face Transformers.

# Install the packages if not already installed
!pip install timm -q
!pip install vit-pytorch -q

First, define the image encoder. This is a ViT-Base model that processes images of size 224×224×3, splits them into 16×16 patches, and has an embedding dimension of 768. Let's also say we want all final encodings to be 512-dimensional vectors, so we define a projection layer as well.

# lets initialize the image encoder to process the image data
img_encoder = timm.create_model("vit_base_patch16_224", pretrained=True)
img_proj = nn.Linear(768, 512)

For the text encoder, we need to initialize a tokenizer and a transformer backbone. We will use the BERT WordPiece tokenizer and a pretrained BERT model in this example. Finally, we again need a projection layer.

# tokenizer and text encoder
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

def text_tokenize(text):
    # Tokenize a caption; add_special_tokens=True prepends [CLS] and appends [SEP]
    return tokenizer.encode(text, add_special_tokens=True)

model_name = 'bert-base-uncased'
model = BertModel.from_pretrained(model_name)
text_proj = nn.Linear(768, 512)

Similar to the image encoder, we will use a ViT for video. I will not explain how 3D volumes are processed in the transformer here; you can check this blog for details.

# video encoder: This is also a ViT
from vit_pytorch.vit_3d import ViT

vit_vid = ViT(
    image_size = 224,       # image size
    frames = 60,            # number of frames
    image_patch_size = 16,  # image patch size
    frame_patch_size = 15,  # frame patch size
    num_classes = 0,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)
vit_vid.mlp_head = nn.Identity() # No need for the classification head
video_proj = nn.Linear(1024, 512)

Let's do the forward pass for the (image, text) pair.

# Process the image text pair
## Process the image
img1_embeddings = img_encoder.forward_features(img1.permute(2, 0, 1).unsqueeze(0)) # [1, 197, 768]
img1_cls_token = img1_embeddings[:, 0, ...] # [1, 768]
img1_proj = img_proj(img1_cls_token) # [1, 512]
## Process the text
input_ids = text_tokenize(text) # [101, 2330, 3120, 2003, 2204, 102]
# Convert input_ids to PyTorch tensor
input_ids_tensor = torch.tensor(input_ids).unsqueeze(0) # Add batch dimension, [1, 6]

# Forward pass through the BERT model
outputs = model(input_ids_tensor)
last_hidden_states = outputs.last_hidden_state
text_cls_token = last_hidden_states[:, 0, ...] # [1, 768]
text_cls_proj = text_proj(text_cls_token) # [1, 512]

Next, let's process the (image, video) pair.

# process the image-video pair
## Process the video
video_cls_token = vit_vid(video.permute(3, 0, 1, 2).unsqueeze(0)) # input: [B, C, F, H, W], output: [1, 1024]
video_cls_proj = video_proj(video_cls_token) # [1, 512]
## Process the image
img2_embeddings = img_encoder.forward_features(img2.permute(2, 0, 1).unsqueeze(0)) # [1, 197, 768]
img2_cls_token = img2_embeddings[:, 0, ...] # [1, 768]
img2_proj = img_proj(img2_cls_token) # [1, 512]

Create the batch by stacking the data together.

# batch: [(img1, text), (img2, video)]
img_embeddings = torch.vstack([img1_proj, img2_proj]) # [2, 512]
second_modality_embeddings = torch.vstack([text_cls_proj, video_cls_proj]) # [2, 512]

Next, we need to compute the loss. First, let’s normalize features.

img_embeddings = img_embeddings / img_embeddings.norm(dim=1, keepdim=True)
second_modality_embeddings = second_modality_embeddings / second_modality_embeddings.norm(dim=1, keepdim=True)

Compute the cosine similarity between every image embedding and every second-modality embedding, and scale it with a learnable temperature:

cosine_sim = img_embeddings @ second_modality_embeddings.t()

logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) # learnable temperature, stored in log scale

logits_per_image = logit_scale.exp() * cosine_sim
logits_per_text = logits_per_image.t()

Create the ground-truth labels: each image should be most similar to its own second-modality counterpart.

labels = torch.arange(logits_per_image.shape[0], dtype=torch.long)
labels # [0, 1]

Compute the symmetric cross-entropy loss:

total_loss = (
    torch.nn.functional.cross_entropy(logits_per_image, labels) +
    torch.nn.functional.cross_entropy(logits_per_text, labels)
) / 2
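That completes one forward pass and the loss for a single mixed batch. In an actual training loop (not covered in this post, and the optimizer choice below is mine, not the paper's) we would back-propagate this loss and step an optimizer; as discussed earlier, only the encoders that appear in the batch receive gradients.

optimizer = torch.optim.AdamW(
    list(img_encoder.parameters()) + list(img_proj.parameters())
    + list(model.parameters()) + list(text_proj.parameters())
    + list(vit_vid.parameters()) + list(video_proj.parameters()),
    lr=1e-4,
)

optimizer.zero_grad()
total_loss.backward()
optimizer.step()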

Resources

Connect with me

Feel free to drop me a message, or:

  1. Connect and reach out on LinkedIn
  2. Follow me on Medium or GitHub

Have a nice day ❤️
