Source code for vismatch.im_models.matchformer

import torchvision.transforms as tfm
from safetensors.torch import load_file

from huggingface_hub import snapshot_download
from vismatch import THIRD_PARTY_DIR, BaseMatcher
from vismatch.utils import resize_to_divisible, lower_config, add_to_path, pad_images_to_same_shape

add_to_path(THIRD_PARTY_DIR.joinpath("MatchFormer"))

from model.matchformer import Matchformer
from config.defaultmf import get_cfg_defaults as mf_cfg_defaults


class MatchformerMatcher(BaseMatcher):
    """Image matcher wrapping MatchFormer with the "outdoor-large-LA" pretrained weights.

    Each input image is resized so its height/width are divisible by
    ``divisible_size`` and converted to grayscale. The two images are padded to a
    common shape and matched. The resulting keypoints are rescaled back to each
    image's original resolution.
    """

    # Input height/width are resized to a multiple of this value before matching.
    divisible_size = 32

    def __init__(self, device="cpu", **kwargs):
        """Build the matcher and move the loaded model to ``device`` in eval mode.

        Args:
            device: Device the model is moved to (passed to ``.to(device)``).
            **kwargs: Forwarded unchanged to ``BaseMatcher.__init__``.
        """
        super().__init__(device, **kwargs)
        self.matcher = self.load_model().to(device).eval()

    def load_model(self, cfg_path=None):
        """Construct a MatchFormer model and load the pretrained weights into it.

        Args:
            cfg_path: Optional path to a config file merged over the MatchFormer
                defaults. The backbone type, scene, resolution and coarse model
                sizes set below are assigned *after* this merge, so any values for
                those keys in ``cfg_path`` are overridden.

        Returns:
            The ``Matchformer`` module with the
            ``matchformer_outdoor-large-LA.safetensors`` weights loaded (strict
            key matching, the ``load_state_dict`` default).
        """
        config = mf_cfg_defaults()
        if cfg_path is not None:
            config.merge_from_file(cfg_path)
        # Applied last -- presumably so the architecture always matches the
        # "outdoor-large-LA" checkpoint loaded below.
        config.MATCHFORMER.BACKBONE_TYPE = "largela"
        config.MATCHFORMER.SCENS = "outdoor"
        config.MATCHFORMER.RESOLUTION = (8, 2)
        config.MATCHFORMER.COARSE.D_MODEL = 256
        config.MATCHFORMER.COARSE.D_FFN = 256

        matcher = Matchformer(config=lower_config(config)["matchformer"])

        weights_path = f"{snapshot_download('vismatch/matchformer')}/matchformer_outdoor-large-LA.safetensors"
        state_dict = load_file(weights_path)

        # Checkpoint keys carry a leading "matcher." that the bare module does not
        # have. Strip it only as a prefix: a substring replace would also remove
        # "matcher." from the middle of nested parameter names.
        prefix = "matcher."
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }
        matcher.load_state_dict(state_dict)
        return matcher

    def preprocess(self, img):
        """Prepare a single image for MatchFormer.

        Args:
            img: Image tensor unpacked as ``(channels, height, width)``. The
                grayscale transform presumably expects 1 or 3 channels -- TODO
                confirm against callers.

        Returns:
            A tuple ``(image, (h, w))``. ``image`` is the resized grayscale image
            with a leading batch dimension of size 1. ``(h, w)`` is the original
            spatial size, used later to rescale keypoints.
        """
        _, h, w = img.shape
        orig_shape = h, w
        img = resize_to_divisible(img, self.divisible_size)
        return tfm.Grayscale()(img).unsqueeze(0), orig_shape

    def _forward(self, img0, img1):
        """Match two images and return correspondences in original pixel coordinates.

        Args:
            img0: First image, in the format accepted by ``preprocess``.
            img1: Second image, in the format accepted by ``preprocess``.

        Returns:
            ``(mkpts0, mkpts1, None, None, None, None)``. The first two elements
            are the matched keypoints in ``img0`` and ``img1``, rescaled to each
            image's original size. The remaining slots are not produced by this
            matcher.
        """
        img0, img0_orig_shape = self.preprocess(img0)
        img1, img1_orig_shape = self.preprocess(img1)

        # Record the resized, pre-padding sizes; keypoints are rescaled from these.
        # Assumes padding only extends the bottom/right so coordinates are not
        # shifted -- TODO confirm against pad_images_to_same_shape.
        H0, W0 = img0.shape[-2:]
        H1, W1 = img1.shape[-2:]

        img0, img1 = pad_images_to_same_shape(img0, img1)

        # The model's return value is ignored; its outputs are read back from
        # the batch dict, which it fills in place.
        batch = {"image0": img0, "image1": img1}
        self.matcher(batch)

        mkpts0 = self.rescale_coords(batch["mkpts0_f"], *img0_orig_shape, H0, W0)
        mkpts1 = self.rescale_coords(batch["mkpts1_f"], *img1_orig_shape, H1, W1)
        return mkpts0, mkpts1, None, None, None, None