Exploring PE and PLM on the RTX 5090#

Perception Models: Powerful Models for Image and Video Perception

PLM#

Perception Language Model (PLM) is a state-of-the-art, fully open and reproducible MLLM for transparent research in image and video understanding.


Log in to the Hugging Face Hub#

from huggingface_hub.hf_api import HfFolder
HfFolder.save_token('ZZZZ')  # replace 'ZZZZ' with your Hugging Face access token
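
The same result can be achieved with the login() helper available in newer huggingface_hub releases; a minimal sketch (the token string is a placeholder):

from huggingface_hub import login
login(token='ZZZZ')  # placeholder token; caches the same credential as HfFolder.save_token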

Use a newer PyTorch build (CUDA 12.8 wheels) for the 5090 in response to the warning below:

⚠️⚠️ WARNING ⚠️⚠️ NVIDIA GeForce RTX 5090 with CUDA capability sm_120 is not compatible with the current PyTorch installation. Supported sm_50 sm_60 sm_70 sm_75 sm_80 sm_86 sm_90.

pip install torch==2.7.1 torchvision==0.22.1 torchaudio==2.7.1 xformers --index-url https://download.pytorch.org/whl/cu128
conda install ffmpeg -c conda-forge
pip install torchcodec==0.4.0
pip install -e .  
import torch
print(torch.__version__)           # should report a +cu128 build
print(torch.cuda.get_arch_list())  # should include sm_120 for the RTX 5090
print(torch.randn(1).cuda())       # should allocate a tensor on the GPU without warnings
2.7.1+cu128
['sm_75', 'sm_80', 'sm_86', 'sm_90', 'sm_100', 'sm_120', 'compute_120']
tensor([1.1256], device='cuda:0')
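
To make the Blackwell requirement explicit, a small sanity check can be added; a minimal sketch:

import torch

# Confirm the installed wheels ship kernels for the sm_120 (Blackwell) architecture
assert "sm_120" in torch.cuda.get_arch_list(), "This PyTorch build lacks sm_120 support"
print(torch.cuda.get_device_name(0))  # e.g. NVIDIA GeForce RTX 5090
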
import os
import sys
import re 

# Add the project root to Python path
sys.path.append('/home/n03an/code/references/perception_models')

import torch
from PIL import Image, ImageDraw

import time
from IPython.display import HTML
from base64 import b64encode
import textwrap
import requests
import urllib.request

from core.args import dataclass_from_dict
from core.transforms.image_transform import get_image_transform
from core.transforms.video_transform import get_video_transform
from apps.plm.generate import PackedCausalTransformerGeneratorArgs, PackedCausalTransformerGenerator, load_consolidated_model_and_tokenizer

Load PLM#

# ckpt = "facebook/Perception-LM-1B"
ckpt = "facebook/Perception-LM-3B" 
# ckpt = "facebook/Perception-LM-8B" 
model, tokenizer, config = load_consolidated_model_and_tokenizer(ckpt)
Downloading facebook/Perception-LM-3B from Hugging Face Hub...
Fetching 19 files: 100%|██████████| 19/19 [00:00<00:00, 184471.70it/s]
INFO:apps.plm.tokenizer:Reloaded tiktoken model from /home/n03an/.cache/huggingface/hub/models--facebook--Perception-LM-3B/snapshots/31e3665b544e2dbac215e80200923139fb398975/original/tokenizer.model
INFO:apps.plm.tokenizer:#words: 128256 - BOS ID: 128000 - EOS ID: 128001
Downloaded to: /home/n03an/.cache/huggingface/hub/models--facebook--Perception-LM-3B/snapshots/31e3665b544e2dbac215e80200923139fb398975/original
INFO:apps.plm.transformer:Initializing PE_VisionTransformer with args: {'image_size': 448, 'patch_size': 14, 'width': 1024, 'layers': 23, 'heads': 16, 'use_cls_token': True, 'use_abs_posemb': True, 'mlp_ratio': 4.0, 'ls_init_value': 0.1, 'drop_path': 0.1, 'use_ln_post': False, 'pool_type': 'none'}
def generate(
    media_path,
    question="Describe the image in details.",
    media_type="image",
    number_of_frames=4,
    number_of_tiles=1,
    temperature=0.0,
    top_p=None,
    top_k=None,
    return_text=False,
):
    prompts = []
    if media_type == "image":
        transform = get_image_transform(
            vision_input_type=(
                "vanilla" if number_of_tiles == 1 else config.data.vision_input_type
            ),
            image_res=model.vision_model.image_size,
            max_num_tiles=number_of_tiles,
        )
        if isinstance(media_path, str):
            image = Image.open(media_path).convert("RGB")
        else:
            image = media_path
        image, _ = transform(image)
        prompts.append((question, image))
    elif media_type == "video":
        transform = get_video_transform(
            image_res=model.vision_model.image_size,
        )
        video_info = (media_path, number_of_frames, None, None, None)
        frames, _ = transform(video_info)
        prompts.append((question, frames))
    else:
        raise NotImplementedError(
            "The provided generate function only supports image and video."
        )
    # Create generator
    gen_cfg = dataclass_from_dict(
        PackedCausalTransformerGeneratorArgs,
        {"temperature": temperature, "top_p": top_p, "top_k": top_k},
        strict=False,
    )
    generator = PackedCausalTransformerGenerator(gen_cfg, model, tokenizer)
    # Run generation
    start_time = time.time()
    generation, loglikelihood, greedy = generator.generate(prompts)
    end_time = time.time()
    if return_text:
        print(generation[0])
        return generation[0]
    
    # Calculate tokens per second across all generated outputs
    total_tokens = sum(
        len(tokenizer.encode(g, False, False)) for g in generation
    )
    tokens_per_second = total_tokens / (end_time - start_time)
    for gen in generation:
        print("=================================================")
        print(textwrap.fill(gen, width=75))
        print(f"Tokens per second: {tokens_per_second:.2f}")
        print("=================================================")

def extract_all_bounding_boxes(text: str) -> list[list[str]]:
    """Extracts any list of arbitrary length from a string."""
    pattern = r"\[\s*([^\[\]]*?)\s*\]"
    extracted_lists = [
        [num.strip() for num in match.split(",")] for match in re.findall(pattern, text)
    ]
    return extracted_lists
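
For example, on a raw model output like the one produced later in this notebook, the helper returns the comma-separated parts of each bracketed group:

print(extract_all_bounding_boxes("[040,490,109,570]"))
# -> [['040', '490', '109', '570']]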

def rescale_2d_bboxes(bboxes, img_w, img_h, box_format="000", verbose=True):
    w, h = img_w, img_h
    rescaled_bboxes = []
    for bbox in bboxes:
        try:
            if box_format == "000":
                bbox = [float("0." + b.strip()) for b in bbox]
            elif box_format == "standard":
                bbox = [float(b.strip()) for b in bbox]
            else:
                # unknown format: try the "000" convention first, then plain floats
                try:
                    bbox = [float("0." + b.strip()) for b in bbox]
                except ValueError:
                    bbox = [float(b.strip()) for b in bbox]

            x1, y1, x2, y2 = bbox
            bbox = [x1 * w, y1 * h, x2 * w, y2 * h]

            rescaled_bboxes.append(bbox)
        except Exception as e:
            if verbose:
                print("[rescale_2d_bboxes]:", e, bbox, flush=True)
    return rescaled_bboxes

def postprocess_grounding(x: str, img_w: int, img_h: int) -> list[float]:
    bboxes = extract_all_bounding_boxes(x)
    bboxes = rescale_2d_bboxes(bboxes, img_w, img_h)
    if len(bboxes) > 0:
        box = bboxes[0]
    else:
        # no box found.
        box = [0, 0, img_w, img_h]
    return box
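
With the default "000" format, each three-digit value is read as a fraction of the image size; for example, an output of "[040,490,109,570]" on a hypothetical 640x480 image maps back to pixel coordinates like this:

box = postprocess_grounding("[040,490,109,570]", img_w=640, img_h=480)
print(box)  # -> [25.6, 235.2, 69.76, 273.6]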


def generate_grounding(media_path: str, question: str, number_of_tiles: int):
    image = Image.open(media_path)
    w, h = image.size

    print("Generating...")
    output = generate(media_path=media_path, question=question, number_of_tiles=number_of_tiles, media_type="image", return_text=True)

    box = postprocess_grounding(output, w, h)

    draw = ImageDraw.Draw(image)
    try:
        x_min, y_min, x_max, y_max = box

        if x_min < x_max and y_min < y_max:
            # Draw the bounding box
            draw.rectangle([x_min, y_min, x_max, y_max], outline="red", width=2)
    except Exception as e:
        print(f"Error drawing bounding box: {e}")
    return image

Run inference for image grounding task#

question_template = "Provide a bounding box of the region this sentence describes: '{caption}'.\nUse the format [x1, y1, x2, y2]."

image_url = "http://farm3.staticflickr.com/2453/3867429392_ed6f3d337a_z.jpg"
image_path = "3867429392_ed6f3d337a_z.jpg"

urllib.request.urlretrieve(image_url, image_path)

description = "white fire hydrant in the back"
question = question_template.format(caption=description)

img = Image.open(image_path)
display(img)

img_drawn = generate_grounding(media_path=image_path, question=question, number_of_tiles=36)
display(img_drawn)
../../_images/e5c8b7aed625cdd96a273149e60bef80d3757df23366117baf3bfddef6901cbd.png
INFO:root:VariableSizeImageTransform size: 448
INFO:root:ImageTransform size: 448
INFO:root:Initialized transforms with: vision_input_type: 'thumb+tile' and max_num_tiles: 36.
Generating...
[040,490,109,570]
../../_images/39242c8f3d0c08bdc625edbd15c4c98e05f36958fc5f1b6c5c3bf8567a56ba5e.png

Region captioning (bounding box as text)#

question = "Please describe the region ({bbox}) in details.\nThe region is in the format of [x1, y1, x2, y2]."
question = question.format(bbox="[040,482,112,576]")
generate(media_path=img_drawn, question=question, number_of_tiles=36, media_type="image")
INFO:root:VariableSizeImageTransform size: 448
INFO:root:ImageTransform size: 448
INFO:root:Initialized transforms with: vision_input_type: 'thumb+tile' and max_num_tiles: 36.
=================================================
fire hydrant in the background
Tokens per second: 1.29
=================================================

Region captioning (bounding box as drawing)#

question = "Please describe the region inside the rec rectangle."
generate(media_path=img_drawn, question=question, number_of_tiles=36, media_type="image")
INFO:root:VariableSizeImageTransform size: 448
INFO:root:ImageTransform size: 448
INFO:root:Initialized transforms with: vision_input_type: 'thumb+tile' and max_num_tiles: 36.
=================================================
[040,482,109,572]
Tokens per second: 7.43
=================================================

Testing on samples from VIRAT Dataset#

question_template = "Provide a bounding box of the region this sentence describes: '{caption}'.\nUse the format [x1, y1, x2, y2]."

image_path = "condo_parking.png"

description = "person walking with umbrella"
question = question_template.format(caption=description)

img = Image.open(image_path)
display(img)

img_drawn = generate_grounding(media_path=image_path, question=question, number_of_tiles=36)
display(img_drawn)
../../_images/97769185ce013910003e530aeaf57e388eacfacc3bf4705921799240774ba059.png
INFO:root:VariableSizeImageTransform size: 448
INFO:root:ImageTransform size: 448
INFO:root:Initialized transforms with: vision_input_type: 'thumb+tile' and max_num_tiles: 36.
Generating...
[231,720,255,833]
../../_images/e172d8ce8adadbab6188c9db3ad27e67180f9a6438071d58c91f242bbd2fcaa5.png

Injecting bounding boxes from YOLOE#

from ultralytics import YOLOE

yoloe_model = YOLOE("yoloe-11l-seg-pf.pt")
import matplotlib.pyplot as plt
from IPython.display import display  # avoid importing IPython's Image here, which would shadow PIL.Image

def detect_objects(image_path, model=yoloe_model):
    results = model(image_path)
    plt.figure(figsize=(12, 8))
    plt.imshow(results[0].plot()[:, :, ::-1])
    plt.axis("off")
    plt.show()
    return results
results = detect_objects("condo_parking.png")
# print(results)
image 1/1 /home/n03an/code/references/perception_models/apps/plm/notebook_demos/condo_parking.png: 384x640 2 barrels, 20 cars, 1 emergency vehicle, 1 information desk, 1 recycling bin, 1 shipping container, 6 suvs, 1 tow truck, 2 trailer trucks, 4 vans, 1 waste container, 39.9ms
Speed: 1.9ms preprocess, 39.9ms inference, 9.8ms postprocess per image at shape (1, 3, 384, 640)
../../_images/5dc81ee83baaed99c8bfaafa0cc047ff3f750203e6adda445bdf4eae306e2fbf.png

Clean up the image to keep only the bounding boxes

from PIL import Image, ImageDraw

img_drawn = Image.open("condo_parking.png")
draw = ImageDraw.Draw(img_drawn)

for box in results[0].boxes:
    x1, y1, x2, y2 = box.xyxy[0].tolist()  
    cls = int(box.cls[0])
    draw.rectangle([x1, y1, x2, y2], outline="red", width=2)
    draw.text((x1, y1), str(cls), fill="red")

img_drawn.show()
../../_images/c24783e623c63af53770033afc52d14cb0689f64841a2642b69193b67ad0c985.png

Inference using the bounding boxes from YOLOE

question = "Given a list of regions ({bbox}), describe each region in detail.\nThe region is in the format of [x1, y1, x2, y2]."
question = question.format(bbox=str(results[0].boxes.xyxy.tolist()))
print(question)
generate(media_path=img_drawn, question=question, number_of_tiles=36, media_type="image")
INFO:root:VariableSizeImageTransform size: 448
INFO:root:ImageTransform size: 448
INFO:root:Initialized transforms with: vision_input_type: 'thumb+tile' and max_num_tiles: 36.
Given a list of regions ([[1055.907958984375, 166.67059326171875, 1199.28271484375, 232.6873779296875], [1569.0347900390625, 835.081787109375, 1620.7767333984375, 915.9642333984375], [904.4067993164062, 874.775390625, 1184.998291015625, 1012.383544921875], [1356.87841796875, 299.5640869140625, 1509.678466796875, 374.88604736328125], [718.4602661132812, 361.9792785644531, 920.3095092773438, 443.2321472167969], [809.491943359375, 542.8463745117188, 1029.044189453125, 620.627685546875], [592.3390502929688, 667.2838745117188, 843.1241455078125, 816.3665771484375], [1530.1661376953125, 395.6509094238281, 1672.77099609375, 481.7243957519531], [865.474365234375, 44.66704559326172, 1027.41259765625, 127.0682373046875], [818.4708251953125, 613.285400390625, 1088.744384765625, 725.915283203125], [813.5240478515625, 472.24884033203125, 1026.7803955078125, 566.2918090820312], [718.8458251953125, 305.9491882324219, 907.4540405273438, 377.5452575683594], [679.593994140625, 1040.403076171875, 869.1578369140625, 1076.3095703125], [566.4322509765625, 558.3406982421875, 798.9175415039062, 690.7618408203125], [566.3768920898438, 557.79541015625, 798.870849609375, 690.9448852539062], [115.34028625488281, 796.808349609375, 254.01187133789062, 907.4794921875], [907.5245361328125, 869.864501953125, 1190.736328125, 1011.8411865234375], [654.5325927734375, 154.59243774414062, 814.1592407226562, 213.017333984375], [98.18843078613281, 286.7320861816406, 125.93412780761719, 331.06549072265625], [218.403076171875, 196.6102752685547, 331.19952392578125, 360.74072265625], [882.6234741210938, 764.5292358398438, 1140.887939453125, 876.4059448242188], [881.64990234375, 765.260009765625, 1142.0595703125, 876.714111328125], [881.8800659179688, 764.2064208984375, 1140.34765625, 877.5531005859375], [493.6552734375, 156.16937255859375, 649.2734375, 229.53030395507812], [489.9690856933594, 125.33489227294922, 609.1558837890625, 165.96951293945312], [518.6258544921875, 277.63067626953125, 689.021240234375, 378.22918701171875], [515.8487548828125, 230.30799865722656, 671.5083618164062, 312.4571533203125], [217.7808837890625, 194.8688507080078, 333.8619689941406, 360.88525390625], [520.1320190429688, 278.1162109375, 689.9099731445312, 378.905029296875], [238.57997131347656, 78.4800796508789, 338.61602783203125, 132.6048583984375], [1487.9527587890625, 531.4498901367188, 1678.8438720703125, 814.0872802734375], [219.56698608398438, 214.45811462402344, 326.6293029785156, 361.3482666015625], [1488.176025390625, 531.3645629882812, 1678.5810546875, 814.4005126953125], [515.9783935546875, 231.59841918945312, 673.5609130859375, 311.1498107910156], [519.9911499023438, 280.01165771484375, 689.618896484375, 376.92181396484375], [123.79180908203125, 281.9728088378906, 148.2416534423828, 326.42376708984375], [679.5899658203125, 212.58888244628906, 889.1549072265625, 302.16455078125], [68.47206115722656, 271.57977294921875, 98.42008209228516, 320.25762939453125], [680.17138671875, 211.77749633789062, 886.9091796875, 301.7431640625], [98.52893829345703, 282.4366455078125, 126.19361114501953, 331.3118896484375]]), describe each region in detail.
The region is in the format of [x1, y1, x2, y2].
=================================================
a car parked in a parking lot.
Tokens per second: 1.18
=================================================
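
Passing the full list of boxes in one prompt produces only a single region description, so captioning each box separately may work better; a rough sketch that reuses the earlier region-captioning prompt for each YOLOE detection (one PLM call per box):

# Hypothetical per-box captioning loop: one call per YOLOE detection
for box in results[0].boxes.xyxy.tolist():
    q = (
        f"Please describe the region ({[round(v, 1) for v in box]}) in details.\n"
        "The region is in the format of [x1, y1, x2, y2]."
    )
    generate(media_path=img_drawn, question=q, number_of_tiles=36, media_type="image")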
question = "Please describe region inside each rectangles in details."
generate(media_path=img_drawn, question=question, number_of_tiles=36, media_type="image")
INFO:root:VariableSizeImageTransform size: 448
INFO:root:ImageTransform size: 448
INFO:root:Initialized transforms with: vision_input_type: 'thumb+tile' and max_num_tiles: 36.
=================================================
the video shows a parking lot with many cars parked in it. the cars are
parked in a straight line and are of different colors and models. there are
also some other objects in the parking lot such as a truck, a van, and some
barrels. the parking lot is wet, indicating that it has been raining
recently."
Tokens per second: 3.91
=================================================
question = "Count total vehicles by type identified by each red bounding box"
generate(media_path=img_drawn, question=question, number_of_tiles=36, media_type="image")
INFO:root:VariableSizeImageTransform size: 448
INFO:root:ImageTransform size: 448
INFO:root:Initialized transforms with: vision_input_type: 'thumb+tile' and max_num_tiles: 36.
=================================================
the total number of cars is 24, the total number of trucks is 2, the total
number of vans is 1, and the total number of SUVs is 2
Tokens per second: 3.38
=================================================
question = "describe the man in detail"
generate(media_path=img_drawn, question=question, number_of_tiles=36, media_type="image")
INFO:root:VariableSizeImageTransform size: 448
INFO:root:ImageTransform size: 448
INFO:root:Initialized transforms with: vision_input_type: 'thumb+tile' and max_num_tiles: 36.
=================================================
the man is wearing a white shirt and black pants. he is holding an orange
umbrella in his right hand. he is walking towards the left side of the
frame. he is also wearing a black cap
Tokens per second: 3.39
=================================================
question = "describe all people visible in the image"
generate(media_path=img_drawn, question=question, number_of_tiles=36, media_type="image")
INFO:root:VariableSizeImageTransform size: 448
INFO:root:ImageTransform size: 448
INFO:root:Initialized transforms with: vision_input_type: 'thumb+tile' and max_num_tiles: 36.
=================================================
There is one person visible in the image. The person is walking in the
parking lot with an umbrella.
Tokens per second: 2.65
=================================================
!nvidia-smi
Sat Nov 29 18:26:17 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.82.02              Driver Version: 581.15         CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 5090        On  |   00000000:02:00.0  On |                  N/A |
|  0%   45C    P8             33W /  575W |   31980MiB /  32607MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A              33      G   /Xwayland                             N/A      |
+-----------------------------------------------------------------------------------------+
!free -h
               total        used        free      shared  buff/cache   available
Mem:            54Gi        12Gi        39Gi        89Mi       3.8Gi        42Gi
Swap:           14Gi          0B        14Gi
orig_img = Image.open("crowd.jpg")
orig_img
../../_images/6957643cfa616e4a8f300253606e792fb62fa908524a7465cf73b46a9e29b04d.png
question = "describe all people visible in the image?"
generate(media_path=orig_img, question=question, number_of_tiles=36, media_type="image")
INFO:root:VariableSizeImageTransform size: 448
INFO:root:ImageTransform size: 448
INFO:root:Initialized transforms with: vision_input_type: 'thumb+tile' and max_num_tiles: 36.
=================================================
The image shows a large group of people walking on a crosswalk. The people
are of various ages and ethnicities. Some are carrying bags, purses, or
other items. Some are wearing casual clothing, while others are dressed
more formally.
Tokens per second: 3.51
=================================================
question = "Identify person wearing pink top and red pants carrying a pink bag. Describe the person in details such as gender, age, ethinicity, top color, bottom color, head gear, and their behavior"
generate(media_path=orig_img, question=question, number_of_tiles=36, media_type="image")
INFO:root:VariableSizeImageTransform size: 448
INFO:root:ImageTransform size: 448
INFO:root:Initialized transforms with: vision_input_type: 'thumb+tile' and max_num_tiles: 36.
=================================================
The person in question is a woman, who appears to be of Asian descent. She
is wearing a pink top and red pants, and is carrying a pink bag. She is
walking and seems to be in a hurry.
Tokens per second: 3.37
=================================================

Grounding Test#

question_template = "Provide a bounding box of the region this sentence describes: '{caption}'.\nUse the format [x1, y1, x2, y2]."

description = "pink bag"
question = question_template.format(caption=description)

img_drawn = generate_grounding(media_path="crowd.jpg", question=question, number_of_tiles=36)
display(img_drawn)
INFO:root:VariableSizeImageTransform size: 448
INFO:root:ImageTransform size: 448
INFO:root:Initialized transforms with: vision_input_type: 'thumb+tile' and max_num_tiles: 36.
Generating...
[503,637,581,898]
../../_images/7ed41232b654cdb3592dbf096df8cae97d20859934e0b040104977aa812a9299.png
question_template = "Provide a bounding box of the region this sentence describes: '{caption}'.\nUse the format [x1, y1, x2, y2]."

description = "young girl in blue top"
question = question_template.format(caption=description)

img_drawn = generate_grounding(media_path="crowd.jpg", question=question, number_of_tiles=36)
display(img_drawn)
INFO:root:VariableSizeImageTransform size: 448
INFO:root:ImageTransform size: 448
INFO:root:Initialized transforms with: vision_input_type: 'thumb+tile' and max_num_tiles: 36.
Generating...
[000,243,999,999]
../../_images/e173b27335a393fd0912820ad7704cb6acc62743cd7e7cb73f66b7aef4c970f6.png
question_template = "Provide a bounding box of the region this sentence describes: '{caption}'.\nUse the format [x1, y1, x2, y2]."

description = "women in blue scarf"
question = question_template.format(caption=description)

img_drawn = generate_grounding(media_path="crowd.jpg", question=question, number_of_tiles=36)
display(img_drawn)
INFO:root:VariableSizeImageTransform size: 448
INFO:root:ImageTransform size: 448
INFO:root:Initialized transforms with: vision_input_type: 'thumb+tile' and max_num_tiles: 36.
Generating...
[581,453,647,728]
../../_images/a7ae3e0b1371c9711fc1ae1c71475fcded340c31905d5850954fc34367570d3a.png

Alternate Crowd Image#

question_template = "Provide a bounding box of the region this sentence describes: '{caption}'.\nUse the format [x1, y1, x2, y2]."

description = "person in blue tshirt carrying bag"
question = question_template.format(caption=description)

img_drawn = generate_grounding(media_path="crowd_2.webp", question=question, number_of_tiles=36)
display(img_drawn)
INFO:root:VariableSizeImageTransform size: 448
INFO:root:ImageTransform size: 448
INFO:root:Initialized transforms with: vision_input_type: 'thumb+tile' and max_num_tiles: 36.
Generating...
[543,476,616,647]
../../_images/21df262ad8e281dbc606bd05ba9417125cc34e7cfea7792fe00470c06c0e851b.png
orig_img = Image.open("crowd_2.webp")
question = "describe the person in detail carrying yellow backpack"
generate(media_path=orig_img, question=question, number_of_tiles=36, media_type="image")
INFO:root:VariableSizeImageTransform size: 448
INFO:root:ImageTransform size: 448
INFO:root:Initialized transforms with: vision_input_type: 'thumb+tile' and max_num_tiles: 36.
=================================================
The person in the image is a woman with long dark hair. She is wearing a
white shirt, black pants, and a yellow backpack. She is also carrying a
pink bag in her left hand.
Tokens per second: 16.89
=================================================

Testing Perception Encoder Core - PE-Core-L14-336#

import torch
from PIL import Image
import core.vision_encoder.pe as pe
import core.vision_encoder.transforms as transforms

# print("CLIP configs:", pe.CLIP.available_configs())
# CLIP configs: ['PE-Core-G14-448', 'PE-Core-L14-336', 'PE-Core-B16-224', 'PE-Core-S16-384', 'PE-Core-T16-384']

model = pe.CLIP.from_config("PE-Core-L14-336", pretrained=True)  # Downloads from HF
model = model.cuda()

preprocess = transforms.get_image_transform(model.image_size)
tokenizer = transforms.get_text_tokenizer(model.context_length)

orig_img = Image.open("condo_parking.png")
orig_img.show()

image = preprocess(orig_img).unsqueeze(0).cuda()

def test(labels):
    text = tokenizer(labels).cuda()
    with torch.no_grad(), torch.autocast("cuda"):
        image_features, text_features, logit_scale = model(image, text)
        text_probs = (logit_scale * image_features @ text_features.T).softmax(dim=-1)

    print("=============> ⭕ Results for labels:", labels)
    for label, prob in zip(labels, text_probs[0]):
        print(f"{label}: {prob*100}")

test(["black sedan", "blue umbrella"])
test(["man with blue umbrella", "women with orange umbrella"])
test(["black sedan", "white truck"])
test(["chairs and tables", "white barrels", "blue barrels"])
INFO:root:Missing keys for loading model: []
INFO:root:Unexpected keys for loading model: []
Missing keys for loading model: []
Unexpected keys for loading model: []
../../_images/97769185ce013910003e530aeaf57e388eacfacc3bf4705921799240774ba059.png
=============> ⭕ Results for labels: ['black sedan', 'blue umbrella']
black sedan: 0.0009818327380344272
blue umbrella: 99.9990234375
=============> ⭕ Results for labels: ['man with blue umbrella', 'women with orange umbrella']
man with blue umbrella: 58.888912200927734
women with orange umbrella: 41.111083984375
=============> ⭕ Results for labels: ['black sedan', 'white truck']
black sedan: 87.83142852783203
white truck: 12.168575286865234
=============> ⭕ Results for labels: ['chairs and tables', 'white barrels', 'blue barrels']
chairs and tables: 26.0778751373291
white barrels: 65.04957580566406
blue barrels: 8.872550010681152

Testing SigLIP#

from transformers import AutoProcessor, SiglipModel
import torch

model_id = "google/siglip-so400m-patch14-384"
device = ("cuda" if torch.cuda.is_available() else "cpu")
processor = AutoProcessor.from_pretrained(model_id)
model = SiglipModel.from_pretrained(model_id).to(device).eval()
/home/n03an/miniconda3/envs/perception_models/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
model
SiglipModel(
  (text_model): SiglipTextTransformer(
    (embeddings): SiglipTextEmbeddings(
      (token_embedding): Embedding(32000, 1152)
      (position_embedding): Embedding(64, 1152)
    )
    (encoder): SiglipEncoder(
      (layers): ModuleList(
        (0-26): 27 x SiglipEncoderLayer(
          (layer_norm1): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
          (self_attn): SiglipAttention(
            (k_proj): Linear(in_features=1152, out_features=1152, bias=True)
            (v_proj): Linear(in_features=1152, out_features=1152, bias=True)
            (q_proj): Linear(in_features=1152, out_features=1152, bias=True)
            (out_proj): Linear(in_features=1152, out_features=1152, bias=True)
          )
          (layer_norm2): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
          (mlp): SiglipMLP(
            (activation_fn): PytorchGELUTanh()
            (fc1): Linear(in_features=1152, out_features=4304, bias=True)
            (fc2): Linear(in_features=4304, out_features=1152, bias=True)
          )
        )
      )
    )
    (final_layer_norm): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
    (head): Linear(in_features=1152, out_features=1152, bias=True)
  )
  (vision_model): SiglipVisionTransformer(
    (embeddings): SiglipVisionEmbeddings(
      (patch_embedding): Conv2d(3, 1152, kernel_size=(14, 14), stride=(14, 14), padding=valid)
      (position_embedding): Embedding(729, 1152)
    )
    (encoder): SiglipEncoder(
      (layers): ModuleList(
        (0-26): 27 x SiglipEncoderLayer(
          (layer_norm1): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
          (self_attn): SiglipAttention(
            (k_proj): Linear(in_features=1152, out_features=1152, bias=True)
            (v_proj): Linear(in_features=1152, out_features=1152, bias=True)
            (q_proj): Linear(in_features=1152, out_features=1152, bias=True)
            (out_proj): Linear(in_features=1152, out_features=1152, bias=True)
          )
          (layer_norm2): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
          (mlp): SiglipMLP(
            (activation_fn): PytorchGELUTanh()
            (fc1): Linear(in_features=1152, out_features=4304, bias=True)
            (fc2): Linear(in_features=4304, out_features=1152, bias=True)
          )
        )
      )
    )
    (post_layernorm): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
    (head): SiglipMultiheadAttentionPoolingHead(
      (attention): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=1152, out_features=1152, bias=True)
      )
      (layernorm): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
      (mlp): SiglipMLP(
        (activation_fn): PytorchGELUTanh()
        (fc1): Linear(in_features=1152, out_features=4304, bias=True)
        (fc2): Linear(in_features=4304, out_features=1152, bias=True)
      )
    )
  )
)
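
The loaded SigLIP checkpoint can be exercised with the same zero-shot labels used for PE-Core above; a minimal sketch, assuming condo_parking.png is still on disk (SigLIP scores each label independently with a sigmoid rather than a softmax over the label set):

from PIL import Image

img = Image.open("condo_parking.png").convert("RGB")
labels = ["black sedan", "blue umbrella", "white truck"]

# SigLIP was trained with max-length text padding, so pad accordingly
inputs = processor(text=labels, images=img, padding="max_length", return_tensors="pt").to(device)
with torch.no_grad():
    outputs = model(**inputs)

# Independent per-label probabilities
probs = torch.sigmoid(outputs.logits_per_image)[0]
for label, p in zip(labels, probs):
    print(f"{label}: {p.item() * 100:.2f}")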