# Source code for vismatch.im_models.matchanything

import cv2
import numpy as np
from PIL import Image
from safetensors.torch import load_file
import torch
import torch.nn.functional as F

from huggingface_hub import snapshot_download
from vismatch import BaseMatcher, THIRD_PARTY_DIR
from vismatch.utils import add_to_path, to_device, pad_images_to_same_shape, disable_xformers

# Expose the MatchAnything HF Space code (nested under imcui/third_party/MatchAnything) and its deps.
MATCHANYTHING_DIR = THIRD_PARTY_DIR.joinpath("MatchAnything", "imcui", "third_party", "MatchAnything")
add_to_path(MATCHANYTHING_DIR)

# Add ROMA's parent so ``ROMA`` is importable as a top-level namespace package,
# and ROMA itself so its internal ``from roma.models import ...`` works.
add_to_path(MATCHANYTHING_DIR.joinpath("third_party"))
add_to_path(MATCHANYTHING_DIR.joinpath("third_party", "ROMA"))

from yacs.config import CfgNode as CN  # noqa: E402
from src.loftr import LoFTR  # noqa: E402
from src.config.default import get_cfg_defaults  # noqa: E402
from ROMA.roma.matchanything_roma_model import MatchAnything_Model  # noqa: E402


def _lower_config(yacs_cfg):
    """Recursively convert a yacs ``CfgNode`` into a plain dict with lower-cased keys.

    Anything that is not a ``CfgNode`` (leaf values, lists, ...) is returned
    unchanged, which is what terminates the recursion.
    """
    if isinstance(yacs_cfg, CN):
        lowered = {}
        for key, value in yacs_cfg.items():
            lowered[key.lower()] = _lower_config(value)
        return lowered
    return yacs_cfg


