◉ Perception Encoder (CLIP) for ReID#

Tested on an RTX 5090 laptop GPU. Installation steps for the PE-Core-L14-336 checkpoint:

git clone https://github.com/facebookresearch/perception_models.git
cd perception_models

conda create -n pe_test python=3.12 -y
conda activate pe_test

pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124

pip install pillow tqdm numpy pandas scikit-learn matplotlib
pip install huggingface_hub
pip install -e .

Download the model locally:

hf download facebook/PE-Core-L14-336 \
  --local-dir models   
from pathlib import Path
import time
import json
import numpy as np
import pandas as pd
from PIL import Image
from tqdm.auto import tqdm

import torch
from torch.utils.data import Dataset, DataLoader
from helpers import *
# --- Dataset layout -------------------------------------------------------
# Everything hangs off a single dataset root; the ReID split and the
# distractor images live in sibling sub-directories.
ROOT = Path("/home/user/data/datasets/VeRi_Self")
REID_ROOT = ROOT.joinpath("reid")
DISTRACTOR_ROOT = ROOT.joinpath("PKU_Vehicle")

METADATA_CSV = REID_ROOT.joinpath("all_metadata.csv")

# --- Output locations -----------------------------------------------------
# Both directories are created eagerly (relative to the current working
# directory) so later cells can write into them without checking.
EMB_DIR = Path("outputs") / "embeddings"
OUTPUT_DIR = Path("outputs") / "evals"
for _out_dir in (EMB_DIR, OUTPUT_DIR):
    _out_dir.mkdir(parents=True, exist_ok=True)

# --- Compute device -------------------------------------------------------
# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
_cuda_ok = torch.cuda.is_available()
DEVICE = "cuda" if _cuda_ok else "cpu"
DTYPE = torch.bfloat16
print("Device:", DEVICE)
if _cuda_ok:
    _cuda_name = torch.cuda.get_device_name(0)
else:
    _cuda_name = None
print("CUDA device:", _cuda_name)
Device: cuda
CUDA device: NVIDIA GeForce RTX 5090 Laptop GPU
import core.vision_encoder.pe as pe
import core.vision_encoder.transforms as transforms
from pathlib import Path

# Which Perception Encoder variant to use and where its downloaded files live.
MODEL_NAME = "PE-Core-L14-336"
LOCAL_MODEL_DIR = Path("models")

# The checkpoint filename is derived from MODEL_NAME so that switching to a
# different variant only requires editing one constant.
ckpt_path = LOCAL_MODEL_DIR / f"{MODEL_NAME}.pt"
config_path = LOCAL_MODEL_DIR / "config.yaml"

# Sanity-check the download before attempting the model load.
print(ckpt_path.exists(), ckpt_path)
print(config_path.exists(), config_path)
if ckpt_path.exists():
    print("Checkpoint size GB:", round(ckpt_path.stat().st_size / 1024**3, 2))
else:
    # stat() on a missing file raises FileNotFoundError without any hint of
    # what to do next; report the problem in actionable terms instead.
    print("Checkpoint not found; download it into", LOCAL_MODEL_DIR, "before loading the model")
True models/PE-Core-L14-336.pt
True models/config.yaml
Checkpoint size GB: 2.5

Load the PE model

import time
import core.vision_encoder.pe as pe
import core.vision_encoder.transforms as transforms

# Step 1: read the checkpoint onto the CPU so this timing covers disk I/O only.
# weights_only=True limits unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code while loading.
t = time.time()
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
print("1. torch.load:", round(time.time() - t, 2), "sec")

# Fail fast on an unusable checkpoint instead of building the model first.
if not isinstance(ckpt, dict):
    raise ValueError("Unexpected checkpoint format")
print("Checkpoint keys:", ckpt.keys())

# Step 2: build the architecture only; weights are loaded manually below
# from the local checkpoint rather than fetched by from_config.
t = time.time()
model = pe.CLIP.from_config(MODEL_NAME, pretrained=False)
print("2. model build:", round(time.time() - t, 2), "sec")

# Some checkpoints nest the weights under a wrapper key; otherwise the dict
# itself is treated as the state dict. First matching key wins.
state_dict = ckpt
for wrapper_key in ("state_dict", "model", "model_state_dict"):
    if wrapper_key in ckpt:
        state_dict = ckpt[wrapper_key]
        break

# Drop leading "module." / "model." wrapper prefixes. removeprefix only
# touches the start of the key, so a parameter whose name merely contains
# one of these substrings further in is left intact.
clean_state_dict = {
    k.removeprefix("module.").removeprefix("model."): v
    for k, v in state_dict.items()
}

# Step 3: strict=False so key mismatches are reported below rather than raised;
# both counts should be 0 for a checkpoint that matches the architecture.
t = time.time()
missing, unexpected = model.load_state_dict(clean_state_dict, strict=False)
print("3. load_state_dict:", round(time.time() - t, 2), "sec")
print("Missing:", len(missing), "Unexpected:", len(unexpected))
print("Missing sample:", missing[:5])
print("Unexpected sample:", unexpected[:5])

# Step 4: move to the target device in eval mode. CUDA copies can be
# asynchronous, so synchronize before reading the clock.
t = time.time()
model = model.to(DEVICE).eval()
if DEVICE == "cuda":
    torch.cuda.synchronize()
print("4. to device:", round(time.time() - t, 2), "sec")

# Image transform and text tokenizer sized from the model's own settings.
preprocess = transforms.get_image_transform(model.image_size)
tokenizer = transforms.get_text_tokenizer(model.context_length)

print("Done")
print("Image size:", model.image_size)
print("Context length:", model.context_length)

# Inspect one parameter to confirm where the weights ended up and their dtype.
p = next(model.parameters())

print("Model device:", p.device)
print("Model dtype:", p.dtype)

