Source code for vismatch.im_models.topicfm

import torch
import torchvision.transforms as tfm
from safetensors.torch import load_file

from huggingface_hub import snapshot_download
from vismatch import BaseMatcher, THIRD_PARTY_DIR
from vismatch.utils import add_to_path, resize_to_divisible

add_to_path(THIRD_PARTY_DIR.joinpath("TopicFM"))

from src.models.topic_fm import TopicFM  # noqa: E402
from src import get_model_cfg  # noqa: E402


[docs] class TopicFMMatcher(BaseMatcher): divisible_size = 16 def __init__(self, device="cpu", variant="fast", *args, **kwargs): """TopicFM matcher. Args: device (str, optional): 'cpu' or 'cuda'. Defaults to "cpu". variant (str, optional): 'fast' or 'plus' - which pretrained model to use. Defaults to "fast". """ super().__init__(device, **kwargs) self.variant = variant # Load model config conf = dict(get_model_cfg()) # Set config based on variant if variant == "fast": conf["coarse"]["nhead"] = 2 conf["coarse"]["attention"] = "full" conf["coarse"]["n_samples"] = 0 conf["match_coarse"]["thr"] = 0.2 elif variant == "plus": conf["coarse"]["nhead"] = 8 conf["coarse"]["attention"] = "linear" conf["coarse"]["n_samples"] = 8 conf["coarse"]["n_topic_transformers"] = 2 conf["match_coarse"]["thr"] = 0.25 else: raise ValueError(f"Unknown variant: {variant}. Must be 'fast' or 'plus'") conf["loss"]["fine_type"] = "sym_epi" # Initialize model self.model = TopicFM(config=conf) # Download and load pretrained weights repo = "vismatch/topicfm" if self.variant == "fast" else "vismatch/topicfm-plus" weights_path = f"{snapshot_download(repo)}/model.safetensors" self.model.load_state_dict(load_file(weights_path)) self.model = self.model.eval().to(self.device)
[docs] def preprocess(self, img): """Convert RGB image to grayscale, resize to divisible, and add batch dimension.""" _, h, w = img.shape orig_shape = h, w img = resize_to_divisible(img, self.divisible_size) return tfm.Grayscale()(img).unsqueeze(0), orig_shape
def _forward(self, img0, img1): # Preprocess images img0, img0_orig_shape = self.preprocess(img0) img1, img1_orig_shape = self.preprocess(img1) img0 = img0.to(self.device) img1 = img1.to(self.device) # Prepare data dict data = { "image0": img0, "image1": img1, } # Run model with torch.no_grad(): self.model(data) # Extract matched keypoints mkpts0 = data["mkpts0_f"] mkpts1 = data["mkpts1_f"] H0, W0, H1, W1 = *img0.shape[-2:], *img1.shape[-2:] mkpts0 = self.rescale_coords(mkpts0, *img0_orig_shape, H0, W0) mkpts1 = self.rescale_coords(mkpts1, *img1_orig_shape, H1, W1) return mkpts0, mkpts1, None, None, None, None