Introducing TorchMultimodal - a library for accelerating exploration in Multimodal AI
by Kartikay Khandelwal, Ankita De
We are announcing TorchMultimodal Beta, a PyTorch domain library for training SoTA multi-task multimodal models at scale. The library provides composable building blocks (modules, transforms, loss functions) to accelerate model development, SoTA model architectures (FLAVA, MDETR, Omnivore) from published research, training and evaluation scripts, as well as notebooks for exploring these models. The library is under active development, and we’d love to hear your feedback! You can find more details on how to get started here.
Why TorchMultimodal?
Interest is rising around AI models that understand multiple input types (text, images, videos and audio signals), and optionally use this understanding to generate different forms of outputs (sentences, pictures, videos). Recent work from FAIR such as FLAVA, Omnivore and data2vec has shown that multimodal models for understanding are competitive with their unimodal counterparts, and in some cases are establishing the new state-of-the-art. Generative models such as Make-a-video and Make-a-scene are redefining what modern AI systems can do.
As interest in multimodal AI has grown, researchers are looking for tools and libraries to quickly experiment with ideas and build on top of the latest research in the field. While the PyTorch ecosystem has a rich repository of libraries and frameworks, it’s not always obvious how components from these libraries interoperate with each other, or how they can be stitched together to build SoTA multimodal models.
TorchMultimodal solves this problem by providing:
Composable and easy-to-use building blocks which researchers can use to accelerate model development and experimentation in their own workflows. These are designed to be modular, and can be easily extended to handle new modalities.
End-to-end examples for training and evaluating the latest models from research. These should serve as starting points for ongoing/future research, as well as examples for using advanced features such as integrating with FSDP and activation checkpointing for scaling up model and batch sizes.
Introducing TorchMultimodal
TorchMultimodal is a PyTorch domain library for training multi-task multimodal models at scale. In the repository, we provide:
Building Blocks. A collection of modular and composable building blocks like models, fusion layers, loss functions, datasets and utilities. Some examples include:
Contrastive Loss with Temperature. A commonly used loss function for training models like CLIP and FLAVA. We also include variants such as ImageTextContrastiveLoss used in models like ALBEF.
Codebook layers, which compress high dimensional data via nearest neighbor lookup in an embedding space and are a vital component of VQVAE (provided as a model in the repository); a minimal sketch of the lookup appears after this list.
Shifted-window Attention. Window-based multi-head self attention which is a vital component of encoders like the Swin 3D Transformer.
Components for CLIP. A popular model published by OpenAI which has proven to be extremely effective at learning text and image representations.
Multimodal GPT. An abstraction that extends OpenAI’s GPT architecture for multimodal generation when combined with the generation utility.
MultiHeadAttention. A critical component for attention-based models with support for fast auto-regressive decoding.
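To make the codebook lookup concrete, below is a minimal plain-PyTorch sketch of nearest neighbor quantization against an embedding table; the names, sizes, and quantize helper are illustrative stand-ins, not TorchMultimodal’s Codebook implementation:

import torch

# Each continuous input vector is replaced by its closest entry in a learned
# embedding table (the "codebook"), compressing the representation.
num_codes, code_dim = 512, 64
codebook = torch.randn(num_codes, code_dim)  # stands in for a learned embedding table

def quantize(z):
    # z: (batch, code_dim) continuous encoder outputs
    distances = torch.cdist(z, codebook)     # (batch, num_codes) pairwise L2 distances
    indices = distances.argmin(dim=-1)       # index of the nearest code per input vector
    return codebook[indices]                 # quantized vectors, same shape as z

z = torch.randn(8, code_dim)
z_q = quantize(z)                            # (8, 64), each row is a codebook entry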
Examples. A collection of examples that show how to combine these building blocks with components and common infrastructure (Lightning, TorchMetrics) from across the PyTorch ecosystem to replicate state-of-the-art models published in the literature. We currently provide five examples, which include:
FLAVA [paper]. Official code for the paper accepted at CVPR, including a tutorial on finetuning FLAVA.
MDETR [paper]. Collaboration with authors from NYU to provide an example which alleviates interoperability pain points in the PyTorch ecosystem, including a notebook on using MDETR for phrase grounding and visual question answering.
Omnivore [paper]. First example in TorchMultimodal of a model which deals with Video and 3D data, including a notebook for exploring the model.
MUGEN [paper]. Foundational work for auto-regressive generation and retrieval, including demos for text-video generation and retrieval with a large-scale synthetic dataset enriched from OpenAI coinrun.
ALBEF [paper]. Code for the model, including a notebook for using this model for Visual Question Answering.
The following code snippet showcases an example usage of several TorchMultimodal components related to CLIP:
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import CocoCaptions
# CLIPTransform, clip_vit_l14 and ContrastiveLossWithTemperature are imported from torchmultimodal

# instantiate clip transform
clip_transform = CLIPTransform()

# pass the transform to your dataset. Here we use coco captions
dataset = CocoCaptions(root=..., annFile=..., transforms=clip_transform)
dataloader = DataLoader(dataset, batch_size=16)

# instantiate model. Here we use clip with vit-L as the image encoder
model = clip_vit_l14()

# define loss and other things needed for training
clip_loss = ContrastiveLossWithTemperature()
optim = torch.optim.AdamW(model.parameters(), lr=1e-5)
epochs = 1

# write your train loop
for _ in range(epochs):
    for batch_idx, batch in enumerate(dataloader):
        image, text = batch
        image_embeddings, text_embeddings = model(image, text)
        loss = clip_loss(image_embeddings, text_embeddings)
        optim.zero_grad()
        loss.backward()
        optim.step()
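For intuition about what ContrastiveLossWithTemperature computes, here is a minimal plain-PyTorch sketch of a CLIP-style symmetric contrastive loss; the function and the log_temperature parameter are illustrative stand-ins, not the library’s implementation:

import torch
import torch.nn.functional as F

def clip_style_contrastive_loss(image_emb, text_emb, log_temperature):
    # Normalize embeddings and compute a (batch, batch) similarity matrix,
    # scaled by a learnable temperature.
    image_emb = F.normalize(image_emb, dim=-1)
    text_emb = F.normalize(text_emb, dim=-1)
    logits = image_emb @ text_emb.t() * log_temperature.exp()
    targets = torch.arange(logits.size(0), device=logits.device)
    # Symmetric cross-entropy: match each image to its caption and vice versa.
    return (F.cross_entropy(logits, targets) + F.cross_entropy(logits.t(), targets)) / 2

log_temperature = torch.nn.Parameter(torch.zeros(()))
loss = clip_style_contrastive_loss(torch.randn(16, 512), torch.randn(16, 512), log_temperature)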
Apart from the code, we are also releasing a tutorial for fine-tuning multimodal foundation models, and a blog post (with code pointers) on how to scale up such models using techniques from PyTorch Distributed (FSDP and activation checkpointing). We hope such examples and tutorials will serve to demystify a number of advanced features available in the PyTorch ecosystem.
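As a rough illustration of the scaling pattern referenced above, the sketch below shards the CLIP model from the earlier snippet with PyTorch’s FullyShardedDataParallel; it assumes the script is launched with torchrun on GPU machines, and the exact wrapping and activation checkpointing policy used in the blog post may differ:

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Assumes launch via torchrun, which sets the env vars init_process_group reads.
dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# Wrapping shards parameters, gradients and optimizer state across ranks.
model = FSDP(clip_vit_l14().cuda())
optim = torch.optim.AdamW(model.parameters(), lr=1e-5)

# The training loop from the earlier snippet works unchanged; activation
# checkpointing can additionally be applied to large submodules to trade
# recomputation for memory (see the linked blog post for the full recipe).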
What’s Next?
While this is an exciting launch, there’s a lot more to come. The library is under active development, and we are working on adding some of the exciting developments in the space of diffusion models, as well as examples to showcase common trends from research. As you explore and use the library, we’d love to hear any feedback you might have! You can find more details on how to get started here.
Team
The primary contributors and developers of TorchMultimodal include Ankita De, Evan Smothers, Kartikay Khandelwal, Lan Gong, Laurence Rouesnel, Nahiyan Malik, Rafi Ayub and Yosua Michael Maranatha.