Source code for vismatch.im_models.omniglue

import tensorflow  # noqa: F401 -- must import before torch to avoid segfault (github.com/tensorflow/tensorflow/issues/14812)
import py3_wget
import tarfile
import zipfile
from pathlib import Path
from kornia import tensor_to_image
import torch
import numpy as np
from skimage.util import img_as_ubyte
from huggingface_hub import snapshot_download

from vismatch import BaseMatcher, THIRD_PARTY_DIR
from vismatch.utils import add_to_path


# Directory of the OmniGlue checkout under THIRD_PARTY_DIR, and its "src" subdirectory.
BASE_PATH = THIRD_PARTY_DIR.joinpath("omniglue")
OMNI_SRC_PATH = BASE_PATH.joinpath("src")
# Alias of BASE_PATH: the checkout root itself is registered so dinov2 can be found.
OMNI_THIRD_PARTY_PATH = BASE_PATH

# Both directories are registered BEFORE `import omniglue` below; keep this order.
# NOTE(review): add_to_path presumably extends sys.path -- confirm in vismatch.utils.
add_to_path(OMNI_SRC_PATH)
add_to_path(OMNI_THIRD_PARTY_PATH)  # allow access to dinov2
import omniglue


class OmniglueMatcher(BaseMatcher):
    """Matcher built on top of ``omniglue.OmniGlue``.

    The three weight artifacts (OmniGlue export, SuperPoint export, DINOv2
    checkpoint) are looked up inside the Hugging Face snapshot of
    ``hf_model_id``; any artifact missing there is downloaded into that same
    directory by :meth:`download_weights`.
    """

    hf_model_id = "vismatch/omniglue"

    def __init__(self, device="cpu", conf_thresh=0.02, **kwargs):
        """Resolve the weight paths, fetch missing weights and build the model.

        Args:
            device: Forwarded unchanged to ``BaseMatcher.__init__``.
            conf_thresh: Matches are kept only when their confidence is
                strictly greater than this value. ``None`` disables the
                filtering in :meth:`_forward`.
            **kwargs: Forwarded unchanged to ``BaseMatcher.__init__``.
        """
        super().__init__(device, **kwargs)

        cache_dir = Path(snapshot_download(self.hf_model_id))
        self.OG_WEIGHTS_PATH = cache_dir / "og_export"
        self.SP_WEIGHTS_PATH = cache_dir / "sp_v6"
        self.DINOv2_PATH = cache_dir / "dinov2_vitb14_pretrain.pth"

        # Must run after the three paths above are set: it reads them.
        self.download_weights(cache_dir)

        self.model = omniglue.OmniGlue(
            og_export=str(self.OG_WEIGHTS_PATH),
            sp_export=str(self.SP_WEIGHTS_PATH),
            dino_export=str(self.DINOv2_PATH),
        )
        self.conf_thresh = conf_thresh

    def download_weights(self, cache_dir):
        """Download and unpack every weight artifact that is not on disk yet.

        Each artifact is checked independently via ``Path.exists()``, so a
        partially populated ``cache_dir`` only triggers the missing downloads.

        Args:
            cache_dir: Directory into which the archives are extracted. The
                archives themselves are saved next to the expected weight
                paths (``og_export.zip``, ``sp_v6.tgz``).
        """
        if not self.OG_WEIGHTS_PATH.exists():
            print("Downloading omniglue matcher weights...")
            og_archive = self.OG_WEIGHTS_PATH.with_suffix(".zip")
            py3_wget.download_file(
                "https://storage.googleapis.com/omniglue/og_export.zip",
                og_archive,
            )
            with zipfile.ZipFile(og_archive) as zip_f:
                zip_f.extractall(path=cache_dir)

        if not self.SP_WEIGHTS_PATH.exists():
            print("Downloading omniglue superpoint weights...")
            sp_archive = self.SP_WEIGHTS_PATH.with_suffix(".tgz")
            py3_wget.download_file(
                "https://github.com/rpautrat/SuperPoint/raw/master/pretrained_models/sp_v6.tgz",
                sp_archive,
            )
            # Context manager closes the archive even if extraction fails.
            with tarfile.open(sp_archive) as tar:
                tar.extractall(path=cache_dir)

        if not self.DINOv2_PATH.exists():
            print("Downloading omniglue DINOv2 weights...")
            py3_wget.download_file(
                "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_pretrain.pth",
                self.DINOv2_PATH,
            )

    def preprocess(self, img):
        """Convert an image to a ``uint8`` numpy array.

        Args:
            img: A ``torch.Tensor`` (converted with kornia's
                ``tensor_to_image``) or a ``np.ndarray``.

        Returns:
            The image clipped to ``[0, 1]`` and passed through
            ``skimage.util.img_as_ubyte``.

        Raises:
            AssertionError: If ``img`` is not a numpy array after the
                optional tensor conversion.
        """
        if isinstance(img, torch.Tensor):
            img = tensor_to_image(img)

        assert isinstance(img, np.ndarray), (
            f"expected torch.Tensor or np.ndarray, got {type(img).__name__}"
        )
        # NOTE(review): the clip implies float input in [0, 1] -- confirm with callers.
        return img_as_ubyte(np.clip(img, 0, 1))

    def _forward(self, img0, img1):
        """Match two images and return the confidence-filtered keypoint pairs.

        Args:
            img0: First image; anything accepted by :meth:`preprocess`.
            img1: Second image; anything accepted by :meth:`preprocess`.

        Returns:
            A 6-tuple ``(mkpts0, mkpts1, None, None, None, None)``. Only the
            matched keypoints are produced; the remaining four slots are
            always ``None``.
        """
        img0 = self.preprocess(img0)
        img1 = self.preprocess(img1)

        mkpts0, mkpts1, match_conf = self.model.FindMatches(img0, img1)

        if self.conf_thresh is not None:
            # Strictly-greater comparison, applied to all matches at once.
            keep = np.asarray(match_conf) > self.conf_thresh
            mkpts0 = mkpts0[keep]
            mkpts1 = mkpts1[keep]

        return mkpts0, mkpts1, None, None, None, None