# -*- coding: utf-8 -*-
"""
Utility functions for patch-based sampling, index generation, total variation (TV) regularization,
and structural similarity (SSIM) loss in 2D and 3D tensor volumes. These are commonly used in image
reconstruction, inverse problems, and spectral imaging applications.
Author: Antony Vamvakeros
Contents:
- Index generation utilities:
* generate_indices: Draws (row, col) indices from uniform, normal, or Sobol distributions.
* generate_sobol_indices_batch: Batch generation of Sobol-distributed indices.
* draw_valid_indices: Wrapper for drawing valid indices using masks and custom sampling.
* filter_patch_indices: Filters mask into valid patch locations.
* calc_patches_indices: Extracts tensor patches based on index positions.
* initialize_counter, update_counter: Track pixel utilization in sampling-based methods.
- Normalization:
* denormalize: Converts normalized parameter values [0, 1] back to their physical scale.
- Total Variation (TV) regularization:
* tv_spatial: TV along spatial dimensions (y, x) of a 2D image.
* tv_spectral: TV along spectral (channel) dimension.
* tv_3d_spectral: TV across both spectral and spatial dimensions (channel, y, x).
* tv_3d: TV in full 3D volumes (z, y, x), supporting 5D tensors (B, C, D, H, W).
- Structural Similarity (SSIM) loss functions:
* SSIM2D: Perceptual similarity loss for 2D images based on luminance, contrast, and structure.
* SSIM3DLoss: SSIM loss adapted for 3D volumes using a uniform kernel.
Notes:
All functions are implemented with PyTorch and designed to run on either CPU or CUDA.
The TV and SSIM losses are differentiable and can be integrated directly into model training loops
for regularization or perceptual fidelity in reconstruction tasks.
"""
import torch
import torch.nn.functional as F
from torch.quasirandom import SobolEngine
from torch import nn
import numpy as np
def denormalize(param, param_name, param_min, param_max, peak_number=None):
    """
    Map a normalized value from [0, 1] back onto its physical range.

    Parameters:
        param (float): The normalized value to rescale.
        param_name (str): Key used to look up the bounds in ``param_min`` / ``param_max``.
        param_min (dict): Lower bounds, keyed by parameter name.
        param_max (dict): Upper bounds, keyed by parameter name.
        peak_number (int, optional): When given, the bounds stored under ``param_name`` are
            themselves indexed by this value (per-peak bounds). Default is None, in which case
            the bounds are used directly (e.g. slope and intercept).

    Returns:
        float: ``lower + (upper - lower) * param``.
    """
    lower = param_min[param_name]
    upper = param_max[param_name]
    if peak_number is not None:
        # Per-peak parameters keep one bound per peak under the same key.
        lower = lower[peak_number]
        upper = upper[peak_number]
    return lower + (upper - lower) * param
def calc_patches_indices(indices, tensor, patch_size, use_middle=False):
    """
    Extract square patches from the last two dimensions of a tensor.

    Each entry of ``indices`` is a (row, col) reference point. It is interpreted either as the
    top-left corner of the patch or, with ``use_middle=True``, as the pixel located
    ``patch_size // 2`` rows and columns below/right of the top-left corner.

    Parameters:
        indices (list of tuples): (row, col) reference points, one per patch.
        tensor (torch.Tensor): Source tensor; patches are cut from its last two dimensions.
        patch_size (int): Side length of the square patches.
        use_middle (bool): If True, the reference point is the middle pixel. If False, it is the
            top-left corner.

    Returns:
        torch.Tensor: The patches concatenated along dimension 0.

    Raises:
        ValueError: If a non-negative reference point would place the patch's top-left corner
            before the start of the tensor. A negative slice start would otherwise be
            interpreted as "counted from the end" and silently return an empty or wrong patch.

    Notes:
        ``torch.cat`` raises if ``indices`` is empty, and patches that run past the far edge are
        truncated by slicing (which makes ``torch.cat`` fail when sizes then differ).
    """
    # Distance from the reference point back to the patch's top-left corner.
    offset = patch_size // 2 if use_middle else 0

    patches = []
    for row, col in indices:
        top = row - offset
        left = col - offset
        if (row >= 0 > top) or (col >= 0 > left):
            raise ValueError(
                f"Patch centred at ({row}, {col}) with patch_size={patch_size} "
                "extends beyond the top/left edge of the tensor."
            )
        patches.append(tensor[..., top:top + patch_size, left:left + patch_size])

    # Concatenate patches for processing
    return torch.cat(patches, dim=0)
