# -*- coding: utf-8 -*-
"""
PyTorch functions for 2D/3D affine image registration, volume warping, and point cloud alignment.
This module provides differentiable implementations for aligning 2D images and 3D volumes
using geometric transformations. It includes tools for both intensity-based image registration
and geometry-based point cloud registration (ICP).
Main features:
- Differentiable 2D and 3D registration (`register_affine_2d` and `register_affine_3d`) supporting Rotation, Translation, Scale, and Shear.
- Batched volume warping (`warp_volume_xy_batched`) to apply 2D transforms to 3D stacks efficiently.
- Point cloud alignment using a differentiable Iterative Closest Point (`icp_torch`) implementation.
- Utilities for handling pixel-space vs. normalized-space affine matrices.
Author: Antony Vamvakeros
"""
import torch
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
# ==============================================================================
# HELPER UTILITIES
# ==============================================================================
def to_tensor2d(x, device=None, dtype=torch.float32):
    """
    Wrap a 2D image as a 4D Torch tensor (1, 1, H, W) on the requested device.

    NumPy input is converted with ``torch.from_numpy``. Only a 2D (H, W) input
    gains the leading batch and channel axes; input of any other rank keeps its
    shape and is only moved/cast.

    Parameters
    ----------
    x : numpy.ndarray or torch.Tensor
        Image data.
    device : str or torch.device, optional
        Target device. Defaults to 'cuda' when available, otherwise 'cpu'.
    dtype : torch.dtype
        Target dtype of the returned tensor.

    Returns
    -------
    torch.Tensor
        The input on `device` with dtype `dtype`.
    """
    image = x if torch.is_tensor(x) else torch.from_numpy(x)
    if image.dim() == 2:
        image = image[None, None]  # (H, W) -> (1, 1, H, W)
    target = device
    if target is None:
        target = 'cuda' if torch.cuda.is_available() else 'cpu'
    return image.to(device=target, dtype=dtype)
def to_tensor3d(x, device=None, dtype=torch.float32):
    """
    Wrap a 3D volume as a 5D Torch tensor (1, 1, D, H, W) on the requested device.

    NumPy input is converted with ``torch.from_numpy``. A 3D (D, H, W) input
    gains batch and channel axes; a 4D (C, D, H, W) input gains only a batch
    axis; input of any other rank keeps its shape and is only moved/cast.

    Parameters
    ----------
    x : numpy.ndarray or torch.Tensor
        Volume data.
    device : str or torch.device, optional
        Target device. Defaults to 'cuda' when available, otherwise 'cpu'.
    dtype : torch.dtype
        Target dtype of the returned tensor.

    Returns
    -------
    torch.Tensor
        The input on `device` with dtype `dtype`.
    """
    volume = x if torch.is_tensor(x) else torch.from_numpy(x)
    rank = volume.dim()
    if rank == 3:
        volume = volume[None, None]  # (D, H, W) -> (1, 1, D, H, W)
    elif rank == 4:
        volume = volume[None]  # (C, D, H, W) -> (1, C, D, H, W)
    target = device
    if target is None:
        target = 'cuda' if torch.cuda.is_available() else 'cpu'
    return volume.to(device=target, dtype=dtype)
def normalize_affine_matrix(matrix_pixel, height, width):
    """
    Re-express a 3x3 pixel-space affine matrix in normalized [-1, 1] coordinates.

    The matrix is conjugated with the pixel-to-normalized change of basis
    ``N``: ``T_norm = N @ T_pix @ inv(N)``, where ``N`` sends x = 0 to -1 and
    x = `width` to +1 (likewise for y and `height`).

    Parameters
    ----------
    matrix_pixel : numpy.ndarray
        3x3 affine matrix in pixel units (e.g., from skimage or pystackreg).
    height, width : int
        Dimensions of the image.

    Returns
    -------
    torch.Tensor
        (1, 2, 3) float32 tensor (top two rows of the normalized matrix),
        ready for F.affine_grid.

    Notes
    -----
    NOTE(review): the change of basis places pixel coordinate 0 on the image
    edge (-1). Whether that matches the coordinate convention of the tool that
    produced `matrix_pixel` (pixel corners vs. pixel centres) should be
    confirmed for transforms that include rotation, scale or shear.
    """
    to_normalized = np.array([
        [2 / width, 0, -1],
        [0, 2 / height, -1],
        [0, 0, 1],
    ])
    to_pixel = np.linalg.inv(to_normalized)
    conjugated = to_normalized @ matrix_pixel @ to_pixel
    # affine_grid only needs the two non-homogeneous rows.
    top_rows = conjugated[:2, :]
    return torch.tensor(top_rows, dtype=torch.float32).unsqueeze(0)