# Bare expression: in a notebook this displays the module tree.
model
1. torch.load: 1.04 sec
Checkpoint keys: dict_keys(['ln_final.bias', 'ln_final.weight', 'logit_scale', 'positional_embedding', 'text_projection', 'token_embedding.weight', 'transformer.resblocks.0.attn.in_proj_bias', 'transformer.resblocks.0.attn.in_proj_weight', 'transformer.resblocks.0.attn.out_proj.bias', 'transformer.resblocks.0.attn.out_proj.weight', 'transformer.resblocks.0.ln_1.bias', 'transformer.resblocks.0.ln_1.weight', 'transformer.resblocks.0.ln_2.bias', 'transformer.resblocks.0.ln_2.weight', 'transformer.resblocks.0.mlp.c_fc.bias', 'transformer.resblocks.0.mlp.c_fc.weight', 'transformer.resblocks.0.mlp.c_proj.bias', 'transformer.resblocks.0.mlp.c_proj.weight', 'transformer.resblocks.1.attn.in_proj_bias', 'transformer.resblocks.1.attn.in_proj_weight', 'transformer.resblocks.1.attn.out_proj.bias', 'transformer.resblocks.1.attn.out_proj.weight', 'transformer.resblocks.1.ln_1.bias', 'transformer.resblocks.1.ln_1.weight', 'transformer.resblocks.1.ln_2.bias', 'transformer.resblocks.1.ln_2.weight', 'transformer.resblocks.1.mlp.c_fc.bias', 'transformer.resblocks.1.mlp.c_fc.weight', 'transformer.resblocks.1.mlp.c_proj.bias', 'transformer.resblocks.1.mlp.c_proj.weight', 'transformer.resblocks.10.attn.in_proj_bias', 'transformer.resblocks.10.attn.in_proj_weight', 'transformer.resblocks.10.attn.out_proj.bias', 'transformer.resblocks.10.attn.out_proj.weight', 'transformer.resblocks.10.ln_1.bias', 'transformer.resblocks.10.ln_1.weight', 'transformer.resblocks.10.ln_2.bias', 'transformer.resblocks.10.ln_2.weight', 'transformer.resblocks.10.mlp.c_fc.bias', 'transformer.resblocks.10.mlp.c_fc.weight', 'transformer.resblocks.10.mlp.c_proj.bias', 'transformer.resblocks.10.mlp.c_proj.weight', 'transformer.resblocks.11.attn.in_proj_bias', 'transformer.resblocks.11.attn.in_proj_weight', 'transformer.resblocks.11.attn.out_proj.bias', 'transformer.resblocks.11.attn.out_proj.weight', 'transformer.resblocks.11.ln_1.bias', 'transformer.resblocks.11.ln_1.weight', 'transformer.resblocks.11.ln_2.bias', 
'transformer.resblocks.11.ln_2.weight', 'transformer.resblocks.11.mlp.c_fc.bias', 'transformer.resblocks.11.mlp.c_fc.weight', 'transformer.resblocks.11.mlp.c_proj.bias', 'transformer.resblocks.11.mlp.c_proj.weight', 'transformer.resblocks.12.attn.in_proj_bias', 'transformer.resblocks.12.attn.in_proj_weight', 'transformer.resblocks.12.attn.out_proj.bias', 'transformer.resblocks.12.attn.out_proj.weight', 'transformer.resblocks.12.ln_1.bias', 'transformer.resblocks.12.ln_1.weight', 'transformer.resblocks.12.ln_2.bias', 'transformer.resblocks.12.ln_2.weight', 'transformer.resblocks.12.mlp.c_fc.bias', 'transformer.resblocks.12.mlp.c_fc.weight', 'transformer.resblocks.12.mlp.c_proj.bias', 'transformer.resblocks.12.mlp.c_proj.weight', 'transformer.resblocks.13.attn.in_proj_bias', 'transformer.resblocks.13.attn.in_proj_weight', 'transformer.resblocks.13.attn.out_proj.bias', 'transformer.resblocks.13.attn.out_proj.weight', 'transformer.resblocks.13.ln_1.bias', 'transformer.resblocks.13.ln_1.weight', 'transformer.resblocks.13.ln_2.bias', 'transformer.resblocks.13.ln_2.weight', 'transformer.resblocks.13.mlp.c_fc.bias', 'transformer.resblocks.13.mlp.c_fc.weight', 'transformer.resblocks.13.mlp.c_proj.bias', 'transformer.resblocks.13.mlp.c_proj.weight', 'transformer.resblocks.14.attn.in_proj_bias', 'transformer.resblocks.14.attn.in_proj_weight', 'transformer.resblocks.14.attn.out_proj.bias', 'transformer.resblocks.14.attn.out_proj.weight', 'transformer.resblocks.14.ln_1.bias', 'transformer.resblocks.14.ln_1.weight', 'transformer.resblocks.14.ln_2.bias', 'transformer.resblocks.14.ln_2.weight', 'transformer.resblocks.14.mlp.c_fc.bias', 'transformer.resblocks.14.mlp.c_fc.weight', 'transformer.resblocks.14.mlp.c_proj.bias', 'transformer.resblocks.14.mlp.c_proj.weight', 'transformer.resblocks.15.attn.in_proj_bias', 'transformer.resblocks.15.attn.in_proj_weight', 'transformer.resblocks.15.attn.out_proj.bias', 'transformer.resblocks.15.attn.out_proj.weight', 
'transformer.resblocks.15.ln_1.bias', 'transformer.resblocks.15.ln_1.weight', 'transformer.resblocks.15.ln_2.bias', 'transformer.resblocks.15.ln_2.weight', 'transformer.resblocks.15.mlp.c_fc.bias', 'transformer.resblocks.15.mlp.c_fc.weight', 'transformer.resblocks.15.mlp.c_proj.bias', 'transformer.resblocks.15.mlp.c_proj.weight', 'transformer.resblocks.16.attn.in_proj_bias', 'transformer.resblocks.16.attn.in_proj_weight', 'transformer.resblocks.16.attn.out_proj.bias', 'transformer.resblocks.16.attn.out_proj.weight', 'transformer.resblocks.16.ln_1.bias', 'transformer.resblocks.16.ln_1.weight', 'transformer.resblocks.16.ln_2.bias', 'transformer.resblocks.16.ln_2.weight', 'transformer.resblocks.16.mlp.c_fc.bias', 'transformer.resblocks.16.mlp.c_fc.weight', 'transformer.resblocks.16.mlp.c_proj.bias', 'transformer.resblocks.16.mlp.c_proj.weight', 'transformer.resblocks.17.attn.in_proj_bias', 'transformer.resblocks.17.attn.in_proj_weight', 'transformer.resblocks.17.attn.out_proj.bias', 'transformer.resblocks.17.attn.out_proj.weight', 'transformer.resblocks.17.ln_1.bias', 'transformer.resblocks.17.ln_1.weight', 'transformer.resblocks.17.ln_2.bias', 'transformer.resblocks.17.ln_2.weight', 'transformer.resblocks.17.mlp.c_fc.bias', 'transformer.resblocks.17.mlp.c_fc.weight', 'transformer.resblocks.17.mlp.c_proj.bias', 'transformer.resblocks.17.mlp.c_proj.weight', 'transformer.resblocks.18.attn.in_proj_bias', 'transformer.resblocks.18.attn.in_proj_weight', 'transformer.resblocks.18.attn.out_proj.bias', 'transformer.resblocks.18.attn.out_proj.weight', 'transformer.resblocks.18.ln_1.bias', 'transformer.resblocks.18.ln_1.weight', 'transformer.resblocks.18.ln_2.bias', 'transformer.resblocks.18.ln_2.weight', 'transformer.resblocks.18.mlp.c_fc.bias', 'transformer.resblocks.18.mlp.c_fc.weight', 'transformer.resblocks.18.mlp.c_proj.bias', 'transformer.resblocks.18.mlp.c_proj.weight', 'transformer.resblocks.19.attn.in_proj_bias', 'transformer.resblocks.19.attn.in_proj_weight', 
'transformer.resblocks.19.attn.out_proj.bias', 'transformer.resblocks.19.attn.out_proj.weight', 'transformer.resblocks.19.ln_1.bias', 'transformer.resblocks.19.ln_1.weight', 'transformer.resblocks.19.ln_2.bias', 'transformer.resblocks.19.ln_2.weight', 'transformer.resblocks.19.mlp.c_fc.bias', 'transformer.resblocks.19.mlp.c_fc.weight', 'transformer.resblocks.19.mlp.c_proj.bias', 'transformer.resblocks.19.mlp.c_proj.weight', 'transformer.resblocks.2.attn.in_proj_bias', 'transformer.resblocks.2.attn.in_proj_weight', 'transformer.resblocks.2.attn.out_proj.bias', 'transformer.resblocks.2.attn.out_proj.weight', 'transformer.resblocks.2.ln_1.bias', 'transformer.resblocks.2.ln_1.weight', 'transformer.resblocks.2.ln_2.bias', 'transformer.resblocks.2.ln_2.weight', 'transformer.resblocks.2.mlp.c_fc.bias', 'transformer.resblocks.2.mlp.c_fc.weight', 'transformer.resblocks.2.mlp.c_proj.bias', 'transformer.resblocks.2.mlp.c_proj.weight', 'transformer.resblocks.20.attn.in_proj_bias', 'transformer.resblocks.20.attn.in_proj_weight', 'transformer.resblocks.20.attn.out_proj.bias', 'transformer.resblocks.20.attn.out_proj.weight', 'transformer.resblocks.20.ln_1.bias', 'transformer.resblocks.20.ln_1.weight', 'transformer.resblocks.20.ln_2.bias', 'transformer.resblocks.20.ln_2.weight', 'transformer.resblocks.20.mlp.c_fc.bias', 'transformer.resblocks.20.mlp.c_fc.weight', 'transformer.resblocks.20.mlp.c_proj.bias', 'transformer.resblocks.20.mlp.c_proj.weight', 'transformer.resblocks.21.attn.in_proj_bias', 'transformer.resblocks.21.attn.in_proj_weight', 'transformer.resblocks.21.attn.out_proj.bias', 'transformer.resblocks.21.attn.out_proj.weight', 'transformer.resblocks.21.ln_1.bias', 'transformer.resblocks.21.ln_1.weight', 'transformer.resblocks.21.ln_2.bias', 'transformer.resblocks.21.ln_2.weight', 'transformer.resblocks.21.mlp.c_fc.bias', 'transformer.resblocks.21.mlp.c_fc.weight', 'transformer.resblocks.21.mlp.c_proj.bias', 'transformer.resblocks.21.mlp.c_proj.weight', 
'transformer.resblocks.22.attn.in_proj_bias', 'transformer.resblocks.22.attn.in_proj_weight', 'transformer.resblocks.22.attn.out_proj.bias', 'transformer.resblocks.22.attn.out_proj.weight', 'transformer.resblocks.22.ln_1.bias', 'transformer.resblocks.22.ln_1.weight', 'transformer.resblocks.22.ln_2.bias', 'transformer.resblocks.22.ln_2.weight', 'transformer.resblocks.22.mlp.c_fc.bias', 'transformer.resblocks.22.mlp.c_fc.weight', 'transformer.resblocks.22.mlp.c_proj.bias', 'transformer.resblocks.22.mlp.c_proj.weight', 'transformer.resblocks.23.attn.in_proj_bias', 'transformer.resblocks.23.attn.in_proj_weight', 'transformer.resblocks.23.attn.out_proj.bias', 'transformer.resblocks.23.attn.out_proj.weight', 'transformer.resblocks.23.ln_1.bias', 'transformer.resblocks.23.ln_1.weight', 'transformer.resblocks.23.ln_2.bias', 'transformer.resblocks.23.ln_2.weight', 'transformer.resblocks.23.mlp.c_fc.bias', 'transformer.resblocks.23.mlp.c_fc.weight', 'transformer.resblocks.23.mlp.c_proj.bias', 'transformer.resblocks.23.mlp.c_proj.weight', 'transformer.resblocks.3.attn.in_proj_bias', 'transformer.resblocks.3.attn.in_proj_weight', 'transformer.resblocks.3.attn.out_proj.bias', 'transformer.resblocks.3.attn.out_proj.weight', 'transformer.resblocks.3.ln_1.bias', 'transformer.resblocks.3.ln_1.weight', 'transformer.resblocks.3.ln_2.bias', 'transformer.resblocks.3.ln_2.weight', 'transformer.resblocks.3.mlp.c_fc.bias', 'transformer.resblocks.3.mlp.c_fc.weight', 'transformer.resblocks.3.mlp.c_proj.bias', 'transformer.resblocks.3.mlp.c_proj.weight', 'transformer.resblocks.4.attn.in_proj_bias', 'transformer.resblocks.4.attn.in_proj_weight', 'transformer.resblocks.4.attn.out_proj.bias', 'transformer.resblocks.4.attn.out_proj.weight', 'transformer.resblocks.4.ln_1.bias', 'transformer.resblocks.4.ln_1.weight', 'transformer.resblocks.4.ln_2.bias', 'transformer.resblocks.4.ln_2.weight', 'transformer.resblocks.4.mlp.c_fc.bias', 'transformer.resblocks.4.mlp.c_fc.weight', 
'transformer.resblocks.4.mlp.c_proj.bias', 'transformer.resblocks.4.mlp.c_proj.weight', 'transformer.resblocks.5.attn.in_proj_bias', 'transformer.resblocks.5.attn.in_proj_weight', 'transformer.resblocks.5.attn.out_proj.bias', 'transformer.resblocks.5.attn.out_proj.weight', 'transformer.resblocks.5.ln_1.bias', 'transformer.resblocks.5.ln_1.weight', 'transformer.resblocks.5.ln_2.bias', 'transformer.resblocks.5.ln_2.weight', 'transformer.resblocks.5.mlp.c_fc.bias', 'transformer.resblocks.5.mlp.c_fc.weight', 'transformer.resblocks.5.mlp.c_proj.bias', 'transformer.resblocks.5.mlp.c_proj.weight', 'transformer.resblocks.6.attn.in_proj_bias', 'transformer.resblocks.6.attn.in_proj_weight', 'transformer.resblocks.6.attn.out_proj.bias', 'transformer.resblocks.6.attn.out_proj.weight', 'transformer.resblocks.6.ln_1.bias', 'transformer.resblocks.6.ln_1.weight', 'transformer.resblocks.6.ln_2.bias', 'transformer.resblocks.6.ln_2.weight', 'transformer.resblocks.6.mlp.c_fc.bias', 'transformer.resblocks.6.mlp.c_fc.weight', 'transformer.resblocks.6.mlp.c_proj.bias', 'transformer.resblocks.6.mlp.c_proj.weight', 'transformer.resblocks.7.attn.in_proj_bias', 'transformer.resblocks.7.attn.in_proj_weight', 'transformer.resblocks.7.attn.out_proj.bias', 'transformer.resblocks.7.attn.out_proj.weight', 'transformer.resblocks.7.ln_1.bias', 'transformer.resblocks.7.ln_1.weight', 'transformer.resblocks.7.ln_2.bias', 'transformer.resblocks.7.ln_2.weight', 'transformer.resblocks.7.mlp.c_fc.bias', 'transformer.resblocks.7.mlp.c_fc.weight', 'transformer.resblocks.7.mlp.c_proj.bias', 'transformer.resblocks.7.mlp.c_proj.weight', 'transformer.resblocks.8.attn.in_proj_bias', 'transformer.resblocks.8.attn.in_proj_weight', 'transformer.resblocks.8.attn.out_proj.bias', 'transformer.resblocks.8.attn.out_proj.weight', 'transformer.resblocks.8.ln_1.bias', 'transformer.resblocks.8.ln_1.weight', 'transformer.resblocks.8.ln_2.bias', 'transformer.resblocks.8.ln_2.weight', 'transformer.resblocks.8.mlp.c_fc.bias', 
'transformer.resblocks.8.mlp.c_fc.weight', 'transformer.resblocks.8.mlp.c_proj.bias', 'transformer.resblocks.8.mlp.c_proj.weight', 'transformer.resblocks.9.attn.in_proj_bias', 'transformer.resblocks.9.attn.in_proj_weight', 'transformer.resblocks.9.attn.out_proj.bias', 'transformer.resblocks.9.attn.out_proj.weight', 'transformer.resblocks.9.ln_1.bias', 'transformer.resblocks.9.ln_1.weight', 'transformer.resblocks.9.ln_2.bias', 'transformer.resblocks.9.ln_2.weight', 'transformer.resblocks.9.mlp.c_fc.bias', 'transformer.resblocks.9.mlp.c_fc.weight', 'transformer.resblocks.9.mlp.c_proj.bias', 'transformer.resblocks.9.mlp.c_proj.weight', 'visual.attn_pool.attn.in_proj_bias', 'visual.attn_pool.attn.in_proj_weight', 'visual.attn_pool.attn.out_proj.bias', 'visual.attn_pool.attn.out_proj.weight', 'visual.attn_pool.layernorm.bias', 'visual.attn_pool.layernorm.weight', 'visual.attn_pool.mlp.c_fc.bias', 'visual.attn_pool.mlp.c_fc.weight', 'visual.attn_pool.mlp.c_proj.bias', 'visual.attn_pool.mlp.c_proj.weight', 'visual.attn_pool.probe', 'visual.class_embedding', 'visual.conv1.weight', 'visual.ln_post.bias', 'visual.ln_post.weight', 'visual.ln_pre.bias', 'visual.ln_pre.weight', 'visual.positional_embedding', 'visual.proj', 'visual.transformer.resblocks.0.attn.in_proj_bias', 'visual.transformer.resblocks.0.attn.in_proj_weight', 'visual.transformer.resblocks.0.attn.out_proj.bias', 'visual.transformer.resblocks.0.attn.out_proj.weight', 'visual.transformer.resblocks.0.ln_1.bias', 'visual.transformer.resblocks.0.ln_1.weight', 'visual.transformer.resblocks.0.ln_2.bias', 'visual.transformer.resblocks.0.ln_2.weight', 'visual.transformer.resblocks.0.mlp.c_fc.bias', 'visual.transformer.resblocks.0.mlp.c_fc.weight', 'visual.transformer.resblocks.0.mlp.c_proj.bias', 'visual.transformer.resblocks.0.mlp.c_proj.weight', 'visual.transformer.resblocks.1.attn.in_proj_bias', 'visual.transformer.resblocks.1.attn.in_proj_weight', 'visual.transformer.resblocks.1.attn.out_proj.bias', 
'visual.transformer.resblocks.1.attn.out_proj.weight', 'visual.transformer.resblocks.1.ln_1.bias', 'visual.transformer.resblocks.1.ln_1.weight', 'visual.transformer.resblocks.1.ln_2.bias', 'visual.transformer.resblocks.1.ln_2.weight', 'visual.transformer.resblocks.1.mlp.c_fc.bias', 'visual.transformer.resblocks.1.mlp.c_fc.weight', 'visual.transformer.resblocks.1.mlp.c_proj.bias', 'visual.transformer.resblocks.1.mlp.c_proj.weight', 'visual.transformer.resblocks.10.attn.in_proj_bias', 'visual.transformer.resblocks.10.attn.in_proj_weight', 'visual.transformer.resblocks.10.attn.out_proj.bias', 'visual.transformer.resblocks.10.attn.out_proj.weight', 'visual.transformer.resblocks.10.ln_1.bias', 'visual.transformer.resblocks.10.ln_1.weight', 'visual.transformer.resblocks.10.ln_2.bias', 'visual.transformer.resblocks.10.ln_2.weight', 'visual.transformer.resblocks.10.mlp.c_fc.bias', 'visual.transformer.resblocks.10.mlp.c_fc.weight', 'visual.transformer.resblocks.10.mlp.c_proj.bias', 'visual.transformer.resblocks.10.mlp.c_proj.weight', 'visual.transformer.resblocks.11.attn.in_proj_bias', 'visual.transformer.resblocks.11.attn.in_proj_weight', 'visual.transformer.resblocks.11.attn.out_proj.bias', 'visual.transformer.resblocks.11.attn.out_proj.weight', 'visual.transformer.resblocks.11.ln_1.bias', 'visual.transformer.resblocks.11.ln_1.weight', 'visual.transformer.resblocks.11.ln_2.bias', 'visual.transformer.resblocks.11.ln_2.weight', 'visual.transformer.resblocks.11.mlp.c_fc.bias', 'visual.transformer.resblocks.11.mlp.c_fc.weight', 'visual.transformer.resblocks.11.mlp.c_proj.bias', 'visual.transformer.resblocks.11.mlp.c_proj.weight', 'visual.transformer.resblocks.12.attn.in_proj_bias', 'visual.transformer.resblocks.12.attn.in_proj_weight', 'visual.transformer.resblocks.12.attn.out_proj.bias', 'visual.transformer.resblocks.12.attn.out_proj.weight', 'visual.transformer.resblocks.12.ln_1.bias', 'visual.transformer.resblocks.12.ln_1.weight', 
'visual.transformer.resblocks.12.ln_2.bias', 'visual.transformer.resblocks.12.ln_2.weight', 'visual.transformer.resblocks.12.mlp.c_fc.bias', 'visual.transformer.resblocks.12.mlp.c_fc.weight', 'visual.transformer.resblocks.12.mlp.c_proj.bias', 'visual.transformer.resblocks.12.mlp.c_proj.weight', 'visual.transformer.resblocks.13.attn.in_proj_bias', 'visual.transformer.resblocks.13.attn.in_proj_weight', 'visual.transformer.resblocks.13.attn.out_proj.bias', 'visual.transformer.resblocks.13.attn.out_proj.weight', 'visual.transformer.resblocks.13.ln_1.bias', 'visual.transformer.resblocks.13.ln_1.weight', 'visual.transformer.resblocks.13.ln_2.bias', 'visual.transformer.resblocks.13.ln_2.weight', 'visual.transformer.resblocks.13.mlp.c_fc.bias', 'visual.transformer.resblocks.13.mlp.c_fc.weight', 'visual.transformer.resblocks.13.mlp.c_proj.bias', 'visual.transformer.resblocks.13.mlp.c_proj.weight', 'visual.transformer.resblocks.14.attn.in_proj_bias', 'visual.transformer.resblocks.14.attn.in_proj_weight', 'visual.transformer.resblocks.14.attn.out_proj.bias', 'visual.transformer.resblocks.14.attn.out_proj.weight', 'visual.transformer.resblocks.14.ln_1.bias', 'visual.transformer.resblocks.14.ln_1.weight', 'visual.transformer.resblocks.14.ln_2.bias', 'visual.transformer.resblocks.14.ln_2.weight', 'visual.transformer.resblocks.14.mlp.c_fc.bias', 'visual.transformer.resblocks.14.mlp.c_fc.weight', 'visual.transformer.resblocks.14.mlp.c_proj.bias', 'visual.transformer.resblocks.14.mlp.c_proj.weight', 'visual.transformer.resblocks.15.attn.in_proj_bias', 'visual.transformer.resblocks.15.attn.in_proj_weight', 'visual.transformer.resblocks.15.attn.out_proj.bias', 'visual.transformer.resblocks.15.attn.out_proj.weight', 'visual.transformer.resblocks.15.ln_1.bias', 'visual.transformer.resblocks.15.ln_1.weight', 'visual.transformer.resblocks.15.ln_2.bias', 'visual.transformer.resblocks.15.ln_2.weight', 'visual.transformer.resblocks.15.mlp.c_fc.bias', 
'visual.transformer.resblocks.15.mlp.c_fc.weight', 'visual.transformer.resblocks.15.mlp.c_proj.bias', 'visual.transformer.resblocks.15.mlp.c_proj.weight', 'visual.transformer.resblocks.16.attn.in_proj_bias', 'visual.transformer.resblocks.16.attn.in_proj_weight', 'visual.transformer.resblocks.16.attn.out_proj.bias', 'visual.transformer.resblocks.16.attn.out_proj.weight', 'visual.transformer.resblocks.16.ln_1.bias', 'visual.transformer.resblocks.16.ln_1.weight', 'visual.transformer.resblocks.16.ln_2.bias', 'visual.transformer.resblocks.16.ln_2.weight', 'visual.transformer.resblocks.16.mlp.c_fc.bias', 'visual.transformer.resblocks.16.mlp.c_fc.weight', 'visual.transformer.resblocks.16.mlp.c_proj.bias', 'visual.transformer.resblocks.16.mlp.c_proj.weight', 'visual.transformer.resblocks.17.attn.in_proj_bias', 'visual.transformer.resblocks.17.attn.in_proj_weight', 'visual.transformer.resblocks.17.attn.out_proj.bias', 'visual.transformer.resblocks.17.attn.out_proj.weight', 'visual.transformer.resblocks.17.ln_1.bias', 'visual.transformer.resblocks.17.ln_1.weight', 'visual.transformer.resblocks.17.ln_2.bias', 'visual.transformer.resblocks.17.ln_2.weight', 'visual.transformer.resblocks.17.mlp.c_fc.bias', 'visual.transformer.resblocks.17.mlp.c_fc.weight', 'visual.transformer.resblocks.17.mlp.c_proj.bias', 'visual.transformer.resblocks.17.mlp.c_proj.weight', 'visual.transformer.resblocks.18.attn.in_proj_bias', 'visual.transformer.resblocks.18.attn.in_proj_weight', 'visual.transformer.resblocks.18.attn.out_proj.bias', 'visual.transformer.resblocks.18.attn.out_proj.weight', 'visual.transformer.resblocks.18.ln_1.bias', 'visual.transformer.resblocks.18.ln_1.weight', 'visual.transformer.resblocks.18.ln_2.bias', 'visual.transformer.resblocks.18.ln_2.weight', 'visual.transformer.resblocks.18.mlp.c_fc.bias', 'visual.transformer.resblocks.18.mlp.c_fc.weight', 'visual.transformer.resblocks.18.mlp.c_proj.bias', 'visual.transformer.resblocks.18.mlp.c_proj.weight', 
'visual.transformer.resblocks.19.attn.in_proj_bias', 'visual.transformer.resblocks.19.attn.in_proj_weight', 'visual.transformer.resblocks.19.attn.out_proj.bias', 'visual.transformer.resblocks.19.attn.out_proj.weight', 'visual.transformer.resblocks.19.ln_1.bias', 'visual.transformer.resblocks.19.ln_1.weight', 'visual.transformer.resblocks.19.ln_2.bias', 'visual.transformer.resblocks.19.ln_2.weight', 'visual.transformer.resblocks.19.mlp.c_fc.bias', 'visual.transformer.resblocks.19.mlp.c_fc.weight', 'visual.transformer.resblocks.19.mlp.c_proj.bias', 'visual.transformer.resblocks.19.mlp.c_proj.weight', 'visual.transformer.resblocks.2.attn.in_proj_bias', 'visual.transformer.resblocks.2.attn.in_proj_weight', 'visual.transformer.resblocks.2.attn.out_proj.bias', 'visual.transformer.resblocks.2.attn.out_proj.weight', 'visual.transformer.resblocks.2.ln_1.bias', 'visual.transformer.resblocks.2.ln_1.weight', 'visual.transformer.resblocks.2.ln_2.bias', 'visual.transformer.resblocks.2.ln_2.weight', 'visual.transformer.resblocks.2.mlp.c_fc.bias', 'visual.transformer.resblocks.2.mlp.c_fc.weight', 'visual.transformer.resblocks.2.mlp.c_proj.bias', 'visual.transformer.resblocks.2.mlp.c_proj.weight', 'visual.transformer.resblocks.20.attn.in_proj_bias', 'visual.transformer.resblocks.20.attn.in_proj_weight', 'visual.transformer.resblocks.20.attn.out_proj.bias', 'visual.transformer.resblocks.20.attn.out_proj.weight', 'visual.transformer.resblocks.20.ln_1.bias', 'visual.transformer.resblocks.20.ln_1.weight', 'visual.transformer.resblocks.20.ln_2.bias', 'visual.transformer.resblocks.20.ln_2.weight', 'visual.transformer.resblocks.20.mlp.c_fc.bias', 'visual.transformer.resblocks.20.mlp.c_fc.weight', 'visual.transformer.resblocks.20.mlp.c_proj.bias', 'visual.transformer.resblocks.20.mlp.c_proj.weight', 'visual.transformer.resblocks.21.attn.in_proj_bias', 'visual.transformer.resblocks.21.attn.in_proj_weight', 'visual.transformer.resblocks.21.attn.out_proj.bias', 
'visual.transformer.resblocks.21.attn.out_proj.weight', 'visual.transformer.resblocks.21.ln_1.bias', 'visual.transformer.resblocks.21.ln_1.weight', 'visual.transformer.resblocks.21.ln_2.bias', 'visual.transformer.resblocks.21.ln_2.weight', 'visual.transformer.resblocks.21.mlp.c_fc.bias', 'visual.transformer.resblocks.21.mlp.c_fc.weight', 'visual.transformer.resblocks.21.mlp.c_proj.bias', 'visual.transformer.resblocks.21.mlp.c_proj.weight', 'visual.transformer.resblocks.22.attn.in_proj_bias', 'visual.transformer.resblocks.22.attn.in_proj_weight', 'visual.transformer.resblocks.22.attn.out_proj.bias', 'visual.transformer.resblocks.22.attn.out_proj.weight', 'visual.transformer.resblocks.22.ln_1.bias', 'visual.transformer.resblocks.22.ln_1.weight', 'visual.transformer.resblocks.22.ln_2.bias', 'visual.transformer.resblocks.22.ln_2.weight', 'visual.transformer.resblocks.22.mlp.c_fc.bias', 'visual.transformer.resblocks.22.mlp.c_fc.weight', 'visual.transformer.resblocks.22.mlp.c_proj.bias', 'visual.transformer.resblocks.22.mlp.c_proj.weight', 'visual.transformer.resblocks.23.attn.in_proj_bias', 'visual.transformer.resblocks.23.attn.in_proj_weight', 'visual.transformer.resblocks.23.attn.out_proj.bias', 'visual.transformer.resblocks.23.attn.out_proj.weight', 'visual.transformer.resblocks.23.ln_1.bias', 'visual.transformer.resblocks.23.ln_1.weight', 'visual.transformer.resblocks.23.ln_2.bias', 'visual.transformer.resblocks.23.ln_2.weight', 'visual.transformer.resblocks.23.mlp.c_fc.bias', 'visual.transformer.resblocks.23.mlp.c_fc.weight', 'visual.transformer.resblocks.23.mlp.c_proj.bias', 'visual.transformer.resblocks.23.mlp.c_proj.weight', 'visual.transformer.resblocks.3.attn.in_proj_bias', 'visual.transformer.resblocks.3.attn.in_proj_weight', 'visual.transformer.resblocks.3.attn.out_proj.bias', 'visual.transformer.resblocks.3.attn.out_proj.weight', 'visual.transformer.resblocks.3.ln_1.bias', 'visual.transformer.resblocks.3.ln_1.weight', 
'visual.transformer.resblocks.3.ln_2.bias', 'visual.transformer.resblocks.3.ln_2.weight', 'visual.transformer.resblocks.3.mlp.c_fc.bias', 'visual.transformer.resblocks.3.mlp.c_fc.weight', 'visual.transformer.resblocks.3.mlp.c_proj.bias', 'visual.transformer.resblocks.3.mlp.c_proj.weight', 'visual.transformer.resblocks.4.attn.in_proj_bias', 'visual.transformer.resblocks.4.attn.in_proj_weight', 'visual.transformer.resblocks.4.attn.out_proj.bias', 'visual.transformer.resblocks.4.attn.out_proj.weight', 'visual.transformer.resblocks.4.ln_1.bias', 'visual.transformer.resblocks.4.ln_1.weight', 'visual.transformer.resblocks.4.ln_2.bias', 'visual.transformer.resblocks.4.ln_2.weight', 'visual.transformer.resblocks.4.mlp.c_fc.bias', 'visual.transformer.resblocks.4.mlp.c_fc.weight', 'visual.transformer.resblocks.4.mlp.c_proj.bias', 'visual.transformer.resblocks.4.mlp.c_proj.weight', 'visual.transformer.resblocks.5.attn.in_proj_bias', 'visual.transformer.resblocks.5.attn.in_proj_weight', 'visual.transformer.resblocks.5.attn.out_proj.bias', 'visual.transformer.resblocks.5.attn.out_proj.weight', 'visual.transformer.resblocks.5.ln_1.bias', 'visual.transformer.resblocks.5.ln_1.weight', 'visual.transformer.resblocks.5.ln_2.bias', 'visual.transformer.resblocks.5.ln_2.weight', 'visual.transformer.resblocks.5.mlp.c_fc.bias', 'visual.transformer.resblocks.5.mlp.c_fc.weight', 'visual.transformer.resblocks.5.mlp.c_proj.bias', 'visual.transformer.resblocks.5.mlp.c_proj.weight', 'visual.transformer.resblocks.6.attn.in_proj_bias', 'visual.transformer.resblocks.6.attn.in_proj_weight', 'visual.transformer.resblocks.6.attn.out_proj.bias', 'visual.transformer.resblocks.6.attn.out_proj.weight', 'visual.transformer.resblocks.6.ln_1.bias', 'visual.transformer.resblocks.6.ln_1.weight', 'visual.transformer.resblocks.6.ln_2.bias', 'visual.transformer.resblocks.6.ln_2.weight', 'visual.transformer.resblocks.6.mlp.c_fc.bias', 'visual.transformer.resblocks.6.mlp.c_fc.weight', 
'visual.transformer.resblocks.6.mlp.c_proj.bias', 'visual.transformer.resblocks.6.mlp.c_proj.weight', 'visual.transformer.resblocks.7.attn.in_proj_bias', 'visual.transformer.resblocks.7.attn.in_proj_weight', 'visual.transformer.resblocks.7.attn.out_proj.bias', 'visual.transformer.resblocks.7.attn.out_proj.weight', 'visual.transformer.resblocks.7.ln_1.bias', 'visual.transformer.resblocks.7.ln_1.weight', 'visual.transformer.resblocks.7.ln_2.bias', 'visual.transformer.resblocks.7.ln_2.weight', 'visual.transformer.resblocks.7.mlp.c_fc.bias', 'visual.transformer.resblocks.7.mlp.c_fc.weight', 'visual.transformer.resblocks.7.mlp.c_proj.bias', 'visual.transformer.resblocks.7.mlp.c_proj.weight', 'visual.transformer.resblocks.8.attn.in_proj_bias', 'visual.transformer.resblocks.8.attn.in_proj_weight', 'visual.transformer.resblocks.8.attn.out_proj.bias', 'visual.transformer.resblocks.8.attn.out_proj.weight', 'visual.transformer.resblocks.8.ln_1.bias', 'visual.transformer.resblocks.8.ln_1.weight', 'visual.transformer.resblocks.8.ln_2.bias', 'visual.transformer.resblocks.8.ln_2.weight', 'visual.transformer.resblocks.8.mlp.c_fc.bias', 'visual.transformer.resblocks.8.mlp.c_fc.weight', 'visual.transformer.resblocks.8.mlp.c_proj.bias', 'visual.transformer.resblocks.8.mlp.c_proj.weight', 'visual.transformer.resblocks.9.attn.in_proj_bias', 'visual.transformer.resblocks.9.attn.in_proj_weight', 'visual.transformer.resblocks.9.attn.out_proj.bias', 'visual.transformer.resblocks.9.attn.out_proj.weight', 'visual.transformer.resblocks.9.ln_1.bias', 'visual.transformer.resblocks.9.ln_1.weight', 'visual.transformer.resblocks.9.ln_2.bias', 'visual.transformer.resblocks.9.ln_2.weight', 'visual.transformer.resblocks.9.mlp.c_fc.bias', 'visual.transformer.resblocks.9.mlp.c_fc.weight', 'visual.transformer.resblocks.9.mlp.c_proj.bias', 'visual.transformer.resblocks.9.mlp.c_proj.weight'])
2. model build: 1.51 sec
3. load_state_dict: 0.24 sec
Missing: 0 Unexpected: 0
Missing sample: []
Unexpected sample: []
4. to device: 0.34 sec
Done
Image size: 336
Context length: 32
Model device: cuda:0
Model dtype: torch.float32
CLIP(
  (token_embedding): Embedding(49408, 1024)
  (transformer): Transformer(
    (resblocks): ModuleList(
      (0-23): 24 x ResidualAttentionBlock(
        (attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True)
        )
        (ls_1): Identity()
        (ls_2): Identity()
        (ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (drop_path1): Identity()
        (drop_path2): Identity()
        (mlp): Sequential(
          (c_fc): Linear(in_features=1024, out_features=4096, bias=True)
          (gelu): GELU(approximate='none')
          (c_proj): Linear(in_features=4096, out_features=1024, bias=True)
        )
      )
    )
  )
  (ln_final): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  (visual): VisionTransformer(
    (conv1): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
    (ln_pre): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    (ln_post): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    (transformer): Transformer(
      (resblocks): ModuleList(
        (0-23): 24 x ResidualAttentionBlock(
          (attn): SelfAttention(
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (ls_1): Identity()
          (ls_2): Identity()
          (ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (drop_path1): Identity()
          (drop_path2): Identity()
          (mlp): Sequential(
            (c_fc): Linear(in_features=1024, out_features=4096, bias=True)
            (gelu): GELU(approximate='none')
            (c_proj): Linear(in_features=4096, out_features=1024, bias=True)
          )
        )
      )
    )
    (attn_pool): AttentionPooling(
      (attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True)
      )
      (layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (mlp): Sequential(
        (c_fc): Linear(in_features=1024, out_features=4096, bias=True)
        (gelu): GELU(approximate='none')
        (c_proj): Linear(in_features=4096, out_features=1024, bias=True)
      )
    )
  )
)

Datasets#

Load gylang reid set

# Load the ReID metadata table (one row per image).
df = pd.read_csv(METADATA_CSV)

# Validate the schema before any column is accessed, so that a malformed CSV
# fails with an explicit message instead of a bare KeyError from the prints
# below. An explicit raise is used because `assert` is stripped under -O.
required_cols = ["path", "vid", "camid", "final_split"]
missing = [c for c in required_cols if c not in df.columns]
if missing:
    raise ValueError(f"Missing columns: {missing}")

print(df.head())
print(df.columns.tolist())
print(df["final_split"].value_counts(dropna=False))

# Keep only the columns used downstream and pin their dtypes.
df = df[required_cols].copy()
df["path"] = df["path"].astype(str)
df["vid"] = df["vid"].astype(int)
df["camid"] = df["camid"].astype(int)

# Split the table into query and gallery rows by the `final_split` label.
query_df = df[df["final_split"] == "query"].copy()
gallery_df = df[df["final_split"] == "gallery"].copy()

print("query:", len(query_df))
print("gallery:", len(gallery_df))
print("query vids:", query_df["vid"].nunique())
print("gallery vids:", gallery_df["vid"].nunique())
                                                path  \
0  /home/user/data/datasets/VeRi_Self/reid/cam_0_...   
1  /home/user/data/datasets/VeRi_Self/reid/cam_0_...   
2  /home/user/data/datasets/VeRi_Self/reid/cam_0_...   
3  /home/user/data/datasets/VeRi_Self/reid/cam_0_...   
4  /home/user/data/datasets/VeRi_Self/reid/cam_0_...   

                                           rel_path  \
0  cam_0_001/rank_01_frame_00000102_score_0.898.jpg   
1  cam_0_001/rank_02_frame_00000070_score_0.811.jpg   
2  cam_0_001/rank_03_frame_00000040_score_0.743.jpg   
3  cam_0_002/rank_01_frame_00000206_score_0.898.jpg   
4  cam_0_002/rank_02_frame_00000191_score_0.824.jpg   

                                file_name     folder  vid  camid  rank  \
0  rank_01_frame_00000102_score_0.898.jpg  cam_0_001    1      0     1   
1  rank_02_frame_00000070_score_0.811.jpg  cam_0_001    1      0     2   
2  rank_03_frame_00000040_score_0.743.jpg  cam_0_001    1      0     3   
3  rank_01_frame_00000206_score_0.898.jpg  cam_0_002    2      0     1   
4  rank_02_frame_00000191_score_0.824.jpg  cam_0_002    2      0     2   

   frame_idx  score  ext split_initial final_split  
0        102  0.898  jpg         query       query  
1         70  0.811  jpg       gallery     gallery  
2         40  0.743  jpg       gallery     gallery  
3        206  0.898  jpg         query       query  
4        191  0.824  jpg       gallery     gallery  
['path', 'rel_path', 'file_name', 'folder', 'vid', 'camid', 'rank', 'frame_idx', 'score', 'ext', 'split_initial', 'final_split']
final_split
gallery    58
query      29
Name: count, dtype: int64
query: 29
gallery: 58
query vids: 15
gallery vids: 15

load distractors 100k

# File suffixes (lower-case, including the dot) that are treated as images.
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}

# Walk DISTRACTOR_ROOT recursively and keep every image file as a resolved
# path string, in sorted order.
distractor_paths = sorted(
    str(candidate.resolve())
    for candidate in DISTRACTOR_ROOT.rglob("*")
    if candidate.is_file() and candidate.suffix.lower() in IMAGE_EXTS
)

print("num distractors:", len(distractor_paths))
print(distractor_paths[:5])
num distractors: 100000
['/home/user/data/datasets/VeRi_Self/PKU_Vehicle/0/000a3ff7-d31c-4fe1-bdf0-020b7d7a93a4.jpg', '/home/user/data/datasets/VeRi_Self/PKU_Vehicle/0/0010ded2-9ec0-4e78-b7d4-d64f4c1c9303.jpg', '/home/user/data/datasets/VeRi_Self/PKU_Vehicle/0/00148612-07d9-4e7d-8552-04be8af22435.jpg', '/home/user/data/datasets/VeRi_Self/PKU_Vehicle/0/001497dd-6933-4f52-9291-fa32283ac95f.jpg', '/home/user/data/datasets/VeRi_Self/PKU_Vehicle/0/001a7312-d74e-46bf-b5f6-b91e49558feb.jpg']

merge dataframes

# Distractor rows get vid = camid = -1 and their own split label.
distractor_df = pd.DataFrame({
    "path": distractor_paths,
    "vid": -1,
    "camid": -1,
    "final_split": "gallery_distractor",
})

# Stack the labelled gallery on top of the distractors using the same four
# columns, then give both tables a clean 0..N-1 index.
shared_cols = ["path", "vid", "camid", "final_split"]
gallery_plus_df = pd.concat(
    [gallery_df[shared_cols], distractor_df[shared_cols]],
    ignore_index=True,
).reset_index(drop=True)
query_df = query_df.reset_index(drop=True)

print("gallery original:", len(gallery_df))
print("gallery + distractor:", len(gallery_plus_df))
print(gallery_plus_df["final_split"].value_counts())
gallery original: 58
gallery + distractor: 100058
final_split
gallery_distractor    100000
gallery                   58
Name: count, dtype: int64

create unique imageid

# Give every row a unique id: the split prefix plus its row position,
# zero-padded to eight digits.
query_df["image_id"] = [f"query_{i:08d}" for i in range(len(query_df))]
gallery_plus_df["image_id"] = [f"gallery_{i:08d}" for i in range(len(gallery_plus_df))]

query_df.head()
path vid camid final_split image_id
0 /home/user/data/datasets/VeRi_Self/reid/cam_0_... 1 0 query query_00000000
1 /home/user/data/datasets/VeRi_Self/reid/cam_0_... 2 0 query query_00000001
2 /home/user/data/datasets/VeRi_Self/reid/cam_0_... 3 0 query query_00000002
3 /home/user/data/datasets/VeRi_Self/reid/cam_0_... 4 0 query query_00000003
4 /home/user/data/datasets/VeRi_Self/reid/cam_0_... 5 0 query query_00000004

Dataset class for both CLIP and ReID#

class ImageDataset(Dataset):
    """Map-style dataset that yields one preprocessed image plus its metadata.

    Args:
        df: Table with the columns ``path``, ``vid``, ``camid``,
            ``image_id`` and ``final_split``; its index is reset so that
            positional access lines up with ``idx``.
        preprocess: Callable applied to the RGB ``PIL.Image`` before it is
            returned.
    """

    def __init__(self, df, preprocess):
        self.preprocess = preprocess
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        record = self.df.iloc[idx]

        # Always decode to 3-channel RGB before handing off to the transform.
        pixels = self.preprocess(Image.open(record["path"]).convert("RGB"))

        sample = {
            "image": pixels,
            "vid": int(record["vid"]),
            "camid": int(record["camid"]),
        }
        # The remaining fields are passed through unchanged.
        for field in ("path", "image_id", "final_split"):
            sample[field] = record[field]
        return sample

def collate_fn(batch):
    """Merge a list of per-sample dicts into a single batch dict.

    Images are stacked into one tensor along a new leading dimension,
    ``vid``/``camid`` become int64 NumPy arrays, and the remaining fields
    are returned as plain Python lists in sample order.
    """
    def gather(key):
        return [sample[key] for sample in batch]

    return {
        "images": torch.stack(gather("image")),
        "vids": np.array(gather("vid"), dtype=np.int64),
        "camids": np.array(gather("camid"), dtype=np.int64),
        "paths": gather("path"),
        "image_ids": gather("image_id"),
        "final_splits": gather("final_split"),
    }

DataLoaders#

BATCH_SIZE = 128
NUM_WORKERS = 8

query_dataset = ImageDataset(query_df, preprocess)
gallery_dataset = ImageDataset(gallery_plus_df, preprocess)

# Both loaders use exactly the same settings. Shuffling is off so batches
# come out in dataframe row order.
loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    persistent_workers=True,
    prefetch_factor=4,
    collate_fn=collate_fn,
)

