◉ Perception Encoder (CLIP) for ReID
Tested on an RTX 5090 laptop GPU. Installation steps for the PE-Core-L14-336 checkpoint:
git clone https://github.com/facebookresearch/perception_models.git
cd perception_models
conda create -n pe_test python=3.12 -y
conda activate pe_test
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
pip install pillow tqdm numpy pandas scikit-learn matplotlib
pip install huggingface_hub
pip install -e .
Download the model locally:
hf download facebook/PE-Core-L14-336 \
--local-dir models
from pathlib import Path
import time
import json
import numpy as np
import pandas as pd
from PIL import Image
from tqdm.auto import tqdm
import torch
from torch.utils.data import Dataset, DataLoader
from helpers import *
# --- Dataset locations -------------------------------------------------
# Everything lives under one dataset root; the ReID split and the
# distractor images sit in sibling sub-directories.
ROOT = Path("/home/user/data/datasets/VeRi_Self")
REID_ROOT = ROOT / "reid"
DISTRACTOR_ROOT = ROOT / "PKU_Vehicle"
METADATA_CSV = REID_ROOT / "all_metadata.csv"

# --- Output locations --------------------------------------------------
# Both directories are created up front (relative to the working
# directory) so later cells can write into them unconditionally.
EMB_DIR = Path("outputs/embeddings")
OUTPUT_DIR = Path("outputs/evals")
for _out_dir in (EMB_DIR, OUTPUT_DIR):
    _out_dir.mkdir(parents=True, exist_ok=True)

# --- Compute configuration ---------------------------------------------
# Use the GPU when one is visible, otherwise fall back to the CPU.
_has_cuda = torch.cuda.is_available()
DEVICE = "cuda" if _has_cuda else "cpu"
DTYPE = torch.bfloat16

print("Device:", DEVICE)
print("CUDA device:", torch.cuda.get_device_name(0) if _has_cuda else None)
Device: cuda
CUDA device: NVIDIA GeForce RTX 5090 Laptop GPU
import core.vision_encoder.pe as pe
import core.vision_encoder.transforms as transforms
from pathlib import Path
# Locate the locally downloaded PE-Core checkpoint and its config file.
MODEL_NAME = "PE-Core-L14-336"
LOCAL_MODEL_DIR = Path("models")
ckpt_path = LOCAL_MODEL_DIR.joinpath("PE-Core-L14-336.pt")
config_path = LOCAL_MODEL_DIR.joinpath("config.yaml")

# Report whether each expected file is present, followed by its path.
for _artifact in (ckpt_path, config_path):
    print(_artifact.exists(), _artifact)

# Checkpoint size in binary gigabytes (bytes / 1024**3), two decimals.
_ckpt_bytes = ckpt_path.stat().st_size
print("Checkpoint size GB:", round(_ckpt_bytes / 1024**3, 2))
True models/PE-Core-L14-336.pt
True models/config.yaml
Checkpoint size GB: 2.5
Load the PE model
import time
import core.vision_encoder.pe as pe
import core.vision_encoder.transforms as transforms
# Load the PE-Core CLIP model from the local checkpoint, timing each stage
# (deserialise -> build architecture -> copy weights -> move to device).


def _strip_wrapper_prefixes(key, prefixes=("module.", "model.")):
    """Return *key* with any leading wrapper prefixes removed.

    Only *leading* occurrences are stripped (repeatedly, so
    ``module.model.x`` becomes ``x``). A parameter whose name merely
    contains one of the prefixes, e.g. ``text_model.proj``, is returned
    unchanged instead of being mangled into ``text_proj``.
    """
    changed = True
    while changed:
        changed = False
        for prefix in prefixes:
            if key.startswith(prefix):
                key = key[len(prefix):]
                changed = True
    return key


# 1. Deserialise the checkpoint on the CPU. ``weights_only=True`` limits
#    unpickling to tensors and plain containers, which is all a state dict
#    needs, so a tampered file cannot execute arbitrary code on load.
t = time.time()
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
print("1. torch.load:", round(time.time() - t, 2), "sec")

# Reject unsupported formats before paying for the model build.
if not isinstance(ckpt, dict):
    raise ValueError("Unexpected checkpoint format")
print("Checkpoint keys:", ckpt.keys())

# The weights may be nested under a wrapper key; the first match wins,
# otherwise the checkpoint itself is treated as the state dict.
for _wrapper_key in ("state_dict", "model", "model_state_dict"):
    if _wrapper_key in ckpt:
        state_dict = ckpt[_wrapper_key]
        break
else:
    state_dict = ckpt

# 2. Build the architecture without fetching pretrained weights.
t = time.time()
model = pe.CLIP.from_config(MODEL_NAME, pretrained=False)
print("2. model build:", round(time.time() - t, 2), "sec")

# 3. Copy the weights in. ``strict=False`` returns the mismatching keys
#    instead of raising, so they are printed for inspection below.
clean_state_dict = {
    _strip_wrapper_prefixes(name): tensor for name, tensor in state_dict.items()
}
t = time.time()
missing, unexpected = model.load_state_dict(clean_state_dict, strict=False)
print("3. load_state_dict:", round(time.time() - t, 2), "sec")
print("Missing:", len(missing), "Unexpected:", len(unexpected))
print("Missing sample:", missing[:5])
print("Unexpected sample:", unexpected[:5])

# 4. Move to the target device in eval mode. Synchronise first so the
#    timing covers the actual transfer rather than just the async launch.
t = time.time()
model = model.to(DEVICE).eval()
if DEVICE == "cuda":
    torch.cuda.synchronize()
print("4. to device:", round(time.time() - t, 2), "sec")

# Image transform and text tokenizer sized to match the loaded model.
preprocess = transforms.get_image_transform(model.image_size)
tokenizer = transforms.get_text_tokenizer(model.context_length)
print("Done")
print("Image size:", model.image_size)
print("Context length:", model.context_length)

# Sanity-check where the parameters ended up and in which precision.
p = next(model.parameters())
print("Model device:", p.device)
print("Model dtype:", p.dtype)

