import torch
from pathlib import Path
from torchvision.datasets.utils import download_file_from_google_drive
from huggingface_hub import snapshot_download
from vismatch import BaseMatcher
class xFeatSteerersMatcher(BaseMatcher):
    """XFeat matcher augmented with descriptor steerers (fixed permutation or learned).

    Image-0 descriptors are steered up to three additional times and matched
    against image-1 descriptors after each step; the step yielding the most
    matches is kept.

    Reference for perm steerer: https://colab.research.google.com/drive/1ZFifMqUAOQhky1197-WAquEV1K-LhDYP?usp=sharing
    Reference for learned steerer: https://colab.research.google.com/drive/1sCqgi3yo3OuxA8VX_jPUt5ImHDmEajsZ?usp=sharing
    """

    # Permutations of a 64-dim descriptor: view it as 4 blocks of 16 and roll
    # the blocks by k. steer_permutations[0] is the identity. Presumably k
    # corresponds to k rotation steps of the perm steerer (see reference above).
    steer_permutations = [torch.arange(64).reshape(4, 16).roll(k, dims=0).reshape(64) for k in range(4)]
    hf_model_ids = {
        "perm": "vismatch/xfeat-steerers-perm",
        "learned": "vismatch/xfeat-steerers-learned",
    }
    # Google Drive fallbacks, used when a weight file is missing from the HF snapshot.
    perm_weights_gdrive_id = "1nzYg4dmkOAZPi4sjOGpQnawMoZSXYXHt"
    learned_weights_gdrive_id = "1yJtmRhPVrpbXyN7Be32-FYctmX2Oz77r"
    steerer_weights_drive_id = "1Qh_5YMjK1ZIBFVFvZlTe_eyjNPrOQ2Dv"

    def __init__(self, device="cpu", max_num_keypoints=4096, mode="sparse", steerer_type="learned", *args, **kwargs):
        """Build the XFeat model, load steerer-trained weights and the steerer.

        Args:
            device: Device the model (and steerer) are moved to.
            max_num_keypoints: ``top_k`` passed to XFeat and to detection.
            mode: ``"sparse"`` or ``"semi-dense"``. Sparse mode is rejected on ``"mps"``.
            steerer_type: ``"learned"`` (linear 64x64 steerer) or ``"perm"``
                (fixed descriptor permutations).
            **kwargs: Forwarded to ``BaseMatcher``. ``min_cossim`` (if given)
                overrides the matching threshold, whose default is 0.8 for
                ``"learned"`` and 0.9 for ``"perm"``.

        Raises:
            ValueError: If ``mode`` or ``steerer_type`` is not supported.
        """
        super().__init__(device, **kwargs)
        if mode not in ["sparse", "semi-dense"]:
            # Report the argument: self.mode has not been assigned yet here.
            raise ValueError(f'unsupported mode for xfeat: {mode}. Must choose from ["sparse", "semi-dense"]')
        if mode != "semi-dense":
            assert self.device != "mps", (
                f"Device must be 'cpu' or 'cuda' for {self.name} with mode {mode}. Device='{self.device}' not supported"
            )
        self.steerer_type = steerer_type
        if self.steerer_type not in ["learned", "perm"]:
            raise ValueError(
                f'unsupported type for xfeat-steerer: {steerer_type}. Must choose from ["perm", "learned"]. Learned usually performs better.'
            )
        self.max_num_keypoints = max_num_keypoints
        self.mode = mode
        self.min_cossim = kwargs.get("min_cossim", 0.8 if steerer_type == "learned" else 0.9)

        cache_dir = Path(snapshot_download(self.hf_model_ids[steerer_type]))
        self.perm_weights_path = cache_dir / "xfeat_perm_steer.pth"
        self.learned_weights_path = cache_dir / "xfeat_learn_steer.pth"
        self.steerer_weights_path = cache_dir / "xfeat_learn_steer_steerer.pth"

        # XFeat is built without its stock weights; the steerer-trained ones are loaded below.
        self.model = torch.hub.load("verlab/accelerated_features", "XFeat", pretrained=False, top_k=max_num_keypoints)
        self.download_weights(cache_dir)  # also sets self.weights_path

        # Load xfeat-steerers weights; keys are prefixed with "net." to line up
        # with self.model's state dict.
        state_dict = torch.load(self.weights_path, map_location="cpu", weights_only=True)
        state_dict = {"net." + key: value for key, value in state_dict.items()}
        self.model.load_state_dict(state_dict)
        self.model.to(device)
        self.model.dev = device

        if steerer_type == "learned":
            self.steerer = torch.nn.Linear(64, 64, bias=False)
            # The stored weight carries two trailing singleton dims (presumably a
            # 1x1 conv kernel); drop them to obtain the 64x64 linear weight.
            self.steerer.weight.data = torch.load(self.steerer_weights_path, map_location="cpu", weights_only=True)[
                "weight"
            ][..., 0, 0]
            self.steerer.eval()
            self.steerer.to(device)
        else:
            # Instance attribute shadows the class-level (CPU) list.
            self.steer_permutations = [perm.to(device) for perm in self.steer_permutations]

    def download_weights(self, cache_dir):
        """Select ``self.weights_path`` for the steerer type and fetch missing files.

        Any required file absent from ``cache_dir`` is downloaded from Google
        Drive. The learned variant also needs the steerer matrix
        (``self.steerer_weights_path``).

        Args:
            cache_dir: Directory the weight files are read from / downloaded into.
        """

        def _fetch_if_missing(path, gdrive_id):
            # Download ``path`` from Google Drive only when it is not already cached.
            if not path.exists():
                download_file_from_google_drive(gdrive_id, root=cache_dir, filename=path.name)

        if self.steerer_type == "perm":
            self.weights_path = self.perm_weights_path
            _fetch_if_missing(self.perm_weights_path, self.perm_weights_gdrive_id)
        elif self.steerer_type == "learned":
            self.weights_path = self.learned_weights_path
            _fetch_if_missing(self.learned_weights_path, self.learned_weights_gdrive_id)
            _fetch_if_missing(self.steerer_weights_path, self.steerer_weights_drive_id)

    def preprocess(self, img: torch.Tensor) -> torch.Tensor:
        """Run XFeat's ``parse_input`` on ``img``, rescaling uint8 input where needed.

        Args:
            img: Input image tensor.

        Returns:
            The parsed image. In semi-dense mode on CUDA, uint8 input is divided by 255.
        """
        img = self.model.parse_input(img)
        # str() + startswith also covers "cuda:0" and torch.device("cuda", n).
        if str(self.device).startswith("cuda") and self.mode == "semi-dense" and img.dtype == torch.uint8:
            img = img / 255  # cuda error in upsample_bilinear_2d_out_frame if img is ubyte
        return img

    def _forward(self, img0, img1):
        """Match ``img0`` against ``img1``, trying every steering of image 0.

        Returns:
            ``(mkpts0, mkpts1, keypoints0, keypoints1, descriptors0, descriptors1)``:
            matched keypoints in each image, then all keypoints and descriptors
            of each image (squeezed). In semi-dense mode ``descriptors1`` has been
            steered back into image 0's frame for refinement. In both modes
            ``descriptors0`` is the unsteered image-0 output.
        """
        img0, img1 = self.preprocess(img0), self.preprocess(img1)
        if self.mode == "semi-dense":
            return self._forward_semi_dense(img0, img1)
        return self._forward_sparse(img0, img1)

    def _steer_step(self, steered0, descriptors0, r):
        # Return image-0 descriptors steered r times, given the (r-1)-times steered ``steered0``.
        if self.steerer_type == "learned":
            # Learned steerer is applied cumulatively: one more application per step.
            return torch.nn.functional.normalize(self.steerer(steered0), dim=-1)
        return descriptors0[..., self.steer_permutations[r]]

    def _forward_semi_dense(self, img0, img1):
        # Semi-dense path: dense detection, steered batch matching, alignment, refinement.
        output0 = self.model.detectAndComputeDense(img0, top_k=self.max_num_keypoints)
        output1 = self.model.detectAndComputeDense(img1, top_k=self.max_num_keypoints)
        descriptors0, descriptors1 = output0["descriptors"], output1["descriptors"]

        rot0to1 = 0
        idxs_list = self.model.batch_match(descriptors0, descriptors1, min_cossim=self.min_cossim)
        steered0 = descriptors0
        for r in range(1, 4):
            steered0 = self._steer_step(steered0, descriptors0, r)
            new_idxs_list = self.model.batch_match(steered0, descriptors1, min_cossim=self.min_cossim)
            # Keep the steering with the most matches for batch 0 (ties keep the earlier one).
            if len(new_idxs_list[0][0]) > len(idxs_list[0][0]):
                idxs_list = new_idxs_list
                rot0to1 = r

        # align to first image for refinement MLP
        if self.steerer_type == "learned":
            if rot0to1 > 0:
                # Completes the cycle: rot0to1 + (4 - rot0to1) = 4 applications in total.
                for _ in range(4 - rot0to1):
                    # Adding normalization here hurts performance for some reason,
                    # probably due to the way it's done during training.
                    output1["descriptors"] = self.steerer(output1["descriptors"])
        else:
            output1["descriptors"] = output1["descriptors"][..., self.steer_permutations[-rot0to1]]

        matches = self.model.refine_matches(output0, output1, matches=idxs_list, batch_idx=0)
        mkpts0, mkpts1 = matches[:, :2], matches[:, 2:]
        return (
            mkpts0,
            mkpts1,
            output0["keypoints"].squeeze(),
            output1["keypoints"].squeeze(),
            output0["descriptors"].squeeze(),
            output1["descriptors"].squeeze(),
        )

    def _forward_sparse(self, img0, img1):
        # Sparse path: sparse detection and steered mutual matching; descriptors are returned unsteered.
        output0 = self.model.detectAndCompute(img0, top_k=self.max_num_keypoints)[0]
        output1 = self.model.detectAndCompute(img1, top_k=self.max_num_keypoints)[0]
        descriptors0, descriptors1 = output0["descriptors"], output1["descriptors"]

        idxs0, idxs1 = self.model.match(descriptors0, descriptors1, min_cossim=self.min_cossim)
        # Steer a local copy so output0["descriptors"] (returned below) stays unsteered.
        steered0 = descriptors0
        for r in range(1, 4):
            steered0 = self._steer_step(steered0, descriptors0, r)
            new_idxs0, new_idxs1 = self.model.match(steered0, descriptors1, min_cossim=self.min_cossim)
            if len(new_idxs0) > len(idxs0):
                idxs0, idxs1 = new_idxs0, new_idxs1

        mkpts0, mkpts1 = output0["keypoints"][idxs0], output1["keypoints"][idxs1]
        return (
            mkpts0,
            mkpts1,
            output0["keypoints"].squeeze(),
            output1["keypoints"].squeeze(),
            output0["descriptors"].squeeze(),
            output1["descriptors"].squeeze(),
        )