def generate_sobol_indices_batch(rows, cols, batch_size, patch_size, device):
    """
    Draw a batch of quasi-random (Sobol) top-left patch coordinates.

    Points are drawn from a freshly scrambled 2D Sobol sequence and mapped onto the integer
    ranges ``[0, rows - patch_size]`` and ``[0, cols - patch_size]`` (both inclusive), so a
    patch of side ``patch_size`` anchored at any returned coordinate fits inside the grid.

    Parameters:
        rows (int): The total number of rows in the dataset.
        cols (int): The total number of columns in the dataset.
        batch_size (int): The number of indices to generate.
        patch_size (int): The size of the patch to consider for boundary adjustment.
        device (torch.device): The device on which the returned tensors live.

    Returns:
        tuple: Two int32 tensors of length ``batch_size`` with the row and column indices.

    Raises:
        ValueError: If ``patch_size`` is larger than ``rows`` or ``cols``; no valid coordinate
            exists in that case.

    Notes:
        A new scrambled engine is created on every call, so the low-discrepancy property only
        holds within one batch, not across successive calls.
    """
    # Largest admissible top-left coordinate along each axis.
    max_row = rows - patch_size
    max_col = cols - patch_size
    if max_row < 0 or max_col < 0:
        raise ValueError(
            f"patch_size={patch_size} does not fit inside a {rows}x{cols} grid."
        )

    engine = SobolEngine(dimension=2, scramble=True)
    # The engine samples on the CPU; move the points before scaling.
    points = engine.draw(batch_size).to(device)

    # Scale by (max + 1) and truncate so that every integer in [0, max] is reachable;
    # the clamp guards against a draw that rounds up to exactly 1.0.
    row_idx = (points[:, 0] * (max_row + 1)).int().clamp(max=max_row)
    col_idx = (points[:, 1] * (max_col + 1)).int().clamp(max=max_col)
    return row_idx, col_idx
def generate_indices(rows, cols, num_indices, patch_size, distribution_type='uniform', std=3, mask=None, device='cuda', batched=False):
    """
    Generate a specified number of valid (row, col) indices from a 2D grid, optionally constrained
    by a mask and following a given sampling distribution.

    Candidates are drawn repeatedly (rejection sampling) until ``num_indices`` of them fall inside
    the grid and, if a mask is given, on a non-zero mask entry.

    Parameters
    ----------
    rows : int
        Number of rows in the 2D grid.
    cols : int
        Number of columns in the 2D grid.
    num_indices : int
        Total number of valid (row, col) pairs to generate. Values <= 0 return an empty list.
    patch_size : int
        Size of the patch, used only for Sobol sampling mode.
    distribution_type : str, optional
        Sampling distribution to use: 'uniform', 'normal', or 'Sobol'. Default is 'uniform'.
    std : float, optional
        Standard deviation for 'normal' sampling, centred on the middle of the grid. Ignored for
        other distributions. Default is 3.
    mask : torch.Tensor or None, optional
        A mask (shape: [rows, cols]) where only non-zero locations are considered valid. Any
        dtype is accepted; it is converted to boolean and moved to ``device``. Default is None.
    device : str, optional
        Device to perform computations on ('cuda' or 'cpu'). Default is 'cuda'.
    batched : bool, optional
        If True, all still-missing samples are drawn in one batch per iteration instead of one
        sample at a time. Default is False.

    Returns
    -------
    indices : list of tuple[int, int]
        A list of (row, col) tuples representing valid sampled locations, possibly
        mask-constrained.

    Raises
    ------
    ValueError
        If ``distribution_type`` is unknown, or if ``mask`` has no non-zero entry (the rejection
        loop could never terminate).

    Notes
    -----
    The loop can still run for a very long time when the valid region is extremely unlikely under
    the chosen distribution (e.g. a narrow 'normal' far from the only unmasked pixels).
    """
    if num_indices <= 0:
        return []
    if distribution_type not in ('normal', 'uniform', 'Sobol'):
        raise ValueError("Invalid distribution type specified. Choose 'normal', 'uniform', or 'Sobol'.")

    allowed = None
    if mask is not None:
        # Boolean view of the mask on the sampling device, so it can be combined with `valid`
        # regardless of the mask's original dtype or placement.
        allowed = mask.to(device) != 0
        if not bool(allowed.any()):
            raise ValueError("mask contains no non-zero entries, so no valid index can be drawn.")

    def _draw_normal(center, count):
        # Rounded normal samples around `center`; may fall outside the grid.
        mean = torch.full((count,), center, dtype=torch.float32, device=device)
        spread = torch.full((count,), std, dtype=torch.float32, device=device)
        return torch.normal(mean=mean, std=spread).round()

    mean_row = float(rows) / 2
    mean_col = float(cols) / 2

    indices = []
    while len(indices) < num_indices:
        batch_size = num_indices - len(indices) if batched else 1

        if distribution_type == 'normal':
            sample_rows = _draw_normal(mean_row, batch_size)
            sample_cols = _draw_normal(mean_col, batch_size)
        elif distribution_type == 'uniform':
            sample_rows = torch.randint(0, rows, (batch_size,), device=device)
            sample_cols = torch.randint(0, cols, (batch_size,), device=device)
        else:  # 'Sobol'
            sample_rows, sample_cols = generate_sobol_indices_batch(rows, cols, batch_size, patch_size, device)

        # Long dtype is required for tensor indexing.
        sample_rows = sample_rows.long()
        sample_cols = sample_cols.long()

        # Keep only candidates that lie inside the grid.
        valid = (sample_rows >= 0) & (sample_rows < rows) & (sample_cols >= 0) & (sample_cols < cols)

        if allowed is not None:
            # Clamp before the lookup so out-of-grid candidates cannot raise; they are already
            # marked invalid above and stay invalid after the AND.
            safe_rows = sample_rows.clamp(0, rows - 1)
            safe_cols = sample_cols.clamp(0, cols - 1)
            valid &= allowed[safe_rows, safe_cols]

        indices.extend(zip(sample_rows[valid].tolist(), sample_cols[valid].tolist()))

    # A batch may overshoot; never return more than requested.
    return indices[:num_indices]
