Source code for vismatch.im_models.romav2

import torch

from vismatch import BaseMatcher, THIRD_PARTY_DIR
from vismatch.utils import add_to_path

add_to_path(THIRD_PARTY_DIR.joinpath("RoMaV2/src"))

from romav2 import RoMaV2  # noqa: E402
import romav2.device as romav2_device  # noqa: E402


[docs] class RoMaV2Matcher(BaseMatcher): def __init__(self, device="cpu", max_num_keypoints=2048, *args, **kwargs): super().__init__(device, **kwargs) assert "cuda" in self.device, f"Device must be 'cuda' for {self.name}. Device='{self.device}' not supported" # Temporarily override the global device for proper initialization original_device = romav2_device.device romav2_device.device = torch.device(device) try: # Disable compilation to avoid dtype issues cfg = RoMaV2.Cfg(compile=False) self.romav2_model = RoMaV2(cfg=cfg) # Load pretrained weights (not loaded automatically when custom cfg is provided) weights = torch.hub.load_state_dict_from_url( "https://github.com/Parskatt/RoMaV2/releases/download/weights/romav2.pt" ) self.romav2_model.load_state_dict(weights) finally: # Restore original device romav2_device.device = original_device # Convert to float32 for better CPU compatibility (bfloat16 not fully supported on CPU) self.romav2_model = self.romav2_model.float() self.romav2_model.train(False) # Move all components to the specified device AFTER everything is initialized # This ensures all lazy-initialized parameters/buffers are also moved self.romav2_model = self.romav2_model.to(torch.device(device)) self.max_keypoints = max_num_keypoints
[docs] def preprocess(self, img): return img.unsqueeze(0)
def _forward(self, img0, img1): img0 = self.preprocess(img0) img1 = self.preprocess(img1) img0 = img0.to(self.device) img1 = img1.to(self.device) h0, w0 = img0.shape[-2:] h1, w1 = img1.shape[-2:] preds = self.romav2_model.match(img0, img1) matches, confidence, precision_AB, precision_BA = self.romav2_model.sample(preds, self.max_keypoints) mkpts0, mkpts1 = self.romav2_model.to_pixel_coordinates(matches, h0, w0, h1, w1) return mkpts0, mkpts1, None, None, None, None