# Bare expression: in a notebook this displays the module structure.
model
1. torch.load: 1.04 sec
Checkpoint keys: dict_keys(['ln_final.bias', 'ln_final.weight', 'logit_scale', 'positional_embedding', 'text_projection', 'token_embedding.weight', 'transformer.resblocks.0.attn.in_proj_bias', 'transformer.resblocks.0.attn.in_proj_weight', 'transformer.resblocks.0.attn.out_proj.bias', 'transformer.resblocks.0.attn.out_proj.weight', 'transformer.resblocks.0.ln_1.bias', 'transformer.resblocks.0.ln_1.weight', 'transformer.resblocks.0.ln_2.bias', 'transformer.resblocks.0.ln_2.weight', 'transformer.resblocks.0.mlp.c_fc.bias', 'transformer.resblocks.0.mlp.c_fc.weight', 'transformer.resblocks.0.mlp.c_proj.bias', 'transformer.resblocks.0.mlp.c_proj.weight', 'transformer.resblocks.1.attn.in_proj_bias', 'transformer.resblocks.1.attn.in_proj_weight', 'transformer.resblocks.1.attn.out_proj.bias', 'transformer.resblocks.1.attn.out_proj.weight', 'transformer.resblocks.1.ln_1.bias', 'transformer.resblocks.1.ln_1.weight', 'transformer.resblocks.1.ln_2.bias', 'transformer.resblocks.1.ln_2.weight', 'transformer.resblocks.1.mlp.c_fc.bias', 'transformer.resblocks.1.mlp.c_fc.weight', 'transformer.resblocks.1.mlp.c_proj.bias', 'transformer.resblocks.1.mlp.c_proj.weight', 'transformer.resblocks.10.attn.in_proj_bias', 'transformer.resblocks.10.attn.in_proj_weight', 'transformer.resblocks.10.attn.out_proj.bias', 'transformer.resblocks.10.attn.out_proj.weight', 'transformer.resblocks.10.ln_1.bias', 'transformer.resblocks.10.ln_1.weight', 'transformer.resblocks.10.ln_2.bias', 'transformer.resblocks.10.ln_2.weight', 'transformer.resblocks.10.mlp.c_fc.bias', 'transformer.resblocks.10.mlp.c_fc.weight', 'transformer.resblocks.10.mlp.c_proj.bias', 'transformer.resblocks.10.mlp.c_proj.weight', 'transformer.resblocks.11.attn.in_proj_bias', 'transformer.resblocks.11.attn.in_proj_weight', 'transformer.resblocks.11.attn.out_proj.bias', 'transformer.resblocks.11.attn.out_proj.weight', 'transformer.resblocks.11.ln_1.bias', 'transformer.resblocks.11.ln_1.weight', 'transformer.resblocks.11.ln_2.bias', 
'transformer.resblocks.11.ln_2.weight', 'transformer.resblocks.11.mlp.c_fc.bias', 'transformer.resblocks.11.mlp.c_fc.weight', 'transformer.resblocks.11.mlp.c_proj.bias', 'transformer.resblocks.11.mlp.c_proj.weight', 'transformer.resblocks.12.attn.in_proj_bias', 'transformer.resblocks.12.attn.in_proj_weight', 'transformer.resblocks.12.attn.out_proj.bias', 'transformer.resblocks.12.attn.out_proj.weight', 'transformer.resblocks.12.ln_1.bias', 'transformer.resblocks.12.ln_1.weight', 'transformer.resblocks.12.ln_2.bias', 'transformer.resblocks.12.ln_2.weight', 'transformer.resblocks.12.mlp.c_fc.bias', 'transformer.resblocks.12.mlp.c_fc.weight', 'transformer.resblocks.12.mlp.c_proj.bias', 'transformer.resblocks.12.mlp.c_proj.weight', 'transformer.resblocks.13.attn.in_proj_bias', 'transformer.resblocks.13.attn.in_proj_weight', 'transformer.resblocks.13.attn.out_proj.bias', 'transformer.resblocks.13.attn.out_proj.weight', 'transformer.resblocks.13.ln_1.bias', 'transformer.resblocks.13.ln_1.weight', 'transformer.resblocks.13.ln_2.bias', 'transformer.resblocks.13.ln_2.weight', 'transformer.resblocks.13.mlp.c_fc.bias', 'transformer.resblocks.13.mlp.c_fc.weight', 'transformer.resblocks.13.mlp.c_proj.bias', 'transformer.resblocks.13.mlp.c_proj.weight', 'transformer.resblocks.14.attn.in_proj_bias', 'transformer.resblocks.14.attn.in_proj_weight', 'transformer.resblocks.14.attn.out_proj.bias', 'transformer.resblocks.14.attn.out_proj.weight', 'transformer.resblocks.14.ln_1.bias', 'transformer.resblocks.14.ln_1.weight', 'transformer.resblocks.14.ln_2.bias', 'transformer.resblocks.14.ln_2.weight', 'transformer.resblocks.14.mlp.c_fc.bias', 'transformer.resblocks.14.mlp.c_fc.weight', 'transformer.resblocks.14.mlp.c_proj.bias', 'transformer.resblocks.14.mlp.c_proj.weight', 'transformer.resblocks.15.attn.in_proj_bias', 'transformer.resblocks.15.attn.in_proj_weight', 'transformer.resblocks.15.attn.out_proj.bias', 'transformer.resblocks.15.attn.out_proj.weight', 
'transformer.resblocks.15.ln_1.bias', 'transformer.resblocks.15.ln_1.weight', 'transformer.resblocks.15.ln_2.bias', 'transformer.resblocks.15.ln_2.weight', 'transformer.resblocks.15.mlp.c_fc.bias', 'transformer.resblocks.15.mlp.c_fc.weight', 'transformer.resblocks.15.mlp.c_proj.bias', 'transformer.resblocks.15.mlp.c_proj.weight', 'transformer.resblocks.16.attn.in_proj_bias', 'transformer.resblocks.16.attn.in_proj_weight', 'transformer.resblocks.16.attn.out_proj.bias', 'transformer.resblocks.16.attn.out_proj.weight', 'transformer.resblocks.16.ln_1.bias', 'transformer.resblocks.16.ln_1.weight', 'transformer.resblocks.16.ln_2.bias', 'transformer.resblocks.16.ln_2.weight', 'transformer.resblocks.16.mlp.c_fc.bias', 'transformer.resblocks.16.mlp.c_fc.weight', 'transformer.resblocks.16.mlp.c_proj.bias', 'transformer.resblocks.16.mlp.c_proj.weight', 'transformer.resblocks.17.attn.in_proj_bias', 'transformer.resblocks.17.attn.in_proj_weight', 'transformer.resblocks.17.attn.out_proj.bias', 'transformer.resblocks.17.attn.out_proj.weight', 'transformer.resblocks.17.ln_1.bias', 'transformer.resblocks.17.ln_1.weight', 'transformer.resblocks.17.ln_2.bias', 'transformer.resblocks.17.ln_2.weight', 'transformer.resblocks.17.mlp.c_fc.bias', 'transformer.resblocks.17.mlp.c_fc.weight', 'transformer.resblocks.17.mlp.c_proj.bias', 'transformer.resblocks.17.mlp.c_proj.weight', 'transformer.resblocks.18.attn.in_proj_bias', 'transformer.resblocks.18.attn.in_proj_weight', 'transformer.resblocks.18.attn.out_proj.bias', 'transformer.resblocks.18.attn.out_proj.weight', 'transformer.resblocks.18.ln_1.bias', 'transformer.resblocks.18.ln_1.weight', 'transformer.resblocks.18.ln_2.bias', 'transformer.resblocks.18.ln_2.weight', 'transformer.resblocks.18.mlp.c_fc.bias', 'transformer.resblocks.18.mlp.c_fc.weight', 'transformer.resblocks.18.mlp.c_proj.bias', 'transformer.resblocks.18.mlp.c_proj.weight', 'transformer.resblocks.19.attn.in_proj_bias', 'transformer.resblocks.19.attn.in_proj_weight', 
'transformer.resblocks.19.attn.out_proj.bias', 'transformer.resblocks.19.attn.out_proj.weight', 'transformer.resblocks.19.ln_1.bias', 'transformer.resblocks.19.ln_1.weight', 'transformer.resblocks.19.ln_2.bias', 'transformer.resblocks.19.ln_2.weight', 'transformer.resblocks.19.mlp.c_fc.bias', 'transformer.resblocks.19.mlp.c_fc.weight', 'transformer.resblocks.19.mlp.c_proj.bias', 'transformer.resblocks.19.mlp.c_proj.weight', 'transformer.resblocks.2.attn.in_proj_bias', 'transformer.resblocks.2.attn.in_proj_weight', 'transformer.resblocks.2.attn.out_proj.bias', 'transformer.resblocks.2.attn.out_proj.weight', 'transformer.resblocks.2.ln_1.bias', 'transformer.resblocks.2.ln_1.weight', 'transformer.resblocks.2.ln_2.bias', 'transformer.resblocks.2.ln_2.weight', 'transformer.resblocks.2.mlp.c_fc.bias', 'transformer.resblocks.2.mlp.c_fc.weight', 'transformer.resblocks.2.mlp.c_proj.bias', 'transformer.resblocks.2.mlp.c_proj.weight', 'transformer.resblocks.20.attn.in_proj_bias', 'transformer.resblocks.20.attn.in_proj_weight', 'transformer.resblocks.20.attn.out_proj.bias', 'transformer.resblocks.20.attn.out_proj.weight', 'transformer.resblocks.20.ln_1.bias', 'transformer.resblocks.20.ln_1.weight', 'transformer.resblocks.20.ln_2.bias', 'transformer.resblocks.20.ln_2.weight', 'transformer.resblocks.20.mlp.c_fc.bias', 'transformer.resblocks.20.mlp.c_fc.weight', 'transformer.resblocks.20.mlp.c_proj.bias', 'transformer.resblocks.20.mlp.c_proj.weight', 'transformer.resblocks.21.attn.in_proj_bias', 'transformer.resblocks.21.attn.in_proj_weight', 'transformer.resblocks.21.attn.out_proj.bias', 'transformer.resblocks.21.attn.out_proj.weight', 'transformer.resblocks.21.ln_1.bias', 'transformer.resblocks.21.ln_1.weight', 'transformer.resblocks.21.ln_2.bias', 'transformer.resblocks.21.ln_2.weight', 'transformer.resblocks.21.mlp.c_fc.bias', 'transformer.resblocks.21.mlp.c_fc.weight', 'transformer.resblocks.21.mlp.c_proj.bias', 'transformer.resblocks.21.mlp.c_proj.weight', 
'transformer.resblocks.22.attn.in_proj_bias', 'transformer.resblocks.22.attn.in_proj_weight', 'transformer.resblocks.22.attn.out_proj.bias', 'transformer.resblocks.22.attn.out_proj.weight', 'transformer.resblocks.22.ln_1.bias', 'transformer.resblocks.22.ln_1.weight', 'transformer.resblocks.22.ln_2.bias', 'transformer.resblocks.22.ln_2.weight', 'transformer.resblocks.22.mlp.c_fc.bias', 'transformer.resblocks.22.mlp.c_fc.weight', 'transformer.resblocks.22.mlp.c_proj.bias', 'transformer.resblocks.22.mlp.c_proj.weight', 'transformer.resblocks.23.attn.in_proj_bias', 'transformer.resblocks.23.attn.in_proj_weight', 'transformer.resblocks.23.attn.out_proj.bias', 'transformer.resblocks.23.attn.out_proj.weight', 'transformer.resblocks.23.ln_1.bias', 'transformer.resblocks.23.ln_1.weight', 'transformer.resblocks.23.ln_2.bias', 'transformer.resblocks.23.ln_2.weight', 'transformer.resblocks.23.mlp.c_fc.bias', 'transformer.resblocks.23.mlp.c_fc.weight', 'transformer.resblocks.23.mlp.c_proj.bias', 'transformer.resblocks.23.mlp.c_proj.weight', 'transformer.resblocks.3.attn.in_proj_bias', 'transformer.resblocks.3.attn.in_proj_weight', 'transformer.resblocks.3.attn.out_proj.bias', 'transformer.resblocks.3.attn.out_proj.weight', 'transformer.resblocks.3.ln_1.bias', 'transformer.resblocks.3.ln_1.weight', 'transformer.resblocks.3.ln_2.bias', 'transformer.resblocks.3.ln_2.weight', 'transformer.resblocks.3.mlp.c_fc.bias', 'transformer.resblocks.3.mlp.c_fc.weight', 'transformer.resblocks.3.mlp.c_proj.bias', 'transformer.resblocks.3.mlp.c_proj.weight', 'transformer.resblocks.4.attn.in_proj_bias', 'transformer.resblocks.4.attn.in_proj_weight', 'transformer.resblocks.4.attn.out_proj.bias', 'transformer.resblocks.4.attn.out_proj.weight', 'transformer.resblocks.4.ln_1.bias', 'transformer.resblocks.4.ln_1.weight', 'transformer.resblocks.4.ln_2.bias', 'transformer.resblocks.4.ln_2.weight', 'transformer.resblocks.4.mlp.c_fc.bias', 'transformer.resblocks.4.mlp.c_fc.weight', 
'transformer.resblocks.4.mlp.c_proj.bias', 'transformer.resblocks.4.mlp.c_proj.weight', 'transformer.resblocks.5.attn.in_proj_bias', 'transformer.resblocks.5.attn.in_proj_weight', 'transformer.resblocks.5.attn.out_proj.bias', 'transformer.resblocks.5.attn.out_proj.weight', 'transformer.resblocks.5.ln_1.bias', 'transformer.resblocks.5.ln_1.weight', 'transformer.resblocks.5.ln_2.bias', 'transformer.resblocks.5.ln_2.weight', 'transformer.resblocks.5.mlp.c_fc.bias', 'transformer.resblocks.5.mlp.c_fc.weight', 'transformer.resblocks.5.mlp.c_proj.bias', 'transformer.resblocks.5.mlp.c_proj.weight', 'transformer.resblocks.6.attn.in_proj_bias', 'transformer.resblocks.6.attn.in_proj_weight', 'transformer.resblocks.6.attn.out_proj.bias', 'transformer.resblocks.6.attn.out_proj.weight', 'transformer.resblocks.6.ln_1.bias', 'transformer.resblocks.6.ln_1.weight', 'transformer.resblocks.6.ln_2.bias', 'transformer.resblocks.6.ln_2.weight', 'transformer.resblocks.6.mlp.c_fc.bias', 'transformer.resblocks.6.mlp.c_fc.weight', 'transformer.resblocks.6.mlp.c_proj.bias', 'transformer.resblocks.6.mlp.c_proj.weight', 'transformer.resblocks.7.attn.in_proj_bias', 'transformer.resblocks.7.attn.in_proj_weight', 'transformer.resblocks.7.attn.out_proj.bias', 'transformer.resblocks.7.attn.out_proj.weight', 'transformer.resblocks.7.ln_1.bias', 'transformer.resblocks.7.ln_1.weight', 'transformer.resblocks.7.ln_2.bias', 'transformer.resblocks.7.ln_2.weight', 'transformer.resblocks.7.mlp.c_fc.bias', 'transformer.resblocks.7.mlp.c_fc.weight', 'transformer.resblocks.7.mlp.c_proj.bias', 'transformer.resblocks.7.mlp.c_proj.weight', 'transformer.resblocks.8.attn.in_proj_bias', 'transformer.resblocks.8.attn.in_proj_weight', 'transformer.resblocks.8.attn.out_proj.bias', 'transformer.resblocks.8.attn.out_proj.weight', 'transformer.resblocks.8.ln_1.bias', 'transformer.resblocks.8.ln_1.weight', 'transformer.resblocks.8.ln_2.bias', 'transformer.resblocks.8.ln_2.weight', 'transformer.resblocks.8.mlp.c_fc.bias', 
'transformer.resblocks.8.mlp.c_fc.weight', 'transformer.resblocks.8.mlp.c_proj.bias', 'transformer.resblocks.8.mlp.c_proj.weight', 'transformer.resblocks.9.attn.in_proj_bias', 'transformer.resblocks.9.attn.in_proj_weight', 'transformer.resblocks.9.attn.out_proj.bias', 'transformer.resblocks.9.attn.out_proj.weight', 'transformer.resblocks.9.ln_1.bias', 'transformer.resblocks.9.ln_1.weight', 'transformer.resblocks.9.ln_2.bias', 'transformer.resblocks.9.ln_2.weight', 'transformer.resblocks.9.mlp.c_fc.bias', 'transformer.resblocks.9.mlp.c_fc.weight', 'transformer.resblocks.9.mlp.c_proj.bias', 'transformer.resblocks.9.mlp.c_proj.weight', 'visual.attn_pool.attn.in_proj_bias', 'visual.attn_pool.attn.in_proj_weight', 'visual.attn_pool.attn.out_proj.bias', 'visual.attn_pool.attn.out_proj.weight', 'visual.attn_pool.layernorm.bias', 'visual.attn_pool.layernorm.weight', 'visual.attn_pool.mlp.c_fc.bias', 'visual.attn_pool.mlp.c_fc.weight', 'visual.attn_pool.mlp.c_proj.bias', 'visual.attn_pool.mlp.c_proj.weight', 'visual.attn_pool.probe', 'visual.class_embedding', 'visual.conv1.weight', 'visual.ln_post.bias', 'visual.ln_post.weight', 'visual.ln_pre.bias', 'visual.ln_pre.weight', 'visual.positional_embedding', 'visual.proj', 'visual.transformer.resblocks.0.attn.in_proj_bias', 'visual.transformer.resblocks.0.attn.in_proj_weight', 'visual.transformer.resblocks.0.attn.out_proj.bias', 'visual.transformer.resblocks.0.attn.out_proj.weight', 'visual.transformer.resblocks.0.ln_1.bias', 'visual.transformer.resblocks.0.ln_1.weight', 'visual.transformer.resblocks.0.ln_2.bias', 'visual.transformer.resblocks.0.ln_2.weight', 'visual.transformer.resblocks.0.mlp.c_fc.bias', 'visual.transformer.resblocks.0.mlp.c_fc.weight', 'visual.transformer.resblocks.0.mlp.c_proj.bias', 'visual.transformer.resblocks.0.mlp.c_proj.weight', 'visual.transformer.resblocks.1.attn.in_proj_bias', 'visual.transformer.resblocks.1.attn.in_proj_weight', 'visual.transformer.resblocks.1.attn.out_proj.bias', 
'visual.transformer.resblocks.1.attn.out_proj.weight', 'visual.transformer.resblocks.1.ln_1.bias', 'visual.transformer.resblocks.1.ln_1.weight', 'visual.transformer.resblocks.1.ln_2.bias', 'visual.transformer.resblocks.1.ln_2.weight', 'visual.transformer.resblocks.1.mlp.c_fc.bias', 'visual.transformer.resblocks.1.mlp.c_fc.weight', 'visual.transformer.resblocks.1.mlp.c_proj.bias', 'visual.transformer.resblocks.1.mlp.c_proj.weight', 'visual.transformer.resblocks.10.attn.in_proj_bias', 'visual.transformer.resblocks.10.attn.in_proj_weight', 'visual.transformer.resblocks.10.attn.out_proj.bias', 'visual.transformer.resblocks.10.attn.out_proj.weight', 'visual.transformer.resblocks.10.ln_1.bias', 'visual.transformer.resblocks.10.ln_1.weight', 'visual.transformer.resblocks.10.ln_2.bias', 'visual.transformer.resblocks.10.ln_2.weight', 'visual.transformer.resblocks.10.mlp.c_fc.bias', 'visual.transformer.resblocks.10.mlp.c_fc.weight', 'visual.transformer.resblocks.10.mlp.c_proj.bias', 'visual.transformer.resblocks.10.mlp.c_proj.weight', 'visual.transformer.resblocks.11.attn.in_proj_bias', 'visual.transformer.resblocks.11.attn.in_proj_weight', 'visual.transformer.resblocks.11.attn.out_proj.bias', 'visual.transformer.resblocks.11.attn.out_proj.weight', 'visual.transformer.resblocks.11.ln_1.bias', 'visual.transformer.resblocks.11.ln_1.weight', 'visual.transformer.resblocks.11.ln_2.bias', 'visual.transformer.resblocks.11.ln_2.weight', 'visual.transformer.resblocks.11.mlp.c_fc.bias', 'visual.transformer.resblocks.11.mlp.c_fc.weight', 'visual.transformer.resblocks.11.mlp.c_proj.bias', 'visual.transformer.resblocks.11.mlp.c_proj.weight', 'visual.transformer.resblocks.12.attn.in_proj_bias', 'visual.transformer.resblocks.12.attn.in_proj_weight', 'visual.transformer.resblocks.12.attn.out_proj.bias', 'visual.transformer.resblocks.12.attn.out_proj.weight', 'visual.transformer.resblocks.12.ln_1.bias', 'visual.transformer.resblocks.12.ln_1.weight', 
'visual.transformer.resblocks.12.ln_2.bias', 'visual.transformer.resblocks.12.ln_2.weight', 'visual.transformer.resblocks.12.mlp.c_fc.bias', 'visual.transformer.resblocks.12.mlp.c_fc.weight', 'visual.transformer.resblocks.12.mlp.c_proj.bias', 'visual.transformer.resblocks.12.mlp.c_proj.weight', 'visual.transformer.resblocks.13.attn.in_proj_bias', 'visual.transformer.resblocks.13.attn.in_proj_weight', 'visual.transformer.resblocks.13.attn.out_proj.bias', 'visual.transformer.resblocks.13.attn.out_proj.weight', 'visual.transformer.resblocks.13.ln_1.bias', 'visual.transformer.resblocks.13.ln_1.weight', 'visual.transformer.resblocks.13.ln_2.bias', 'visual.transformer.resblocks.13.ln_2.weight', 'visual.transformer.resblocks.13.mlp.c_fc.bias', 'visual.transformer.resblocks.13.mlp.c_fc.weight', 'visual.transformer.resblocks.13.mlp.c_proj.bias', 'visual.transformer.resblocks.13.mlp.c_proj.weight', 'visual.transformer.resblocks.14.attn.in_proj_bias', 'visual.transformer.resblocks.14.attn.in_proj_weight', 'visual.transformer.resblocks.14.attn.out_proj.bias', 'visual.transformer.resblocks.14.attn.out_proj.weight', 'visual.transformer.resblocks.14.ln_1.bias', 'visual.transformer.resblocks.14.ln_1.weight', 'visual.transformer.resblocks.14.ln_2.bias', 'visual.transformer.resblocks.14.ln_2.weight', 'visual.transformer.resblocks.14.mlp.c_fc.bias', 'visual.transformer.resblocks.14.mlp.c_fc.weight', 'visual.transformer.resblocks.14.mlp.c_proj.bias', 'visual.transformer.resblocks.14.mlp.c_proj.weight', 'visual.transformer.resblocks.15.attn.in_proj_bias', 'visual.transformer.resblocks.15.attn.in_proj_weight', 'visual.transformer.resblocks.15.attn.out_proj.bias', 'visual.transformer.resblocks.15.attn.out_proj.weight', 'visual.transformer.resblocks.15.ln_1.bias', 'visual.transformer.resblocks.15.ln_1.weight', 'visual.transformer.resblocks.15.ln_2.bias', 'visual.transformer.resblocks.15.ln_2.weight', 'visual.transformer.resblocks.15.mlp.c_fc.bias', 
'visual.transformer.resblocks.15.mlp.c_fc.weight', 'visual.transformer.resblocks.15.mlp.c_proj.bias', 'visual.transformer.resblocks.15.mlp.c_proj.weight', 'visual.transformer.resblocks.16.attn.in_proj_bias', 'visual.transformer.resblocks.16.attn.in_proj_weight', 'visual.transformer.resblocks.16.attn.out_proj.bias', 'visual.transformer.resblocks.16.attn.out_proj.weight', 'visual.transformer.resblocks.16.ln_1.bias', 'visual.transformer.resblocks.16.ln_1.weight', 'visual.transformer.resblocks.16.ln_2.bias', 'visual.transformer.resblocks.16.ln_2.weight', 'visual.transformer.resblocks.16.mlp.c_fc.bias', 'visual.transformer.resblocks.16.mlp.c_fc.weight', 'visual.transformer.resblocks.16.mlp.c_proj.bias', 'visual.transformer.resblocks.16.mlp.c_proj.weight', 'visual.transformer.resblocks.17.attn.in_proj_bias', 'visual.transformer.resblocks.17.attn.in_proj_weight', 'visual.transformer.resblocks.17.attn.out_proj.bias', 'visual.transformer.resblocks.17.attn.out_proj.weight', 'visual.transformer.resblocks.17.ln_1.bias', 'visual.transformer.resblocks.17.ln_1.weight', 'visual.transformer.resblocks.17.ln_2.bias', 'visual.transformer.resblocks.17.ln_2.weight', 'visual.transformer.resblocks.17.mlp.c_fc.bias', 'visual.transformer.resblocks.17.mlp.c_fc.weight', 'visual.transformer.resblocks.17.mlp.c_proj.bias', 'visual.transformer.resblocks.17.mlp.c_proj.weight', 'visual.transformer.resblocks.18.attn.in_proj_bias', 'visual.transformer.resblocks.18.attn.in_proj_weight', 'visual.transformer.resblocks.18.attn.out_proj.bias', 'visual.transformer.resblocks.18.attn.out_proj.weight', 'visual.transformer.resblocks.18.ln_1.bias', 'visual.transformer.resblocks.18.ln_1.weight', 'visual.transformer.resblocks.18.ln_2.bias', 'visual.transformer.resblocks.18.ln_2.weight', 'visual.transformer.resblocks.18.mlp.c_fc.bias', 'visual.transformer.resblocks.18.mlp.c_fc.weight', 'visual.transformer.resblocks.18.mlp.c_proj.bias', 'visual.transformer.resblocks.18.mlp.c_proj.weight', 
'visual.transformer.resblocks.19.attn.in_proj_bias', 'visual.transformer.resblocks.19.attn.in_proj_weight', 'visual.transformer.resblocks.19.attn.out_proj.bias', 'visual.transformer.resblocks.19.attn.out_proj.weight', 'visual.transformer.resblocks.19.ln_1.bias', 'visual.transformer.resblocks.19.ln_1.weight', 'visual.transformer.resblocks.19.ln_2.bias', 'visual.transformer.resblocks.19.ln_2.weight', 'visual.transformer.resblocks.19.mlp.c_fc.bias', 'visual.transformer.resblocks.19.mlp.c_fc.weight', 'visual.transformer.resblocks.19.mlp.c_proj.bias', 'visual.transformer.resblocks.19.mlp.c_proj.weight', 'visual.transformer.resblocks.2.attn.in_proj_bias', 'visual.transformer.resblocks.2.attn.in_proj_weight', 'visual.transformer.resblocks.2.attn.out_proj.bias', 'visual.transformer.resblocks.2.attn.out_proj.weight', 'visual.transformer.resblocks.2.ln_1.bias', 'visual.transformer.resblocks.2.ln_1.weight', 'visual.transformer.resblocks.2.ln_2.bias', 'visual.transformer.resblocks.2.ln_2.weight', 'visual.transformer.resblocks.2.mlp.c_fc.bias', 'visual.transformer.resblocks.2.mlp.c_fc.weight', 'visual.transformer.resblocks.2.mlp.c_proj.bias', 'visual.transformer.resblocks.2.mlp.c_proj.weight', 'visual.transformer.resblocks.20.attn.in_proj_bias', 'visual.transformer.resblocks.20.attn.in_proj_weight', 'visual.transformer.resblocks.20.attn.out_proj.bias', 'visual.transformer.resblocks.20.attn.out_proj.weight', 'visual.transformer.resblocks.20.ln_1.bias', 'visual.transformer.resblocks.20.ln_1.weight', 'visual.transformer.resblocks.20.ln_2.bias', 'visual.transformer.resblocks.20.ln_2.weight', 'visual.transformer.resblocks.20.mlp.c_fc.bias', 'visual.transformer.resblocks.20.mlp.c_fc.weight', 'visual.transformer.resblocks.20.mlp.c_proj.bias', 'visual.transformer.resblocks.20.mlp.c_proj.weight', 'visual.transformer.resblocks.21.attn.in_proj_bias', 'visual.transformer.resblocks.21.attn.in_proj_weight', 'visual.transformer.resblocks.21.attn.out_proj.bias', 
'visual.transformer.resblocks.21.attn.out_proj.weight', 'visual.transformer.resblocks.21.ln_1.bias', 'visual.transformer.resblocks.21.ln_1.weight', 'visual.transformer.resblocks.21.ln_2.bias', 'visual.transformer.resblocks.21.ln_2.weight', 'visual.transformer.resblocks.21.mlp.c_fc.bias', 'visual.transformer.resblocks.21.mlp.c_fc.weight', 'visual.transformer.resblocks.21.mlp.c_proj.bias', 'visual.transformer.resblocks.21.mlp.c_proj.weight', 'visual.transformer.resblocks.22.attn.in_proj_bias', 'visual.transformer.resblocks.22.attn.in_proj_weight', 'visual.transformer.resblocks.22.attn.out_proj.bias', 'visual.transformer.resblocks.22.attn.out_proj.weight', 'visual.transformer.resblocks.22.ln_1.bias', 'visual.transformer.resblocks.22.ln_1.weight', 'visual.transformer.resblocks.22.ln_2.bias', 'visual.transformer.resblocks.22.ln_2.weight', 'visual.transformer.resblocks.22.mlp.c_fc.bias', 'visual.transformer.resblocks.22.mlp.c_fc.weight', 'visual.transformer.resblocks.22.mlp.c_proj.bias', 'visual.transformer.resblocks.22.mlp.c_proj.weight', 'visual.transformer.resblocks.23.attn.in_proj_bias', 'visual.transformer.resblocks.23.attn.in_proj_weight', 'visual.transformer.resblocks.23.attn.out_proj.bias', 'visual.transformer.resblocks.23.attn.out_proj.weight', 'visual.transformer.resblocks.23.ln_1.bias', 'visual.transformer.resblocks.23.ln_1.weight', 'visual.transformer.resblocks.23.ln_2.bias', 'visual.transformer.resblocks.23.ln_2.weight', 'visual.transformer.resblocks.23.mlp.c_fc.bias', 'visual.transformer.resblocks.23.mlp.c_fc.weight', 'visual.transformer.resblocks.23.mlp.c_proj.bias', 'visual.transformer.resblocks.23.mlp.c_proj.weight', 'visual.transformer.resblocks.3.attn.in_proj_bias', 'visual.transformer.resblocks.3.attn.in_proj_weight', 'visual.transformer.resblocks.3.attn.out_proj.bias', 'visual.transformer.resblocks.3.attn.out_proj.weight', 'visual.transformer.resblocks.3.ln_1.bias', 'visual.transformer.resblocks.3.ln_1.weight', 
'visual.transformer.resblocks.3.ln_2.bias', 'visual.transformer.resblocks.3.ln_2.weight', 'visual.transformer.resblocks.3.mlp.c_fc.bias', 'visual.transformer.resblocks.3.mlp.c_fc.weight', 'visual.transformer.resblocks.3.mlp.c_proj.bias', 'visual.transformer.resblocks.3.mlp.c_proj.weight', 'visual.transformer.resblocks.4.attn.in_proj_bias', 'visual.transformer.resblocks.4.attn.in_proj_weight', 'visual.transformer.resblocks.4.attn.out_proj.bias', 'visual.transformer.resblocks.4.attn.out_proj.weight', 'visual.transformer.resblocks.4.ln_1.bias', 'visual.transformer.resblocks.4.ln_1.weight', 'visual.transformer.resblocks.4.ln_2.bias', 'visual.transformer.resblocks.4.ln_2.weight', 'visual.transformer.resblocks.4.mlp.c_fc.bias', 'visual.transformer.resblocks.4.mlp.c_fc.weight', 'visual.transformer.resblocks.4.mlp.c_proj.bias', 'visual.transformer.resblocks.4.mlp.c_proj.weight', 'visual.transformer.resblocks.5.attn.in_proj_bias', 'visual.transformer.resblocks.5.attn.in_proj_weight', 'visual.transformer.resblocks.5.attn.out_proj.bias', 'visual.transformer.resblocks.5.attn.out_proj.weight', 'visual.transformer.resblocks.5.ln_1.bias', 'visual.transformer.resblocks.5.ln_1.weight', 'visual.transformer.resblocks.5.ln_2.bias', 'visual.transformer.resblocks.5.ln_2.weight', 'visual.transformer.resblocks.5.mlp.c_fc.bias', 'visual.transformer.resblocks.5.mlp.c_fc.weight', 'visual.transformer.resblocks.5.mlp.c_proj.bias', 'visual.transformer.resblocks.5.mlp.c_proj.weight', 'visual.transformer.resblocks.6.attn.in_proj_bias', 'visual.transformer.resblocks.6.attn.in_proj_weight', 'visual.transformer.resblocks.6.attn.out_proj.bias', 'visual.transformer.resblocks.6.attn.out_proj.weight', 'visual.transformer.resblocks.6.ln_1.bias', 'visual.transformer.resblocks.6.ln_1.weight', 'visual.transformer.resblocks.6.ln_2.bias', 'visual.transformer.resblocks.6.ln_2.weight', 'visual.transformer.resblocks.6.mlp.c_fc.bias', 'visual.transformer.resblocks.6.mlp.c_fc.weight', 
'visual.transformer.resblocks.6.mlp.c_proj.bias', 'visual.transformer.resblocks.6.mlp.c_proj.weight', 'visual.transformer.resblocks.7.attn.in_proj_bias', 'visual.transformer.resblocks.7.attn.in_proj_weight', 'visual.transformer.resblocks.7.attn.out_proj.bias', 'visual.transformer.resblocks.7.attn.out_proj.weight', 'visual.transformer.resblocks.7.ln_1.bias', 'visual.transformer.resblocks.7.ln_1.weight', 'visual.transformer.resblocks.7.ln_2.bias', 'visual.transformer.resblocks.7.ln_2.weight', 'visual.transformer.resblocks.7.mlp.c_fc.bias', 'visual.transformer.resblocks.7.mlp.c_fc.weight', 'visual.transformer.resblocks.7.mlp.c_proj.bias', 'visual.transformer.resblocks.7.mlp.c_proj.weight', 'visual.transformer.resblocks.8.attn.in_proj_bias', 'visual.transformer.resblocks.8.attn.in_proj_weight', 'visual.transformer.resblocks.8.attn.out_proj.bias', 'visual.transformer.resblocks.8.attn.out_proj.weight', 'visual.transformer.resblocks.8.ln_1.bias', 'visual.transformer.resblocks.8.ln_1.weight', 'visual.transformer.resblocks.8.ln_2.bias', 'visual.transformer.resblocks.8.ln_2.weight', 'visual.transformer.resblocks.8.mlp.c_fc.bias', 'visual.transformer.resblocks.8.mlp.c_fc.weight', 'visual.transformer.resblocks.8.mlp.c_proj.bias', 'visual.transformer.resblocks.8.mlp.c_proj.weight', 'visual.transformer.resblocks.9.attn.in_proj_bias', 'visual.transformer.resblocks.9.attn.in_proj_weight', 'visual.transformer.resblocks.9.attn.out_proj.bias', 'visual.transformer.resblocks.9.attn.out_proj.weight', 'visual.transformer.resblocks.9.ln_1.bias', 'visual.transformer.resblocks.9.ln_1.weight', 'visual.transformer.resblocks.9.ln_2.bias', 'visual.transformer.resblocks.9.ln_2.weight', 'visual.transformer.resblocks.9.mlp.c_fc.bias', 'visual.transformer.resblocks.9.mlp.c_fc.weight', 'visual.transformer.resblocks.9.mlp.c_proj.bias', 'visual.transformer.resblocks.9.mlp.c_proj.weight'])
2. model build: 1.51 sec
3. load_state_dict: 0.24 sec
Missing: 0 Unexpected: 0
Missing sample: []
Unexpected sample: []
4. to device: 0.34 sec
Done
Image size: 336
Context length: 32
Model device: cuda:0
Model dtype: torch.float32
CLIP(
(token_embedding): Embedding(49408, 1024)
(transformer): Transformer(
(resblocks): ModuleList(
(0-23): 24 x ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True)
)
(ls_1): Identity()
(ls_2): Identity()
(ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(drop_path1): Identity()
(drop_path2): Identity()
(mlp): Sequential(
(c_fc): Linear(in_features=1024, out_features=4096, bias=True)
(gelu): GELU(approximate='none')
(c_proj): Linear(in_features=4096, out_features=1024, bias=True)
)
)
)
)
(ln_final): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(visual): VisionTransformer(
(conv1): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
(ln_pre): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(ln_post): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(transformer): Transformer(
(resblocks): ModuleList(
(0-23): 24 x ResidualAttentionBlock(
(attn): SelfAttention(
(out_proj): Linear(in_features=1024, out_features=1024, bias=True)
)
(ls_1): Identity()
(ls_2): Identity()
(ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(drop_path1): Identity()
(drop_path2): Identity()
(mlp): Sequential(
(c_fc): Linear(in_features=1024, out_features=4096, bias=True)
(gelu): GELU(approximate='none')
(c_proj): Linear(in_features=4096, out_features=1024, bias=True)
)
)
)
)
(attn_pool): AttentionPooling(
(attn): MultiheadAttention(
(out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True)
)
(layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=1024, out_features=4096, bias=True)
(gelu): GELU(approximate='none')
(c_proj): Linear(in_features=4096, out_features=1024, bias=True)
)
)
)
)
Datasets#
Load gylang reid set
# Load the ReID metadata table and print a quick overview of its contents.
df = pd.read_csv(METADATA_CSV)
print(df.head())
print(df.columns.tolist())
print(df["final_split"].value_counts(dropna=False))