def initialize_counter(rows, cols):
    """
    Create a zero-filled pixel-utilization counter.

    Parameters:
    - rows (int): Number of rows of the counter matrix.
    - cols (int): Number of columns of the counter matrix.

    Returns:
    - 2D torch tensor: float32 tensor of shape (rows, cols) filled with zeros.
    """
    counter_shape = (rows, cols)
    return torch.zeros(counter_shape, dtype=torch.float32)
def update_counter(counter, indices, patch_size=(1, 1)):
    """
    Increment the pixel-utilization counter over every sampled patch, in place.

    Parameters:
    - counter (2D torch tensor): Counter matrix to be updated (modified in place).
    - indices (list of tuples): (row, col) of the top-left corner of each patch or pixel.
    - patch_size (tuple or int): Size (height, width) of the patches, or a single number for a
      square patch. Default is (1, 1), i.e. single pixels.

    Returns:
    - None. Patches reaching past the counter's edge only update the part that lies inside it.
    """
    try:
        patch_rows, patch_cols = patch_size
    except TypeError:
        # A scalar was given: square patch, as used by the other helpers in this module.
        patch_rows = patch_cols = patch_size
    patch_rows = int(patch_rows)
    patch_cols = int(patch_cols)

    for r, c in indices:
        counter[r:r + patch_rows, c:c + patch_cols] += 1
def draw_valid_indices(rows, cols, num_indices, patch_size, distribution_type='normal',
                       batched=False, std_dev=3, mask=None, device='cuda'):
    """
    Draw valid (row, col) indices from the requested distribution.

    Convenience wrapper around ``generate_indices`` that fills in a default standard deviation
    for normal sampling when ``std_dev`` is passed as None.

    Parameters:
        rows (int): The total number of rows.
        cols (int): The total number of columns.
        num_indices (int): The number of indices to draw.
        patch_size (int): The size of the square patch.
        distribution_type (str): The type of distribution to draw from ('normal', 'uniform' or 'Sobol').
        batched (bool): Forwarded to ``generate_indices``; draw samples in batches if True.
        std_dev (float): The standard deviation for the normal distribution. If None and the
            distribution is 'normal', one third of the largest grid dimension is used.
        mask (torch.Tensor): A mask to apply to the indices.
        device (torch.device): The device on which to generate the indices.

    Returns:
        list: A list of valid indices.
    """
    use_default_spread = std_dev is None and distribution_type == 'normal'
    if use_default_spread:
        # 3 sigma to span the largest dimension
        std_dev = max(rows, cols) / 3

    return generate_indices(
        rows,
        cols,
        num_indices,
        patch_size,
        distribution_type=distribution_type,
        std=std_dev,
        mask=mask,
        device=device,
        batched=batched,
    )
