Source code for vismatch.im_models.xfeat_steerers

import torch
from pathlib import Path
from torchvision.datasets.utils import download_file_from_google_drive
from huggingface_hub import snapshot_download

from vismatch import BaseMatcher


class xFeatSteerersMatcher(BaseMatcher):
    """XFeat matcher that additionally tries steered versions of image 0's descriptors.

    Descriptors of the first image are matched against the second image four times:
    once unmodified and three times after applying a "steerer" (either a fixed channel
    permutation or a learned linear map). The attempt that yields the most matches wins.

    Reference for perm steerer: https://colab.research.google.com/drive/1ZFifMqUAOQhky1197-WAquEV1K-LhDYP?usp=sharing
    Reference for learned steerer: https://colab.research.google.com/drive/1sCqgi3yo3OuxA8VX_jPUt5ImHDmEajsZ?usp=sharing
    """

    # Four index permutations of a 64-channel descriptor: the channels are viewed as
    # 4 groups of 16 and the groups are cyclically rolled by k = 0..3 (k = 0 is identity).
    # NOTE(review): presumably each step corresponds to a 90-degree rotation - see references.
    steer_permutations = [torch.arange(64).reshape(4, 16).roll(k, dims=0).reshape(64) for k in range(4)]

    # Hugging Face repositories fetched with snapshot_download, keyed by steerer type.
    hf_model_ids = {
        "perm": "vismatch/xfeat-steerers-perm",
        "learned": "vismatch/xfeat-steerers-learned",
    }

    # Google Drive file ids used by download_weights when a file is missing from the snapshot.
    perm_weights_gdrive_id = "1nzYg4dmkOAZPi4sjOGpQnawMoZSXYXHt"
    learned_weights_gdrive_id = "1yJtmRhPVrpbXyN7Be32-FYctmX2Oz77r"
    steerer_weights_drive_id = "1Qh_5YMjK1ZIBFVFvZlTe_eyjNPrOQ2Dv"

    def __init__(self, device="cpu", max_num_keypoints=4096, mode="sparse", steerer_type="learned", *args, **kwargs):
        """Build the XFeat model, load its weights and set up the chosen steerer.

        Args:
            device: device the model (and steerer) is moved to.
            max_num_keypoints: passed as ``top_k`` to the XFeat model and its detectors.
            mode: ``"sparse"`` or ``"semi-dense"``.
            steerer_type: ``"learned"`` (linear steerer) or ``"perm"`` (fixed permutations).
            *args: accepted for signature compatibility and ignored.
            **kwargs: forwarded to ``BaseMatcher.__init__``; ``min_cossim`` is also read
                from here (default 0.8 for ``"learned"``, 0.9 for ``"perm"``).

        Raises:
            ValueError: if ``mode`` or ``steerer_type`` is not one of the supported values.
            AssertionError: if ``mode`` is not ``"semi-dense"`` and the device is ``"mps"``.
        """
        super().__init__(device, **kwargs)

        if mode not in ["sparse", "semi-dense"]:
            # Use the argument here: self.mode is only assigned at the end of __init__.
            raise ValueError(f'unsupported mode for xfeat: {mode}. Must choose from ["sparse", "semi-dense"]')

        if mode != "semi-dense":
            assert self.device != "mps", (
                f"Device must be 'cpu' or 'cuda' for {self.name} with mode {mode}. Device='{self.device}' not supported"
            )

        # Set before validation because download_weights() reads self.steerer_type.
        self.steerer_type = steerer_type
        if self.steerer_type not in ["learned", "perm"]:
            raise ValueError(
                f'unsupported type for xfeat-steerer: {steerer_type}. Must choose from ["perm", "learned"]. Learned usually perofrms better.'
            )

        # Expected locations of the weight files inside the downloaded snapshot.
        cache_dir = Path(snapshot_download(self.hf_model_ids[steerer_type]))
        self.perm_weights_path = cache_dir / "xfeat_perm_steer.pth"
        self.learned_weights_path = cache_dir / "xfeat_learn_steer.pth"
        self.steerer_weights_path = cache_dir / "xfeat_learn_steer_steerer.pth"

        # pretrained=False: the steerer-specific weights are loaded explicitly below.
        self.model = torch.hub.load("verlab/accelerated_features", "XFeat", pretrained=False, top_k=max_num_keypoints)
        self.download_weights(cache_dir)  # also sets self.weights_path

        # Load the backbone weights; every key is prefixed with "net." before loading.
        raw_state_dict = torch.load(self.weights_path, map_location="cpu", weights_only=True)
        state_dict = {"net." + key: value for key, value in raw_state_dict.items()}
        self.model.load_state_dict(state_dict)
        self.model.to(device)
        self.model.dev = device

        if steerer_type == "learned":
            # The stored steerer weight carries two trailing dims that are indexed away
            # to obtain the 2-D weight of a bias-free 64 -> 64 linear layer.
            self.steerer = torch.nn.Linear(64, 64, bias=False)
            steerer_state = torch.load(self.steerer_weights_path, map_location="cpu", weights_only=True)
            self.steerer.weight.data = steerer_state["weight"][..., 0, 0]
            self.steerer.eval()
            self.steerer.to(device)
        else:
            # Instance-level copy of the class-level permutations, moved to the target device.
            self.steer_permutations = [perm.to(device) for perm in self.steer_permutations]

        self.max_num_keypoints = max_num_keypoints
        self.mode = mode
        self.min_cossim = kwargs.get("min_cossim", 0.8 if steerer_type == "learned" else 0.9)

    def download_weights(self, cache_dir):
        """Select the backbone weights file and download any missing file from Google Drive.

        Sets ``self.weights_path`` according to ``self.steerer_type``. Files that already
        exist at their expected path are not downloaded again.

        Args:
            cache_dir: directory into which missing files are downloaded.
        """
        if self.steerer_type == "perm":
            self.weights_path = self.perm_weights_path
            if not self.perm_weights_path.exists():
                download_file_from_google_drive(
                    self.perm_weights_gdrive_id, root=cache_dir, filename=self.perm_weights_path.name
                )
        elif self.steerer_type == "learned":
            self.weights_path = self.learned_weights_path
            if not self.learned_weights_path.exists():
                download_file_from_google_drive(
                    self.learned_weights_gdrive_id, root=cache_dir, filename=self.learned_weights_path.name
                )
            # The learned variant needs a second file holding the steerer matrix itself.
            if not self.steerer_weights_path.exists():
                download_file_from_google_drive(
                    self.steerer_weights_drive_id, root=cache_dir, filename=self.steerer_weights_path.name
                )

    def preprocess(self, img: torch.Tensor) -> torch.Tensor:
        """Run the model's ``parse_input`` on ``img`` and rescale uint8 input where required.

        Args:
            img: input image tensor.

        Returns:
            The parsed image; divided by 255 when running semi-dense mode on ``"cuda"``
            with a ``uint8`` image.
        """
        img = self.model.parse_input(img)
        if self.device == "cuda" and self.mode == "semi-dense" and img.dtype == torch.uint8:
            img = img / 255  # cuda error in upsample_bilinear_2d_out_frame if img is ubyte
        return img

    def _forward(self, img0, img1):
        """Match two images, keeping the steering of image 0 that yields the most matches.

        Args:
            img0: first image.
            img1: second image.

        Returns:
            Tuple ``(mkpts0, mkpts1, keypoints0, keypoints1, descriptors0, descriptors1)``:
            matched keypoints followed by the squeezed keypoints and descriptors of each image.
        """
        img0, img1 = self.preprocess(img0), self.preprocess(img1)

        if self.mode == "semi-dense":
            output0 = self.model.detectAndComputeDense(img0, top_k=self.max_num_keypoints)
            output1 = self.model.detectAndComputeDense(img1, top_k=self.max_num_keypoints)

            # Baseline: unsteered descriptors (steering index 0).
            rot0to1 = 0
            idxs_list = self.model.batch_match(
                output0["descriptors"], output1["descriptors"], min_cossim=self.min_cossim
            )

            # Work on a copy so output0 stays untouched for the refinement step below.
            descriptors0 = output0["descriptors"].clone()
            for r in range(1, 4):
                if self.steerer_type == "learned":
                    # The learned steerer is applied cumulatively: r applications after r iterations.
                    descriptors0 = torch.nn.functional.normalize(self.steerer(descriptors0), dim=-1)
                else:
                    descriptors0 = output0["descriptors"][..., self.steer_permutations[r]]

                new_idxs_list = self.model.batch_match(descriptors0, output1["descriptors"], min_cossim=self.min_cossim)
                if len(new_idxs_list[0][0]) > len(idxs_list[0][0]):
                    idxs_list = new_idxs_list
                    rot0to1 = r

            # align to first image for refinement MLP
            if self.steerer_type == "learned":
                if rot0to1 > 0:
                    for _ in range(4 - rot0to1):
                        output1["descriptors"] = self.steerer(
                            output1["descriptors"]
                        )  # Adding normalization here hurts performance for some reason, probably due to the way it's done during training
            else:
                output1["descriptors"] = output1["descriptors"][..., self.steer_permutations[-rot0to1]]

            matches = self.model.refine_matches(output0, output1, matches=idxs_list, batch_idx=0)
            mkpts0, mkpts1 = matches[:, :2], matches[:, 2:]
        else:
            output0 = self.model.detectAndCompute(img0, top_k=self.max_num_keypoints)[0]
            output1 = self.model.detectAndCompute(img1, top_k=self.max_num_keypoints)[0]

            # Baseline: unsteered descriptors.
            idxs0, idxs1 = self.model.match(output0["descriptors"], output1["descriptors"], min_cossim=self.min_cossim)

            # Steer a local variable rather than output0["descriptors"], so the descriptors
            # returned for image 0 are the ones produced by the model.
            steered0 = output0["descriptors"]
            for r in range(1, 4):
                if self.steerer_type == "learned":
                    # Cumulative application, as in the semi-dense branch.
                    steered0 = torch.nn.functional.normalize(self.steerer(steered0), dim=-1)
                else:
                    steered0 = output0["descriptors"][..., self.steer_permutations[r]]

                new_idxs0, new_idxs1 = self.model.match(steered0, output1["descriptors"], min_cossim=self.min_cossim)
                if len(new_idxs0) > len(idxs0):
                    idxs0 = new_idxs0
                    idxs1 = new_idxs1

            mkpts0, mkpts1 = output0["keypoints"][idxs0], output1["keypoints"][idxs1]

        return (
            mkpts0,
            mkpts1,
            output0["keypoints"].squeeze(),
            output1["keypoints"].squeeze(),
            output0["descriptors"].squeeze(),
            output1["descriptors"].squeeze(),
        )