# Validate with an explicit raise: a bare ``assert`` is stripped under
# ``python -O`` and would let a malformed CSV through silently.
required_cols = ["path", "vid", "camid", "final_split"]
missing = [c for c in required_cols if c not in df.columns]
if missing:
    raise ValueError(f"Missing columns: {missing}")

# Keep only the columns used downstream and pin their dtypes in one step.
df = df[required_cols].astype({"path": str, "vid": int, "camid": int})

# Split rows into the query set and the (distractor-free) gallery set.
query_df = df[df["final_split"] == "query"].copy()
gallery_df = df[df["final_split"] == "gallery"].copy()
print("query:", len(query_df))
print("gallery:", len(gallery_df))
print("query vids:", query_df["vid"].nunique())
print("gallery vids:", gallery_df["vid"].nunique())
path \
0 /home/user/data/datasets/VeRi_Self/reid/cam_0_...
1 /home/user/data/datasets/VeRi_Self/reid/cam_0_...
2 /home/user/data/datasets/VeRi_Self/reid/cam_0_...
3 /home/user/data/datasets/VeRi_Self/reid/cam_0_...
4 /home/user/data/datasets/VeRi_Self/reid/cam_0_...
rel_path \
0 cam_0_001/rank_01_frame_00000102_score_0.898.jpg
1 cam_0_001/rank_02_frame_00000070_score_0.811.jpg
2 cam_0_001/rank_03_frame_00000040_score_0.743.jpg
3 cam_0_002/rank_01_frame_00000206_score_0.898.jpg
4 cam_0_002/rank_02_frame_00000191_score_0.824.jpg
file_name folder vid camid rank \
0 rank_01_frame_00000102_score_0.898.jpg cam_0_001 1 0 1
1 rank_02_frame_00000070_score_0.811.jpg cam_0_001 1 0 2
2 rank_03_frame_00000040_score_0.743.jpg cam_0_001 1 0 3
3 rank_01_frame_00000206_score_0.898.jpg cam_0_002 2 0 1
4 rank_02_frame_00000191_score_0.824.jpg cam_0_002 2 0 2
frame_idx score ext split_initial final_split
0 102 0.898 jpg query query
1 70 0.811 jpg gallery gallery
2 40 0.743 jpg gallery gallery
3 206 0.898 jpg query query
4 191 0.824 jpg gallery gallery
['path', 'rel_path', 'file_name', 'folder', 'vid', 'camid', 'rank', 'frame_idx', 'score', 'ext', 'split_initial', 'final_split']
final_split
gallery 58
query 29
Name: count, dtype: int64
query: 29
gallery: 58
query vids: 15
gallery vids: 15
load distractors 100k
# Lower-case file suffixes that count as images when scanning the distractor tree.
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}