def filter_patch_indices(mask, patch_size):
    """
    Tile a mask with non-overlapping square patches and keep the tiles that are not empty.

    Parameters
    ----------
    mask : torch.Tensor
        A 2D tensor indicating valid regions (non-zero entries).
    patch_size : int
        Side length of the square tiles; also the stride of the tiling.

    Returns
    -------
    patch_indices : list of tuple[int, int]
        (row, col) top-left corners of every tile containing at least one non-zero element,
        in random order. Tiles at the bottom/right edge may be smaller than ``patch_size``.
    """
    kept = []
    for top in range(0, mask.shape[0], patch_size):
        for left in range(0, mask.shape[1], patch_size):
            tile = mask[top:top + patch_size, left:left + patch_size]
            if torch.all(tile == 0):
                # Nothing valid inside this tile.
                continue
            kept.append((top, left))

    if not kept:
        return kept

    # Random permutation of the surviving tiles.
    order = torch.randperm(len(kept)).tolist()
    return [kept[position] for position in order]
def tv_spatial(x, isotropic=True, epsilon=1e-6):
    """
    Compute the Total Variation (TV) update direction along the spatial dimensions (y, x).

    Backward differences are taken along the last two dimensions (zero at the first row/column),
    normalised, and the divergence of the resulting field is returned. With this sign convention
    the result is the *negative* gradient of the TV functional, i.e. a descent direction:
    for ``isotropic=True`` it equals ``-d/dx sum(sqrt(grad_y**2 + grad_x**2 + epsilon))``.

    Args:
        x: Input volume of shape (1, nch, npix, npix); any 4D (B, C, H, W) tensor works.
        isotropic: If True, use isotropic TV (gradients divided by their joint magnitude).
            If False, use anisotropic TV (sign of each gradient component).
        epsilon: Small value added under the square root to prevent division by zero
            (isotropic TV only).
    Returns:
        Tensor of the same shape as x.
    """
    # Backward differences, zero-padded at the front so both fields have the shape of x.
    grad_y = F.pad(x[:, :, 1:, :] - x[:, :, :-1, :], (0, 0, 1, 0))
    grad_x = F.pad(x[:, :, :, 1:] - x[:, :, :, :-1], (1, 0, 0, 0))

    if isotropic:
        grad_norm = torch.sqrt(grad_y ** 2 + grad_x ** 2 + epsilon)
        # Out-of-place division keeps the function usable under autograd.
        grad_y = grad_y / grad_norm
        grad_x = grad_x / grad_norm
    else:
        grad_y = torch.sign(grad_y)
        grad_x = torch.sign(grad_x)

    # Divergence = negative adjoint of the backward difference:
    # div[i] = p[i + 1] - p[i] for interior samples and -p[last] at the far boundary.
    div_y = torch.cat((grad_y[:, :, 1:, :] - grad_y[:, :, :-1, :],
                       -grad_y[:, :, -1:, :]), dim=2)
    div_x = torch.cat((grad_x[:, :, :, 1:] - grad_x[:, :, :, :-1],
                       -grad_x[:, :, :, -1:]), dim=3)

    return div_y + div_x
def tv_spectral(x, isotropic=True, epsilon=1e-6):
    """
    Compute the Total Variation (TV) update direction along the spectral dimension (nch).

    Differences between neighbouring channels are normalised and the divergence of the result is
    returned. With this sign convention the output points in the TV descent direction
    (the negative of the TV (sub)gradient).

    Args:
        x: Input volume of shape (1, nch, npix, npix); any 4D (B, C, H, W) tensor works.
        isotropic: If True, each difference g is replaced by the smoothed sign
            ``g / (sqrt(g**2 + epsilon) + epsilon)``. If False, by ``sign(g)``.
        epsilon: Small value to prevent division by zero (isotropic mode only).
    Returns:
        Tensor of the same shape as x. A single-channel input yields all zeros.
    """
    # Backward difference along the channel axis, zero-padded at the front so the field keeps
    # the shape of x (this also makes the single-channel case well defined).
    grad_spec = F.pad(x[:, 1:, :, :] - x[:, :-1, :, :], (0, 0, 0, 0, 1, 0))

    if isotropic:
        grad_norm = torch.sqrt(grad_spec ** 2 + epsilon)
        grad_spec = grad_spec / (grad_norm + epsilon)
    else:
        grad_spec = torch.sign(grad_spec)

    # Divergence: p[c + 1] - p[c] for interior channels, -p[last] at the final channel.
    div_spec = torch.cat((grad_spec[:, 1:, :, :] - grad_spec[:, :-1, :, :],
                          -grad_spec[:, -1:, :, :]), dim=1)
    return div_spec
