Source code for vismatch.im_models.roma

import torch
import torchvision.transforms as tfm
from kornia.augmentation import PadTo
from kornia.utils import tensor_to_image
import tempfile
from pathlib import Path


from vismatch import BaseMatcher, THIRD_PARTY_DIR
from vismatch.utils import add_to_path, disable_xformers

add_to_path(THIRD_PARTY_DIR.joinpath("RoMa"))
from romatch import roma_outdoor, tiny_roma_v1_outdoor

from PIL import Image
from skimage.util import img_as_ubyte


class RomaMatcher(BaseMatcher):
    """Dense matcher wrapping the RoMa outdoor model from ``romatch``.

    ``roma_model.match`` is called here with file paths, so each input tensor is
    first written to a temporary PNG. Matches are then sampled from the dense
    warp and converted to pixel coordinates.
    """

    dino_patch_size = 14
    coarse_ratio = 560 / 864

    def __init__(self, device="cpu", max_num_keypoints=2048, *args, **kwargs):
        """Build the RoMa outdoor model on ``device``.

        Args:
            device: Device forwarded to ``BaseMatcher`` and ``roma_outdoor``.
            max_num_keypoints: Number of matches drawn by ``roma_model.sample``.
            *args: Accepted for signature compatibility; ignored.
            **kwargs: Forwarded to ``BaseMatcher.__init__``.
        """
        super().__init__(device, **kwargs)
        self.roma_model = roma_outdoor(device=device)
        self.max_keypoints = max_num_keypoints
        # Not used by the methods of this class; NOTE(review): may be read elsewhere.
        self.normalize = tfm.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        self.roma_model.train(False)
        if device == "cpu":
            disable_xformers()

    def compute_padding(self, img0, img1):
        """Configure ``self.pad`` to pad images to one common square size.

        The square side is the largest height or width across both images.
        Inputs are unpacked as rank-3 ``(C, H, W)`` tensors.
        """
        _, h0, w0 = img0.shape
        _, h1, w1 = img1.shape
        pad_dim = max(h0, w0, h1, w1)
        self.pad = PadTo((pad_dim, pad_dim), keepdim=True)

    def preprocess(self, img: torch.Tensor, pad: bool = False) -> tuple:
        """Write ``img`` to a temporary PNG file for ``roma_model.match``.

        Args:
            img: Image tensor (unpacked as ``(C, H, W)`` by ``compute_padding``).
                Floating-point inputs are clamped to ``[-1, 1]``, the range
                ``img_as_ubyte`` accepts.
            pad: If True, apply ``self.pad`` (configured by ``compute_padding``)
                before writing.

        Returns:
            ``(temp, (width, height))``. ``temp`` is an already-closed
            ``NamedTemporaryFile`` whose ``.name`` is the PNG path. The file is
            NOT deleted automatically; the caller must unlink it.
        """
        # Clamp every float dtype, not just float32: img_as_ubyte raises on
        # float input outside [-1, 1].
        if isinstance(img, torch.Tensor) and torch.is_floating_point(img):
            img = torch.clamp(img, -1, 1)
        if pad:
            img = self.pad(img)
        img = tensor_to_image(img)
        pil_img = Image.fromarray(img_as_ubyte(img), mode="RGB")
        # Close our handle before PIL reopens the file by name: reopening a
        # still-open NamedTemporaryFile by name is not portable across platforms.
        temp = tempfile.NamedTemporaryFile("w+b", suffix=".png", delete=False)
        temp.close()
        try:
            pil_img.save(temp.name, format="png")
        except Exception:
            Path(temp.name).unlink(missing_ok=True)
            raise
        return temp, pil_img.size

    def _forward(self, img0, img1, pad=False):
        """Match two images with RoMa and return sampled keypoint pairs.

        Args:
            img0, img1: Image tensors (see ``preprocess``).
            pad: If True, pad both images to a common square size before matching.

        Returns:
            ``(mkpts0, mkpts1, None, None, None, None)``: matched pixel coordinates
            in each image, followed by four ``None`` placeholders.
        """
        if pad:
            self.compute_padding(img0, img1)
        temps = []
        try:
            # `pad` must be forwarded, otherwise the padding configured above is
            # never applied. Sizes are taken after padding, so coordinates refer to
            # the padded images. NOTE(review): kornia's PadTo pads right/bottom only,
            # so these coincide with original-image coordinates — confirm on upgrade.
            img0_temp, (w0, h0) = self.preprocess(img0, pad=pad)
            temps.append(img0_temp)
            img1_temp, (w1, h1) = self.preprocess(img1, pad=pad)
            temps.append(img1_temp)
            warp, certainty = self.roma_model.match(
                img0_temp.name, img1_temp.name, batched=False, device=self.device
            )
        finally:
            # Always remove the temp PNGs, even if preprocessing or matching fails.
            for temp in temps:
                temp.close()
                Path(temp.name).unlink(missing_ok=True)
        matches, certainty = self.roma_model.sample(warp, certainty, num=self.max_keypoints)
        mkpts0, mkpts1 = self.roma_model.to_pixel_coordinates(matches, h0, w0, h1, w1)
        return mkpts0, mkpts1, None, None, None, None
class TinyRomaMatcher(BaseMatcher):
    """Dense matcher wrapping the Tiny RoMa v1 outdoor model from ``romatch``.

    Tensors are passed to ``roma_model.match`` directly (no temporary files).
    """

    def __init__(self, device="cpu", max_num_keypoints=2048, *args, **kwargs):
        """Build the Tiny RoMa v1 outdoor model on ``device``.

        Args:
            device: Device forwarded to ``BaseMatcher`` and ``tiny_roma_v1_outdoor``.
            max_num_keypoints: Number of matches drawn by ``roma_model.sample``.
            *args: Accepted for signature compatibility; ignored.
            **kwargs: Forwarded to ``BaseMatcher.__init__``.
        """
        super().__init__(device, **kwargs)
        self.roma_model = tiny_roma_v1_outdoor(device=device)
        self.max_keypoints = max_num_keypoints
        self.normalize = tfm.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        self.roma_model.train(False)

    def preprocess(self, img):
        """Normalize ``img`` (ImageNet mean/std), add a leading batch dimension,
        and move it to ``self.device``.

        The device move is a no-op when the tensor is already there. It makes
        the inputs live on the same device the model was built on.
        """
        return self.normalize(img).unsqueeze(0).to(self.device)

    def _forward(self, img0, img1):
        """Match two images with Tiny RoMa and return sampled keypoint pairs.

        Returns:
            ``(mkpts0, mkpts1, None, None, None, None)``: matched pixel coordinates
            in each image, followed by four ``None`` placeholders.
        """
        img0 = self.preprocess(img0)
        img1 = self.preprocess(img1)
        # Sizes come from the trailing (H, W) dims of the batched tensors.
        h0, w0 = img0.shape[-2:]
        h1, w1 = img1.shape[-2:]
        warp, certainty = self.roma_model.match(img0, img1, batched=False)
        matches, certainty = self.roma_model.sample(warp, certainty, num=self.max_keypoints)
        mkpts0, mkpts1 = self.roma_model.to_pixel_coordinates(matches, h0, w0, h1, w1)
        return mkpts0, mkpts1, None, None, None, None