# Walk DISTRACTOR_ROOT recursively and keep absolute paths of image files only.
# Sorting makes the resulting order deterministic from run to run.
distractor_paths = sorted(
    str(candidate.resolve())
    for candidate in DISTRACTOR_ROOT.rglob("*")
    if candidate.is_file() and candidate.suffix.lower() in IMAGE_EXTS
)
print("num distractors:", len(distractor_paths))
print(distractor_paths[:5])
num distractors: 100000
['/home/user/data/datasets/VeRi_Self/PKU_Vehicle/0/000a3ff7-d31c-4fe1-bdf0-020b7d7a93a4.jpg', '/home/user/data/datasets/VeRi_Self/PKU_Vehicle/0/0010ded2-9ec0-4e78-b7d4-d64f4c1c9303.jpg', '/home/user/data/datasets/VeRi_Self/PKU_Vehicle/0/00148612-07d9-4e7d-8552-04be8af22435.jpg', '/home/user/data/datasets/VeRi_Self/PKU_Vehicle/0/001497dd-6933-4f52-9291-fa32283ac95f.jpg', '/home/user/data/datasets/VeRi_Self/PKU_Vehicle/0/001a7312-d74e-46bf-b5f6-b91e49558feb.jpg']
merge dataframes
# Distractor rows get vid/camid of -1 and their own split tag so they can be
# told apart from the labelled gallery rows.
distractor_df = pd.DataFrame({"path": distractor_paths}).assign(
    vid=-1,
    camid=-1,
    final_split="gallery_distractor",
)