def tv_3d_spectral(x, isotropic=True, epsilon=1e-6):
    """
    Compute the Total Variation (TV) update direction jointly along the spectral and spatial
    dimensions (nch, y, x).

    Backward differences are taken along dimensions 1, 2 and 3 (zero at the first sample of each
    axis), normalised, and the divergence of the resulting field is returned. With this sign
    convention the result is the *negative* gradient of the TV functional; for ``isotropic=True``
    it equals ``-d/dx sum(sqrt(grad_spec**2 + grad_y**2 + grad_x**2 + epsilon))``.

    Args:
        x: Input volume of shape (1, nch, npix, npix); any 4D (B, C, H, W) tensor works.
        isotropic: If True, use isotropic TV. If False, use anisotropic TV.
        epsilon: Small value added under the square root to prevent division by zero
            (isotropic TV only).
    Returns:
        Tensor of the same shape as x.
    """
    # Backward differences, zero-padded at the front so every field has the shape of x.
    grad_spec = F.pad(x[:, 1:, :, :] - x[:, :-1, :, :], (0, 0, 0, 0, 1, 0))
    grad_y = F.pad(x[:, :, 1:, :] - x[:, :, :-1, :], (0, 0, 1, 0))
    grad_x = F.pad(x[:, :, :, 1:] - x[:, :, :, :-1], (1, 0, 0, 0))

    if isotropic:
        grad_norm = torch.sqrt(grad_spec ** 2 + grad_y ** 2 + grad_x ** 2 + epsilon)
        # Out-of-place division keeps the function usable under autograd.
        grad_spec = grad_spec / grad_norm
        grad_y = grad_y / grad_norm
        grad_x = grad_x / grad_norm
    else:
        grad_spec = torch.sign(grad_spec)
        grad_y = torch.sign(grad_y)
        grad_x = torch.sign(grad_x)

    # Divergence = negative adjoint of the backward difference:
    # div[i] = p[i + 1] - p[i] for interior samples and -p[last] at the far boundary.
    div_spec = torch.cat((grad_spec[:, 1:, :, :] - grad_spec[:, :-1, :, :],
                          -grad_spec[:, -1:, :, :]), dim=1)
    div_y = torch.cat((grad_y[:, :, 1:, :] - grad_y[:, :, :-1, :],
                       -grad_y[:, :, -1:, :]), dim=2)
    div_x = torch.cat((grad_x[:, :, :, 1:] - grad_x[:, :, :, :-1],
                       -grad_x[:, :, :, -1:]), dim=3)

    return div_spec + div_y + div_x
def tv_3d(x, isotropic=True, epsilon=1e-6):
    """
    Compute the Total Variation (TV) update direction for 3D volumes along all three axes
    (z, y, x).

    Backward differences are taken along the last three dimensions (zero at the first sample of
    each axis), normalised, and the divergence of the resulting field is returned. With this sign
    convention the result is the *negative* gradient of the TV functional; for ``isotropic=True``
    it equals ``-d/dx sum(sqrt(grad_z**2 + grad_y**2 + grad_x**2 + epsilon))``.

    Args:
        x: Input volume of shape (1, nch, D, H, W); any 5D (B, C, D, H, W) tensor works.
        isotropic: If True, use isotropic TV. If False, use anisotropic TV.
        epsilon: Small value added under the square root to prevent division by zero
            (isotropic TV only).
    Returns:
        Tensor of the same shape as x.
    """
    # Backward differences, zero-padded at the front so every field has the shape of x.
    grad_z = F.pad(x[:, :, 1:, :, :] - x[:, :, :-1, :, :], (0, 0, 0, 0, 1, 0))
    grad_y = F.pad(x[:, :, :, 1:, :] - x[:, :, :, :-1, :], (0, 0, 1, 0, 0, 0))
    grad_x = F.pad(x[:, :, :, :, 1:] - x[:, :, :, :, :-1], (1, 0, 0, 0, 0, 0))

    if isotropic:
        grad_norm = torch.sqrt(grad_z ** 2 + grad_y ** 2 + grad_x ** 2 + epsilon)
        # Out-of-place division keeps the function usable under autograd.
        grad_z = grad_z / grad_norm
        grad_y = grad_y / grad_norm
        grad_x = grad_x / grad_norm
    else:
        grad_z = torch.sign(grad_z)
        grad_y = torch.sign(grad_y)
        grad_x = torch.sign(grad_x)

    # Divergence = negative adjoint of the backward difference:
    # div[i] = p[i + 1] - p[i] for interior samples and -p[last] at the far boundary.
    div_z = torch.cat((grad_z[:, :, 1:, :, :] - grad_z[:, :, :-1, :, :],
                       -grad_z[:, :, -1:, :, :]), dim=2)
    div_y = torch.cat((grad_y[:, :, :, 1:, :] - grad_y[:, :, :, :-1, :],
                       -grad_y[:, :, :, -1:, :]), dim=3)
    div_x = torch.cat((grad_x[:, :, :, :, 1:] - grad_x[:, :, :, :, :-1],
                       -grad_x[:, :, :, :, -1:]), dim=4)

    return div_z + div_y + div_x
