import torch
from huggingface_hub import snapshot_download
from vismatch import BaseMatcher, THIRD_PARTY_DIR
from vismatch.utils import add_to_path
add_to_path(THIRD_PARTY_DIR.joinpath("LISRD"))
from lisrd.models.lisrd import Lisrd
from lisrd.models.base_model import Mode
from lisrd.utils.geometry_utils import extract_descriptors, lisrd_matcher
add_to_path(THIRD_PARTY_DIR.joinpath("LightGlue"))
from lightglue import ALIKED, SuperPoint, SIFT
# LISRD matcher: LightGlue keypoint extractors combined with LISRD descriptors.
class LISRDMatcher(BaseMatcher):
    """Match two images using LISRD descriptors sampled at detected keypoints.

    Keypoints are produced by a LightGlue extractor (SuperPoint, ALIKED or
    SIFT). The LISRD network produces descriptors and meta descriptors, which
    are sampled at the keypoint locations and paired up by ``lisrd_matcher``.
    """

    # Configuration passed to the LISRD network, both when it is constructed
    # and on every forward pass.
    model_config = {
        "name": "lisrd",
        "desc_size": 128,
        "tile": 3,
        "n_clusters": 8,
        "meta_desc_dim": 128,
        "learning_rate": 0.001,
        "compute_meta_desc": True,
        "freeze_local_desc": False,
    }

    def __init__(
        self,
        device="cpu",
        detector="superpoint",
        max_num_keypoints=4096,
        *args,
        **kwargs,
    ):
        """Download the LISRD weights and build the keypoint extractor.

        Args:
            device: Device the model and the extractor run on; a string or a
                ``torch.device``.
            detector: Keypoint detector name, case-insensitive. ``"aliked"``
                and ``"sift"`` select those extractors; any other value falls
                back to SuperPoint.
            max_num_keypoints: Upper bound on keypoints passed to the
                extractor.
            *args: Accepted for signature compatibility and ignored.
            **kwargs: Forwarded to ``BaseMatcher.__init__``.
        """
        super().__init__(device, **kwargs)
        print("WARNING: LISRD may take awhile to load.")

        model_path = f"{snapshot_download('vismatch/lisrd')}/lisrd_vidit.pth"
        self.model = Lisrd(None, self.model_config, device)
        self.model.load(model_path, Mode.EXPORT)

        # On multi-GPU machines, BaseModel wraps _net in DataParallel even when
        # targeting CPU. Unwrap it so inference works correctly on CPU.
        # Compare on the string form so that both "cpu" and torch.device("cpu")
        # are recognised (a torch.device never compares equal to a str).
        on_cpu = str(device).split(":", 1)[0] == "cpu"
        if on_cpu and isinstance(self.model._net, torch.nn.DataParallel):
            self.model._net = self.model._net.module
        self.model._net.eval()

        detector = detector.lower()
        if detector == "aliked":
            self.extractor = ALIKED(max_num_keypoints=max_num_keypoints)
        elif detector == "sift":
            self.extractor = SIFT(max_num_keypoints=max_num_keypoints)
        else:
            self.extractor = SuperPoint(max_num_keypoints=max_num_keypoints)
        self.extractor = self.extractor.eval().to(device)

    def preprocess(self, img: torch.Tensor) -> "tuple[torch.Tensor, tuple[int, int]]":
        """Add a batch dimension to ``img`` and move it to the matcher device.

        Args:
            img: 3-D image tensor whose last two dimensions are height and
                width.

        Returns:
            A pair ``(batched_img, (h, w))``: the image with a leading batch
            dimension on ``self.device``, and its original height and width.
        """
        _, h, w = img.shape
        orig_shape = h, w
        return img.unsqueeze(0).to(self.device), orig_shape

    def _forward(self, img0, img1):
        """Detect keypoints in both images and match them with LISRD.

        Args:
            img0: First image, as accepted by ``preprocess``.
            img1: Second image, as accepted by ``preprocess``.

        Returns:
            A 6-tuple ``(mkpts0, mkpts1, keypoints0, keypoints1, None, None)``:
            the matched keypoints of each image, all detected keypoints of
            each image, and two ``None`` placeholders where descriptors would
            normally be returned.
        """
        img0, img0_orig_shape = self.preprocess(img0)
        img1, img1_orig_shape = self.preprocess(img1)

        # Keypoint detection. Drop only the batch dimension: a bare squeeze()
        # would also collapse the keypoint axis when exactly one keypoint is
        # found, which breaks the 2-D indexing below.
        keypoints0 = self.extractor.extract(img0)["keypoints"].squeeze(0)
        keypoints1 = self.extractor.extract(img1)["keypoints"].squeeze(0)

        # Descriptor inference
        outputs0 = self.model._forward({"image0": img0}, Mode.EXPORT, self.model_config)
        desc0 = outputs0["descriptors"]
        meta_desc0 = outputs0["meta_descriptors"]
        outputs1 = self.model._forward({"image0": img1}, Mode.EXPORT, self.model_config)
        desc1 = outputs1["descriptors"]
        meta_desc1 = outputs1["meta_descriptors"]

        # Sample the descriptors at the keypoint positions
        # keypoints_to_grid expects (row, col) order; extractors return (x, y) = (col, row)
        desc0, meta_desc0 = extract_descriptors(keypoints0[:, [1, 0]], desc0, meta_desc0, img0_orig_shape)
        desc1, meta_desc1 = extract_descriptors(keypoints1[:, [1, 0]], desc1, meta_desc1, img1_orig_shape)

        matches = lisrd_matcher(desc0, desc1, meta_desc0, meta_desc1).cpu().numpy()
        mkpts0, mkpts1 = (
            keypoints0[matches[:, 0]],
            keypoints1[matches[:, 1]],
        )

        # lisrd has N x 4 x D dimensional descriptors, inconsistent with other
        # methods, hence return None as descs
        return (
            mkpts0,
            mkpts1,
            keypoints0,
            keypoints1,
            None,  # desc0,
            None,  # desc1,
        )