# Stack labelled gallery rows first, distractors after; ignore_index=True
# already yields a fresh 0..N-1 index for the combined frame.
gallery_plus_df = pd.concat(
    [
        gallery_df[["path", "vid", "camid", "final_split"]],
        distractor_df,
    ],
    ignore_index=True,
)
query_df = query_df.reset_index(drop=True)

print("gallery original:", len(gallery_df))
print("gallery + distractor:", len(gallery_plus_df))
print(gallery_plus_df["final_split"].value_counts())
gallery original: 58
gallery + distractor: 100058
final_split
gallery_distractor 100000
gallery 58
Name: count, dtype: int64
Create a unique image_id for every query and gallery image
# Positional, zero-padded ids: the prefix tells query rows from gallery rows,
# the 8-digit counter is the row position in the respective frame.
query_df["image_id"] = [f"query_{pos:08d}" for pos in range(len(query_df))]
gallery_plus_df["image_id"] = [f"gallery_{pos:08d}" for pos in range(len(gallery_plus_df))]
query_df.head()
| path | vid | camid | final_split | image_id | |
|---|---|---|---|---|---|
| 0 | /home/user/data/datasets/VeRi_Self/reid/cam_0_... | 1 | 0 | query | query_00000000 |
| 1 | /home/user/data/datasets/VeRi_Self/reid/cam_0_... | 2 | 0 | query | query_00000001 |
| 2 | /home/user/data/datasets/VeRi_Self/reid/cam_0_... | 3 | 0 | query | query_00000002 |
| 3 | /home/user/data/datasets/VeRi_Self/reid/cam_0_... | 4 | 0 | query | query_00000003 |
| 4 | /home/user/data/datasets/VeRi_Self/reid/cam_0_... | 5 | 0 | query | query_00000004 |
Dataset class for both CLIP and ReID#
class ImageDataset(Dataset):
    """Map-style dataset that loads one image per metadata row.

    Args:
        df: DataFrame with the columns ``path``, ``vid``, ``camid``,
            ``image_id`` and ``final_split``. Its index is reset so that
            positional access by ``idx`` is always valid.
        preprocess: Callable applied to the RGB ``PIL.Image`` of each row; its
            return value is stored under the ``"image"`` key.
    """

    def __init__(self, df, preprocess):
        self.df = df.reset_index(drop=True)
        self.preprocess = preprocess

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return a dict with the preprocessed image and the row's metadata."""
        row = self.df.iloc[idx]
        # Open inside a context manager so the file handle is released as soon
        # as the pixels are decoded, even if decoding/conversion raises.
        with Image.open(row["path"]) as raw:
            img = raw.convert("RGB")
        img = self.preprocess(img)
        return {
            "image": img,
            "vid": int(row["vid"]),
            "camid": int(row["camid"]),
            "path": row["path"],
            "image_id": row["image_id"],
            "final_split": row["final_split"],
        }