class SSIM2D(torch.nn.Module):
    """
    A PyTorch module for computing the Structural Similarity Index (SSIM) loss between 2D images.

    SSIM is a perceptual metric that captures similarity in terms of luminance, contrast, and
    structure. Local statistics are computed with a 2D Gaussian filter window that is applied to
    every channel independently.

    Parameters
    ----------
    window_size : int, optional
        Size of the Gaussian filter window. Default is 11.
    sigma : float, optional
        Standard deviation of the Gaussian kernel. Default is 1.5.
    C1 : float, optional
        Stabilizing constant for luminance term. Default is (0.01)^2.
    C2 : float, optional
        Stabilizing constant for contrast term. Default is (0.03)^2.
    device : str, optional
        Device on which the Gaussian window is initially stored. Default is 'cuda'. If CUDA is
        requested but unavailable the window is kept on the CPU instead.

    Forward
    -------
    img1 : torch.Tensor
        First image tensor of shape (N, C, H, W).
    img2 : torch.Tensor
        Second image tensor of shape (N, C, H, W).

    Returns
    -------
    loss : torch.Tensor
        A scalar tensor representing 1 - SSIM(img1, img2). Suitable for use as a loss function.

    Notes
    -----
    - The output is `1 - SSIM`, so this module can be used directly as a loss function.
    - The default C1/C2 correspond to inputs normalized to [0, 1].
    - The window is a non-persistent buffer: it follows ``module.to(...)`` but is not part of the
      ``state_dict``. During the forward pass it is additionally matched to the device and dtype
      of ``img1``.
    """

    def __init__(self, window_size=11, sigma=1.5, C1=0.01**2, C2=0.03**2, device='cuda'):
        super().__init__()
        self.window_size = window_size
        self.C1 = C1
        self.C2 = C2
        self.channel = 1  # Channel count of the stored window; forward expands it as needed

        window = self.create_window(window_size, sigma)
        if torch.device(device).type == 'cuda' and not torch.cuda.is_available():
            # Stay usable on CPU-only hosts; forward re-homes the window to the input anyway.
            device = 'cpu'
        self.register_buffer('window', window.to(device), persistent=False)

    @staticmethod
    def gaussian(window_size, sigma):
        """Return a 1D Gaussian of length ``window_size``, centred and normalized to sum to 1."""
        # Sample positions centred on zero
        x = torch.arange(window_size).float() - (window_size - 1) / 2
        gauss = torch.exp(-0.5 * (x / sigma) ** 2)
        return gauss / gauss.sum()

    def create_window(self, window_size, sigma):
        """Return the 2D Gaussian window with shape [1, 1, window_size, window_size]."""
        g1d = self.gaussian(window_size, sigma)
        # Outer product of the separable 1D kernel gives the 2D kernel
        g2d = torch.outer(g1d, g1d)
        return g2d.unsqueeze(0).unsqueeze(0)

    def _ssim(self, img1, img2):
        """Return the per-pixel SSIM map of two (N, C, H, W) tensors."""
        channels = img1.shape[1]
        # Match the window to the input, then replicate it for a depthwise convolution.
        window = self.window.to(device=img1.device, dtype=img1.dtype)
        window = window.expand(channels, -1, -1, -1).contiguous()
        pad = self.window_size // 2

        def local_mean(t):
            return F.conv2d(t, window, padding=pad, groups=channels)

        mu1 = local_mean(img1)
        mu2 = local_mean(img2)
        mu1_sq = mu1.pow(2)
        mu2_sq = mu2.pow(2)
        mu1_mu2 = mu1 * mu2

        # Local (co)variances via E[ab] - E[a]E[b]
        sigma1_sq = local_mean(img1 * img1) - mu1_sq
        sigma2_sq = local_mean(img2 * img2) - mu2_sq
        sigma12 = local_mean(img1 * img2) - mu1_mu2

        numerator = (2 * mu1_mu2 + self.C1) * (2 * sigma12 + self.C2)
        denominator = (mu1_sq + mu2_sq + self.C1) * (sigma1_sq + sigma2_sq + self.C2)
        return numerator / denominator

    def forward(self, img1, img2):
        """Return the scalar loss ``1 - mean(SSIM map)``."""
        ssim_map = self._ssim(img1, img2)
        return 1 - ssim_map.mean()