query_loader = DataLoader(query_dataset, **loader_kwargs)
gallery_loader = DataLoader(gallery_dataset, **loader_kwargs)

Embedding#

def l2_normalize(x: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Scale ``x`` to unit L2 norm along its last dimension.

    The norm is floored at ``eps`` so all-zero vectors divide cleanly and
    stay zero instead of producing NaNs.
    """
    lengths = x.norm(dim=-1, keepdim=True)
    safe_lengths = torch.clamp(lengths, min=eps)
    return x / safe_lengths
@torch.inference_mode()
def encode_loader(loader, model, device=DEVICE, dtype=torch.bfloat16):
    """Embed every image yielded by ``loader`` with ``model.encode_image``.

    Args:
        loader: Iterable of batch dicts with the keys ``images``,
            ``image_ids``, ``paths``, ``vids``, ``camids`` and
            ``final_splits``.
        model: Object exposing ``encode_image(images)``.
        device: Target device for the image batch; a string such as
            ``"cuda"`` / ``"cuda:0"`` / ``"cpu"`` or a ``torch.device``.
        dtype: Autocast dtype, used only when ``device`` is a CUDA device.

    Returns:
        ``(embeddings, meta)``: a float32 array of L2-normalised features
        stacked in loader order, and a DataFrame with one metadata row per
        embedding.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    # Compare the device *type*, not the raw value: "cuda:0" or a
    # torch.device would fail a plain `== "cuda"` test and silently turn
    # autocast off.
    use_autocast = torch.device(device).type == "cuda"

    all_embeddings = []
    rows = []

    start = time.time()

    for batch in tqdm(loader):
        images = batch["images"].to(device, non_blocking=True)

        with torch.autocast(
            device_type="cuda",
            dtype=dtype,
            enabled=use_autocast,
        ):
            features = model.encode_image(images)

        # Normalise in float32 rather than in the autocast dtype.
        features = l2_normalize(features.float())

        emb = features.cpu().numpy().astype("float32")
        all_embeddings.append(emb)

        rows.extend(
            {
                "image_id": batch["image_ids"][i],
                "path": batch["paths"][i],
                "vid": int(batch["vids"][i]),
                "camid": int(batch["camids"][i]),
                "final_split": batch["final_splits"][i],
            }
            for i in range(emb.shape[0])
        )

    if not all_embeddings:
        # np.vstack([]) would fail with a message that hides the real cause.
        raise ValueError("encode_loader: loader yielded no batches")

    embeddings = np.vstack(all_embeddings)

    elapsed = time.time() - start
    print("embeddings:", embeddings.shape)
    print("elapsed min:", round(elapsed / 60, 2))
    # Guard against a zero interval on very coarse clocks.
    print("images/sec:", round(len(rows) / max(elapsed, 1e-9), 2))

    meta = pd.DataFrame(rows)

    return embeddings, meta

Text Encoding#

@torch.inference_mode()
def encode_texts(texts, model, tokenizer, device=DEVICE, dtype=torch.bfloat16):
    """Embed one or more text prompts with ``model.encode_text``.

    Args:
        texts: A single string or a sequence of strings; a lone string is
            wrapped into a one-element list.
        model: Object exposing ``encode_text(tokens)``.
        tokenizer: Callable turning a list of strings into a tensor.
        device: Target device for the tokens; a string such as ``"cuda"`` /
            ``"cuda:0"`` / ``"cpu"`` or a ``torch.device``.
        dtype: Autocast dtype, used only when ``device`` is a CUDA device.

    Returns:
        float32 NumPy array of L2-normalised text features, one row per
        input string.
    """
    if isinstance(texts, str):
        texts = [texts]

    # Compare the device *type*: "cuda:0" or a torch.device would fail a
    # plain `== "cuda"` test and silently disable autocast.
    use_autocast = torch.device(device).type == "cuda"

    text_tokens = tokenizer(texts).to(device)

    with torch.autocast(
        device_type="cuda",
        dtype=dtype,
        enabled=use_autocast,
    ):
        text_features = model.encode_text(text_tokens)

    # Normalise in float32 rather than in the autocast dtype.
    text_features = l2_normalize(text_features.float())

    return text_features.cpu().numpy().astype("float32")

Text to Image search

def text_to_image_search(query, top_k=5):
    """Return the ``top_k`` gallery images that best match a text prompt.

    Uses the notebook-level ``model``, ``tokenizer``, ``gallery_embeddings``
    and ``gallery_meta``.

    Args:
        query: Text prompt.
        top_k: Number of results to keep.

    Returns:
        ``(results, scores)``: ``results`` is a copy of the matching
        ``gallery_meta`` rows with the extra columns ``query``, ``rank``,
        ``cosine_score``, ``z_score`` and ``gap_from_top1``; ``scores`` is
        the full score vector over the gallery.
    """
    text_emb = encode_texts(query, model, tokenizer)[0]

    # Dot product of every gallery embedding with the text embedding.
    scores = gallery_embeddings @ text_emb

    top_indices = np.argsort(scores)[::-1][:top_k]
    top_scores = scores[top_indices]

    # Statistics over the whole gallery, so z_score says how far a hit stands
    # out from the full score distribution. The epsilon avoids 0-division.
    mean = scores.mean()
    std = scores.std() + 1e-12

    results = gallery_meta.iloc[top_indices].copy()
    results["query"] = query
    results["rank"] = np.arange(1, len(results) + 1)
    results["cosine_score"] = top_scores
    results["z_score"] = (top_scores - mean) / std
    # Slice instead of indexing element 0: an empty result set (top_k=0 or an
    # empty gallery) then yields an empty column instead of an IndexError.
    results["gap_from_top1"] = top_scores[:1] - top_scores

    return results.reset_index(drop=True), scores

Show results

# Run one example prompt and plot its five best matches.
search_text = "white car with black bonnet"
results, scores = text_to_image_search(search_text, top_k=5)
# display(results)
show_search_results(results, title=search_text)
../_images/224fe979323782efef8d44ac233fac51e0ebc8cb27d0455dced2defafed4700a.png

Evals#

# Text prompts to review by hand.
review_queries = [
    "white car with sunroof",
    "white car with black roof",
    "SUV with roof rails",
    "truck with open cargo bed",
    "box truck with cargo container",
    "white van with yellow stripes",
    "pink sports car",
    "green taxi with roof sign",
    "white car with flames painted on the side",
    "ambulance with red cross",
    "vehicle with zebra stripes",
    "car carrying bicycle on roof",
]

# One row per prompt, with ids q001, q002, ...
queries_df = pd.DataFrame({
    "query_id": [f"q{n:03d}" for n in range(1, len(review_queries) + 1)],
    "query": review_queries,
})

# Collect the top-5 hits of every prompt, tagged with the prompt's id.
all_review_results = []
query_pairs = zip(queries_df["query_id"], queries_df["query"])

for query_id, query in tqdm(query_pairs, total=len(queries_df)):
    results, _ = text_to_image_search(query, top_k=5)
    results["query_id"] = query_id
    all_review_results.append(results)

# Column order of the review sheet; the last two are left for the reviewer
# (-1 means "not reviewed").
review_columns = [
    "query_id",
    "query",
    "rank",
    "image_id",
    "path",
    "vid",
    "camid",
    "final_split",
    "cosine_score",
    "z_score",
    "gap_from_top1",
    "human_label",
    "human_notes",
]

review_results = pd.concat(all_review_results, ignore_index=True)
review_results = review_results.assign(human_label=-1, human_notes="")
review_results = review_results[review_columns]

review_results.to_csv(OUTPUT_DIR / "text_review_top5_unlabeled.csv", index=False)

review_results.head()
100%|██████████| 12/12 [00:00<00:00, 57.15it/s]
query_id query rank image_id path vid camid final_split cosine_score z_score gap_from_top1 human_label human_notes
0 q001 white car with sunroof 1 gallery_00032241 /home/user/data/datasets/VeRi_Self/PKU_Vehicle... -1 -1 gallery_distractor 0.286042 2.644217 0.000000 -1
1 q001 white car with sunroof 2 gallery_00045133 /home/user/data/datasets/VeRi_Self/PKU_Vehicle... -1 -1 gallery_distractor 0.284771 2.609696 0.001271 -1
2 q001 white car with sunroof 3 gallery_00035865 /home/user/data/datasets/VeRi_Self/PKU_Vehicle... -1 -1 gallery_distractor 0.284661 2.606714 0.001381 -1
3 q001 white car with sunroof 4 gallery_00040658 /home/user/data/datasets/VeRi_Self/PKU_Vehicle... -1 -1 gallery_distractor 0.284134 2.592382 0.001909 -1
4 q001 white car with sunroof 5 gallery_00008036 /home/user/data/datasets/VeRi_Self/PKU_Vehicle... -1 -1 gallery_distractor 0.284064 2.590488 0.001978 -1

labelling#

The plan was to label results manually for threshold calibration: for each returned rank, assign one of the labels below.

  • 2 = correct / strong match

  • 1 = partial / acceptable

  • 0 = wrong

  • -1 = not reviewed

So for query 0 we could record [2, 2, 1, 0, 0], which would mean:

  • Rank 1 → 2 = correct / strong match

  • Rank 2 → 2 = correct / strong match

  • Rank 3 → 1 = partial / acceptable

  • Rank 4 → 0 = wrong

  • Rank 5 → 0 = wrong

and then write this back into the metadata file under human_label…

No time, will skip this for now

# Working copy that labels are meant to be written into.
review_results_labeled = review_results.copy()


def next_unreviewed_queries(n=5):
    """Return up to ``n`` query ids that still have rows with human_label == -1."""
    pending_mask = review_results_labeled["human_label"] == -1
    pending_ids = review_results_labeled.loc[pending_mask, "query_id"].unique()
    return list(pending_ids[:n])


qids = next_unreviewed_queries(12)
show_queries_for_labeling(qids)
../_images/5ac72c37825d09fbfac636af875d346edbf115948605e193ea48116867d5244f.png

Image to Image Search (simulate ReID)#

Using the CLIP model for image-to-image search to simulate ReID.

def image_to_image_search(query_index, top_k=10, remove_same_cam=True):
    """Rank gallery images by similarity to one query image.

    Uses the notebook-level ``query_embeddings``, ``query_meta``,
    ``gallery_embeddings`` and ``gallery_meta``.

    Args:
        query_index: Row position of the query in ``query_meta`` /
            ``query_embeddings``.
        top_k: Number of best-scoring candidates to return.
        remove_same_cam: When true, drop gallery rows that share both the
            query's ``vid`` and its ``camid`` before ranking.

    Returns:
        DataFrame of the ``top_k`` candidates in descending score order,
        with the added columns ``score``, ``rank``, ``query_path``,
        ``query_vid``, ``query_camid`` and ``is_correct_vid``.
    """
    probe_emb = query_embeddings[query_index]
    probe = query_meta.iloc[query_index]
    probe_vid, probe_cam = probe["vid"], probe["camid"]

    # Dot-product similarity of every gallery embedding with the query.
    candidate_df = gallery_meta.copy()
    candidate_df["score"] = gallery_embeddings @ probe_emb

    if remove_same_cam:
        # Keep a row unless it matches the query on BOTH vid and camid.
        keep_mask = (candidate_df["vid"] != probe_vid) | (candidate_df["camid"] != probe_cam)
        candidate_df = candidate_df[keep_mask].copy()

    candidate_df = candidate_df.sort_values("score", ascending=False).head(top_k)

    candidate_df["rank"] = np.arange(1, len(candidate_df) + 1)

    # Carry the query's own metadata on every row for later display.
    candidate_df["query_path"] = probe["path"]
    candidate_df["query_vid"] = probe_vid
    candidate_df["query_camid"] = probe_cam

    candidate_df["is_correct_vid"] = candidate_df["vid"] == probe_vid

    return candidate_df.reset_index(drop=True)
# Top-10 candidates for the first query, reduced to the columns worth reading.
display_cols = ["rank", "score", "vid", "camid", "is_correct_vid", "path"]
reid_results = image_to_image_search(0, top_k=10)
reid_results[display_cols]
rank score vid camid is_correct_vid path
0 1 0.842100 1 1 True /home/user/data/datasets/VeRi_Self/reid/cam_1_...
1 2 0.842017 1 1 True /home/user/data/datasets/VeRi_Self/reid/cam_1_...
2 3 0.832953 8 1 False /home/user/data/datasets/VeRi_Self/reid/cam_1_...
3 4 0.811326 8 1 False /home/user/data/datasets/VeRi_Self/reid/cam_1_...
4 5 0.781126 8 0 False /home/user/data/datasets/VeRi_Self/reid/cam_0_...
5 6 0.778882 3 0 False /home/user/data/datasets/VeRi_Self/reid/cam_0_...
6 7 0.777549 3 0 False /home/user/data/datasets/VeRi_Self/reid/cam_0_...
7 8 0.764441 6 0 False /home/user/data/datasets/VeRi_Self/reid/cam_0_...
8 9 0.753039 4 0 False /home/user/data/datasets/VeRi_Self/reid/cam_0_...
9 10 0.740953 8 0 False /home/user/data/datasets/VeRi_Self/reid/cam_0_...
# Plot the top-5 candidates for a single query next to the query image.
query_index = 0
reid_results = image_to_image_search(query_index, top_k=5)
show_image_to_image_results(query_index, reid_results)
../_images/fd8b9cbf6a6cbdf62b981b3cbd9bd841a79dc6c8df632d8ea9cf1c29f29a68e8.png
def evaluate_reid_rank_k(max_queries=None, top_k=10):
    """Compute rank-1 / rank-5 / rank-10 hit rates for image-to-image search.

    A query counts as a hit at rank ``k`` when any of its first ``k``
    candidates (same-vid + same-camera rows removed) carries the query's
    ``vid``. Uses the notebook-level ``query_meta`` and
    ``image_to_image_search``.

    Args:
        max_queries: Evaluate only the first ``max_queries`` queries;
            ``None`` evaluates all of them.
        top_k: Requested retrieval depth. At least ten candidates are always
            fetched so that ``rank5`` and ``rank10`` are never computed on a
            truncated list.

    Returns:
        Dict with ``valid_queries`` and the ``rank1`` / ``rank5`` /
        ``rank10`` hit rates. The rates are NaN when no query was valid.
    """
    n_queries = len(query_meta) if max_queries is None else min(max_queries, len(query_meta))

    cutoffs = (1, 5, 10)
    # With top_k < 10 the rank-10 (or rank-5) figure would silently be
    # measured on fewer candidates than its name promises.
    depth = max(top_k, max(cutoffs))

    hits = dict.fromkeys(cutoffs, 0)
    valid_queries = 0

    for qi in tqdm(range(n_queries)):
        q_vid = query_meta.iloc[qi]["vid"]

        results = image_to_image_search(qi, top_k=depth, remove_same_cam=True)

        if len(results) == 0:
            continue

        # NOTE(review): a query with no matching vid left in the gallery
        # still counts as a miss here - confirm this is the intended protocol.
        valid_queries += 1

        correct = results["vid"].values == q_vid

        for k in cutoffs:
            hits[k] += int(correct[:k].any())

    def rate(count):
        # No valid query means the metric is undefined; avoid ZeroDivisionError.
        return count / valid_queries if valid_queries else float("nan")

    return {
        "valid_queries": valid_queries,
        "rank1": rate(hits[1]),
        "rank5": rate(hits[5]),
        "rank10": rate(hits[10]),
    }
# Rank-1/5/10 hit rates over at most the first 100 queries.
evaluate_reid_rank_k(max_queries=100, top_k=10)
100%|██████████| 29/29 [00:00<00:00, 48.03it/s]
{'valid_queries': 29,
 'rank1': 0.896551724137931,
 'rank5': 0.9655172413793104,
 'rank10': 1.0}
# Plot the top-5 candidates for every query. The count comes from query_meta
# instead of a hard-coded 29, so the loop follows the query set if it changes.
for query_index in range(len(query_meta)):
    reid_results = image_to_image_search(query_index, top_k=5)
    show_image_to_image_results(query_index, reid_results)
../_images/fd8b9cbf6a6cbdf62b981b3cbd9bd841a79dc6c8df632d8ea9cf1c29f29a68e8.png ../_images/9733659f3ef67c44293b5453d96b9c487c0cfea6c963fb1b76562554b26cec9c.png ../_images/a4e7357b133d6fdd54a8ac3d4459be4530f257a3c42a297819f1e958081689e8.png ../_images/369c9aa14b1eedf25511f2e40bc3766b4ec3e22b27c978d07c1066a725b03195.png ../_images/44785ea67ebd3ea65fa260e456eee60d0ea345d218aa4617c89293df03e060a4.png ../_images/0833863a403778412ad5552f7036e08729104e49b1546c1aed8f81e3657ac09f.png ../_images/2e9806fdd02be197702f60c59f250de0f386c926e6652769a9f76d3cb23b53fb.png ../_images/46a0ce9a21aa98902b604a5cb30c11c967dc9bd4f3c184446906e9737ee40ceb.png ../_images/1be9dd9f2985f0cd42fb77a7deb1892099c19876d6c3626f93ed67f3af0adcce.png ../_images/9b1413aa19e6ec9cb043de62f49e11b6246d454905dd3f06a34eeac4f688cb32.png ../_images/22fc4b37592d26efcbd386cdb81c39b2259238212a8e85085e44793d9363ab4d.png ../_images/ea28481a2340d81e75f5eb4221ba30d406f146f89783914dbdfbeeb3fac246cb.png ../_images/b102f968d7e7d6ac150da02cb494774608604ec25d03fb2e6e71824d27dfb48f.png ../_images/b5b63db4986cacb78cad4a5d87e5f6ec713038802baa88b5034e8d3c0477ae9b.png ../_images/04b4848875f3a6c1ee79f40f679f12f8acba9c5a4a02cfda58b53d046a6c34c1.png ../_images/eb10297dafef25ae579ed5f2a1d1e58f7937f4e746625e4b0e7c87ee76e1581f.png ../_images/f62319386fea41ba8002bef7713ecb364cda3b0954e223cb09b66623eb6f348d.png ../_images/6e92919a2105b8ceb234dc00b875400cc8d17b3909c73ad50ef0414b7945b0c5.png ../_images/2fbf85b2e1cdfe2a6b1bfccc8c0d7327ca6aad01784e10eb843ac4089bb3c03f.png ../_images/0778b54a9a6e84ac2f80eba97dc5b418368767a000b2ff9b50d3d28b892ca571.png ../_images/7da44a476ef7446d24ae965ea5809c501b54e1b5b6d8c172bbb9c5a07ef8bc17.png ../_images/9fce512ee7b8fdbfdb9664a89b006d935c0fa905bc9b1b46c93480ce66b5d900.png ../_images/5ee8e00951be01261d6fcc863ec5c906d8b5d157c1b36c3dc6030c4fd45d9921.png ../_images/e8d032e583037d4b65935beb421af2a574e6da00e1e8df3217032661e1c3bb0d.png ../_images/364448319be6b0b5f7b59b6c9adce26b451987b2520efdb8988b77d496719f21.png 
../_images/45d72ceb55fa52a0b986cdcf4f74f1dc4aeb94ece5d2a7d4de13f51092648f35.png ../_images/acc9e8e08504b0eb732c83e69f7ec632db51d4fbfc072a60f32b50b3e92fc15e.png ../_images/adaea4db3ee6d0d270f89612d1bbc5d83bb89f3e911fd52f3ca156f9098343fc.png ../_images/66267eb53b20c0b9d1e88842f9e3ec69f77df8c95bd48ff51f7d97933df37fa8.png