"""Visualization utilities for image matching.
Adapted from LightGlue's viz2d: https://github.com/cvg/LightGlue
"""
import sys
import cv2
import matplotlib
import matplotlib.patheffects as path_effects
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from kornia.utils import tensor_to_image
from vismatch.utils import to_numpy, to_tensor_image
from pathlib import Path
# `sys.ps1` is only defined when the interpreter runs interactively; in every other
# case (scripts, test runs, servers) switch matplotlib to the "Agg" backend.
if not hasattr(sys, "ps1"):
    matplotlib.use("Agg")
[docs]
def plot_images(
    imgs: list[torch.Tensor | np.ndarray | str | Path | Image.Image],
    titles=None,
    cmaps="gray",
    dpi=100,
    pad=0.5,
    adaptive=True,
) -> np.ndarray[matplotlib.axes.Axes]:
    """Show several images side by side in a single-row figure.

    Args:
        imgs: images to display. Each one is passed through ``to_tensor_image`` and
            ``tensor_to_image`` and then clipped to ``[0, 1]``.
        titles: optional per-image titles, indexed in the same order as ``imgs``.
        cmaps: one colormap name shared by all images, or a list/tuple with one entry per image.
        dpi: resolution of the created figure.
        pad: padding forwarded to ``fig.tight_layout``.
        adaptive: if true, every column is as wide as its image's width/height ratio;
            otherwise all columns use a fixed 4:3 ratio.

    Returns:
        Array holding one axis per image (also for a single image).
    """
    arrays = [np.clip(tensor_to_image(to_tensor_image(img)), 0, 1) for img in imgs]
    count = len(arrays)

    cmap_names = cmaps if isinstance(cmaps, (list, tuple)) else [cmaps] * count

    if adaptive:
        ratios = [arr.shape[1] / arr.shape[0] for arr in arrays]
    else:
        ratios = [4 / 3] * count

    fig, axs = plt.subplots(
        1,
        count,
        figsize=[sum(ratios) * 4.5, 4.5],
        dpi=dpi,
        gridspec_kw={"width_ratios": ratios},
    )
    # plt.subplots hands back a bare Axes for a single column; wrap it so callers can always index.
    if count == 1:
        axs = np.array([axs])

    for pos, (ax, arr) in enumerate(zip(axs, arrays)):
        ax.imshow(arr, cmap=plt.get_cmap(cmap_names[pos]))
        ax.set_axis_off()
        if titles:
            ax.set_title(titles[pos])

    fig.tight_layout(pad=pad)
    return axs
def _draw_kpts(
    kpts: list[np.ndarray | torch.Tensor],
    axs: list[matplotlib.axes.Axes] | None,
    colors: str | list = "lime",
    point_size: float = 4,
) -> list[matplotlib.axes.Axes]:
    """Scatter one set of keypoints onto each axis.

    Args:
        kpts: one keypoint array per axis; columns 0 and 1 are used as x and y.
        axs: axes to draw on, in the same order as ``kpts``. ``None`` falls back to the
            axes of the current pyplot figure.
        colors: a single color applied to every set, or a list with one color per set.
        point_size: marker size forwarded to ``Axes.scatter``.

    Returns:
        The axes that were drawn on (the resolved list when ``axs`` was ``None``).

    Raises:
        ValueError: if the number of keypoint sets differs from the number of axes.
    """
    # Resolve the fallback first: the length check below needs a real sequence.
    if axs is None:
        axs = plt.gcf().axes
    if len(kpts) != len(axs):
        raise ValueError("Number of keypoints sets must match number of axes.")
    if not isinstance(colors, list):
        colors = [colors] * len(kpts)

    for ax, pts, color in zip(np.array(axs).flatten(), kpts, colors):
        pts = to_numpy(pts)
        if len(pts) == 0:
            # Nothing to draw; also avoids 2-D indexing into an empty array.
            continue
        ax.scatter(pts[:, 0], pts[:, 1], c=color, s=point_size, linewidths=0)
    return axs
[docs]
def add_text(
    ax: matplotlib.axes.Axes,
    text: str,
    pos: tuple[float, float] = (0.01, 0.99),
    fs: int = 15,
    color="w",
    outline_color="k",
    outline_width=2,
    va="top",
) -> matplotlib.axes.Axes:
    """Write a left-aligned label on an axis, optionally with an outline.

    Args:
        ax: axis to annotate.
        text: string to display.
        pos: anchor position, interpreted in axes coordinates (``ax.transAxes``).
        fs: font size.
        color: fill color of the text.
        outline_color: stroke color drawn around the glyphs; ``None`` disables the outline.
        outline_width: line width of the outline stroke.
        va: vertical alignment of the text relative to ``pos``.

    Returns:
        The same axis, to allow chaining.
    """
    label = ax.text(*pos, text, fontsize=fs, ha="left", va=va, color=color, transform=ax.transAxes)
    if outline_color is None:
        return ax

    # Stroke paints the outline, Normal then draws the regular glyphs on top of it.
    outline = path_effects.Stroke(linewidth=outline_width, foreground=outline_color)
    label.set_path_effects([outline, path_effects.Normal()])
    return ax