class SSIM3DLoss(nn.Module):
    """
    Computes the Structural Similarity Index (SSIM) loss between two 3D volumes.

    SSIM is a perceptual similarity metric that evaluates structural fidelity based on local
    statistics (mean, variance, and covariance). This implementation uses a uniform 3D window
    (box filter, not Gaussian) to approximate the local statistics.

    Parameters
    ----------
    window_size : int, optional
        Size of the cubic window used for computing local means and variances. Default is 11.
    C1 : float, optional
        Stabilizing constant for the luminance term. Default is (0.01)^2.
    C2 : float, optional
        Stabilizing constant for the contrast term. Default is (0.03)^2.

    Forward
    -------
    x : torch.Tensor
        First input volume of shape (D, H, W).
    y : torch.Tensor
        Second input volume of shape (D, H, W).

    Returns
    -------
    loss : torch.Tensor
        Scalar tensor representing `1 - SSIM(x, y)`, suitable for optimization as a loss function.

    Notes
    -----
    - The default C1/C2 correspond to volumes normalized to [0, 1].
    - Singleton batch and channel dimensions are added internally; only single-channel volumes
      are supported.
    - The window is a non-persistent buffer placed on the GPU when CUDA is available (CPU
      otherwise). In the forward pass it is matched to the device and dtype of ``x``, so the
      module works on either device.
    """

    def __init__(self, window_size=11, C1=0.01 ** 2, C2=0.03 ** 2):
        super().__init__()
        self.window_size = window_size
        self.C1 = C1
        self.C2 = C2

        window = self.create_3D_window(window_size)
        if torch.cuda.is_available():
            window = window.cuda()
        self.register_buffer('window', window, persistent=False)

    def create_3D_window(self, window_size):
        """Return a uniform averaging kernel of shape [1, 1, k, k, k] that sums to 1."""
        window = torch.ones(1, 1, window_size, window_size, window_size)
        return window / window.numel()

    def forward(self, x, y):
        """Return the scalar loss ``1 - mean(SSIM map)`` for two (D, H, W) volumes."""
        # Add singleton dimensions for batch and channel
        x = x.unsqueeze(0).unsqueeze(0)
        y = y.unsqueeze(0).unsqueeze(0)

        window = self.window.to(device=x.device, dtype=x.dtype)
        pad = self.window_size // 2

        def local_mean(t):
            return F.conv3d(t, window, padding=pad, groups=1)

        mu_x = local_mean(x)
        mu_y = local_mean(y)
        mu_x_mu_y = mu_x * mu_y
        mu_x_sq = mu_x.pow(2)
        mu_y_sq = mu_y.pow(2)

        # Local (co)variances via E[ab] - E[a]E[b]
        sigma_x_sq = local_mean(x * x) - mu_x_sq
        sigma_y_sq = local_mean(y * y) - mu_y_sq
        sigma_xy = local_mean(x * y) - mu_x_mu_y

        numerator = (2 * mu_x_mu_y + self.C1) * (2 * sigma_xy + self.C2)
        denominator = (mu_x_sq + mu_y_sq + self.C1) * (sigma_x_sq + sigma_y_sq + self.C2)
        ssim_map = numerator / denominator

        return 1 - ssim_map.mean()