class MatchAnythingMatcher(BaseMatcher):
    """Wrapper around the MatchAnything checkpoints.

    Two variants are supported, selected by ``variant``: ``"eloftr"`` (built
    on ``LoFTR``) and ``"roma"`` (built on ``MatchAnything_Model``). Weights
    are fetched from the ``vismatch/matchanything-<variant>`` Hugging Face
    repository.
    """

    def __init__(
        self,
        device="cpu",
        variant="eloftr",
        match_threshold=0.2,
        img_resize=None,
        *args,
        **kwargs,
    ):
        """Configure the matcher and load the requested checkpoint.

        Args:
            device: device string forwarded to ``BaseMatcher`` and used to
                place the network.
            variant: ``"eloftr"`` or ``"roma"`` (case-insensitive).
            match_threshold: written to ``cfg.LOFTR.MATCH_COARSE.THR``.
            img_resize: optional size used for the last two entries of
                ``cfg.LOFTR.COARSE.NPE`` (falls back to 832); it is not
                applied to the input images in ``preprocess``.
            *args: accepted for signature compatibility; not forwarded.
            **kwargs: forwarded to ``BaseMatcher.__init__``.

        Raises:
            ValueError: if ``variant`` is not one of the supported names.
        """
        super().__init__(device, **kwargs)

        self.variant = variant.lower()
        if self.variant not in ("eloftr", "roma"):
            raise ValueError(f"Unsupported MatchAnything variant: {variant}")

        self.match_threshold = match_threshold
        self.img_resize = img_resize
        self.model_name = f"matchanything_{self.variant}"

        self._load_model()

        if device == "cpu":
            disable_xformers()

    def _load_model(self):
        """Build the network for ``self.variant`` and load its pretrained weights."""
        # Ensure MatchAnything's ``src`` is resolvable even when another
        # matcher was loaded between module import and this instantiation.
        add_to_path(MATCHANYTHING_DIR)

        model_cfg_dir = MATCHANYTHING_DIR.joinpath("configs", "models")
        cfg = get_cfg_defaults()
        if self.variant == "eloftr":
            cfg.merge_from_file(str(model_cfg_dir.joinpath("eloftr_model.py")))
            # A None NPE_NAME can never equal "megadepth", so one test suffices.
            if cfg.DATASET.NPE_NAME == "megadepth":
                target_size = self.img_resize or 832
                cfg.LOFTR.COARSE.NPE = [832, 832, target_size, target_size]
        else:
            cfg.merge_from_file(str(model_cfg_dir.joinpath("roma_model.py")))

        # NOTE(review): this block was reconstructed from whitespace-collapsed
        # source; the override is applied to both variants here -- confirm it
        # was not meant to live inside the ROMA branch only.
        if self.device == "cpu":
            cfg.LOFTR.FP16 = False
            cfg.ROMA.MODEL.AMP = False

        cfg.METHOD = self.model_name
        cfg.LOFTR.MATCH_COARSE.THR = self.match_threshold
        cfg_lower = _lower_config(cfg)

        if self.variant == "eloftr":
            self.net = LoFTR(config=cfg_lower["loftr"])
        else:
            assert self.device != "mps", (
                f"Device must be 'cpu' or 'cuda' for {self.name}. Device='{self.device}' not supported"
            )
            self.net = MatchAnything_Model(config=cfg_lower["roma"], test_mode=True)

        weights_path = f"{snapshot_download(f'vismatch/matchanything-{self.variant}')}/model.safetensors"
        state_dict = load_file(weights_path)
        # strict=False: keys that do not line up with the network are ignored.
        self.net.load_state_dict(state_dict, strict=False)
        self.net.eval().to(self.device)

    def preprocess(self, img):
        """Turn an input image tensor into the padded grayscale tensor fed to the network.

        Args:
            img: image tensor; assumed to be (C, H, W) or (1, C, H, W) RGB
                with values in [0, 1] -- TODO confirm against ``BaseMatcher``.

        Returns:
            Tuple ``(img_tensor, img_size, scale_hw, mask, img)``:
            ``img_tensor`` is the grayscale image from ``resize`` (dims
            floored to a multiple of 32, padded to a square) with two leading
            singleton dims and values divided by 255; ``img_size`` is the
            original ``[H, W]`` as a numpy array; ``scale_hw`` is
            ``[h_scale, w_scale]`` (original / resized); ``mask`` marks the
            un-padded region; ``img`` is the untouched input.
        """
        img_np = img.cpu().numpy().squeeze() * 255
        # Round before the uint8 cast: plain truncation can land one level
        # low when the float product falls just under an integer, and
        # out-of-range values would otherwise wrap around.
        img_np = np.clip(np.rint(img_np), 0, 255)
        img_np = img_np.transpose(1, 2, 0).astype("uint8")
        img_size = np.array(img_np.shape[:2])

        img_gray = np.array(Image.fromarray(img_np).convert("L"))
        img_resized, scale_hw, mask = resize(img_gray, df=32)
        img_tensor = torch.from_numpy(img_resized)[None][None] / 255.0
        return img_tensor, img_size, scale_hw, mask, img

    def _coarse_mask(self, mask):
        """Downsample a boolean numpy mask by 1/8 (nearest) and add a batch dim."""
        mask_t = torch.from_numpy(mask).to(self.device)
        coarse = F.interpolate(
            mask_t[None, None].float(),
            scale_factor=0.125,
            mode="nearest",
            recompute_scale_factor=False,
        )[0, 0].bool()
        return coarse[None]

    def _forward(self, img0, img1):
        """Match two images and return the matched keypoints.

        Returns:
            ``(mkpts0, mkpts1, None, None, None, None)`` where the first two
            are CPU tensors read from ``batch["mkpts0_f"]`` / ``["mkpts1_f"]``
            after the network has run; the remaining slots are always ``None``.
        """
        img0_proc, img0_size, img0_scale, mask0, img0_orig = self.preprocess(img0)
        img1_proc, img1_size, img1_scale, mask1, img1_orig = self.preprocess(img1)

        # Pad grayscale images to same shape for eloftr and disable masks
        # (third-party attention code can't handle masks with different True regions)
        if self.variant == "eloftr":
            img0_proc, img1_proc = pad_images_to_same_shape(img0_proc, img1_proc)
            mask0 = None
            mask1 = None

        batch = {
            "image0": img0_proc,
            "image1": img1_proc,
            # ROMA expects a leading batch dim on RGB images; keep it for both variants
            "image0_rgb_origin": img0_orig[None],
            "image1_rgb_origin": img1_orig[None],
            "origin_img_size0": torch.from_numpy(img0_size)[None],
            "origin_img_size1": torch.from_numpy(img1_size)[None],
        }

        if mask0 is not None and mask1 is not None:
            # presumably 1/8 matches the network's coarse feature resolution -- confirm
            batch["mask0"] = self._coarse_mask(mask0)
            batch["mask1"] = self._coarse_mask(mask1)

        batch = to_device(batch, device=self.device)
        # The network writes its outputs into ``batch`` in place.
        self.net(batch)

        mkpts0 = batch["mkpts0_f"].detach().cpu()
        mkpts1 = batch["mkpts1_f"].detach().cpu()

        if self.variant == "eloftr":
            # Scales are stored as [h, w]; the [[1, 0]] index swaps them to
            # [w, h] (presumably keypoints are (x, y) -- confirm).
            mkpts0 *= torch.tensor(img0_scale)[[1, 0]]
            mkpts1 *= torch.tensor(img1_scale)[[1, 0]]

        return mkpts0, mkpts1, None, None, None, None
