"""UFM (UniFlowMatch) image matcher for vismatch (module ``vismatch.im_models.ufm``)."""

import torch

from vismatch import BaseMatcher, THIRD_PARTY_DIR
from vismatch.utils import flow_to_matches, add_to_path
from skimage import img_as_ubyte
from kornia import tensor_to_image
import numpy as np

add_to_path(THIRD_PARTY_DIR / "UFM")
from uniflowmatch.models.ufm import UniFlowMatchClassificationRefinement


class UFMMatcher(BaseMatcher):
    """Dense matcher backed by UniFlowMatch (UFM).

    Loads the ``infinity1096/UFM-Refine`` pretrained
    ``UniFlowMatchClassificationRefinement`` model, predicts a flow field and a
    covisibility mask between two images, and samples point matches from them
    with ``flow_to_matches``.
    """

    def __init__(self, device="cpu", max_num_keypoints=1024, min_confidence=0.2, *args, **kwargs):
        """Load the UFM model and move it to ``device`` in eval mode.

        Args:
            device: ``"cpu"`` or any device string containing ``"cuda"``.
            max_num_keypoints: forwarded to ``flow_to_matches`` as ``num_samples``.
            min_confidence: minimum confidence threshold for matches, forwarded
                to ``flow_to_matches``.
            *args: accepted for signature compatibility; not used or forwarded.
            **kwargs: forwarded to ``BaseMatcher.__init__``.

        Raises:
            AssertionError: if the device is neither CPU nor CUDA (unless Python
                runs with ``-O``, which strips asserts).
        """
        super().__init__(device, **kwargs)
        assert "cuda" in self.device or self.device == "cpu", (
            f"Device must be 'cpu' or 'cuda' for {self.name}. Device='{self.device}' not supported"
        )
        self.model = UniFlowMatchClassificationRefinement.from_pretrained("infinity1096/UFM-Refine")
        self.model = self.model.eval().to(self.device)
        self.max_num_keypoints = max_num_keypoints
        self.min_confidence = min_confidence  # minimum confidence threshold for matches

    def preprocess(self, img) -> tuple:
        """Convert an image into the uint8 (H, W, 3) tensor passed to UFM.

        Args:
            img: a ``torch.Tensor`` image, converted with kornia's
                ``tensor_to_image`` (e.g. CxHxW or 1xCxHxW -> HxWxC), or an
                array-like already laid out as HxWxC or HxW. Non-uint8 input is
                assumed to be float in [0, 1]; uint8 input is used as-is.

        Returns:
            ``(image, orig_shape)``: ``image`` is a uint8 tensor of shape
            (H, W, 3) on ``self.device``; ``orig_shape`` is ``(H, W)`` of the
            input image.

        Raises:
            AssertionError: if the result is not a uint8 HxWx3 image.
        """
        if isinstance(img, torch.Tensor):
            img = tensor_to_image(img)
        img = np.asarray(img)

        # Only rescale non-uint8 (float, [0, 1]) data; clipping a uint8 image to
        # [0, 1] would destroy its 0-255 range.
        if img.dtype != np.uint8:
            img = img_as_ubyte(np.clip(img, 0, 1))

        # tensor_to_image squeezes single-channel images to HxW; replicate the
        # gray channel so the model still receives 3 channels.
        if img.ndim == 2:
            img = img[..., None]
        if img.ndim == 3 and img.shape[2] == 1:
            img = np.repeat(img, 3, axis=2)

        assert img.dtype == np.uint8, "Image must be uint8"
        assert img.ndim == 3 and img.shape[2] == 3, "Image must be HxWx3"

        # Read (H, W) from the HxWx3 layout so numpy HxWxC input is measured
        # correctly (the channel axis is last here, not first).
        orig_shape = img.shape[0], img.shape[1]
        return torch.from_numpy(np.ascontiguousarray(img)).to(self.device), orig_shape

    def _forward(self, img0, img1):
        """Match ``img0`` against ``img1`` with UFM.

        Args:
            img0: source image, in any form accepted by ``preprocess``.
            img1: target image, in any form accepted by ``preprocess``.

        Returns:
            ``(mkpts0, mkpts1, None, None, None, None)``: matched keypoints in
            ``img0`` and ``img1`` (after ``rescale_coords``). The four trailing
            slots are always ``None`` here; per-match confidences from
            ``flow_to_matches`` are currently discarded.
        """
        img0, img0_orig_shape = self.preprocess(img0)
        img1, img1_orig_shape = self.preprocess(img1)

        # Pure inference: avoid building an autograd graph (harmless if the
        # model's predict method already disables gradients).
        with torch.no_grad():
            result = self.model.predict_correspondences_batched(
                source_image=img0,
                target_image=img1,
            )
        # [0] selects the first item of the model's batched output.
        flow = result.flow.flow_output[0].cpu().numpy()
        covisibility = result.covisibility.mask[0].cpu().numpy()

        # postprocess model output to get kpts, desc, etc
        mkpts0, mkpts1, _confidences = flow_to_matches(
            flow, covisibility, min_confidence=self.min_confidence, num_samples=self.max_num_keypoints
        )

        # preprocess() does not resize, so these shapes currently equal the
        # original shapes. The rescale is kept so coordinates stay correct if
        # preprocessing ever resizes the images.
        H0, W0, H1, W1 = *img0.shape[:2], *img1.shape[:2]
        mkpts0 = self.rescale_coords(mkpts0, *img0_orig_shape, H0, W0)
        mkpts1 = self.rescale_coords(mkpts1, *img1_orig_shape, H1, W1)

        return mkpts0, mkpts1, None, None, None, None