[docs]
def save_plot(fig=None, path: str | Path = None, **kw) -> Path:
"""Save the current figure without any white margin."""
if fig is None:
fig = plt.gcf()
fig.savefig(path, bbox_inches="tight", pad_inches=0, **kw)
return Path(path).resolve()
def _draw_matches(
    kpts0: np.ndarray | torch.Tensor,
    kpts1: np.ndarray | torch.Tensor,
    fig: matplotlib.figure.Figure | None,
    color: str = "lime",
    lw: float = 0.2,
    point_size: float = 4,
) -> matplotlib.figure.Figure:
    """Draw match lines between the first two axes of a figure.

    Args:
        kpts0: keypoints in the first axis; columns 0 and 1 are used as x and y.
        kpts1: keypoints in the second axis, paired row by row with ``kpts0``.
        fig: figure whose first two axes hold the images; ``None`` uses the current pyplot figure.
        color: color used for both the lines and the endpoint markers.
        lw: line width of the connecting lines.
        point_size: marker size of the endpoints; values ``<= 0`` skip the markers.

    Returns:
        The figure that was drawn on. It is returned unchanged when there are no matches.

    Raises:
        ValueError: if ``kpts0`` and ``kpts1`` contain a different number of points.
    """
    if fig is None:
        fig = plt.gcf()
    kpts0, kpts1 = to_numpy(kpts0), to_numpy(kpts1)
    if len(kpts0) != len(kpts1):
        raise ValueError(f"Matched keypoints must be paired: got {len(kpts0)} and {len(kpts1)} points.")
    if len(kpts0) == 0:
        return fig

    ax0, ax1 = fig.axes[0], fig.axes[1]
    for pt0, pt1 in zip(kpts0, kpts1):
        # Each endpoint is expressed in the data coordinates of its own axis, so the
        # patch has to be added to the figure rather than to a single axis.
        line = matplotlib.patches.ConnectionPatch(
            xyA=(pt0[0], pt0[1]),
            xyB=(pt1[0], pt1[1]),
            coordsA=ax0.transData,
            coordsB=ax1.transData,
            axesA=ax0,
            axesB=ax1,
            color=color,
            linewidth=lw,
        )
        line.set_annotation_clip(True)
        fig.add_artist(line)

    # Turn autoscaling off before the endpoint markers are added to the axes.
    ax0.autoscale(enable=False)
    ax1.autoscale(enable=False)

    # plot points on the respective axes
    if point_size > 0:
        colors = [color] * len(kpts0)
        ax0.scatter(kpts0[:, 0], kpts0[:, 1], c=colors, s=point_size)
        ax1.scatter(kpts1[:, 0], kpts1[:, 1], c=colors, s=point_size)
    return fig
[docs]
def plot_matches(
    img0: torch.Tensor,
    img1: torch.Tensor,
    result: dict,
    show_matched_kpts: bool = True,
    show_all_kpts: bool = False,
    save_path: str | Path | None = None,
    color: str = "lime",
    lw: float = 0.2,
    point_size: int = 4,
    show_text: bool = True,
) -> np.ndarray:
    """Plot two images side by side with their matches drawn on top.

    Args:
        img0: first image (left).
        img1: second image (right).
        result: matcher output. ``inlier_kpts0`` / ``inlier_kpts1`` are required;
            ``matched_kpts0`` / ``matched_kpts1`` and ``all_kpts0`` / ``all_kpts1`` are optional.
        show_matched_kpts: draw every match (inlier or not) in blue, using thinner lines
            and smaller markers than the inliers.
        show_all_kpts: draw every detected keypoint in orange.
        save_path: if given, the figure is written there via ``save_plot``.
        color: color of the inlier matches.
        lw: line width of the inlier matches.
        point_size: marker size of the inlier endpoints.
        show_text: overlay the inlier/match statistics and the ``Img0`` / ``Img1`` labels.

    Returns:
        Array with the two image axes.
    """
    axs = plot_images([img0, img1])
    fig = axs[0].get_figure()

    # draw all matches (even non-inliers) in blue
    matched0, matched1 = result.get("matched_kpts0"), result.get("matched_kpts1")
    if show_matched_kpts and matched0 is not None and matched1 is not None:
        _draw_matches(matched0, matched1, fig, "blue", lw * 0.25, point_size * 0.5)

    # draw all keypoints in orange
    all0, all1 = result.get("all_kpts0"), result.get("all_kpts1")
    if show_all_kpts and all0 is not None and all1 is not None:
        _draw_kpts([all0, all1], axs, colors="orange", point_size=point_size * 0.5)

    # Inliers go last so they are painted over the blue matches and orange keypoints.
    inliers0, inliers1 = result["inlier_kpts0"], result["inlier_kpts1"]
    _draw_matches(inliers0, inliers1, fig, color, lw, point_size)

    if show_text:
        num_inliers = len(inliers0)
        if matched1 is None:
            # The matched keypoints are optional (see above): report only what is known.
            summary = f"{num_inliers} inliers"
        else:
            num_matches = len(matched1)
            ratio = f"{num_inliers / num_matches:.2f}" if num_matches else "N/A"
            summary = f"{num_inliers} inliers / {num_matches} matches\ninlier ratio: {ratio}"
        add_text(axs[0], summary, fs=17, outline_width=2)
        add_text(axs[0], "Img0", pos=(0.01, 0.01), va="bottom")
        add_text(axs[1], "Img1", pos=(0.01, 0.01), va="bottom")

    if save_path is not None:
        save_plot(fig, save_path)
    return axs
