Source code for vismatch.im_models.lisrd

import torch

from huggingface_hub import snapshot_download
from vismatch import BaseMatcher, THIRD_PARTY_DIR
from vismatch.utils import add_to_path

add_to_path(THIRD_PARTY_DIR.joinpath("LISRD"))
from lisrd.models.lisrd import Lisrd
from lisrd.models.base_model import Mode
from lisrd.utils.geometry_utils import extract_descriptors, lisrd_matcher

add_to_path(THIRD_PARTY_DIR.joinpath("LightGlue"))

from lightglue import ALIKED, SuperPoint, SIFT


class LISRDMatcher(BaseMatcher):
    """Match images with LISRD descriptors sampled at LightGlue-detected keypoints.

    Keypoints come from a LightGlue extractor (SuperPoint by default, or
    ALIKED / SIFT). LISRD descriptors and meta-descriptors are computed by the
    LISRD network, sampled at those keypoints with ``extract_descriptors``,
    and matched with ``lisrd_matcher``.
    """

    # Configuration passed to the LISRD model at construction time and on
    # every forward call.
    model_config = {
        "name": "lisrd",
        "desc_size": 128,
        "tile": 3,
        "n_clusters": 8,
        "meta_desc_dim": 128,
        "learning_rate": 0.001,
        "compute_meta_desc": True,
        "freeze_local_desc": False,
    }

    def __init__(
        self,
        device="cpu",
        detector="superpoint",
        max_num_keypoints=4096,
        *args,
        **kwargs,
    ):
        """Download the LISRD weights, load the model and build the extractor.

        Args:
            device: Device for the LISRD network and the keypoint extractor;
                a string such as ``"cpu"`` / ``"cuda"`` or a ``torch.device``.
            detector: ``"aliked"`` or ``"sift"`` (case-insensitive); any other
                value falls back to SuperPoint.
            max_num_keypoints: Maximum number of keypoints per image, forwarded
                to the extractor.
            *args: Accepted for signature compatibility; ignored.
            **kwargs: Forwarded to ``BaseMatcher.__init__``.
        """
        super().__init__(device, **kwargs)
        print("WARNING: LISRD may take awhile to load.")

        model_path = f"{snapshot_download('vismatch/lisrd')}/lisrd_vidit.pth"
        self.model = Lisrd(None, self.model_config, device)
        self.model.load(model_path, Mode.EXPORT)

        # On multi-GPU machines, BaseModel wraps _net in DataParallel even when
        # targeting CPU. Unwrap it so inference works correctly on CPU.
        # torch.device(...) normalizes both str and torch.device arguments, so
        # passing torch.device("cpu") is handled as well as the string "cpu".
        if torch.device(device).type == "cpu" and isinstance(
            self.model._net, torch.nn.DataParallel
        ):
            self.model._net = self.model._net.module
        self.model._net.eval()

        detector = detector.lower()
        if detector == "aliked":
            extractor = ALIKED(max_num_keypoints=max_num_keypoints)
        elif detector == "sift":
            extractor = SIFT(max_num_keypoints=max_num_keypoints)
        else:
            extractor = SuperPoint(max_num_keypoints=max_num_keypoints)
        self.extractor = extractor.eval().to(device)

    def preprocess(self, img: torch.Tensor) -> tuple[torch.Tensor, tuple[int, int]]:
        """Add a batch dimension to ``img`` and move it to ``self.device``.

        Args:
            img: Three-dimensional image tensor whose last two dimensions are
                read as height and width (presumably channels-first, C x H x W
                — confirm against callers).

        Returns:
            ``(batched_img, (h, w))``: the image with a leading batch dimension
            of size 1, and its original height and width.
        """
        _, h, w = img.shape
        return img.unsqueeze(0).to(self.device), (h, w)

    def _detect_and_describe(self, img, orig_shape):
        """Detect keypoints in a batched image and sample LISRD descriptors there.

        Args:
            img: Image as returned by ``preprocess`` (batch dimension of size 1).
            orig_shape: ``(h, w)`` of the image, as returned by ``preprocess``.

        Returns:
            ``(keypoints, desc, meta_desc)`` where ``keypoints`` holds the
            detected ``(x, y)`` positions, one row per keypoint, and ``desc`` /
            ``meta_desc`` are the LISRD descriptors sampled at those positions.
        """
        # Take the single batch element rather than calling .squeeze(): squeeze
        # would also drop the keypoint axis when exactly one keypoint is found
        # (1 x 1 x 2 -> 2), which breaks the column indexing below.
        keypoints = self.extractor.extract(img)["keypoints"][0]

        outputs = self.model._forward({"image0": img}, Mode.EXPORT, self.model_config)

        # keypoints_to_grid expects (row, col) order; extractors return (x, y) = (col, row)
        desc, meta_desc = extract_descriptors(
            keypoints[:, [1, 0]],
            outputs["descriptors"],
            outputs["meta_descriptors"],
            orig_shape,
        )
        return keypoints, desc, meta_desc

    def _forward(self, img0, img1):
        """Match ``img0`` against ``img1``.

        Returns:
            ``(mkpts0, mkpts1, keypoints0, keypoints1, None, None)``:
            ``mkpts0`` / ``mkpts1`` are the matched keypoints in each image,
            ``keypoints0`` / ``keypoints1`` are all detected keypoints, and
            the descriptor slots are ``None`` (see comment at the return).
        """
        img0, img0_orig_shape = self.preprocess(img0)
        img1, img1_orig_shape = self.preprocess(img1)

        keypoints0, desc0, meta_desc0 = self._detect_and_describe(
            img0, img0_orig_shape
        )
        keypoints1, desc1, meta_desc1 = self._detect_and_describe(
            img1, img1_orig_shape
        )

        # Each row of `matches` pairs an index into keypoints0 (column 0) with
        # an index into keypoints1 (column 1).
        matches = lisrd_matcher(desc0, desc1, meta_desc0, meta_desc1).cpu().numpy()
        mkpts0 = keypoints0[matches[:, 0]]
        mkpts1 = keypoints1[matches[:, 1]]

        # lisrd has N x 4 x D dimensional descriptors, inconsistent with other
        # methods, hence return None as descs
        return mkpts0, mkpts1, keypoints0, keypoints1, None, None