"""Source code for vismatch.im_models.aspanformer."""

import torch
from pathlib import Path
import gdown
import torchvision.transforms as tfm
import tarfile
from huggingface_hub import snapshot_download

from vismatch import THIRD_PARTY_DIR, BaseMatcher
from vismatch.utils import resize_to_divisible, lower_config, add_to_path, pad_images_to_same_shape

# Directory of the aspanformer sources under THIRD_PARTY_DIR.
BASE_PATH = THIRD_PARTY_DIR.joinpath("aspanformer")
# Must run before the ``src.*`` imports that follow so that aspanformer's
# ``src`` package is resolvable.
add_to_path(BASE_PATH)

from src.ASpanFormer.aspanformer import ASpanFormer
from src.config.default import get_cfg_defaults as aspan_cfg_defaults


class AspanformerMatcher(BaseMatcher):
    """Image matcher backed by ASpanFormer with the outdoor checkpoint.

    On construction the Hugging Face snapshot for ``hf_model_id`` is resolved,
    the ``weights/outdoor.ckpt`` checkpoint is looked up inside it (and fetched
    from ``weights_src`` when absent), and an ``ASpanFormer`` model configured
    from ``configs/aspan/outdoor/aspan_test.py`` is loaded in eval mode.
    """

    # Hugging Face repository resolved with ``snapshot_download``; the
    # checkpoint is expected inside the returned snapshot directory.
    hf_model_id = "vismatch/aspanformer"
    # Google Drive archive downloaded when the checkpoint file is missing.
    weights_src = "https://drive.google.com/file/d/1eavM9dTkw9nbc-JqlVVfGPU5UvTTfc6k/view"
    # Value passed to ``resize_to_divisible`` in ``preprocess``.
    divisible_size = 32

    def __init__(self, device="cpu", **kwargs):
        """Build the model and load the outdoor checkpoint.

        Args:
            device: Device the model is moved to. Defaults to ``"cpu"``.
            **kwargs: Forwarded unchanged to ``BaseMatcher.__init__``.
        """
        super().__init__(device, **kwargs)

        cache_dir = Path(snapshot_download(self.hf_model_id))
        self.weights_path = cache_dir / "weights" / "outdoor.ckpt"
        self.download_weights(cache_dir)

        add_to_path(BASE_PATH)  # ensure aspanformer's ``src`` is resolvable
        config = aspan_cfg_defaults()
        config.merge_from_file(BASE_PATH.joinpath("configs", "aspan", "outdoor", "aspan_test.py"))
        self.matcher = ASpanFormer(config=lower_config(config)["aspan"])

        # ``self.device`` is not assigned in this class; presumably it is set
        # by ``BaseMatcher.__init__`` -- TODO confirm.
        checkpoint = torch.load(self.weights_path, map_location=self.device, weights_only=True)
        # strict=False: checkpoint keys are not required to match the model exactly.
        self.matcher.load_state_dict(checkpoint["state_dict"], strict=False)
        self.matcher = self.matcher.to(device).eval()

    def download_weights(self, cache_dir):
        """Download and unpack the weights archive if the checkpoint is missing.

        Does nothing when ``self.weights_path`` already exists. Otherwise the
        archive at ``weights_src`` is saved as ``weights_aspanformer.tar``
        inside ``cache_dir`` and extracted there.

        Args:
            cache_dir: Directory that receives the archive and its contents.

        Raises:
            FileNotFoundError: If the archive is not present after the
                download attempt.
        """
        if self.weights_path.is_file():
            return

        cache_dir = Path(cache_dir)
        archive_path = cache_dir / "weights_aspanformer.tar"

        print("Downloading Aspanformer outdoor... (takes a while)")
        gdown.download(
            self.weights_src,
            output=str(archive_path),
            fuzzy=True,
        )
        if not archive_path.is_file():
            # Fail with a clear message instead of an opaque tarfile error.
            raise FileNotFoundError(
                f"Downloading {self.weights_src} did not produce {archive_path}"
            )

        # The context manager closes the archive even if extraction fails.
        with tarfile.open(archive_path) as archive:
            if hasattr(tarfile, "data_filter"):
                # The archive comes from the network: refuse members that
                # would escape ``cache_dir`` where the interpreter supports
                # extraction filters.
                archive.extractall(cache_dir, filter="data")
            else:
                # NOTE(review): interpreter without extraction filters;
                # members are extracted unchecked.
                archive.extractall(cache_dir)

    def preprocess(self, img):
        """Resize an image and convert it to a batched grayscale tensor.

        Args:
            img: Image with a three-element shape whose last two entries are
                height and width.

        Returns:
            Tuple of the grayscale image with a leading batch dimension added
            and the ``(height, width)`` of the input before resizing.
        """
        _, h, w = img.shape
        orig_shape = h, w

        img = resize_to_divisible(img, self.divisible_size)
        return tfm.Grayscale()(img).unsqueeze(0), orig_shape

    def _forward(self, img0, img1):
        """Match two images and return keypoints in original image coordinates.

        Args:
            img0: First image, in the form accepted by ``preprocess``.
            img1: Second image, in the form accepted by ``preprocess``.

        Returns:
            A six-tuple ``(mkpts0, mkpts1, None, None, None, None)``; only the
            matched keypoints of the two images are populated.
        """
        img0, img0_orig_shape = self.preprocess(img0)
        img1, img1_orig_shape = self.preprocess(img1)

        # Record the resized sizes before padding: they are the reference
        # sizes used below to map coordinates back to the original images.
        H0, W0 = img0.shape[-2:]
        H1, W1 = img1.shape[-2:]
        img0, img1 = pad_images_to_same_shape(img0, img1)

        batch = {"image0": img0, "image1": img1}
        # The model writes its results into ``batch`` in place.
        self.matcher(batch, online_resize=True)  # online_resize prevents breaking at very high res

        mkpts0 = batch["mkpts0_f"]
        mkpts1 = batch["mkpts1_f"]

        # ``rescale_coords`` is not defined in this class (presumably inherited
        # from ``BaseMatcher``).
        mkpts0 = self.rescale_coords(mkpts0, *img0_orig_shape, H0, W0)
        mkpts1 = self.rescale_coords(mkpts1, *img1_orig_shape, H1, W1)

        return mkpts0, mkpts1, None, None, None, None