# Custom resize logic from MatchAnything to preserve padding/masks expected by the upstream config.
def resize(img, resize=None, df=8, padding=True):
    """Resize an image to a ``df``-divisible size and optionally pad it to a square.

    Args:
        img: numpy image whose first two dims are (H, W).
        resize: optional target size spec forwarded to ``process_resize``.
        df: the new width and height are floored to a multiple of this value.
        padding: when true, zero-pad bottom/right to a square whose side is
            the larger of the new height and width, and return a validity mask.

    Returns:
        ``(img_new, [h_scale, w_scale], mask)`` where ``img_new`` is float32,
        the scales are original size divided by resized (pre-padding) size,
        and ``mask`` is ``None`` when ``padding`` is false.
    """
    orig_h, orig_w = img.shape[0], img.shape[1]
    w_new, h_new = process_resize(orig_w, orig_h, resize=resize, df=df, resize_no_larger_than=False)

    img_new = resize_image(img, (w_new, h_new), interp="pil_LANCZOS").astype("float32")
    # Scales are measured before padding so they map back to the original image.
    h_scale = orig_h / img_new.shape[0]
    w_scale = orig_w / img_new.shape[1]

    mask = None
    if padding:
        img_new, mask = pad_bottom_right(img_new, max(h_new, w_new), ret_mask=True)
    return img_new, [h_scale, w_scale], mask
