Paper: Learning Transferable Visual Models From Natural Language Supervision
Link: Learning Transferable Visual Models From Natural Language Supervision
I. About CLIP
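Feature learning from image-text matching: the paper shows that the simple pre-training task of predicting which caption goes with which image is an efficient and scalable way to learn state-of-the-art image representations from scratch, using a dataset of 400 million (image, text) pairs collected from the internet. After pre-training, natural language is used to reference the learned visual concepts (or to describe new ones), which enables zero-shot transfer of the model to downstream tasks.
How does it work? A traditional model jointly trains an image feature extractor and a linear classifier to predict a fixed set of labels. CLIP instead jointly trains an image encoder and a text encoder to predict the correct pairings within a batch of (image, text) training examples. At test time, the learned text encoder synthesizes a zero-shot linear classifier by embedding the names or descriptions of the target dataset's classes.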
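To make the "zero-shot linear classifier" concrete: once every class name has been embedded by the text encoder, the stacked, L2-normalized text embeddings act as the weight matrix of a linear classifier, and classifying an image is a matrix product followed by an argmax. The sketch below uses NumPy with hand-picked toy vectors standing in for real encoder outputs; the helper names `l2_normalize` and `zero_shot_predict` are hypothetical.

```python
import numpy as np


def l2_normalize(x: np.ndarray) -> np.ndarray:
    """Scale each row (or a single vector) to unit L2 norm."""
    return x / np.linalg.norm(x, axis=-1, keepdims=True)


def zero_shot_predict(image_emb: np.ndarray, class_text_emb: np.ndarray) -> np.ndarray:
    """Hypothetical helper: the normalized class text embeddings (C, D) act as the
    weights of a linear classifier; returns the predicted class index per image."""
    weights = l2_normalize(class_text_emb)  # (C, D) classifier weights
    feats = l2_normalize(image_emb)         # (B, D) image features
    logits = feats @ weights.T              # (B, C) cosine similarities
    return logits.argmax(axis=-1)


# Toy stand-ins for the text-encoder outputs of "a photo of a cat/dog/rabbit"
classes = ["cat", "dog", "rabbit"]
text_emb = np.array([
    [1.0, 0.0, 0.0, 0.0],
    [0.0, 1.0, 0.0, 0.0],
    [0.0, 0.0, 1.0, 0.0],
])
# Toy stand-in for an image-encoder output that lies closest to "rabbit"
image_emb = np.array([[0.1, 0.2, 0.9, 0.3]])

pred = zero_shot_predict(image_emb, text_emb)
print(classes[pred[0]])  # rabbit
```

Adding a new class in this picture only means embedding one more prompt and appending a row to `text_emb`; no weights are retrained.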
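- Image Encoder: turns any image into a feature vector the model can work with. The vector condenses the core information of the image.
- Text Encoder: turns any piece of text into a feature vector as well. The vector condenses the core meaning of the text.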
In principle, if an image and a text match, the feature vectors the two encoders produce for them should be very similar.

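The training objective therefore is: compute the similarity matrix between the text features and the image features of a batch, and push the similarities on the diagonal (the matching pairs) up and the off-diagonal ones down. In this way CLIP learns directly from natural-language descriptions and can exploit the huge amount of paired data already available on the web. Because the supervision is free-form text rather than a fixed label set, the model picks up richer semantics from the text and learns more abstract image representations.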
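"High on the diagonal, low off the diagonal" is a symmetric cross-entropy (InfoNCE) over the N×N similarity matrix: row i should pick text i, and column i should pick image i. Below is a NumPy sketch of that loss. In the paper the temperature is learned (initialized to the equivalent of 0.07 and clipped so the logits are never scaled by more than 100); here a fixed `scale` argument is assumed for simplicity, and the function names are illustrative. It mirrors `clip_contrastive_loss` in the full script further down.

```python
import numpy as np


def cross_entropy(logits: np.ndarray, labels: np.ndarray) -> float:
    """Mean cross-entropy of integer labels under a row-wise softmax, computed stably."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_sum_exp = np.log(np.exp(shifted).sum(axis=1))
    correct = shifted[np.arange(len(labels)), labels]
    return float((log_sum_exp - correct).mean())


def clip_loss(image_emb: np.ndarray, text_emb: np.ndarray, scale: float = 100.0) -> float:
    """Symmetric contrastive loss over a batch of N matching (image, text) pairs."""
    img = image_emb / np.linalg.norm(image_emb, axis=1, keepdims=True)
    txt = text_emb / np.linalg.norm(text_emb, axis=1, keepdims=True)
    logits = scale * img @ txt.T          # (N, N); entry [i, j] = scaled cos(image i, text j)
    labels = np.arange(logits.shape[0])   # the partner of sample i is i (the diagonal)
    loss_i = cross_entropy(logits, labels)    # image -> text (rows)
    loss_t = cross_entropy(logits.T, labels)  # text -> image (columns)
    return (loss_i + loss_t) / 2


emb = np.eye(3, 4)                     # three mutually orthogonal toy embeddings
print(clip_loss(emb, emb))             # 0.0   (every pair correctly aligned)
print(clip_loss(emb, emb[[1, 2, 0]]))  # 100.0 (pairs shuffled)
```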
II. Testing the official code
First, install PyTorch 1.7.1 (or later) and torchvision, plus a few small extra dependencies, then install the CLIP repo as a Python package. On a machine with a CUDA GPU, the following is enough:
```bash
conda install --yes -c pytorch pytorch=1.7.1 torchvision cudatoolkit=11.0
pip install ftfy regex tqdm
pip install git+https://github.com/openai/CLIP.git
```
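Replace cudatoolkit=11.0 above with the CUDA version that matches your machine, or with cpuonly when installing on a machine without a GPU.
Next, a quick test. The script computes the similarity between the image's feature vector and the text embeddings of all CIFAR-100 class prompts, and reports the 5 most similar classes.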
```python
import os

import clip
import torch
from PIL import Image
from torchvision.datasets import CIFAR100

# Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device)

# Download the dataset (only its class names are used here)
cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False)

# Prepare the inputs
# image, class_id = cifar100[3637]
image = Image.open('rabbit.jpg')  # any image you like, e.g. a rabbit
image_input = preprocess(image).unsqueeze(0).to(device)
text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]).to(device)

# Calculate features
with torch.no_grad():
    image_features = model.encode_image(image_input)
    text_features = model.encode_text(text_inputs)

# Pick the top 5 most similar labels for the image
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
values, indices = similarity[0].topk(5)

# Print the result
print("\nTop predictions:\n")
for value, index in zip(values, indices):
    print(f"{cifar100.classes[index]:>16s}: {100 * value.item():.2f}%")
```
I fed in a picture of a rabbit:

The output looks like this:
```text
Top predictions:

          rabbit: 99.61%
      lawn_mower: 0.16%
           mouse: 0.04%
        kangaroo: 0.03%
        squirrel: 0.03%
```
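Why multiply by 100.0 before the softmax? Cosine similarities between CLIP image and text embeddings typically fall in a narrow band, so a plain softmax over them would be close to uniform. Scaling by 100 (the cap on CLIP's learned logit scale, which the released checkpoints sit at as far as I know) sharpens the distribution. The similarity values below are made up for illustration.

```python
import numpy as np


def softmax(x: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over the last axis."""
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


# Made-up cosine similarities between one image and three class prompts
cos_sim = np.array([0.30, 0.25, 0.20])

plain = softmax(cos_sim)          # nearly uniform: the gaps between similarities are tiny
sharp = softmax(100.0 * cos_sim)  # what the test script computes: a clear winner

print(np.round(plain, 3))
print(np.round(sharp, 3))
```

The ranking is identical in both cases; only the confidence changes, which is why the rabbit prediction above can reach over 99%.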
III. Building and training it in practice
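Next I want to build a simple CLIP model. CLIP is essentially two encoders: the official models use a ResNet or a Vision Transformer as the image encoder and a Transformer as the text encoder. Here I use the CLIPModel class from the transformers library directly. Note that by default this downloads the pretrained model from Hugging Face; if you have network problems, download it locally first and load it from the local directory.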
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import CLIPModel, CLIPProcessor

# CLIPBatch is a small dataclass defined in the full script below


class CLIPForCIFAR(nn.Module):
    """Thin wrapper around Hugging Face CLIPModel to expose forward and projection features."""

    # model_name can be a Hugging Face Hub ID or a local directory with the downloaded model
    def __init__(self, model_name: str = "openai/clip-vit-base-patch32"):
        super().__init__()
        self.model = CLIPModel.from_pretrained(model_name)
        self.processor = CLIPProcessor.from_pretrained(model_name)

    def forward(self, batch: "CLIPBatch"):
        outputs = self.model(
            input_ids=batch.input_ids,
            attention_mask=batch.attention_mask,
            pixel_values=batch.pixel_values,
            return_dict=True,
        )
        return outputs  # contains logits_per_image (B, B), logits_per_text (B, B)

    @torch.no_grad()
    def encode_text(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        out = self.model.get_text_features(input_ids=input_ids, attention_mask=attention_mask)
        return F.normalize(out, dim=-1)

    @torch.no_grad()
    def encode_image(self, pixel_values: torch.Tensor) -> torch.Tensor:
        out = self.model.get_image_features(pixel_values=pixel_values)
        return F.normalize(out, dim=-1)
```
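The wrapper's `forward` relies on CLIPModel returning `logits_per_image` and `logits_per_text`. As I understand the Hugging Face implementation, both come from one matrix: the normalized image and text embeddings are multiplied, scaled by `logit_scale.exp()`, and `logits_per_image` is the transpose of `logits_per_text`. The NumPy sketch below imitates this with two toy linear "encoders" (random projection matrices standing in for the ViT and the text Transformer; all names and sizes are assumptions) to show how two different input spaces meet in one shared embedding space.

```python
import numpy as np

rng = np.random.default_rng(0)

B, D_IMG, D_TXT, D_EMB = 4, 12, 8, 5     # toy sizes: batch, raw image/text dims, shared dim
W_img = rng.normal(size=(D_IMG, D_EMB))  # stand-in for image encoder + visual projection
W_txt = rng.normal(size=(D_TXT, D_EMB))  # stand-in for text encoder + text projection
logit_scale = np.log(100.0)              # kept in log space, exponentiated before use


def encode(x: np.ndarray, w: np.ndarray) -> np.ndarray:
    """Project into the shared space and L2-normalize (like get_*_features + F.normalize)."""
    z = x @ w
    return z / np.linalg.norm(z, axis=-1, keepdims=True)


images = rng.normal(size=(B, D_IMG))
texts = rng.normal(size=(B, D_TXT))
image_embeds = encode(images, W_img)  # (B, D_EMB)
text_embeds = encode(texts, W_txt)    # (B, D_EMB)

logits_per_text = np.exp(logit_scale) * text_embeds @ image_embeds.T  # (B, B)
logits_per_image = logits_per_text.T                                  # (B, B)
print(logits_per_image.shape)
```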
Next, create a Dataset class to read the data. Taking the ready-made CIFAR-100 dataset from torchvision as an example, it can be read with the following code:
```python
import random
from typing import List

import torch
from torchvision import datasets as tvdatasets
from torchvision import transforms as T
from transformers import CLIPProcessor


class CIFAR100CLIPDataset(torch.utils.data.Dataset):
    """Pair each image with a single prompt built from its class label."""

    def __init__(self, root: str, split: str, processor: CLIPProcessor, templates: List[str] = None):
        assert split in {"train", "test"}
        self.ds = tvdatasets.CIFAR100(root=root, train=(split == "train"), download=True)
        self.processor = processor
        self.templates = templates or ["a photo of a {label}."]
        self.classes = self.ds.classes  # list of 100 class names
        # CLIP ViT-B/32 expects 224x224 inputs by default
        image_size = processor.image_processor.crop_size["height"]  # "width" works too; they are equal
        self.img_transform = T.Compose([
            T.Resize(image_size, interpolation=T.InterpolationMode.BICUBIC),
            T.CenterCrop(image_size),
            T.ToTensor(),
            T.Normalize(mean=processor.image_processor.image_mean,
                        std=processor.image_processor.image_std),
        ])

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx: int):
        img, label = self.ds[idx]
        # Apply CLIP's image preprocessing
        img = self.img_transform(img)
        # Randomly pick a prompt template as text augmentation
        label_text = random.choice(self.templates).format(label=self.classes[label])
        enc = self.processor.tokenizer(
            label_text,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        item = {
            "pixel_values": img,
            "input_ids": enc["input_ids"].squeeze(0),
            "attention_mask": enc["attention_mask"].squeeze(0),
            "target": label,
        }
        return item
```
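What `random.choice(self.templates).format(label=...)` produces can be checked without any model. The pure-Python sketch below (helper names and the three-template list are illustrative) shows the one-random-prompt-per-sample behavior of `__getitem__`, next to the all-templates list that the evaluation code later averages over.

```python
import random
from typing import List

TEMPLATES = [
    "a photo of a {label}.",
    "a blurry photo of a {label}.",
    "a photo of the {label}.",
]


def sample_prompt(label: str, templates: List[str], rng: random.Random) -> str:
    """One random prompt per sample, as in the dataset's __getitem__ (text augmentation)."""
    return rng.choice(templates).format(label=label)


def all_prompts(label: str, templates: List[str]) -> List[str]:
    """Every template filled in, as used for prompt ensembling at evaluation time."""
    return [t.format(label=label) for t in templates]


rng = random.Random(0)
print(sample_prompt("rabbit", TEMPLATES, rng))
print(all_prompts("rabbit", TEMPLATES))
```

CIFAR-100 class names contain underscores (e.g. `lawn_mower`), which reach the tokenizer unchanged; replacing them with spaces is a possible tweak, not something the script does.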
Define the data loaders:
```python
from typing import List, Tuple

import torch
from torch.utils.data import DataLoader
from transformers import CLIPProcessor

# CIFAR100CLIPDataset, CIFAR100_TEMPLATES and CLIPBatch are defined in the full script below


def build_dataloaders(
    root: str,
    processor: CLIPProcessor,
    batch_size: int = 256,
    num_workers: int = 4,
) -> Tuple[DataLoader, DataLoader, List[str]]:
    train_set = CIFAR100CLIPDataset(root=root, split='train', processor=processor, templates=CIFAR100_TEMPLATES)
    test_set = CIFAR100CLIPDataset(root=root, split='test', processor=processor, templates=CIFAR100_TEMPLATES)

    def collate_fn(batch):  # combine a list of samples into one batch
        pixel_values = torch.stack([b["pixel_values"] for b in batch])
        input_ids = torch.stack([b["input_ids"] for b in batch])
        attention_mask = torch.stack([b["attention_mask"] for b in batch])
        # labels are simply 0..B-1 (diagonal matching)
        bsz = pixel_values.size(0)
        labels = torch.arange(bsz)
        return CLIPBatch(pixel_values, input_ids, attention_mask, labels)

    train_loader = DataLoader(
        train_set,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True,
        collate_fn=collate_fn,
    )
    test_loader = DataLoader(
        test_set,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=False,
        collate_fn=collate_fn,
    )
    return train_loader, test_loader, train_set.classes
```
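`collate_fn` uses `torch.arange(bsz)` as labels, so only the diagonal pair counts as a match. On CIFAR-100 this means two images of the same class in one batch are still treated as negatives for each other's prompts. The rough NumPy sketch below (function names hypothetical) counts such same-class off-diagonal pairs, and builds a row-normalized multi-positive target matrix as one possible way to account for them; the script itself does not do this.

```python
import numpy as np


def same_class_mask(targets: np.ndarray) -> np.ndarray:
    """(B, B) boolean matrix: True where sample i and sample j share a class."""
    return targets[:, None] == targets[None, :]


def count_false_negatives(targets: np.ndarray) -> int:
    """Off-diagonal pairs that arange(B) labels as negatives although the classes match."""
    mask = same_class_mask(targets)
    return int(mask.sum() - len(targets))  # subtract the diagonal, which is always True


def soft_targets(targets: np.ndarray) -> np.ndarray:
    """Alternative targets (not used by the script): spread probability over same-class columns."""
    mask = same_class_mask(targets).astype(float)
    return mask / mask.sum(axis=1, keepdims=True)


batch_targets = np.array([0, 1, 0, 2])  # true class ids of a toy batch of 4
print(count_false_negatives(batch_targets))
print(soft_targets(batch_targets))

# 256 samples drawn from 100 classes: same-class collisions are practically unavoidable
rng = np.random.default_rng(0)
print(count_false_negatives(rng.integers(0, 100, size=256)))
```

Recent PyTorch versions let `F.cross_entropy` take class-probability targets, so a matrix like this could in principle replace the integer labels; whether that helps here is untested.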
To fine-tune on your own dataset you need a different dataset class. The one below works with any image dataset laid out in ImageFolder format:
```text
dataset_root/
    train/
        class1/
            img1.jpg
            img2.jpg
            ...
        class2/
            ...
    val/
        class1/
            ...
        class2/
            ...
```
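ImageFolder derives the labels from the directory names: as far as I know, torchvision sorts the class sub-folder names and assigns indices 0..C-1 in that order, and those names are what end up in the prompts through `self.classes`. The pure-Python sketch below imitates that mapping on a temporary directory; `find_classes_like_imagefolder` is a hypothetical name, not the torchvision function.

```python
import os
import tempfile
from typing import Dict, List, Tuple


def find_classes_like_imagefolder(directory: str) -> Tuple[List[str], Dict[str, int]]:
    """Sorted sub-folder names become class names; their position becomes the label."""
    classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
    return classes, {name: idx for idx, name in enumerate(classes)}


with tempfile.TemporaryDirectory() as root:
    for name in ["dog", "cat", "bird"]:
        os.makedirs(os.path.join(root, "train", name))
    classes, class_to_idx = find_classes_like_imagefolder(os.path.join(root, "train"))
    print(classes)       # ['bird', 'cat', 'dog']
    print(class_to_idx)  # {'bird': 0, 'cat': 1, 'dog': 2}
```

Since folder names are inserted into the prompt templates verbatim, naming folders with natural-language class names is probably a good idea.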
```python
import random

import torch
from torchvision import transforms as T
from torchvision.datasets import ImageFolder


class ImageFolderCLIPDataset(torch.utils.data.Dataset):
    def __init__(self, root, processor, templates=None):
        self.ds = ImageFolder(root)
        self.processor = processor
        self.templates = templates or ["a photo of a {label}."]
        self.classes = self.ds.classes
        self.img_transform = T.Compose([
            T.Resize(processor.image_processor.crop_size["height"], interpolation=T.InterpolationMode.BICUBIC),
            T.CenterCrop(processor.image_processor.crop_size["height"]),
            T.ToTensor(),
            T.Normalize(mean=processor.image_processor.image_mean,
                        std=processor.image_processor.image_std),
        ])

    def __getitem__(self, idx):
        img, label = self.ds[idx]
        img = self.img_transform(img)
        label_text = random.choice(self.templates).format(label=self.classes[label])
        enc = self.processor.tokenizer(label_text, padding="max_length", truncation=True, return_tensors="pt")
        return {
            "pixel_values": img,
            "input_ids": enc["input_ids"].squeeze(0),
            "attention_mask": enc["attention_mask"].squeeze(0),
            "target": label,
        }

    def __len__(self):
        return len(self.ds)
```
***Full code***
```python
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Train (optional fine-tune) and evaluate a CLIP model for zero-shot classification on CIFAR-100.
Requirements:
    - torch, torchvision
    - transformers >= 4.41
    - accelerate (optional)
Example usage:
    python clip_cifar100_zeroshot.py --epochs 5 --batch-size 256 --lr 5e-6 \
        --model openai/clip-vit-base-patch32 --data-root ./data --amp
Evaluate only (no training):
    python clip_cifar100_zeroshot.py --eval-only --model openai/clip-vit-base-patch32
Save / load:
    python clip_cifar100_zeroshot.py --epochs 2 --save-path ./clip_cifar100.pt
    python clip_cifar100_zeroshot.py --eval-only --load-path ./clip_cifar100.pt
This script includes:
    - Dataset and dataloader for CIFAR-100
    - Prompt engineering templates
    - CLIP model/processor setup
    - Contrastive training loop (image-text)
    - Zero-shot evaluation using class-name prompts
"""
import argparse
import os
import random
from dataclasses import dataclass
from typing import List, Tuple, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import datasets as tvdatasets
from torchvision import transforms as T
from transformers import CLIPModel, CLIPProcessor


# -------------------------
# Utilities
# -------------------------
# Random seed
def set_seed(seed: int = 42):
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def exists(x):
    return x is not None


# -------------------------
# Prompt engineering: turn class names into text
# -------------------------
CIFAR100_TEMPLATES = [
    "a photo of a {label}.",
    "a blurry photo of a {label}.",
    "a photo of the {label}.",
    "a close-up photo of a {label}.",
    "a bright photo of a {label}.",
    "a cropped photo of a {label}.",
    "a photo of a small {label}.",
    "a photo of a big {label}.",
    "a low contrast photo of a {label}.",
    "a high contrast photo of a {label}.",
]


# -------------------------
# Data: CIFAR-100
# -------------------------
@dataclass
class CLIPBatch:
    pixel_values: torch.Tensor    # (B, C, H, W)
    input_ids: torch.Tensor       # (B, L)
    attention_mask: torch.Tensor  # (B, L)
    labels: torch.Tensor          # (B,) image-text matching along the batch diagonal


class CIFAR100CLIPDataset(torch.utils.data.Dataset):
    """Pair each image with a single prompt built from its class label."""

    def __init__(self, root: str, split: str, processor: CLIPProcessor, templates: List[str] = None):
        assert split in {"train", "test"}
        self.ds = tvdatasets.CIFAR100(root=root, train=(split == "train"), download=True)
        self.processor = processor
        self.templates = templates or ["a photo of a {label}."]
        self.classes = self.ds.classes  # list of 100 class names
        # CLIP ViT-B/32 expects 224x224 inputs by default
        image_size = processor.image_processor.crop_size["height"]  # "width" works too; they are equal
        self.img_transform = T.Compose([
            T.Resize(image_size, interpolation=T.InterpolationMode.BICUBIC),
            T.CenterCrop(image_size),
            T.ToTensor(),
            T.Normalize(mean=processor.image_processor.image_mean,
                        std=processor.image_processor.image_std),
        ])

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx: int):
        img, label = self.ds[idx]
        # Apply CLIP's image preprocessing
        img = self.img_transform(img)
        # Randomly pick a prompt template as text augmentation
        label_text = random.choice(self.templates).format(label=self.classes[label])
        enc = self.processor.tokenizer(
            label_text,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        item = {
            "pixel_values": img,
            "input_ids": enc["input_ids"].squeeze(0),
            "attention_mask": enc["attention_mask"].squeeze(0),
            "target": label,
        }
        return item


# from torchvision.datasets import ImageFolder
#
# class ImageFolderCLIPDataset(torch.utils.data.Dataset):
#     def __init__(self, root, processor, templates=None):
#         self.ds = ImageFolder(root)
#         self.processor = processor
#         self.templates = templates or ["a photo of a {label}."]
#         self.classes = self.ds.classes
#         self.img_transform = T.Compose([
#             T.Resize(processor.image_processor.crop_size["height"], interpolation=T.InterpolationMode.BICUBIC),
#             T.CenterCrop(processor.image_processor.crop_size["height"]),
#             T.ToTensor(),
#             T.Normalize(mean=processor.image_processor.image_mean,
#                         std=processor.image_processor.image_std),
#         ])
#
#     def __getitem__(self, idx):
#         img, label = self.ds[idx]
#         img = self.img_transform(img)
#         label_text = random.choice(self.templates).format(label=self.classes[label])
#         enc = self.processor.tokenizer(label_text, padding="max_length", truncation=True, return_tensors="pt")
#         return {
#             "pixel_values": img,
#             "input_ids": enc["input_ids"].squeeze(0),
#             "attention_mask": enc["attention_mask"].squeeze(0),
#             "target": label,
#         }
#
#     def __len__(self):
#         return len(self.ds)


# Dataloaders
def build_dataloaders(
    root: str,
    processor: CLIPProcessor,
    batch_size: int = 256,
    num_workers: int = 4,
) -> Tuple[DataLoader, DataLoader, List[str]]:
    train_set = CIFAR100CLIPDataset(root=root, split='train', processor=processor, templates=CIFAR100_TEMPLATES)
    test_set = CIFAR100CLIPDataset(root=root, split='test', processor=processor, templates=CIFAR100_TEMPLATES)

    def collate_fn(batch):  # combine a list of samples into one batch
        pixel_values = torch.stack([b["pixel_values"] for b in batch])
        input_ids = torch.stack([b["input_ids"] for b in batch])
        attention_mask = torch.stack([b["attention_mask"] for b in batch])
        # labels are simply 0..B-1 (diagonal matching)
        bsz = pixel_values.size(0)
        labels = torch.arange(bsz)
        return CLIPBatch(pixel_values, input_ids, attention_mask, labels)

    train_loader = DataLoader(
        train_set,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True,
        collate_fn=collate_fn,
    )
    test_loader = DataLoader(
        test_set,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=False,
        collate_fn=collate_fn,
    )
    return train_loader, test_loader, train_set.classes


# -------------------------
# Model wrapper
# -------------------------
class CLIPForCIFAR(nn.Module):
    """Thin wrapper around Hugging Face CLIPModel to expose forward and projection features."""

    # model_name can be a Hugging Face Hub ID or a local directory with the downloaded model
    def __init__(self, model_name: str = "openai/clip-vit-base-patch32"):
        super().__init__()
        self.model = CLIPModel.from_pretrained(model_name)
        self.processor = CLIPProcessor.from_pretrained(model_name)

    def forward(self, batch: CLIPBatch):
        outputs = self.model(
            input_ids=batch.input_ids,
            attention_mask=batch.attention_mask,
            pixel_values=batch.pixel_values,
            return_dict=True,
        )
        return outputs  # contains logits_per_image (B, B), logits_per_text (B, B)

    @torch.no_grad()
    def encode_text(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        out = self.model.get_text_features(input_ids=input_ids, attention_mask=attention_mask)
        return F.normalize(out, dim=-1)

    @torch.no_grad()
    def encode_image(self, pixel_values: torch.Tensor) -> torch.Tensor:
        out = self.model.get_image_features(pixel_values=pixel_values)
        return F.normalize(out, dim=-1)


# -------------------------
# Loss (InfoNCE over CLIP logits)
# -------------------------
def clip_contrastive_loss(logits_per_image: torch.Tensor, logits_per_text: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Standard symmetric contrastive loss (image->text and text->image)."""
    loss_i = F.cross_entropy(logits_per_image, labels)
    loss_t = F.cross_entropy(logits_per_text, labels)
    return (loss_i + loss_t) / 2


# -------------------------
# Training loop
# -------------------------
def train(
    model: CLIPForCIFAR,
    train_loader: DataLoader,
    device: torch.device,
    epochs: int = 5,
    lr: float = 5e-6,
    weight_decay: float = 0.2,
    amp: bool = False,
    freeze_vision: bool = False,
    freeze_text: bool = False,
    grad_accum_steps: int = 1,
    save_path: Optional[str] = None,
):
    model.train()
    # Optionally freeze encoders (useful for quick runs)
    if freeze_vision:
        for p in model.model.vision_model.parameters():
            p.requires_grad = False
    if freeze_text:
        for p in model.model.text_model.parameters():
            p.requires_grad = False
    # Only optimize trainable params
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(params, lr=lr, weight_decay=weight_decay)
    scaler = torch.cuda.amp.GradScaler(enabled=amp)
    global_step = 0
    for epoch in range(epochs):
        running_loss = 0.0
        for step, batch in enumerate(train_loader):
            batch = CLIPBatch(
                pixel_values=batch.pixel_values.to(device, non_blocking=True),
                input_ids=batch.input_ids.to(device, non_blocking=True),
                attention_mask=batch.attention_mask.to(device, non_blocking=True),
                labels=batch.labels.to(device, non_blocking=True),
            )
            with torch.cuda.amp.autocast(enabled=amp):
                outputs = model(batch)
                loss = clip_contrastive_loss(outputs.logits_per_image, outputs.logits_per_text, batch.labels)
                loss = loss / grad_accum_steps
            scaler.scale(loss).backward()
            if (step + 1) % grad_accum_steps == 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad(set_to_none=True)
                global_step += 1
            running_loss += loss.item() * grad_accum_steps
            if (step + 1) % 50 == 0:
                avg = running_loss / 50
                print(f"Epoch {epoch+1} | Step {step+1}/{len(train_loader)} | loss {avg:.4f}")
                running_loss = 0.0
        if exists(save_path):
            ckpt = {
                "model_state": model.state_dict(),
                "epoch": epoch + 1,
            }
            torch.save(ckpt, save_path)
            print(f"[Saved] {save_path} at epoch {epoch+1}")


# -------------------------
# Zero-shot evaluation
# -------------------------
@torch.no_grad()
def build_text_classifier(model: CLIPForCIFAR, classnames: List[str], templates: List[str], device: torch.device):
    """Build several prompts per class, encode them and average them into one class embedding.
    Returns: text_features (C, D), where each row is the normalized class embedding.
    """
    tokenizer = model.processor.tokenizer
    all_class_embeds = []
    for cls in classnames:
        # Encode multiple prompts per class and average
        texts = [template.format(label=cls) for template in templates]
        enc = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
        enc = {k: v.to(device) for k, v in enc.items()}
        class_feats = model.encode_text(enc["input_ids"], enc["attention_mask"])  # (T, D)
        class_feats = class_feats.mean(dim=0)
        class_feats = F.normalize(class_feats, dim=-1)
        all_class_embeds.append(class_feats)
    text_features = torch.stack(all_class_embeds, dim=0)  # (C, D)
    return text_features


# For evaluation we need a more informative collate_fn that also returns the true class indices
@dataclass
class CLIPBatchFull(CLIPBatch):
    targets: torch.Tensor  # (B,) true CIFAR-100 labels


def build_dataloaders_full(
    root: str,
    processor: CLIPProcessor,
    batch_size: int = 256,
    num_workers: int = 4,
) -> Tuple[DataLoader, DataLoader, List[str]]:
    train_set = CIFAR100CLIPDataset(root=root, split='train', processor=processor, templates=CIFAR100_TEMPLATES)
    test_set = CIFAR100CLIPDataset(root=root, split='test', processor=processor, templates=CIFAR100_TEMPLATES)

    def collate_fn(batch):
        pixel_values = torch.stack([b["pixel_values"] for b in batch])
        input_ids = torch.stack([b["input_ids"] for b in batch])
        attention_mask = torch.stack([b["attention_mask"] for b in batch])
        diagonal = torch.arange(pixel_values.size(0))
        targets = torch.tensor([b["target"] for b in batch], dtype=torch.long)
        return CLIPBatchFull(pixel_values, input_ids, attention_mask, diagonal, targets)

    train_loader = DataLoader(
        train_set,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True,
        collate_fn=collate_fn,
    )
    test_loader = DataLoader(
        test_set,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=False,
        collate_fn=collate_fn,
    )
    return train_loader, test_loader, train_set.classes


@torch.no_grad()
def zero_shot_eval(model: CLIPForCIFAR, loader: DataLoader, classnames: List[str], device: torch.device) -> float:
    model.eval()
    text_features = build_text_classifier(model, classnames, CIFAR100_TEMPLATES, device)  # (C, D)
    correct = 0
    total = 0
    for batch in loader:
        pixel_values = batch.pixel_values.to(device)
        targets = batch.targets.to(device)
        image_features = model.encode_image(pixel_values)  # (B, D)
        # similarity (B, C)
        logits = image_features @ text_features.t()
        preds = logits.argmax(dim=-1)
        correct += (preds == targets).sum().item()
        total += targets.size(0)
    acc = correct / total
    return acc


@torch.no_grad()
def predict_single_image(model, image_path, classnames, device):
    model.eval()
    # Build the zero-shot classifier
    text_features = build_text_classifier(model, classnames, CIFAR100_TEMPLATES, device)
    # Load the image
    image = Image.open(image_path).convert("RGB")
    inputs = model.processor(images=image, return_tensors="pt")
    pixel_values = inputs["pixel_values"].to(device)
    # Extract image features
    image_features = model.encode_image(pixel_values)
    # Make sure the features are normalized (encode_image already does this; repeated here for clarity)
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    # Compute similarities and turn them into probabilities
    similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
    values, indices = similarity[0].topk(5)
    # Collect the results
    results = []
    for value, index in zip(values, indices):
        class_name = classnames[index]
        percent_prob = value.item() * 100  # convert to percent
        results.append((class_name, percent_prob))
    return results


# -------------------------
# Main
# -------------------------
def main():
    parser = argparse.ArgumentParser(description="CLIP zero-shot on CIFAR-100")
    parser.add_argument("--model", type=str, default="openai/clip-vit-base-patch32", help="CLIP model name or local path")
    parser.add_argument("--data-root", type=str, default="./data", help="Directory for CIFAR-100")
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--epochs", type=int, default=0, help="Training epochs (0 = skip training)")
    parser.add_argument("--lr", type=float, default=5e-6)
    parser.add_argument("--weight-decay", type=float, default=0.2)
    parser.add_argument("--num-workers", type=int, default=0)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--amp", action="store_true", help="Use mixed precision")
    parser.add_argument("--freeze-vision", action="store_true")
    parser.add_argument("--freeze-text", action="store_true")
    parser.add_argument("--grad-accum-steps", type=int, default=1)
    parser.add_argument("--save-path", type=str, default=None)
    parser.add_argument("--load-path", type=str, default=None)
    parser.add_argument("--eval-only", action="store_true")
    args = parser.parse_args()

    set_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Build model and processor
    model = CLIPForCIFAR(model_name=args.model)
    model.to(device)

    train_root = os.path.join(args.data_root, "train")
    # Build data
    train_loader, test_loader, classnames = build_dataloaders_full(
        root=args.data_root,
        processor=model.processor,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
    )

    # Optional loading
    if exists(args.load_path) and os.path.isfile(args.load_path):
        ckpt = torch.load(args.load_path, map_location="cpu")
        model.load_state_dict(ckpt["model_state"], strict=False)
        print(f"[Loaded] {args.load_path}")

    # Train (optional)
    if not args.eval_only and args.epochs > 0:
        train(
            model=model,
            train_loader=train_loader,
            device=device,
            epochs=args.epochs,
            lr=args.lr,
            weight_decay=args.weight_decay,
            amp=args.amp,
            freeze_vision=args.freeze_vision,
            freeze_text=args.freeze_text,
            grad_accum_steps=args.grad_accum_steps,
            save_path=args.save_path,
        )

    # Zero-shot evaluation
    # acc = zero_shot_eval(model, test_loader, classnames, device)
    # print(f"Zero-shot Top-1 Accuracy on CIFAR-100: {acc * 100:.2f}%")

    # ======================================================================
    image_path = r"90.jpg"  # path of the image to classify
    results = predict_single_image(model, image_path, classnames, device)
    print("\nTop-5 predictions:\n")
    for cls, prob_percent in results:
        print(f"{cls:>16s}: {prob_percent:.2f}%")
    # ======================================================================


if __name__ == "__main__":
    main()
```

To use the official clip-vit-base-patch32 model as-is, leave epochs at 0 (the default) so training is skipped, and set image_path to the image you want to classify. Using the same rabbit image as above, the result is:

```text
Top-5 predictions:

          rabbit: 99.85%
      lawn_mower: 0.04%
           mouse: 0.03%
        kangaroo: 0.02%
        squirrel: 0.01%
```

The prediction is correct!
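`build_text_classifier` embeds every template for a class, averages the normalized embeddings and re-normalizes the mean; the CLIP paper calls this prompt ensembling. A NumPy sketch of just that reduction step, with random vectors standing in for text-encoder outputs (the function name and sizes are illustrative):

```python
import numpy as np


def ensemble_class_embeddings(prompt_embs: np.ndarray) -> np.ndarray:
    """prompt_embs: (C, T, D) text embeddings for C classes x T templates.
    Returns (C, D) unit-norm class embeddings: normalize, average over templates, renormalize."""
    per_prompt = prompt_embs / np.linalg.norm(prompt_embs, axis=-1, keepdims=True)
    mean = per_prompt.mean(axis=1)
    return mean / np.linalg.norm(mean, axis=-1, keepdims=True)


rng = np.random.default_rng(0)
fake_prompt_embs = rng.normal(size=(100, 10, 512))  # 100 classes, 10 templates, 512-dim
classifier = ensemble_class_embeddings(fake_prompt_embs)
print(classifier.shape)
```

The re-normalization matters because a mean of unit vectors is generally shorter than 1; without it, classes whose prompts disagree more would get systematically smaller logits.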
IV. Fine-tuning on your own training set
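Switch the dataset to ImageFolderCLIPDataset: inside build_dataloaders_full, construct ImageFolderCLIPDataset instead of CIFAR100CLIPDataset and drop the split argument, which ImageFolderCLIPDataset does not accept. Then make the following change in main: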
```python
    train_root = os.path.join(args.data_root, "train")
    # Build data
    train_loader, test_loader, classnames = build_dataloaders_full(
        root=args.data_root,  # <-- change this to root=train_root
        processor=model.processor,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
    )
```
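Note that with this change both loaders read from train/; for a held-out evaluation, build the test set from the val/ folder instead. Finally, set the number of epochs and the checkpoint save path (--epochs, --save-path) and you are done!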
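Because ImageFolder numbers the classes by sorted folder name separately for each root directory, train/ and val/ should contain exactly the same class folders; otherwise the same index can mean different classes in the two splits. A small pure-Python pre-flight check could look like the sketch below (the function names and the train/val layout from above are assumptions).

```python
import os
import tempfile
from typing import List


def class_folders(split_dir: str) -> List[str]:
    """Sorted class sub-folder names, the order ImageFolder uses for label indices."""
    return sorted(e.name for e in os.scandir(split_dir) if e.is_dir())


def check_splits(dataset_root: str, train: str = "train", val: str = "val") -> List[str]:
    """Raise if train/ and val/ disagree on class folders; return the shared class list."""
    train_classes = class_folders(os.path.join(dataset_root, train))
    val_classes = class_folders(os.path.join(dataset_root, val))
    if train_classes != val_classes:
        only_train = sorted(set(train_classes) - set(val_classes))
        only_val = sorted(set(val_classes) - set(train_classes))
        raise ValueError(f"class folders differ: only in {train}={only_train}, only in {val}={only_val}")
    return train_classes


with tempfile.TemporaryDirectory() as root:
    for split in ["train", "val"]:
        for name in ["cat", "dog"]:
            os.makedirs(os.path.join(root, split, name))
    print(check_splits(root))  # ['cat', 'dog']
```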
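That's everything! What makes CLIP remarkable is zero-shot classification: it turns a fixed set of categories into an open-ended one that relies entirely on understanding natural language.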