# Source code for vismatch.im_models.gim: GIM-trained DKM and SuperPoint+LightGlue matchers.

import torch
from safetensors.torch import load_file

from kornia.color import rgb_to_grayscale

from huggingface_hub import snapshot_download
from vismatch import THIRD_PARTY_DIR, BaseMatcher
from vismatch.utils import load_module, add_to_path

# Location of the third-party GIM code under THIRD_PARTY_DIR. It is also used by
# GIM_LightGlue to locate its bundled ``gluefactory`` package.
BASE_PATH = THIRD_PARTY_DIR.joinpath("gim")
# Must run before the ``dkm`` import below: add_to_path presumably puts BASE_PATH
# on the import path so the GIM-bundled ``dkm`` package resolves (confirm in vismatch.utils).
add_to_path(BASE_PATH)
from dkm.models.model_zoo.DKMv3 import DKMv3


class GIM_DKM(BaseMatcher):
    """Dense matcher: DKMv3 loaded with the GIM ``gim_dkm_100h`` weights.

    The dense warp predicted by DKM is sub-sampled to ``max_num_keypoints``
    correspondences, which are returned in pixel coordinates of each input image.
    """

    def __init__(self, device="cpu", max_num_keypoints=5000, **kwargs):
        """Build DKMv3, load the GIM weights and move the model to ``device``.

        Args:
            device: Device the model is placed on.
            max_num_keypoints: Number of correspondences sampled from the dense
                warp in ``_forward``.
            **kwargs: Forwarded to ``BaseMatcher.__init__``.
        """
        super().__init__(device, **kwargs)
        # h/w are fixed size arguments of DKMv3 (presumably its working resolution — confirm).
        self.model = DKMv3(weights=None, h=672, w=896)
        self.max_num_keypoints = max_num_keypoints
        self.load_weights()
        self.model = self.model.eval().to(device)

    def load_weights(self):
        """Download the GIM-DKM safetensors checkpoint and load it into ``self.model``.

        A leading ``model.`` prefix is stripped from every key, and every entry whose
        key contains ``encoder.net.fc`` is dropped before a strict ``load_state_dict``.
        """
        weights_path = f"{snapshot_download('vismatch/gim-dkm')}/gim_dkm_100h.safetensors"
        state_dict = load_file(weights_path)
        prefix = "model."
        for key in list(state_dict.keys()):
            # Pop exactly once. The previous code popped a renamed key a second time
            # (KeyError) when it was both ``model.``-prefixed and an ``encoder.net.fc`` key.
            tensor = state_dict.pop(key)
            if "encoder.net.fc" in key:
                # Not loaded into the model (presumably an unused classifier head — confirm).
                continue
            new_key = key[len(prefix):] if key.startswith(prefix) else key
            state_dict[new_key] = tensor
        self.model.load_state_dict(state_dict)

    def preprocess(self, img):
        """Return ``img`` with a leading batch dimension added if it has fewer than 4 dims.

        The tensor is passed directly to ``DKMv3.match`` (no PIL conversion).
        """
        if img.ndim < 4:
            img = img.unsqueeze(0)
        return img

    @staticmethod
    def _normalized_to_pixels(pts, height, width):
        """Map (x, y) coordinates from the ``[-1, 1]`` range to ``[0, width] x [0, height]``."""
        return torch.stack(
            (width * (pts[:, 0] + 1) / 2, height * (pts[:, 1] + 1) / 2),
            dim=-1,
        )

    def _forward(self, img0, img1):
        """Match two images densely with DKM and sample sparse correspondences.

        Returns:
            ``(mkpts0, mkpts1, None, None, None, None)``: matched pixel coordinates in
            each image. DKM yields no separate keypoints/descriptors, so those slots
            are ``None``.
        """
        # Original image sizes, taken before batching, for the pixel conversion.
        height0, width0 = img0.shape[-2:]
        height1, width1 = img1.shape[-2:]
        img0 = self.preprocess(img0)
        img1 = self.preprocess(img1)

        dense_matches, dense_certainty = self.model.match(img0, img1, device=self.device)
        torch.cuda.empty_cache()
        # Sample a sparse set of matches from the dense warp; certainties are unused.
        sparse_matches, _ = self.model.sample(dense_matches, dense_certainty, self.max_num_keypoints)
        torch.cuda.empty_cache()

        # Each row is (x0, y0, x1, y1) in normalized coordinates.
        mkpts0 = self._normalized_to_pixels(sparse_matches[:, :2], height0, width0)
        mkpts1 = self._normalized_to_pixels(sparse_matches[:, 2:], height1, width1)
        return mkpts0, mkpts1, None, None, None, None
