Source code for vismatch.im_models.liftfeat

import torch
import numpy as np

from huggingface_hub import snapshot_download
from vismatch import THIRD_PARTY_DIR, BaseMatcher
from vismatch.utils import to_numpy, add_to_path

add_to_path(THIRD_PARTY_DIR.joinpath("LiftFeat"))
from models.liftfeat_wrapper import LiftFeat


[docs] class LiftFeatMatcher(BaseMatcher): def __init__(self, device="cpu", detect_threshold=0.05, *args, **kwargs): super().__init__(device, **kwargs) self.detect_threshold = detect_threshold weights_path = f"{snapshot_download('vismatch/liftfeat')}/liftfeat.pth" self.model = LiftFeat(weight=weights_path, detect_threshold=self.detect_threshold)
[docs] def preprocess(self, img): "LiftFeat requires input as raw ndarray (result of cv2.imread)" # convert axis, (C x H x W) -> (H x W x C) if isinstance(img, torch.Tensor): img = to_numpy(img).transpose(1, 2, 0) assert isinstance(img, np.ndarray), "LiftFeatModel requires input as np.ndarray" orig_shape = img.shape[:2] return img, orig_shape
def _forward(self, img0, img1): img0, img0_orig_shape = self.preprocess(img0) img1, img1_orig_shape = self.preprocess(img1) mkpts0, mkpts1 = self.model.match_liftfeat(img0, img1) return ( mkpts0, mkpts1, None, # keypoints_0 None, # keypoints_1 None, # desc0 None, # desc1 )