[docs]
def plot_keypoints(
    img0: torch.Tensor, result: dict, model_name: str = "", color="orange", save_path: str | Path | None = None
) -> matplotlib.axes.Axes:
    """Display a single image with its detected keypoints and a count label.

    Args:
        img0: image to display.
        result: matcher output; its ``all_kpts0`` entry holds the keypoints to draw.
        model_name: optional name appended to the count label.
        color: marker color of the keypoints.
        save_path: if given, the figure is written there via ``save_plot``.

    Returns:
        The axis containing the image and keypoints.
    """
    keypoints = result["all_kpts0"]
    axs = plot_images([img0])
    ax = axs[0]
    _draw_kpts([keypoints], [ax], colors=color, point_size=10)

    label = f"{len(keypoints)} kpts"
    if model_name:
        label += f" - {model_name}"
    add_text(ax, label, fs=20)

    if save_path is not None:
        save_plot(ax.get_figure(), save_path)
    return ax
[docs]
def _to_bgra(img: np.ndarray) -> np.ndarray:
    """Return ``img`` with four channels, converting 1- and 3-channel input via OpenCV."""
    if img.ndim == 2 or img.shape[2] == 1:
        return cv2.cvtColor(img, cv2.COLOR_GRAY2BGRA)
    if img.shape[2] == 3:
        return cv2.cvtColor(img, cv2.COLOR_BGR2BGRA)
    return img


def stitch(img0: torch.Tensor | np.ndarray, img1: torch.Tensor | np.ndarray, result) -> np.ndarray:
    """Stitch two images together using homography.

    ``img0`` is warped with ``result["H"]`` onto a canvas large enough to contain both
    the warped ``img0`` and the unwarped ``img1``; ``img1`` is then pasted on top.

    Args:
        img0: image to warp. Tensors are converted with ``tensor_to_image``.
        img1: reference image, pasted unchanged. Tensors are converted with ``tensor_to_image``.
        result: mapping whose ``"H"`` entry is the 3x3 homography applied to ``img0``.

    Returns:
        The four-channel stitched canvas.
    """
    if isinstance(img0, torch.Tensor):
        img0 = tensor_to_image(img0)
    if isinstance(img1, torch.Tensor):
        img1 = tensor_to_image(img1)
    # Grayscale inputs (2-D arrays) are handled too, instead of failing on ``shape[2]``.
    # NOTE(review): assumes both images share dtype and value range - confirm for mixed
    # tensor / uint8-array inputs, since img1 is written into a canvas of img0's dtype.
    img0 = _to_bgra(img0)
    img1 = _to_bgra(img1)

    homography = np.asarray(result["H"], dtype=np.float64)

    h0, w0 = img0.shape[:2]
    h1, w1 = img1.shape[:2]
    corners0 = np.float32([[0, 0], [0, h0], [w0, h0], [w0, 0]]).reshape(-1, 1, 2)
    corners1 = np.float32([[0, 0], [0, h1], [w1, h1], [w1, 0]]).reshape(-1, 1, 2)
    warped_corners0 = cv2.perspectiveTransform(corners0, homography)

    # Bounding box of the warped img0 together with img1 defines the canvas.
    all_corners = np.concatenate((warped_corners0, corners1), axis=0)
    x_min, y_min = np.int32(all_corners.min(axis=0).ravel() - 0.5)
    x_max, y_max = np.int32(all_corners.max(axis=0).ravel() + 0.5)

    # corners1 contains the origin, so x_min and y_min are <= 0 and the shift is non-negative.
    shift_x, shift_y = int(-x_min), int(-y_min)
    canvas_size = (int(x_max - x_min), int(y_max - y_min))
    H_translation = np.array([[1, 0, shift_x], [0, 1, shift_y], [0, 0, 1]])

    stitched = cv2.warpPerspective(img0, H_translation.dot(homography), canvas_size)
    stitched[shift_y : shift_y + h1, shift_x : shift_x + w1] = img1
    return stitched