class GIM_LightGlue(BaseMatcher):
    """Sparse matcher: SuperPoint keypoints matched by LightGlue, using GIM weights.

    Both networks come from the GIM-modified gluefactory bundled under ``BASE_PATH``
    and are loaded from the ``gim_lightglue_100h`` checkpoint.
    """

    def __init__(self, device="cpu", max_keypoints=2048, **kwargs):
        """Build the SuperPoint detector and LightGlue matcher and load their weights.

        Args:
            device: Device both networks are moved to.
            max_keypoints: ``max_num_keypoints`` passed to SuperPoint (with
                ``force_num_keypoints`` enabled).
            **kwargs: Forwarded to ``BaseMatcher.__init__``.
        """
        super().__init__(device, **kwargs)
        # load the altered version of gluefactory
        load_module("gluefactory_gim", BASE_PATH.joinpath("gluefactory/__init__.py"))
        from gluefactory_gim.superpoint import SuperPoint
        from gluefactory_gim.models.matchers.lightglue import LightGlue

        # Download weights from HF and get paths from cache
        repo = snapshot_download("vismatch/gim-lightglue")
        self.ckpt_path = f"{repo}/gim_lightglue_100h.ckpt"
        self.superpoint_v1_path = f"{repo}/superpoint_v1.pth"

        self.detector = SuperPoint(
            {
                "max_num_keypoints": max_keypoints,
                "force_num_keypoints": True,
                "detection_threshold": 0.0,
                "nms_radius": 3,
                "trainable": False,
                "weights_path": self.superpoint_v1_path,
            }
        )
        self.model = LightGlue(
            {
                "filter_threshold": 0.1,
                "flash": False,
                "checkpointed": True,
            }
        )
        self.load_weights()

    @staticmethod
    def _select_state_dict(state_dict, keep, drop):
        """Return the entries not under ``drop``, with the ``keep`` prefix stripped.

        Un-prefixed keys are kept as-is; a stripped ``keep.*`` entry overrides an
        un-prefixed key of the same name.
        """
        selected = {k: v for k, v in state_dict.items() if not k.startswith((keep, drop))}
        selected.update({k[len(keep):]: v for k, v in state_dict.items() if k.startswith(keep)})
        return selected

    def load_weights(self):
        """Load the combined checkpoint into the detector and matcher, then move them to the device.

        The checkpoint (optionally nested under ``"state_dict"``) holds ``superpoint.*``
        detector weights and ``model.*`` LightGlue weights. It is read from disk once
        and split per network.
        """
        state_dict = torch.load(self.ckpt_path, map_location="cpu", weights_only=True)
        if "state_dict" in state_dict:
            state_dict = state_dict["state_dict"]

        self.detector.load_state_dict(self._select_state_dict(state_dict, keep="superpoint.", drop="model."))
        self.model.load_state_dict(self._select_state_dict(state_dict, keep="model.", drop="superpoint."))

        self.detector = self.detector.eval().to(self.device)
        self.model = self.model.eval().to(self.device)

    def preprocess(self, img):
        """Return a batched single-channel tensor for SuperPoint.

        Inputs with fewer than 4 dims get a leading batch dimension. RGB inputs are
        converted with ``rgb_to_grayscale``; inputs that already have one channel are
        returned unchanged.
        """
        if img.ndim < 4:
            img = img.unsqueeze(0)
        if img.shape[-3] == 1:
            # rgb_to_grayscale only accepts 3-channel input.
            return img
        return rgb_to_grayscale(img)

    def _forward(self, img0, img1):
        """Detect SuperPoint keypoints in both images and match them with LightGlue.

        Returns:
            ``(mkpts0, mkpts1, kpts0, kpts1, desc0, desc1)``: matched keypoint
            coordinates in each image, all detected keypoints concatenated over the
            batch, and the descriptors of the first batch element.
        """
        img0 = self.preprocess(img0)
        img1 = self.preprocess(img1)
        data = dict(image0=img0, image1=img1)
        # Images are used at native resolution, so keypoint scale factors are identity.
        scale0 = torch.tensor([1.0, 1.0]).to(self.device)[None]
        scale1 = torch.tensor([1.0, 1.0]).to(self.device)[None]
        # Image sizes in (width, height) order.
        size0 = torch.tensor(img0.shape[-2:][::-1])[None]
        size1 = torch.tensor(img1.shape[-2:][::-1])[None]
        data.update(size0=size0, size1=size1, scale0=scale0, scale1=scale1)

        # Run the detector on image0 then image1, suffixing its outputs with the image index.
        pred = {}
        for suffix in ("0", "1"):
            features = self.detector({"image": data["image" + suffix], "image_size": data["size" + suffix]})
            pred.update({k + suffix: v for k, v in features.items()})
        pred.update(self.model({**pred, **data, "resize0": data["size0"], "resize1": data["size1"]}))

        kpts0 = torch.cat([kp * s for kp, s in zip(pred["keypoints0"], data["scale0"][:, None])])
        kpts1 = torch.cat([kp * s for kp, s in zip(pred["keypoints1"], data["scale1"][:, None])])
        desc0, desc1 = pred["descriptors0"], pred["descriptors1"]

        # Batch index of every keypoint row with coordinate sum > -1. The same ids are
        # reused for image1, which assumes both images have the same keypoint count per
        # batch element (force_num_keypoints) — confirm if that setting changes.
        m_bids = torch.nonzero(pred["keypoints0"].sum(dim=2) > -1)[:, 0]
        matches = pred["matches"]
        bs = data["image0"].size(0)
        mkpts0 = torch.cat([kpts0[m_bids == b_id][matches[b_id][..., 0]] for b_id in range(bs)])
        mkpts1 = torch.cat([kpts1[m_bids == b_id][matches[b_id][..., 1]] for b_id in range(bs)])
        return mkpts0, mkpts1, kpts0, kpts1, desc0[0], desc1[0]