import importlib
import logging
from pathlib import Path
import numpy as np
import torch
from PIL import Image
import torchvision.transforms as tfm
from yacs.config import CfgNode as CN
import sys
logger = logging.getLogger(__name__)
# Level 31 is one above WARNING (30): only records above WARNING (e.g. ERROR,
# CRITICAL) are emitted, which avoids printing useless low-level logs.
logger.setLevel(logging.WARNING + 1)
[docs]
def get_image_pairs_paths(inputs: list[Path] | Path) -> list[tuple[Path, Path]]:
"""process input to produce a list of image pairs paths
Args:
inputs (list[Path] | Path): input path, which could be one of:
(1) two image paths
(2) dir with two images
(3) dir with dirs with image pairs
(4) txt file with two image paths per line
Returns:
list[tuple[Path, Path]]: list of pairs of image paths
"""
if len(inputs) > 2:
raise ValueError(f"--input should be one or two paths, not {len(inputs)} paths like {inputs}")
if len(inputs) == 2:
# --input is two paths of images
if not inputs[0].is_file() or not inputs[1].is_file():
raise ValueError(f"If --input is two paths, it should be two images, not {inputs}")
return [inputs]
assert len(inputs) == 1
inputs = Path(inputs[0])
if not inputs.exists():
raise ValueError(f"{inputs} does not exist")
if inputs.is_file():
# --input is a file with pairs of images paths
with open(inputs) as file:
lines = file.read().splitlines()
pairs_of_paths = [line.strip().split(" ") for line in lines]
for pair in pairs_of_paths:
if len(pair) != 2:
raise ValueError(f"{pair} should be a pair of paths")
return [(Path(path0.strip()), Path(path1.strip())) for path0, path1 in pairs_of_paths]
elif inputs.is_dir():
inner_files = sorted(Path(inputs).glob("*"))
if len(inner_files) == 2 and inner_files[0].is_file() and inner_files[1].is_file():
# --input is a dir with a pair of images
return [inner_files]
else:
# --input is a dir of subdirs, where each subdir has a pair of images
pairs_of_paths = [list(pair_dir.glob("*")) for pair_dir in inner_files]
for pair in pairs_of_paths:
if len(pair) != 2:
raise ValueError(f"{pair} should be a pair of paths")
return pairs_of_paths
else:
print(f"Could not parse inputs: {inputs}")
[docs]
def to_numpy(x: torch.Tensor | np.ndarray | dict | list) -> np.ndarray:
"""convert item or container of items to numpy
Args:
x (torch.Tensor | np.ndarray | dict | list): input
Returns:
np.ndarray: numpy array of input
"""
if isinstance(x, list):
return np.array([to_numpy(i) for i in x])
if isinstance(x, dict):
for k, v in x.items():
x[k] = to_numpy(v)
if isinstance(x, torch.Tensor):
return x.cpu().numpy()
if isinstance(x, np.ndarray):
return x
if x is None:
return
raise NotImplementedError(f"to_numpy not implemented for data type {type(x)}")
[docs]
def to_tensor(x: np.ndarray | torch.Tensor, device: str = None) -> torch.Tensor:
"""Convert to tensor and place on device
Args:
x (np.ndarray | torch.Tensor): item to convert to tensor
device (str, optional): device to place tensor on. Defaults to None.
Returns:
torch.Tensor: tensor with data from `x` on device `device`
"""
if isinstance(x, torch.Tensor):
pass
elif isinstance(x, np.ndarray):
x = torch.from_numpy(x)
if device is not None:
return x.to(device)
else:
return x
[docs]
def to_device(data: torch.Tensor | dict | list, device: str = "cuda"):
"""Recursively move tensors in nested data structures to `device`."""
if isinstance(data, torch.Tensor):
return data.to(device)
elif isinstance(data, dict):
return {k: to_device(v, device) for k, v in data.items()}
elif isinstance(data, list):
return [to_device(item, device) for item in data]
else:
return data
[docs]
def to_normalized_coords(pts: np.ndarray | torch.Tensor, height: int, width: int) -> np.ndarray:
    """Normalize kpt coords from px space to [0, 1].

    Assumes pts are in (x, y) order along the last axis.

    Args:
        pts (np.ndarray | torch.Tensor): array of kpts, shape (N, 2). Any shape
            (*, 2) is accepted, including a single point of shape (2,).
        height (int): height of img
        width (int): width of img

    Returns:
        np.ndarray: kpts in normalized [0, 1] coords, as a new float64 array
    """
    assert pts.shape[-1] == 2, f"input to `to_normalized_coords` should be shape (N, 2), input is shape {pts.shape}"
    # astype(float) always returns a new array, so the caller's data is not modified.
    pts = to_numpy(pts).astype(float)
    pts[..., 0] /= width
    pts[..., 1] /= height
    return pts