def process_resize(w, h, resize=None, df=None, resize_no_larger_than=False):
    """Compute the target ``(width, height)`` for an image of size ``w`` x ``h``.

    Args:
        w: original width.
        h: original height.
        resize: ``None`` to keep the size, a one-element sequence ``[s]`` to
            scale the larger side to ``s`` (``[-1]`` keeps the size), or a
            two-element sequence ``[w_new, h_new]`` used as-is.
        df: when given, both dimensions are floored to a multiple of ``df``.
        resize_no_larger_than: when true, keep the original size if its
            larger side does not exceed ``max(resize)``.

    Returns:
        ``(w_new, h_new)``.

    Raises:
        AssertionError: if ``resize`` does not have one or two elements.
        ValueError: if a one-element ``resize`` is below -1 (this used to
            surface as an ``IndexError`` from the two-element branch).
    """
    w_new, h_new = w, h

    if resize is not None:
        assert len(resize) > 0 and len(resize) <= 2
        keep_size = resize_no_larger_than and max(h, w) <= max(resize)
        if not keep_size:
            if len(resize) == 2:
                w_new, h_new = resize[0], resize[1]
            elif resize[0] > -1:
                # resize the larger side
                scale = resize[0] / max(h, w)
                w_new, h_new = int(round(w * scale)), int(round(h * scale))
            elif resize[0] != -1:
                raise ValueError(
                    f"A single-element resize must be -1 (keep size) or greater than -1, got {resize[0]}"
                )

    if df is not None:
        w_new = int(w_new // df * df)
        h_new = int(h_new // df * df)
    return w_new, h_new
def resize_image(image, size, interp):
    """Resize a numpy image with either OpenCV or Pillow.

    Args:
        image: numpy image array.
        size: target ``(width, height)``.
        interp: ``"cv2_<name>"`` (resolved to ``cv2.INTER_<NAME>``) or
            ``"pil_<name>"`` (resolved to ``PIL.Image.<NAME>``).

    Returns:
        The resized numpy array. The Pillow path round-trips through uint8
        and converts the result back to ``image.dtype``.

    Raises:
        ValueError: if ``interp`` has an unknown prefix or names a filter the
            selected backend does not provide.
    """
    if interp.startswith("cv2_"):
        flag = getattr(cv2, "INTER_" + interp[len("cv2_") :].upper(), None)
        if flag is None:
            raise ValueError(f"Unknown interpolation {interp}.")
        h, w = image.shape[:2]
        # INTER_AREA is swapped for INTER_LINEAR when either side is being enlarged.
        if flag == cv2.INTER_AREA and (w < size[0] or h < size[1]):
            flag = cv2.INTER_LINEAR
        return cv2.resize(image, size, interpolation=flag)

    if interp.startswith("pil_"):
        resample = getattr(Image, interp[len("pil_") :].upper(), None)
        if resample is None:
            raise ValueError(f"Unknown interpolation {interp}.")
        pil_resized = Image.fromarray(image.astype(np.uint8)).resize(size, resample=resample)
        return np.asarray(pil_resized, dtype=image.dtype)

    raise ValueError(f"Unknown interpolation {interp}.")
def pad_bottom_right(inp, pad_size, ret_mask=False):
    """Zero-pad the last two dims of ``inp`` on the bottom/right to a square.

    Args:
        inp: 2-D ``(H, W)`` or 3-D ``(C, H, W)`` numpy array.
        pad_size: side of the output square; must be an ``int`` no smaller
            than the larger of ``H`` and ``W``.
        ret_mask: when true, also return a boolean mask.

    Returns:
        ``(padded, mask)``. ``padded`` has shape ``(pad_size, pad_size)`` or
        ``(C, pad_size, pad_size)`` and the dtype of ``inp``. ``mask`` is a
        2-D ``(pad_size, pad_size)`` bool array that is True over the
        original ``H`` x ``W`` region, or ``None`` when ``ret_mask`` is false.

    Raises:
        AssertionError: if ``pad_size`` is not an int or is too small.
        NotImplementedError: if ``inp`` is neither 2-D nor 3-D.
    """
    assert isinstance(pad_size, int) and pad_size >= max(inp.shape[-2:]), f"{pad_size} < {max(inp.shape[-2:])}"
    if inp.ndim not in (2, 3):
        raise NotImplementedError(f"pad_bottom_right supports 2-D or 3-D arrays, got {inp.ndim}-D")

    height, width = inp.shape[-2:]
    padded = np.zeros((*inp.shape[:-2], pad_size, pad_size), dtype=inp.dtype)
    padded[..., :height, :width] = inp

    mask = None
    if ret_mask:
        # The mask is the same for every channel, so build it once in 2-D.
        mask = np.zeros((pad_size, pad_size), dtype=bool)
        mask[:height, :width] = True
    return padded, mask