def collate_fn(batch):
    """Merge a list of per-sample dicts into one batch dict.

    Images are stacked into a single tensor along a new leading dimension,
    ``vid``/``camid`` become int64 NumPy arrays, and the remaining string
    fields are kept as plain Python lists in sample order.
    """

    def gather(key):
        # Pull one field out of every sample, preserving order.
        return [sample[key] for sample in batch]

    return {
        "images": torch.stack(gather("image")),
        "vids": np.array(gather("vid"), dtype=np.int64),
        "camids": np.array(gather("camid"), dtype=np.int64),
        "paths": gather("path"),
        "image_ids": gather("image_id"),
        "final_splits": gather("final_split"),
    }
DataLoaders#
BATCH_SIZE = 128
NUM_WORKERS = 8


def _build_loader(dataset):
    """Wrap ``dataset`` in a DataLoader using the settings shared by both splits.

    ``shuffle=False`` keeps batches in dataset order, so row i of the encoded
    output corresponds to row i of the source DataFrame.
    """
    return DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=True,
        persistent_workers=True,
        prefetch_factor=4,
        collate_fn=collate_fn,
    )


query_dataset = ImageDataset(query_df, preprocess)
gallery_dataset = ImageDataset(gallery_plus_df, preprocess)

query_loader = _build_loader(query_dataset)
gallery_loader = _build_loader(gallery_dataset)
Embedding#
def l2_normalize(x: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Scale ``x`` to unit L2 norm along its last dimension.

    Norms are clamped from below at ``eps`` so an all-zero vector yields zeros
    rather than a division by zero.
    """
    norms = x.norm(dim=-1, keepdim=True)
    safe_norms = norms.clamp(min=eps)
    return x / safe_norms
@torch.inference_mode()
def encode_loader(loader, model, device=DEVICE, dtype=torch.bfloat16):
    """Encode every image yielded by ``loader`` into L2-normalised embeddings.

    Args:
        loader: Iterable of batch dicts with the keys ``images``, ``image_ids``,
            ``paths``, ``vids``, ``camids`` and ``final_splits``.
        model: Object exposing ``encode_image(images)``.
        device: Target device (string or ``torch.device``) the images are moved to.
        dtype: Autocast dtype used when running on CUDA.

    Returns:
        ``(embeddings, meta)``: a float32 NumPy array with one row per image
        and a DataFrame holding the matching metadata rows, in the same order.
        An empty loader yields a ``(0, 0)`` array and an empty DataFrame.
    """
    # Compare the device *type*: a plain ``device == "cuda"`` is False for
    # "cuda:0" (and may be for torch.device objects), which would silently
    # switch autocast off and run the model in full precision.
    use_autocast = torch.device(device).type == "cuda"

    all_embeddings = []
    rows = []
    start = time.time()

    for batch in tqdm(loader):
        images = batch["images"].to(device, non_blocking=True)
        with torch.autocast(
            device_type="cuda",
            dtype=dtype,
            enabled=use_autocast,
        ):
            features = model.encode_image(images)

        # Normalise in float32 regardless of the autocast dtype.
        features = l2_normalize(features.float())
        all_embeddings.append(features.cpu().numpy().astype("float32"))

        rows.extend(
            {
                "image_id": image_id,
                "path": path,
                "vid": int(vid),
                "camid": int(camid),
                "final_split": final_split,
            }
            for image_id, path, vid, camid, final_split in zip(
                batch["image_ids"],
                batch["paths"],
                batch["vids"],
                batch["camids"],
                batch["final_splits"],
            )
        )

    # np.vstack raises on an empty list, so handle an empty loader explicitly.
    if all_embeddings:
        embeddings = np.vstack(all_embeddings)
    else:
        embeddings = np.empty((0, 0), dtype=np.float32)

    # Guard the throughput figure against a zero-length timing interval.
    elapsed = max(time.time() - start, 1e-9)
    print("embeddings:", embeddings.shape)
    print("elapsed min:", round(elapsed / 60, 2))
    print("images/sec:", round(len(rows) / elapsed, 2))

    meta = pd.DataFrame(
        rows,
        columns=["image_id", "path", "vid", "camid", "final_split"],
    )
    return embeddings, meta
Encode Gallery (reid) + distractors#
# Output files for the gallery(+distractor) embeddings and their row metadata.
GALLERY_EMBEDDINGS_PATH_PE = EMB_DIR / "gallery_plus_embeddings_PE.npy"
GALLERY_METADATA_PATH_PE = EMB_DIR / "gallery_plus_metadata_PE.parquet"
# Encode the full gallery loader; embeddings and metadata come back row-aligned.
gallery_embeddings, gallery_meta = encode_loader(
gallery_loader,
model,
device=DEVICE,
dtype=DTYPE,
)
# Persist both artefacts so later cells can reload them without re-encoding.
np.save(GALLERY_EMBEDDINGS_PATH_PE, gallery_embeddings)
gallery_meta.to_parquet(GALLERY_METADATA_PATH_PE, index=False)
print(gallery_embeddings.shape)
print(gallery_meta.head())
100%|██████████| 782/782 [22:02<00:00, 1.69s/it]
embeddings: (100058, 1024)
elapsed min: 22.05
images/sec: 75.62
(100058, 1024)
image_id path vid \
0 gallery_00000000 /home/user/data/datasets/VeRi_Self/reid/cam_0_... 1
1 gallery_00000001 /home/user/data/datasets/VeRi_Self/reid/cam_0_... 1
2 gallery_00000002 /home/user/data/datasets/VeRi_Self/reid/cam_0_... 2
3 gallery_00000003 /home/user/data/datasets/VeRi_Self/reid/cam_0_... 2
4 gallery_00000004 /home/user/data/datasets/VeRi_Self/reid/cam_0_... 3
camid final_split
0 0 gallery
1 0 gallery
2 0 gallery
3 0 gallery
4 0 gallery
embeddings for query set
# Output files for the query embeddings and their row metadata.
QUERY_EMBEDDINGS_PATH_PE = EMB_DIR / "query_embeddings_PE.npy"
QUERY_METADATA_PATH_PE = EMB_DIR / "query_metadata_PE.parquet"
# Encode the query loader with the same model/device/dtype as the gallery.
query_embeddings, query_meta = encode_loader(
query_loader,
model,
device=DEVICE,
dtype=DTYPE,
)
# Persist both artefacts so later cells can reload them without re-encoding.
np.save(QUERY_EMBEDDINGS_PATH_PE, query_embeddings)
query_meta.to_parquet(QUERY_METADATA_PATH_PE, index=False)
print(query_embeddings.shape)
print(query_meta.head())
100%|██████████| 1/1 [00:00<00:00, 1.13it/s]
embeddings: (29, 1024)
elapsed min: 0.01
images/sec: 32.77
(29, 1024)
image_id path vid \
0 query_00000000 /home/user/data/datasets/VeRi_Self/reid/cam_0_... 1
1 query_00000001 /home/user/data/datasets/VeRi_Self/reid/cam_0_... 2
2 query_00000002 /home/user/data/datasets/VeRi_Self/reid/cam_0_... 3
3 query_00000003 /home/user/data/datasets/VeRi_Self/reid/cam_0_... 4
4 query_00000004 /home/user/data/datasets/VeRi_Self/reid/cam_0_... 5
camid final_split
0 0 query
1 0 query
2 0 query
3 0 query
4 0 query
Load saved embeddings
# Reload embeddings and metadata from disk (lets the notebook resume here
# without re-running the encoding cells).
gallery_embeddings = np.load(GALLERY_EMBEDDINGS_PATH_PE)
gallery_meta = pd.read_parquet(GALLERY_METADATA_PATH_PE)
query_embeddings = np.load(QUERY_EMBEDDINGS_PATH_PE)
query_meta = pd.read_parquet(QUERY_METADATA_PATH_PE)
# Sanity check: the row count of each array should match its metadata frame.
print("gallery:", gallery_embeddings.shape, gallery_meta.shape)
print("query:", query_embeddings.shape, query_meta.shape)
gallery: (100058, 1024) (100058, 5)
query: (29, 1024) (29, 5)
Text Encoding#
@torch.inference_mode()
def encode_texts(texts, model, tokenizer, device=DEVICE, dtype=torch.bfloat16):
    """Encode one or more text prompts into L2-normalised float32 embeddings.

    Args:
        texts: A single string or a list of strings; a lone string is wrapped
            into a one-element list.
        model: Object exposing ``encode_text(tokens)``.
        tokenizer: Callable turning the list of strings into a tensor that
            supports ``.to(device)``.
        device: Target device (string or ``torch.device``) for the tokens.
        dtype: Autocast dtype used when running on CUDA.

    Returns:
        float32 NumPy array with one unit-norm row per input text.
    """
    if isinstance(texts, str):
        texts = [texts]

    # Compare the device *type*: a plain ``device == "cuda"`` is False for
    # "cuda:0" (and may be for torch.device objects), which would silently
    # switch autocast off.
    use_autocast = torch.device(device).type == "cuda"

    text_tokens = tokenizer(texts).to(device)
    with torch.autocast(
        device_type="cuda",
        dtype=dtype,
        enabled=use_autocast,
    ):
        text_features = model.encode_text(text_tokens)

    # Normalise in float32 regardless of the autocast dtype.
    text_features = l2_normalize(text_features.float())
    return text_features.cpu().numpy().astype("float32")
Text to Image search
def text_to_image_search(query, top_k=5):
    """Rank the gallery against a text prompt and return the best ``top_k`` hits.

    Returns:
        ``(results, scores)``: ``results`` is a DataFrame of the top rows of
        ``gallery_meta`` with ``query``, ``rank``, ``cosine_score``, ``z_score``
        and ``gap_from_top1`` columns added; ``scores`` is the dot product of
        the text embedding with every gallery embedding.
    """
    text_emb = encode_texts(query, model, tokenizer)[0]
    scores = gallery_embeddings @ text_emb

    # Gallery-wide statistics used to express each hit as a z-score; the small
    # constant keeps the division defined when all scores are equal.
    mean = scores.mean()
    std = scores.std() + 1e-12

    # Ascending argsort reversed gives best-first order.
    best_first = np.argsort(scores)[::-1]
    top_indices = best_first[:top_k]
    top_scores = scores[top_indices]

    results = gallery_meta.iloc[top_indices].assign(
        query=query,
        rank=np.arange(1, len(top_indices) + 1),
        cosine_score=top_scores,
        z_score=(top_scores - mean) / std,
        gap_from_top1=top_scores[0] - top_scores,
    )
    return results.reset_index(drop=True), scores
Show results
# Run one free-text query against the gallery and keep the top-5 hits.
results, scores = text_to_image_search("white car with black bonnet", top_k=5)
# display(results)
# show_search_results is not defined in this notebook; presumably it comes
# from the `from helpers import *` import — verify there.
show_search_results(results, title="white car with black bonnet")
Evals#
# Free-text prompts used to build the manual review sheet.
review_queries = [
    "white car with sunroof",
    "white car with black roof",
    "SUV with roof rails",
    "truck with open cargo bed",
    "box truck with cargo container",
    "white van with yellow stripes",
    "pink sports car",
    "green taxi with roof sign",
    "white car with flames painted on the side",
    "ambulance with red cross",
    "vehicle with zebra stripes",
    "car carrying bicycle on roof",
]

# One row per prompt, keyed by a stable id q001, q002, ...
queries_df = pd.DataFrame(
    {
        "query_id": [f"q{i+1:03d}" for i in range(len(review_queries))],
        "query": review_queries,
    }
)

# Collect the top-5 hits of every prompt, tagging each hit with its query id.
all_review_results = []
for query_id, query in tqdm(
    zip(queries_df["query_id"], queries_df["query"]),
    total=len(queries_df),
):
    results, _ = text_to_image_search(query, top_k=5)
    results["query_id"] = query_id
    all_review_results.append(results)

# Add empty labelling columns (-1 = not reviewed) and fix the column order
# of the exported sheet.
review_results = pd.concat(all_review_results, ignore_index=True).assign(
    human_label=-1,
    human_notes="",
)
review_results = review_results[
    [
        "query_id",
        "query",
        "rank",
        "image_id",
        "path",
        "vid",
        "camid",
        "final_split",
        "cosine_score",
        "z_score",
        "gap_from_top1",
        "human_label",
        "human_notes",
    ]
]

review_results.to_csv(OUTPUT_DIR / "text_review_top5_unlabeled.csv", index=False)
review_results.head()
100%|██████████| 12/12 [00:00<00:00, 57.15it/s]
| query_id | query | rank | image_id | path | vid | camid | final_split | cosine_score | z_score | gap_from_top1 | human_label | human_notes | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | q001 | white car with sunroof | 1 | gallery_00032241 | /home/user/data/datasets/VeRi_Self/PKU_Vehicle... | -1 | -1 | gallery_distractor | 0.286042 | 2.644217 | 0.000000 | -1 | |
| 1 | q001 | white car with sunroof | 2 | gallery_00045133 | /home/user/data/datasets/VeRi_Self/PKU_Vehicle... | -1 | -1 | gallery_distractor | 0.284771 | 2.609696 | 0.001271 | -1 | |
| 2 | q001 | white car with sunroof | 3 | gallery_00035865 | /home/user/data/datasets/VeRi_Self/PKU_Vehicle... | -1 | -1 | gallery_distractor | 0.284661 | 2.606714 | 0.001381 | -1 | |
| 3 | q001 | white car with sunroof | 4 | gallery_00040658 | /home/user/data/datasets/VeRi_Self/PKU_Vehicle... | -1 | -1 | gallery_distractor | 0.284134 | 2.592382 | 0.001909 | -1 | |
| 4 | q001 | white car with sunroof | 5 | gallery_00008036 | /home/user/data/datasets/VeRi_Self/PKU_Vehicle... | -1 | -1 | gallery_distractor | 0.284064 | 2.590488 | 0.001978 | -1 |
labelling#
The plan was to label results manually for threshold calibration: for each returned rank, assign one of the labels below.
2 = correct / strong match
1 = partial / acceptable
0 = wrong
-1 = not reviewed
So for query 0 we could record [2, 2, 1, 0, 0], which would mean:
Rank 1 → 2 = correct / strong match
Rank 2 → 2 = correct / strong match
Rank 3 → 1 = partial / acceptable
Rank 4 → 0 = wrong
Rank 5 → 0 = wrong
and then write this back into the metadata file under human_label…
No time, will skip this for now
# Work on a copy so labels can be filled in without touching the exported sheet.
review_results_labeled = review_results.copy()


def next_unreviewed_queries(n=5):
    """Return up to ``n`` query ids that still have rows with ``human_label == -1``.

    Ids are returned in order of first appearance in ``review_results_labeled``.
    """
    pending_rows = review_results_labeled["human_label"].eq(-1)
    pending_ids = review_results_labeled.loc[pending_rows, "query_id"].unique()
    return list(pending_ids[:n])


qids = next_unreviewed_queries(12)
show_queries_for_labeling(qids)
Image to Image Search (simulate ReID)#
using a CLIP model for Image to Image search, simulating ReID
def image_to_image_search(query_index, top_k=10, remove_same_cam=True):
    """Rank gallery images against one query image by embedding dot product.

    Args:
        query_index: Row position of the query in ``query_embeddings`` /
            ``query_meta``.
        top_k: Number of best-scoring gallery rows to return.
        remove_same_cam: When true, drop gallery rows that share both the
            vehicle id and the camera id of the query before ranking.

    Returns:
        DataFrame of the top rows of ``gallery_meta`` with ``score``, ``rank``,
        ``query_path``, ``query_vid``, ``query_camid`` and ``is_correct_vid``
        columns added.
    """
    probe = query_meta.iloc[query_index]
    probe_vid, probe_cam = probe["vid"], probe["camid"]

    ranked = gallery_meta.copy()
    ranked["score"] = gallery_embeddings @ query_embeddings[query_index]

    if remove_same_cam:
        # Keep a row unless it matches the query on BOTH vid and camid.
        keep = (ranked["vid"] != probe_vid) | (ranked["camid"] != probe_cam)
        ranked = ranked[keep].copy()

    ranked = ranked.sort_values("score", ascending=False).head(top_k)
    ranked["rank"] = np.arange(1, len(ranked) + 1)
    ranked["query_path"] = probe["path"]
    ranked["query_vid"] = probe_vid
    ranked["query_camid"] = probe_cam
    ranked["is_correct_vid"] = ranked["vid"] == probe_vid
    return ranked.reset_index(drop=True)
# Top-10 gallery matches for the first query image (same-vid/same-cam rows
# are removed by the function's default).
reid_results = image_to_image_search(0, top_k=10)
# Notebook display: only the columns needed to eyeball the ranking.
reid_results[["rank", "score", "vid", "camid", "is_correct_vid", "path"]]
| rank | score | vid | camid | is_correct_vid | path | |
|---|---|---|---|---|---|---|
| 0 | 1 | 0.842100 | 1 | 1 | True | /home/user/data/datasets/VeRi_Self/reid/cam_1_... |
| 1 | 2 | 0.842017 | 1 | 1 | True | /home/user/data/datasets/VeRi_Self/reid/cam_1_... |
| 2 | 3 | 0.832953 | 8 | 1 | False | /home/user/data/datasets/VeRi_Self/reid/cam_1_... |
| 3 | 4 | 0.811326 | 8 | 1 | False | /home/user/data/datasets/VeRi_Self/reid/cam_1_... |
| 4 | 5 | 0.781126 | 8 | 0 | False | /home/user/data/datasets/VeRi_Self/reid/cam_0_... |
| 5 | 6 | 0.778882 | 3 | 0 | False | /home/user/data/datasets/VeRi_Self/reid/cam_0_... |
| 6 | 7 | 0.777549 | 3 | 0 | False | /home/user/data/datasets/VeRi_Self/reid/cam_0_... |
| 7 | 8 | 0.764441 | 6 | 0 | False | /home/user/data/datasets/VeRi_Self/reid/cam_0_... |
| 8 | 9 | 0.753039 | 4 | 0 | False | /home/user/data/datasets/VeRi_Self/reid/cam_0_... |
| 9 | 10 | 0.740953 | 8 | 0 | False | /home/user/data/datasets/VeRi_Self/reid/cam_0_... |
# Visualise the top-5 matches for a single query image.
query_index = 0
reid_results = image_to_image_search(query_index, top_k=5)
# show_image_to_image_results is not defined in this notebook; presumably it
# comes from the `from helpers import *` import — verify there.
show_image_to_image_results(query_index, reid_results)
def evaluate_reid_rank_k(max_queries=None, top_k=10):
    """Compute Rank-1/5/10 hit rates of image-to-image retrieval.

    A query counts as a hit at Rank-k when at least one of its k best-scoring
    gallery rows (after same-vid/same-cam removal) has the query's vehicle id.

    Args:
        max_queries: Evaluate only the first ``max_queries`` queries; ``None``
            evaluates all of them.
        top_k: Retrieval depth per query. Values below 10 are raised to 10 so
            that Rank-5 and Rank-10 are never computed from a shorter list.

    Returns:
        Dict with ``valid_queries`` and the ``rank1``/``rank5``/``rank10`` hit
        rates. The rates are NaN when no query produced any candidates.
    """
    cutoffs = (1, 5, 10)
    # Retrieve at least as many candidates as the deepest cutoff needs.
    depth = max(top_k, max(cutoffs))

    n_queries = len(query_meta) if max_queries is None else min(max_queries, len(query_meta))
    hits = dict.fromkeys(cutoffs, 0)
    valid_queries = 0

    for qi in tqdm(range(n_queries)):
        q_vid = query_meta.iloc[qi]["vid"]
        results = image_to_image_search(qi, top_k=depth, remove_same_cam=True)
        if len(results) == 0:
            continue
        valid_queries += 1
        correct = results["vid"].values == q_vid
        for k in cutoffs:
            hits[k] += int(correct[:k].any())

    # Avoid ZeroDivisionError when nothing could be evaluated (empty query set,
    # max_queries <= 0, or every query left without candidates).
    if valid_queries == 0:
        rates = {k: float("nan") for k in cutoffs}
    else:
        rates = {k: hits[k] / valid_queries for k in cutoffs}

    return {
        "valid_queries": valid_queries,
        "rank1": rates[1],
        "rank5": rates[5],
        "rank10": rates[10],
    }
# Evaluate on at most the first 100 queries (the function caps this at the
# number of available queries).
evaluate_reid_rank_k(max_queries=100, top_k=10)
100%|██████████| 29/29 [00:00<00:00, 48.03it/s]
{'valid_queries': 29,
'rank1': 0.896551724137931,
'rank5': 0.9655172413793104,
'rank10': 1.0}
# Visualise the top-5 matches for every query image. The count is taken from
# query_meta rather than hard-coded, so the loop stays correct if the query
# set changes.
for query_index in range(len(query_meta)):
    reid_results = image_to_image_search(query_index, top_k=5)
    show_image_to_image_results(query_index, reid_results)