# ==============================================================================
# LOSS FUNCTIONS
# ==============================================================================
class NCC(torch.nn.Module):
    """
    Zero-Normalized Cross Correlation loss (robust to brightness changes).

    Both inputs are mean-centred using the mean over *all* of their elements,
    so the correlation is a single global value for the whole tensor. The
    module returns ``1 - correlation``: about 0 for perfectly correlated
    inputs and about 2 for perfectly anti-correlated ones.

    Parameters
    ----------
    eps : float
        Small constant added under the square root and to the denominator to
        avoid division by zero for constant inputs.
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, x, y):
        """Return ``1 - ZNCC(x, y)`` as a scalar tensor."""
        x_centered = x - x.mean()
        y_centered = y - y.mean()
        cross = torch.sum(x_centered * y_centered)
        energy = torch.sum(x_centered * x_centered) * torch.sum(y_centered * y_centered)
        norm = torch.sqrt(energy + self.eps)
        return 1 - cross / (norm + self.eps)
def mae_loss(x, y):
    """Mean Absolute Error between `x` and `y`, averaged over all elements."""
    return (x - y).abs().mean()
def mse_loss(x, y):
    """Mean Squared Error between `x` and `y`, averaged over all elements."""
    residual = x - y
    return (residual ** 2).mean()
# ==============================================================================
# IMAGE REGISTRATION (2D)
# ==============================================================================
def build_affine_matrix(theta, tx, ty, sx, sy, shx, shy, device, dtype):
    """
    Build the inverse 2D affine matrix (normalized coordinates) for F.affine_grid.

    The forward transform is composed as ``T @ R @ Sh @ S`` (translation,
    rotation, shear, scale) and then inverted, because ``affine_grid`` expects
    the output-to-input mapping.

    Parameters
    ----------
    theta : torch.Tensor
        Rotation angle in radians.
    tx, ty : torch.Tensor
        Translation components.
    sx, sy : torch.Tensor
        Scale factors.
    shx, shy : torch.Tensor
        Shear coefficients (written to entries [0, 1] and [1, 0]).
    device, dtype
        Device and dtype of the intermediate 3x3 matrices.

    Returns
    -------
    torch.Tensor
        (1, 2, 3) tensor holding the top two rows of the inverse matrix.
    """
    def _identity():
        return torch.eye(3, device=device, dtype=dtype)

    # Writing the parameter tensors into the identity matrices in place keeps
    # them attached to the autograd graph.
    translate = _identity()
    translate[0, 2], translate[1, 2] = tx, ty

    cos_t, sin_t = torch.cos(theta), torch.sin(theta)
    rotate = _identity()
    rotate[0, 0], rotate[0, 1] = cos_t, -sin_t
    rotate[1, 0], rotate[1, 1] = sin_t, cos_t

    shear = _identity()
    shear[0, 1], shear[1, 0] = shx, shy

    scale = _identity()
    scale[0, 0], scale[1, 1] = sx, sy

    forward = translate @ rotate @ shear @ scale
    # affine_grid maps output pixels back to input pixels, hence the inverse.
    inverse = torch.linalg.inv(forward)
    return inverse[:2].unsqueeze(0)
def register_affine_2d(ref, mov,
                       order=('rot', 'trans', 'scale', 'shear'),
                       loss_type='ncc',
                       iters=200, lr=1e-2, verbose=False, device=None):
    """
    Register two 2D images by gradient-based optimization of affine parameters.

    The moving image is resampled onto the reference grid with
    ``F.affine_grid`` / ``F.grid_sample`` and the selected loss between the
    warped image and the reference is minimized with Adam.

    Parameters
    ----------
    ref, mov : (H,W) numpy array or torch tensor
        Reference and moving images.
    order : sequence of str
        Parameter groups to optimize: 'rot', 'trans', 'scale', 'shear'.
        Groups that are not listed stay fixed at their identity value.
    loss_type : str
        Objective function: 'ncc', 'mae', or 'mse'.
    iters : int
        Number of iterations.
    lr : float
        Learning rate.
    verbose : bool
        If True, print the loss roughly five times during the run.
    device : str, optional
        'cuda' or 'cpu'. Defaults to 'cuda' when available.

    Returns
    -------
    warped_image : (H,W) numpy array
        The moving image resampled with the final parameters.
    params : dict
        Optimized parameters (theta, tx, ty, sx, sy, shx, shy). Empty when
        `order` selects no parameter group, in which case the moving image is
        returned unchanged.

    Raises
    ------
    ValueError
        If `loss_type` is not one of the supported names.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    ref = to_tensor2d(ref, device=device)
    mov = to_tensor2d(mov, device=device, dtype=ref.dtype)

    # --- Initialize Parameters ---
    def _scalar(value, group):
        # A parameter is trainable only when its group was requested.
        return torch.tensor(value, device=device, dtype=ref.dtype,
                            requires_grad=group in order)

    theta = _scalar(0.0, 'rot')
    tx, ty = _scalar(0.0, 'trans'), _scalar(0.0, 'trans')
    sx, sy = _scalar(1.0, 'scale'), _scalar(1.0, 'scale')
    shx, shy = _scalar(0.0, 'shear'), _scalar(0.0, 'shear')
    names = ('theta', 'tx', 'ty', 'sx', 'sy', 'shx', 'shy')
    all_params = (theta, tx, ty, sx, sy, shx, shy)

    trainable = [p for p in all_params if p.requires_grad]
    if not trainable:
        return mov.squeeze().cpu().numpy(), {}
    opt = torch.optim.Adam(trainable, lr=lr)

    # Select Loss
    if loss_type == 'ncc':
        loss_fn = NCC()
    elif loss_type == 'mae':
        loss_fn = mae_loss
    elif loss_type == 'mse':
        loss_fn = mse_loss
    else:
        raise ValueError(f"Unknown loss type: {loss_type}")

    def _warp():
        A = build_affine_matrix(*all_params, device, ref.dtype)
        # align_corners=False preserves rotation center better
        grid = F.affine_grid(A, size=ref.shape, align_corners=False)
        return F.grid_sample(mov, grid, mode='bilinear', padding_mode='zeros',
                             align_corners=False)

    # --- Optimization Loop ---
    report_every = max(1, iters // 5)
    for i in tqdm(range(iters)):
        opt.zero_grad()
        loss = loss_fn(_warp(), ref)
        loss.backward()
        opt.step()
        if verbose and (i % report_every == 0 or i == iters - 1):
            print(f"[{i+1}/{iters}] Loss={loss.item():.5f}")

    # Resample once more with the final parameters: the warp computed inside
    # the loop predates the last optimizer step (and does not exist at all
    # when iters == 0), so it would not match the returned parameters.
    with torch.no_grad():
        warped = _warp()

    return warped[0, 0].cpu().numpy(), {
        name: float(p.detach().cpu()) for name, p in zip(names, all_params)
    }
# ==============================================================================
# VOLUME OPERATIONS (Batched)
# ==============================================================================
def warp_volume_xy_batched(volume_np, affine_matrix, is_pixel_space=False, batch_size=32, device=None):
    """
    Apply a single 2D affine transform to every Z-slice of a 3D volume using batches.

    Parameters
    ----------
    volume_np : numpy.ndarray
        Input 3D volume (Z, H, W).
    affine_matrix : numpy.ndarray
        Affine matrix. In pixel space it must be 3x3; in normalized space it
        may be 3x3, 2x3 or 1x2x3.
    is_pixel_space : bool, optional
        If True, assumes `affine_matrix` is in pixel coordinates (e.g. from Skimage)
        and normalizes it. If False, assumes it is already PyTorch-compatible [-1, 1].
    batch_size : int
        Number of slices to process simultaneously. Must be positive.
    device : str
        'cuda' or 'cpu'. Defaults to 'cuda' when available.

    Returns
    -------
    numpy.ndarray
        Transformed 3D volume (Z, H, W), float32.

    Raises
    ------
    ValueError
        If `batch_size` is not positive or `affine_matrix` has an unsupported
        shape.
    """
    if batch_size < 1:
        # A negative step would make the slice loop empty and silently return zeros.
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    Z, H, W = volume_np.shape
    volume_out = np.zeros_like(volume_np, dtype=np.float32)
    if Z == 0:
        return volume_out

    # Prepare Matrix
    affine_matrix = np.asarray(affine_matrix)
    if is_pixel_space:
        if affine_matrix.shape != (3, 3):
            raise ValueError(
                f"pixel-space affine_matrix must be 3x3, got shape {affine_matrix.shape}")
        A = normalize_affine_matrix(affine_matrix, H, W).to(device)
    else:
        if affine_matrix.shape == (3, 3):
            affine_matrix = affine_matrix[:2, :]
        if affine_matrix.shape not in ((2, 3), (1, 2, 3)):
            raise ValueError(
                "normalized affine_matrix must have shape (3, 3), (2, 3) or (1, 2, 3), "
                f"got {affine_matrix.shape}")
        A = torch.from_numpy(affine_matrix.astype(np.float32)).reshape(1, 2, 3).to(device)

    with torch.no_grad():
        # The sampling grid depends only on the matrix and the slice size, so it
        # is built once for a full batch and sliced for the shorter last batch.
        full = min(batch_size, Z)
        grid = F.affine_grid(A.repeat(full, 1, 1), [full, 1, H, W], align_corners=False)

        # Processing Loop
        for start in range(0, Z, batch_size):
            stop = min(start + batch_size, Z)
            batch = volume_np[start:stop].astype(np.float32)
            # Shape: [Batch, Channel=1, H, W]
            batch_tensor = torch.from_numpy(batch).unsqueeze(1).to(device)
            warped = F.grid_sample(batch_tensor, grid[:stop - start], mode='bilinear',
                                   padding_mode='zeros', align_corners=False)
            volume_out[start:stop] = warped.squeeze(1).cpu().numpy()
    return volume_out
# ==============================================================================
# POINT CLOUD REGISTRATION (ICP)
# ==============================================================================
def icp_torch(A, B, max_iterations=2000, lr=0.01, tolerance=1e-6, verbose=False, device=None):
    """
    Rigidly align point cloud B to A by gradient descent on R and t.

    Optimizes a 3x3 matrix R and translation t such that
    ``B_aligned = B @ R.T + t`` fits A, then projects R onto the nearest proper
    rotation.

    Note that no nearest-neighbour search is performed: the loss is the mean
    Euclidean distance between ``A[i]`` and the transformed ``B[i]``, so the
    two clouds must already be in point-to-point correspondence (same length,
    same ordering).

    Parameters
    ----------
    A, B : torch.Tensor or numpy.ndarray
        Point clouds of shape (N, 3) with matching row order.
    max_iterations : int
        Maximum optimization steps.
    lr : float
        Learning rate.
    tolerance : float
        Convergence threshold on the change in loss between iterations.
    verbose : bool
        If True, print the iteration at which convergence was detected.
    device : str, optional
        'cuda' or 'cpu'. Defaults to 'cuda' when available.

    Returns
    -------
    B_aligned : numpy.ndarray
        Aligned point cloud, computed with the orthogonalized rotation.
    R : numpy.ndarray
        3x3 Rotation matrix (orthogonal, determinant +1).
    t : numpy.ndarray
        Translation vector.

    Raises
    ------
    ValueError
        If A and B are not non-empty (N, 3) arrays of identical shape.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Cast everything to float32 so that the inputs match the dtype of R and t
    # regardless of whether they arrive as NumPy arrays or as tensors.
    A = torch.as_tensor(A).detach().to(device=device, dtype=torch.float32)
    B = torch.as_tensor(B).detach().to(device=device, dtype=torch.float32)
    if A.dim() != 2 or A.shape[1] != 3 or A.shape[0] == 0:
        raise ValueError(f"A must be a non-empty (N, 3) point cloud, got shape {tuple(A.shape)}")
    if B.shape != A.shape:
        raise ValueError(
            "A and B must have the same shape (ordered correspondences); "
            f"got {tuple(A.shape)} and {tuple(B.shape)}")

    # Initialize Rigid Parameters
    R = torch.eye(3, device=device, requires_grad=True)
    t = torch.zeros(3, device=device, requires_grad=True)
    optimizer = torch.optim.Adam([R, t], lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=50, factor=0.5)

    prev_loss = float('inf')
    for i in range(max_iterations):
        optimizer.zero_grad()
        # Points are row vectors (N, 3), hence B @ R.T + t.
        B_transformed = B @ R.T + t
        loss = torch.norm(B_transformed - A, dim=1).mean()
        loss.backward()
        optimizer.step()
        loss_value = loss.item()
        scheduler.step(loss_value)
        if abs(prev_loss - loss_value) < tolerance:
            if verbose:
                print(f"Converged at iter {i}")
            break
        prev_loss = loss_value

    # R was optimized as an unconstrained 3x3 matrix; project it onto the
    # nearest rotation (orthogonal Procrustes) before using it.
    with torch.no_grad():
        U, _, Vh = torch.linalg.svd(R)
        if torch.linalg.det(U @ Vh) < 0:
            # U @ Vh would be a reflection; flipping the last left singular
            # vector yields the closest matrix with determinant +1.
            U[:, -1] = -U[:, -1]
        R_ortho = U @ Vh
        B_final = (B @ R_ortho.T + t).cpu().numpy()
    return B_final, R_ortho.cpu().numpy(), t.detach().cpu().numpy()
# ==============================================================================
# VOLUME REGISTRATION (3D)
# ==============================================================================
def build_affine_matrix_3d(rx, ry, rz, tx, ty, tz, sx, sy, sz, device, dtype):
    """
    Build the inverse 3D affine matrix (3x4) for volumetric F.affine_grid.

    The forward transform is ``T @ R @ S`` with ``R = Rz @ Ry @ Rx``
    (translation, Euler rotations, scaling; no shear). It is inverted because
    ``affine_grid`` expects the output-to-input mapping.

    Parameters
    ----------
    rx, ry, rz : torch.Tensor
        Rotation angles around X, Y, and Z axes (in radians).
    tx, ty, tz : torch.Tensor
        Translation in x, y, z (normalized coordinates [-1, 1]).
    sx, sy, sz : torch.Tensor
        Scale factors for x, y, z.
    device, dtype
        Device and dtype of the intermediate 4x4 matrices.

    Returns
    -------
    torch.Tensor
        (1, 3, 4) tensor representing the top three rows of the 4x4 inverse
        affine matrix, suitable for 3D `F.affine_grid`.
    """
    def _identity():
        return torch.eye(4, device=device, dtype=dtype)

    # Writing the parameter tensors into the identity matrices in place keeps
    # them attached to the autograd graph.
    translation = _identity()
    translation[0, 3], translation[1, 3], translation[2, 3] = tx, ty, tz

    scaling = _identity()
    scaling[0, 0], scaling[1, 1], scaling[2, 2] = sx, sy, sz

    # Elementary rotations about each axis.
    cos_x, sin_x = torch.cos(rx), torch.sin(rx)
    rot_x = _identity()
    rot_x[1, 1], rot_x[1, 2] = cos_x, -sin_x
    rot_x[2, 1], rot_x[2, 2] = sin_x, cos_x

    cos_y, sin_y = torch.cos(ry), torch.sin(ry)
    rot_y = _identity()
    rot_y[0, 0], rot_y[0, 2] = cos_y, sin_y
    rot_y[2, 0], rot_y[2, 2] = -sin_y, cos_y

    cos_z, sin_z = torch.cos(rz), torch.sin(rz)
    rot_z = _identity()
    rot_z[0, 0], rot_z[0, 1] = cos_z, -sin_z
    rot_z[1, 0], rot_z[1, 1] = sin_z, cos_z

    rotation = rot_z @ rot_y @ rot_x
    forward = translation @ rotation @ scaling
    # affine_grid maps output voxels back to input voxels, hence the inverse.
    inverse = torch.linalg.inv(forward)
    return inverse[:3].unsqueeze(0)
def register_affine_3d(ref, mov,
                       order=('rot', 'trans', 'scale'),
                       loss_type='ncc',
                       iters=200, lr=1e-2, verbose=False, device=None):
    """
    Register two 3D volumes by gradient-based optimization of affine parameters.

    The moving volume is resampled onto the reference grid with
    ``F.affine_grid`` / ``F.grid_sample`` and the selected loss between the
    warped volume and the reference is minimized with Adam.

    Parameters
    ----------
    ref, mov : (D,H,W) numpy array or torch tensor
        Input volumes.
    order : sequence of str
        Parameter groups to optimize: 'rot', 'trans', 'scale'. Groups that
        are not listed stay fixed at their identity value.
    loss_type : str
        Objective function: 'ncc', 'mae', or 'mse'.
    iters : int
        Number of iterations.
    lr : float
        Learning rate.
    verbose : bool
        If True, print the loss roughly five times during the run.
    device : str, optional
        'cuda' or 'cpu'. Defaults to 'cuda' when available.

    Returns
    -------
    warped_vol : (D,H,W) numpy array
        The moving volume resampled with the final parameters.
    params : dict
        Optimized parameters (rx, ry, rz, tx, ty, tz, sx, sy, sz). Empty when
        `order` selects no parameter group, in which case the moving volume
        is returned unchanged.

    Raises
    ------
    ValueError
        If `loss_type` is not one of the supported names.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    ref = to_tensor3d(ref, device=device)
    mov = to_tensor3d(mov, device=device, dtype=ref.dtype)

    # --- Initialize Parameters ---
    def _scalar(value, group):
        # A parameter is trainable only when its group was requested.
        return torch.tensor(value, device=device, dtype=ref.dtype,
                            requires_grad=group in order)

    # Rotations (radians)
    rx, ry, rz = (_scalar(0.0, 'rot') for _ in range(3))
    # Translations (x, y, z)
    tx, ty, tz = (_scalar(0.0, 'trans') for _ in range(3))
    # Scales (sx, sy, sz)
    sx, sy, sz = (_scalar(1.0, 'scale') for _ in range(3))
    names = ('rx', 'ry', 'rz', 'tx', 'ty', 'tz', 'sx', 'sy', 'sz')
    all_params = (rx, ry, rz, tx, ty, tz, sx, sy, sz)

    # Collect active parameters
    trainable = [p for p in all_params if p.requires_grad]
    if not trainable:
        print("Warning: No parameters selected for optimization.")
        return mov.squeeze().cpu().numpy(), {}
    opt = torch.optim.Adam(trainable, lr=lr)

    # Select Loss
    if loss_type == 'ncc':
        loss_fn = NCC()
    elif loss_type == 'mae':
        loss_fn = mae_loss
    elif loss_type == 'mse':
        loss_fn = mse_loss
    else:
        raise ValueError(f"Unknown loss type: {loss_type}")

    def _warp():
        A = build_affine_matrix_3d(*all_params, device, ref.dtype)
        # 5D grid_sample expects (N, C, D, H, W)
        grid = F.affine_grid(A, size=ref.shape, align_corners=False)
        return F.grid_sample(mov, grid, mode='bilinear', padding_mode='zeros',
                             align_corners=False)

    # --- Optimization Loop ---
    report_every = max(1, iters // 5)
    for i in tqdm(range(iters)):
        opt.zero_grad()
        loss = loss_fn(_warp(), ref)
        loss.backward()
        opt.step()
        if verbose and (i % report_every == 0 or i == iters - 1):
            print(f"[{i+1}/{iters}] Loss={loss.item():.5f}")

    # Resample once more with the final parameters: the warp computed inside
    # the loop predates the last optimizer step (and does not exist at all
    # when iters == 0), so it would not match the returned parameters.
    with torch.no_grad():
        warped = _warp()

    return warped.squeeze().cpu().numpy(), {
        name: float(p.detach().cpu()) for name, p in zip(names, all_params)
    }