[docs]
def to_px_coords(pts: np.ndarray | torch.Tensor, height: int, width: int) -> np.ndarray:
    """Unnormalize kpt coords from [0, 1] to px space.

    Assumes pts are in (x, y) order along the last axis.

    Args:
        pts (np.ndarray | torch.Tensor): array of kpts, shape (N, 2). Any shape
            (*, 2) is accepted, including a single point of shape (2,).
        height (int): height of img
        width (int): width of img

    Returns:
        np.ndarray: kpts in pixel coords, as a new array with the input's dtype
        (the input is not modified)
    """
    assert pts.shape[-1] == 2, f"input to `to_px_coords` should be shape (N, 2), input is shape {pts.shape}"
    # Copy first: to_numpy returns ndarray input as-is, and a CPU tensor's
    # .numpy() shares memory, so in-place scaling would alter the caller's data.
    pts = to_numpy(pts).copy()
    pts[..., 0] *= width
    pts[..., 1] *= height
    return pts
[docs]
def pad_images_to_same_shape(img0: torch.Tensor, img1: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Zero-pad two image tensors on the right/bottom so their last two dims match.

    If both already share the same (H, W), they are returned unchanged.
    """
    (h0, w0), (h1, w1) = img0.shape[-2:], img1.shape[-2:]
    if (h0, w0) == (h1, w1):
        return img0, img1
    target_h, target_w = max(h0, h1), max(w0, w1)

    def _pad(img: torch.Tensor, h: int, w: int) -> torch.Tensor:
        # F.pad lists the last dim first: (left, right, top, bottom)
        return torch.nn.functional.pad(img, (0, target_w - w, 0, target_h - h))

    return _pad(img0, h0, w0), _pad(img1, h1, w1)
[docs]
def resize_to_divisible(img: torch.Tensor, divisible_by: int = 14) -> torch.Tensor:
    """Resize an image so both spatial dims are multiples of `divisible_by`.

    Each of H and W is rounded to the nearest multiple (never below
    `divisible_by` itself). Useful for ViT-based models with fixed patch sizes.

    Args:
        img (torch.Tensor): image tensor laid out as (*, H, W)
        divisible_by (int, optional): required factor for H and W. Defaults to 14.

    Returns:
        torch.Tensor: antialiased resized image with divisible H and W
    """

    def _snap(length: int) -> int:
        return max(divisible_by, divisible_by * round(length / divisible_by))

    height, width = img.shape[-2:]
    return tfm.functional.resize(img, [_snap(height), _snap(width)], antialias=True)
[docs]
def lower_config(yacs_cfg: CN) -> dict:
    """Recursively turn a yacs CfgNode into a plain dict with lower-cased keys.

    Values that are not CfgNodes (leaves) are returned unchanged.
    """
    if isinstance(yacs_cfg, CN):
        return {key.lower(): lower_config(value) for key, value in yacs_cfg.items()}
    return yacs_cfg
[docs]
def load_module(module_name: str, module_path: Path | str) -> None:
"""Load module from `module_path` into the interpreter with the namespace given by module_name.
Note that `module_path` is usually the path to an `__init__.py` file.
Args:
module_name (str): module name (will be used to import from later, as in `from module_name import my_function`)
module_path (Path | str): path to module (usually an __init__.py file)
"""
# load gluefactory into namespace
# module_name = 'gluefactory'
spec = importlib.util.spec_from_file_location(module_name, module_path)
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
spec.loader.exec_module(module)
# Absolute path of the "third_party" directory next to this file, with a trailing "/".
_THIRD_PARTY_DIR = f"{Path(__file__).resolve().parent / 'third_party'}/"
[docs]
def add_to_path(path: str | Path, **_kwargs) -> None:
    """Add *path* to the front of ``sys.path``, allowing imports from it.

    Always inserts at position 0 so the most recently added directory wins.
    Existing copies of the same resolved path are removed first.
    Auto-detects every package (dir with ``__init__.py``) and module (``.py``
    file) directly in *path*. If any of them are already cached in
    ``sys.modules`` from a *different* directory under ``_THIRD_PARTY_DIR``,
    those entries (and their submodules) are flushed so the next import
    resolves from *path*. Modules loaded from anywhere else (user code, stdlib,
    pip packages) are never touched.

    Args:
        path (str | Path): directory to prepend to ``sys.path``.
        **_kwargs: accepted and ignored.
    """
    base = Path(path).resolve()
    path = str(base)
    # remove every existing copy so the entry ends up exactly once, at the front
    while path in sys.path:
        sys.path.remove(path)
    sys.path.insert(0, path)

    # Auto-detect and flush stale modules from other third-party repos.
    if not base.is_dir():
        return
    # Compare by path components rather than string prefixes with a hard-coded "/",
    # so the checks also work with Windows path separators.
    third_party = Path(_THIRD_PARTY_DIR)
    for child in base.iterdir():
        # Only consider regular Python packages (dir + __init__.py) and .py modules.
        if child.is_dir() and child.joinpath("__init__.py").is_file():
            name = child.name
        elif child.is_file() and child.suffix == ".py" and child.name != "__init__.py":
            name = child.stem
        else:
            continue
        mod = sys.modules.get(name)
        if mod is None:
            continue
        origin = getattr(mod, "__file__", None)
        if not origin:
            continue  # built-in or namespace package: leave it alone
        resolved = Path(origin).resolve()
        if resolved.is_relative_to(base):
            continue  # already loaded from this directory
        if not resolved.is_relative_to(third_party):
            continue  # loaded from user code / pip / stdlib: never touch it
        # Stale module from a different vismatch/third-party repo: flush it and its submodules.
        stale = [k for k in sys.modules if k == name or k.startswith(name + ".")]
        for k in stale:
            del sys.modules[k]
[docs]
def get_default_device() -> str:
    """Pick the best available torch device.

    Preference order: "mps" (only on macOS, when available), then "cuda",
    then "cpu".

    Returns:
        str: name of the selected device
    """
    if sys.platform == "darwin" and torch.backends.mps.is_available():
        return "mps"
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
[docs]
def flow_to_matches(
flow: np.ndarray,
covisibility: np.ndarray,
num_samples: int = 1000,
min_confidence: float = 0.0,
method: str = "probabilistic",
rng: np.random.RandomState | np.random.Generator = None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Convert a dense optical flow + covisibility map to sparse keypoint matches.
Args:
flow (np.ndarray): shape (2, H, W) or (H, W, 2). Interpreted as (dx, dy) per pixel.
covisibility (np.ndarray): shape (H, W) with confidence in [0, 1] (or any non-negative scores).
num_samples (int, optional): max number of matches to return. Defaults to 1000.
min_confidence (float, optional): ignore pixels with covisibility <= min_confidence. Defaults to 0.0.
method (str, optional): sampling method, one of "probabilistic", "topk", or "grid". Defaults to "probabilistic".
rng (np.random.RandomState | np.random.Generator, optional): for reproducibility. Defaults to None.
Returns:
tuple: (matches0, matches1, confidences) where:
- matches0 (np.ndarray): (N, 2) source keypoints as (x, y) (float32)
- matches1 (np.ndarray): (N, 2) target keypoints as (x, y) = source + flow (float32)
- confidences (np.ndarray): (N,) covisibility/confidence values (float32)
"""
if rng is None:
rng = np.random
# Normalize flow shape -> (2, H, W)
if flow.ndim == 3 and flow.shape[2] == 2:
flow_xy = flow.transpose(2, 0, 1) # (2, H, W)
elif flow.ndim == 3 and flow.shape[0] == 2:
flow_xy = flow
else:
raise ValueError("flow must have shape (2,H,W) or (H,W,2)")
H, W = covisibility.shape
if flow_xy.shape[1:] != (H, W):
raise ValueError(
f"flow and covisibility spatial dims mismatch: flow {flow_xy.shape[1:]}, covisibility {covisibility.shape}"
)
# Flatten grids
xs = np.arange(W, dtype=np.float32)
ys = np.arange(H, dtype=np.float32)
gx, gy = np.meshgrid(xs, ys) # gx: (H,W) x coords, gy: (H,W) y coords
flat_conf = covisibility.ravel().astype(np.float64)
valid_mask = flat_conf > min_confidence
if valid_mask.sum() == 0:
return np.zeros((0, 2), dtype=np.float32), np.zeros((0, 2), dtype=np.float32), np.zeros((0,), dtype=np.float32)
valid_idxs = np.nonzero(valid_mask)[0]
if method == "probabilistic":
scores = flat_conf[valid_mask].astype(np.float64)
# avoid degenerate all-zero
if scores.sum() <= 0:
probs = None
else:
probs = scores / scores.sum()
k = min(num_samples, len(valid_idxs))
# if probs is None or degenerate, fallback to uniform
chosen = rng.choice(valid_idxs, size=k, replace=False, p=probs)
elif method == "topk":
k = min(num_samples, len(valid_idxs))
topk_local = np.argsort(-flat_conf[valid_mask])[:k]
chosen = valid_idxs[topk_local]
elif method == "grid":
# choose roughly sqrt grid
n = max(1, int(np.sqrt(num_samples)))
xs_idx = np.linspace(0, W - 1, n, dtype=int)
ys_idx = np.linspace(0, H - 1, int(np.ceil(num_samples / n)), dtype=int)
gx_idx, gy_idx = np.meshgrid(xs_idx, ys_idx)
chosen_coords = np.stack([gy_idx.ravel(), gx_idx.ravel()], axis=1)
chosen_coords = chosen_coords[:num_samples]
chosen = chosen_coords[:, 0] * W + chosen_coords[:, 1]
# mask by min_confidence
keep = flat_conf[chosen] > min_confidence
chosen = chosen[keep]
else:
raise ValueError("method must be one of 'probabilistic','topk','grid'")
# gather coordinates and flows
gy_flat = gy.ravel().astype(np.float32)
gx_flat = gx.ravel().astype(np.float32)
dx_flat = flow_xy[0].ravel().astype(np.float32)
dy_flat = flow_xy[1].ravel().astype(np.float32)
src_x = gx_flat[chosen]
src_y = gy_flat[chosen]
dx = dx_flat[chosen]
dy = dy_flat[chosen]
confs = flat_conf[chosen].astype(np.float32)
matches0 = np.stack([src_x, src_y], axis=1).astype(np.float32)
matches1 = (matches0 + np.stack([dx, dy], axis=1)).astype(np.float32)
return matches0, matches1, confs
def _load_image(path: str | Path, resize: int | tuple = None, rot_angle: float = 0) -> torch.Tensor:
    """Load an image from the filesystem as an RGB tensor, optionally resized and rotated.

    Args:
        path (str | Path): path to image on filesystem
        resize (int | tuple, optional): size to resize img, either single value for
            square resize or tuple of (H, W). Defaults to None (no resize).
        rot_angle (float, optional): CCW rotation angle in degrees, applied after
            resizing. Defaults to 0.

    Returns:
        torch.Tensor: image as tensor (C x H x W)
    """
    if isinstance(resize, int):
        resize = (resize, resize)
    # The context manager closes the file deterministically. A lazily opened
    # PIL image can keep the handle open (e.g. for multi-frame files).
    with Image.open(path) as pil_img:
        img = tfm.ToTensor()(pil_img.convert("RGB"))
    if resize is not None:
        img = tfm.Resize(resize, antialias=True)(img)
    img = tfm.functional.rotate(img, rot_angle)
    return img
[docs]
def to_tensor_image(img):
    """Convert an image in any supported form to a (3, H, W) tensor in [0, 1].

    Args:
        img: one of
            - a path (``str`` or ``Path``): loaded with ``_load_image``;
            - a PIL image: converted to RGB, then to a tensor;
            - a numpy array: channels-last (H, W, 3) arrays are converted like PIL
              images (transposed to (3, H, W); uint8 scaled to [0, 1]), and other
              arrays are wrapped as-is with ``torch.from_numpy``;
            - a torch tensor of shape (3, H, W).

    Returns:
        torch.Tensor: image of shape (3, H, W), values in [0, 1] (with a small tolerance).

    Raises:
        AssertionError: if the result is not a tensor of shape (3, H, W) within range.
    """
    if isinstance(img, (str, Path)):
        img = _load_image(img)
    elif isinstance(img, Image.Image):
        img = tfm.ToTensor()(img.convert("RGB"))
    elif isinstance(img, np.ndarray):
        # torch cannot wrap arrays with negative strides (e.g. img[..., ::-1]);
        # this is a no-op for arrays that are already contiguous.
        img = np.ascontiguousarray(img)
        if img.ndim == 3 and img.shape[-1] == 3 and img.shape[0] != 3:
            # channels-last (H, W, 3): ToTensor transposes to (3, H, W) and scales uint8
            img = tfm.ToTensor()(img)
        else:
            img = torch.from_numpy(img)
    assert isinstance(img, torch.Tensor), "img should be a torch.Tensor, a path, or a PIL Image"
    assert img.ndim == 3 and img.shape[0] == 3, f"img should have shape (3, H, W), got {img.shape}"
    # Small tolerance of 0.2 because images after bicubic resizing can slightly exceed the [0, 1] range
    # This is expected, not a bug, see https://github.com/opencv/opencv/issues/7195
    assert -0.2 <= img.min() and img.max() <= 1.2, f"img should be in [0, 1], got [{img.min()}, {img.max()}]"
    return img