# -*- coding: utf-8 -*-
"""
PyTorch functions for 2D and 3D tomography and Radon transform simulation.
This module includes differentiable and non-differentiable implementations of:
- Forward and back projection routines (Radon and inverse Radon transforms) in 2D and 3D.
- Iterative reconstruction methods such as SIRT and CGLS using functional forward/back projectors.
- Sparse matrix-based forward and backward operations (A-matrix formulation) using PyTorch sparse tensors.
- Utility functions for constructing sparse system matrices, rotating images, and defining affine transforms.
Main features:
- Differentiable 3D forward and backward projectors using `torchvision.transforms.functional.rotate`.
- Support for iterative solvers: SIRT (with normalization) and CGLS.
- Conversion utilities for using SciPy sparse matrices in PyTorch (e.g., `Amatrix_torch`, `Sino_torch`, `Amatrix_rec`).
- Grid-based rotation of images via affine transformation (`imrotate_torch`).
- Compatible with both CPU and CUDA devices.
Author: Antony Vamvakeros
"""
import torch
from torchvision.transforms.functional import rotate
import torch.nn.functional as F
from torchvision.transforms import InterpolationMode
from numpy import vstack
import numpy as np
[docs]
def forward_project_3D(vol, angles, npix, nch, device='cuda'):
"""
Perform forward projection (Radon transform) of a 3D volume using PyTorch.
Parameters
----------
vol : torch.Tensor
Input volume of shape (1, nch, npix, npix), where:
- 1 is the batch dimension,
- nch is the number of channels or slices (e.g., spectral bins or time steps),
- npix x npix is the spatial resolution of each slice.
angles : list or ndarray
List of projection angles in degrees.
npix : int
Number of pixels along each projection axis (image size).
nch : int
Number of slices or channels in the volume.
device : str, optional
PyTorch device string (default: 'cuda').
Returns
-------
sinos : torch.Tensor
Simulated sinogram of shape (nch, npix, len(angles)), where each slice
corresponds to a different channel/slice and each column is a projection.
"""
sinos = torch.zeros((nch, npix, len(angles)), device=device)
for angle in range(len(angles)):
vol_rot = rotate(vol, float(angles[angle]), interpolation=InterpolationMode.BILINEAR)
sinos[:,:,angle] = torch.sum(vol_rot, dim=3)[0,:,:]
return sinos
[docs]
def back_project_3D(sinos, angles, npix, nch, device='cuda'):
"""
Perform backprojection (inverse Radon transform) of a 3D sinogram using PyTorch.
Parameters
----------
sinos : torch.Tensor
Input sinogram of shape (nch, npix, len(angles)), where:
- nch is the number of slices or channels,
- npix is the number of detector elements,
- len(angles) is the number of projection angles.
angles : list or ndarray
List of projection angles in degrees.
npix : int
Number of pixels in the output reconstructed image.
nch : int
Number of slices or channels in the volume.
device : str, optional
PyTorch device string (default: 'cuda').
Returns
-------
vol : torch.Tensor
Reconstructed 3D volume of shape (nch, npix, npix). Each channel corresponds to
a separate slice, reconstructed via filtered or unfiltered backprojection.
"""
vol = torch.zeros((1, nch, npix, npix), device=device)
for angle in range(len(angles)):
vol_rot = sinos[:,:,angle].unsqueeze(0).unsqueeze(0)
vol_rot = torch.transpose(vol_rot, 2, 1)
vol_rot = vol_rot.repeat(1, 1, npix, 1)
vol_rot = torch.transpose(vol_rot, 3, 2)
vol_rot = rotate(vol_rot,
-float(angles[angle]), interpolation=InterpolationMode.BILINEAR)
vol += vol_rot
return vol[0, :, :, :]
# Compute W_ray for SIRT
[docs]
def compute_W_ray(angles, npix, nch, device='cuda'):
"""
Compute ray normalization weights for each voxel via forward projection of a constant volume.
This function simulates the accumulation of contributions each detector sees from a
uniform volume, useful for SIRT or SART-type normalization.
Parameters
----------
angles : list or ndarray
List of projection angles in degrees.
npix : int
Number of pixels in each dimension of the image (image size).
nch : int
Number of slices or channels in the 3D volume.
device : str, optional
PyTorch device string (default: 'cuda').
Returns
-------
W_ray : torch.Tensor
Weighting map of shape (nch, npix, len(angles)) representing forward projection of ones.
"""
vol = torch.ones((1, nch, npix, npix), dtype=torch.float32, device=device)
W_ray = forward_project_3D(vol, angles, npix, nch, device)
return W_ray
[docs]
def sirt_pytorch_functional(sinos, angles, npix, nch=1, n_iter=20, relax=0.01, epsilon=1e-6, device='cuda'):
"""
SIRT reconstruction using PyTorch with function-based forward and backward projectors.
Parameters
----------
sinos : torch.Tensor
Input sinogram tensor of shape (nch, npix, n_angles), e.g. (1, 151, 180).
angles : list or ndarray
List of projection angles in degrees.
npix : int
Width/height of the reconstructed image.
nch : int
Number of slices or channels in the volume (default = 1).
n_iter : int
Number of SIRT iterations.
relax : float
Relaxation factor (typically small, e.g., 0.01).
epsilon : float
Small number to avoid division by zero.
device : str
Computation device, e.g. 'cuda'.
Returns
-------
torch.Tensor
Reconstructed volume of shape (nch, npix, npix).
"""
sinos = sinos.to(device)
W_ray = compute_W_ray(angles, npix, nch, device=device)
vol = torch.zeros((1, nch, npix, npix), dtype=torch.float32, device=device)
for _ in range(n_iter):
sim = forward_project_3D(vol, angles, npix, nch, device)
residual = sinos - sim
correction = back_project_3D(residual / (W_ray + epsilon), angles, npix, nch, device)
vol += relax * correction
if device == 'cuda':
vol = vol.cpu()
vol = vol.squeeze().numpy()
vol = np.transpose(vol)
return vol # shape: (nch, npix, npix)
[docs]
def cgls_pytorch_functional(sinos, angles, npix, nch=1, n_iter=10, device='cuda'):
"""
CGLS reconstruction using PyTorch with functional forward and back projectors.
Parameters
----------
sinos : torch.Tensor
Input sinogram tensor of shape (nch, npix, n_angles), e.g., (1, 151, 180).
angles : list or ndarray
List of projection angles in degrees.
npix : int
Width/height of the reconstructed image.
nch : int
Number of slices or channels in the volume (default = 1).
n_iter : int
Number of CGLS iterations.
device : str
Computation device, e.g., 'cuda'.
Returns
-------
torch.Tensor
Reconstructed volume of shape (nch, npix, npix).
"""
sinos = sinos.to(device)
def forward(x):
x_batched = x.unsqueeze(0) # (1, nch, npix, npix)
return forward_project_3D(x_batched, angles, npix, nch, device)
def backward(y):
return back_project_3D(y, angles, npix, nch, device)
x = torch.zeros((nch, npix, npix), dtype=torch.float32, device=device)
b = sinos
r = b - forward(x)
p = backward(r)
d = p.clone()
delta_new = torch.sum(d * d)
for _ in range(n_iter):
q = forward(d)
alpha = delta_new / torch.sum(q * q)
x += alpha * d
r -= alpha * q
s = backward(r)
delta_old = delta_new
delta_new = torch.sum(s * s)
beta = delta_new / delta_old
d = s + beta * d
if device == 'cuda':
x = x.cpu()
x = x.squeeze().numpy()
x = np.transpose(x)
return x
[docs]
def Amatrix_torch(A, gpu=True):
"""
Converts a SciPy sparse matrix A to a PyTorch sparse tensor.
Parameters
----------
A : scipy.sparse matrix
Input sparse matrix in COO format (or convertible to COO).
gpu : bool, optional
If True, moves the tensor to CUDA. Default is True.
Returns
-------
Atorch : torch.sparse.FloatTensor
PyTorch sparse tensor version of A.
"""
Acoo = A.tocoo()
values = Acoo.data
indices = vstack((Acoo.row, Acoo.col))
i = torch.LongTensor(indices)
v = torch.FloatTensor(values)
shape = Acoo.shape
Atorch = torch.sparse.FloatTensor(i, v, torch.Size(shape))
return Atorch.cuda() if gpu else Atorch
[docs]
def Sino_torch(Atorch, im, ntr, npr):
"""
Generates a sinogram by applying a torch sparse A matrix to an image.
Parameters
----------
Atorch : torch.sparse.FloatTensor
Sparse A matrix in torch format.
im : ndarray
Input image as a NumPy array of shape (ntr, ntr).
ntr : int
Number of translation steps (image side length).
npr : int
Number of projections.
Returns
-------
s : torch.Tensor
Sinogram as a tensor of shape (npr, ntr).
"""
imt = torch.from_numpy(im).float().reshape(ntr * ntr, 1).cuda()
s = Amatrix_sino(Atorch, imt, npr, ntr)
return s
[docs]
def Amatrix_sino(Atorch, im, npr, ntr):
"""
Computes the sinogram using matrix multiplication with a sparse A matrix.
Parameters
----------
Atorch : torch.sparse.FloatTensor
Sparse system matrix A (shape: [npr*ntr, ntr*ntr]).
im : torch.Tensor
Flattened image tensor of shape (ntr*ntr, 1).
npr : int
Number of projections (angles).
ntr : int
Number of translation steps.
Returns
-------
stf : torch.Tensor
Sinogram tensor of shape (npr, ntr).
"""
stf = torch.matmul(Atorch, im)
stf = stf.reshape(npr, ntr)
return stf
[docs]
def Amatrix_rec(AtorchT, s, ntr):
"""
Reconstructs an image from a sinogram using the transpose of the A matrix.
Parameters
----------
AtorchT : torch.sparse.FloatTensor
Transpose of the system matrix A (shape: [ntr*ntr, npr*ntr]).
s : torch.Tensor
Sinogram of shape (npr, ntr).
ntr : int
Number of translation steps (output image side length).
Returns
-------
rec : torch.Tensor
Reconstructed image tensor of shape (ntr, ntr).
"""
rec = torch.matmul(AtorchT, s.view(-1, 1))
return rec.reshape(ntr, ntr)
[docs]
def RotMat(theta):
"""
Creates a 2D rotation matrix for use in affine transformations.
Parameters
----------
theta : float or torch.Tensor
Rotation angle in radians.
Returns
-------
rotmat : torch.Tensor
A 2×3 affine rotation matrix tensor.
"""
theta = torch.tensor(theta)
rotmat = torch.tensor([[torch.cos(theta), -torch.sin(theta), 0],
[torch.sin(theta), torch.cos(theta), 0]])
return rotmat
[docs]
def imrotate_torch(im, theta, dtype=torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor):
"""
Rotates a 2D image (or batch) using an affine transformation.
Parameters
----------
im : torch.Tensor
Input tensor of shape (N, C, H, W).
theta : float
Rotation angle in radians.
dtype : torch dtype, optional
Data type of the affine matrix and grid. Default is float tensor on GPU.
Returns
-------
imr : torch.Tensor
Rotated image of the same shape as input.
"""
rot_mat = RotMat(theta)[None, ...].type(dtype).repeat(im.shape[0], 1, 1)
grid = F.affine_grid(rot_mat, im.size(), align_corners=True).type(dtype)
imr = F.grid_sample(im, grid, align_corners=True)
return imr
[docs]
def create_torch_Amat(Acoo, values, indices, shape, device='cuda'):
"""
Constructs a PyTorch sparse A matrix from COO components.
Parameters
----------
Acoo : scipy.sparse.coo_matrix
Sparse matrix object (for context).
values : ndarray
Non-zero values of the matrix.
indices : ndarray
2×N array of row and column indices.
shape : tuple
Shape of the matrix.
device : str, optional
Target device ('cuda' or 'cpu').
Returns
-------
Amat : torch.sparse.FloatTensor
Constructed sparse matrix on the target device.
"""
i = torch.LongTensor(indices)
v = torch.FloatTensor(values)
return torch.sparse.FloatTensor(i, v, torch.Size(shape)).to(device)