from pathlib import Path
import torchvision.transforms as tfm
import py3_wget
import numpy as np
import torch
from huggingface_hub import snapshot_download
# Monkey patch torch.load to use weights_only=False by default for compatibility with
# PyTorch 2.6+, whose torch.load defaults to weights_only=True.
# SECURITY NOTE: the patch is process-wide -- it changes every later torch.load call,
# not only the ones made by this module. With weights_only=False, pickle may execute
# arbitrary code while loading, so only load checkpoints from trusted sources.
# Callers that pass weights_only explicitly keep their own choice.
import functools

_original_torch_load = torch.load


@functools.wraps(_original_torch_load)
def _patched_torch_load(*args, **kwargs):
    """Call the original ``torch.load``, defaulting ``weights_only`` to False if not given."""
    # weights_only is keyword-only in torch.load, so checking kwargs is sufficient.
    kwargs.setdefault("weights_only", False)
    return _original_torch_load(*args, **kwargs)


torch.load = _patched_torch_load
from vismatch import BaseMatcher, THIRD_PARTY_DIR
from vismatch.utils import resize_to_divisible, add_to_path
add_to_path(THIRD_PARTY_DIR.joinpath("mast3r"))
from mast3r.model import AsymmetricMASt3R
from mast3r.fast_nn import fast_reciprocal_NNs
from dust3r.inference import inference
class Mast3rMatcher(BaseMatcher):
    """Image matcher built on MASt3R (the checkpoint named by ``weight_filename``).

    Both images are passed through MASt3R as one pair. The dense descriptors it
    predicts are matched with ``fast_reciprocal_NNs``. Matches lying within
    ``border_margin`` pixels of either image border are dropped. The surviving
    coordinates are rescaled back to the resolution of the input images.
    """

    hf_model_id = "vismatch/master"
    weight_filename = "MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth"
    # Direct download location, used only when the HF snapshot lacks the checkpoint.
    weight_url = (
        "https://download.europe.naverlabs.com/ComputerVision/MASt3R/"
        "MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth"
    )
    vit_patch_size = 16
    # Passed to fast_reciprocal_NNs as ``subsample_or_initxy1``.
    match_subsample = 8
    # Matches with a coordinate closer than this (in px) to either image edge are discarded.
    border_margin = 3

    def __init__(self, device="cpu", *args, **kwargs):
        # Extra positional arguments are accepted for signature compatibility but are
        # not forwarded to BaseMatcher.
        super().__init__(device, **kwargs)
        # Per-channel mean 0.5 / std 0.5; assumes inputs in [0, 1] (mapped to [-1, 1]) -- confirm.
        self.normalize = tfm.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        # Forwarded to dust3r's ``inference`` as its ``verbose`` flag.
        self.verbose = False
        model_path = self.download_weights()
        self.model = AsymmetricMASt3R.from_pretrained(model_path).to(device)

    @classmethod
    def download_weights(cls):
        """Return the local path of the MASt3R checkpoint, downloading it if needed.

        The Hugging Face repo ``cls.hf_model_id`` is fetched with
        ``snapshot_download``. If ``cls.weight_filename`` is not in that snapshot,
        the file is downloaded from ``cls.weight_url`` into the snapshot directory.

        Returns:
            pathlib.Path: path to ``cls.weight_filename`` inside the snapshot directory.
        """
        cache_dir = Path(snapshot_download(cls.hf_model_id))
        model_path = cache_dir / cls.weight_filename
        if not model_path.is_file():
            print("Downloading Master(ViT large)... (takes a while)")
            py3_wget.download_file(cls.weight_url, model_path)
        return model_path

    def preprocess(self, img):
        """Prepare one image tensor for MASt3R.

        Args:
            img: rank-3 image tensor unpacked as (C, H, W). Presumably float RGB
                in [0, 1] -- confirm against callers.

        Returns:
            tuple: ``(img, (H, W))``. ``img`` has been resized with
            ``resize_to_divisible`` to fit ``vit_patch_size``, normalized, and
            given a leading batch dimension of 1. ``(H, W)`` is the size before
            resizing.
        """
        _, h, w = img.shape
        orig_shape = h, w
        img = resize_to_divisible(img, self.vit_patch_size)
        img = self.normalize(img).unsqueeze(0)
        return img, orig_shape

    def _forward(self, img0, img1):
        """Match ``img0`` against ``img1`` with MASt3R.

        Returns:
            tuple: ``(mkpts0, mkpts1, None, None, None, None)``. ``mkpts0[i]`` and
            ``mkpts1[i]`` are corresponding (x, y) coordinates at the original
            resolution of ``img0`` and ``img1``. The four trailing slots are always
            ``None`` for this matcher. Their meaning is defined by BaseMatcher;
            presumably unmatched keypoints/descriptors -- confirm.
        """
        img0, img0_orig_shape = self.preprocess(img0)
        img1, img1_orig_shape = self.preprocess(img1)

        img_pair = [
            {"img": img0, "idx": 0, "instance": 0, "true_shape": np.int32([img0.shape[-2:]])},
            {"img": img1, "idx": 1, "instance": 1, "true_shape": np.int32([img1.shape[-2:]])},
        ]
        output = inference(
            [tuple(img_pair)], self.model, self.device, batch_size=1, verbose=self.verbose
        )

        # At this stage we have the raw dust3r predictions.
        view1, pred1 = output["view1"], output["pred1"]
        view2, pred2 = output["view2"], output["pred2"]
        desc1, desc2 = pred1["desc"].squeeze(0).detach(), pred2["desc"].squeeze(0).detach()

        # Find 2D-2D matches between the two images.
        matches_im0, matches_im1 = fast_reciprocal_NNs(
            desc1,
            desc2,
            subsample_or_initxy1=self.match_subsample,
            device=self.device,
            dist="dot",
            block_size=2**13,
        )

        # Ignore matches in a small band around the border of either image.
        H0, W0 = view1["true_shape"][0]
        H1, W1 = view2["true_shape"][0]
        valid_matches = self._inside_border(
            matches_im0, H0, W0, self.border_margin
        ) & self._inside_border(matches_im1, H1, W1, self.border_margin)
        mkpts0, mkpts1 = matches_im0[valid_matches], matches_im1[valid_matches]

        # dust3r sometimes requires resizing an image to fit the ViT patch size evenly,
        # so rescale the keypoints back to the original image resolution.
        H0, W0, H1, W1 = *img0.shape[-2:], *img1.shape[-2:]
        mkpts0 = self.rescale_coords(mkpts0, *img0_orig_shape, H0, W0)
        mkpts1 = self.rescale_coords(mkpts1, *img1_orig_shape, H1, W1)
        return mkpts0, mkpts1, None, None, None, None

    @staticmethod
    def _inside_border(pts, h, w, margin):
        """Return a boolean mask of points inside the border band of an h x w image.

        Columns 0 and 1 of ``pts`` are x and y. A point is kept when
        ``margin <= x < w - margin`` and ``margin <= y < h - margin``.
        """
        x, y = pts[:, 0], pts[:, 1]
        return (x >= margin) & (x < int(w) - margin) & (y >= margin) & (y < int(h) - margin)