Source code for nDTomo.pytorch.models_torch

# -*- coding: utf-8 -*-
"""
Neural networks models

@author: Antony Vamvakeros
"""

#%%

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

[docs] class PeakFitCNN(nn.Module): """ A 2D convolutional neural network designed for upsampling and refining spectral or spatial peak data, optionally doubling or quadrupling the input resolution using bilinear interpolation and CNN blocks. Parameters ---------- nch_in : int Number of input channels. nch_out : int Number of output channels. nfilts : int Number of filters in the intermediate convolution layers. upscale_factor : int Upscaling factor for the input. Supported values: 2 or 4. norm_type : str Type of normalization to apply: 'instance', 'batch', or 'layer'. activation : str Final activation function: 'Linear', 'ReLU', 'Sigmoid', or 'LeakyReLU'. padding : str Padding mode for convolutions ('same' or 'valid'). npix : int Number of pixels in the input (required for LayerNorm). Forward ------- x : torch.Tensor Input tensor of shape (batch_size, nch_in, H, W). Returns ------- torch.Tensor Output tensor of shape (batch_size, nch_out, H * upscale_factor, W * upscale_factor). """ def __init__(self, nch_in=1, nch_out=1, nfilts=32, upscale_factor = 4, norm_type='instance', activation='Linear', padding='same', npix=None): super(PeakFitCNN, self).__init__() self.npix = npix self.upscale_factor = upscale_factor # Initial feature extraction self.input = nn.Conv2d(nch_in, nfilts, kernel_size=3, stride=1, padding=padding, bias=True) layers = [] layers.append(nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)) layers.append(nn.Conv2d(nfilts, nfilts, kernel_size=3, stride=1, padding=padding, bias=True)) # Add normalization based on norm_type if norm_type == "instance": layers.append(nn.InstanceNorm2d(nfilts, affine=True)) elif norm_type == "batch": layers.append(nn.BatchNorm2d(nfilts)) elif norm_type == "layer": layers.append(nn.LayerNorm([nfilts, 2*self.npix, 2*self.npix])) # Add activation function layers.append(nn.ReLU(inplace=True)) self.upsample1 = nn.Sequential(*layers) if self.upscale_factor == 4: layers = [] layers.append(nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)) layers.append(nn.Conv2d(nfilts, nfilts, kernel_size=3, stride=1, padding=padding, bias=True)) # Add normalization based on norm_type if norm_type == "instance": layers.append(nn.InstanceNorm2d(nfilts, affine=True)) elif norm_type == "batch": layers.append(nn.BatchNorm2d(nfilts)) elif norm_type == "layer": layers.append(nn.LayerNorm([nfilts, 4*self.npix, 4*self.npix])) # Add activation function layers.append(nn.ReLU(inplace=True)) self.upsample2 = nn.Sequential(*layers) # Final output layer self.xrdct = nn.Conv2d(nfilts, nch_out, kernel_size=3, stride=1, padding=padding, bias=True) # Final activation self.final_activation = None if activation == "ReLU": self.final_activation = nn.ReLU() elif activation == "Sigmoid": self.final_activation = nn.Sigmoid() elif activation == "LeakyReLU": self.final_activation = nn.LeakyReLU(0.2, inplace=True)
[docs] def forward(self, x): # Feature maps from autoencoder2D are passed x = self.input(x) # Upsampling 1 x = self.upsample1(x) if self.upscale_factor == 4: # Upsampling 2 x = self.upsample2(x) # Output layer x = self.xrdct(x) if self.final_activation is not None: x = self.final_activation(x) return x
[docs] class PrmCNN2D(nn.Module): """ A flexible 2D model that combines a trainable tensor (image parameterization) with an optional CNN-based processing module. Can operate in three modes: - Pure parameterization (learned image). - CNN only (applies CNN to input). - Parameterization + CNN (CNN applied to learned image). Parameters ---------- npix : int Image resolution (assumes square images). nch_in : int Number of input channels. nch_out : int Number of output channels. nfilts : int Number of filters in CNN layers. nlayers : int Number of intermediate CNN blocks (excluding first and last layers). norm_type : str Type of normalization: 'layer', 'instance', or 'batchnorm'. prms_layer : bool If True, a learnable tensor is used as input. cnn_layer : bool If True, a CNN processes the input or parameter tensor. tensor_vals : str Initialization mode for the learned tensor: 'random', 'zeros', 'ones', 'mean', 'random_positive', or 'custom'. tensor_initial : torch.Tensor or None Custom tensor to use if tensor_vals == 'custom'. padding : str Padding mode for convolutions ('same' or 'valid'). Forward ------- x : torch.Tensor Input tensor if cnn_layer=True and prms_layer=False. Ignored otherwise. Returns ------- torch.Tensor Output tensor of shape (1, nch_out, npix, npix). """ def __init__(self, npix, nch_in=1, nch_out=1, nfilts=32, nlayers=4, norm_type='layer', prms_layer=True, cnn_layer=True, tensor_vals = 'random', tensor_initial = None, padding='same'): super(PrmCNN2D, self).__init__() self.npix = npix self.prms_layer = prms_layer self.cnn_layer = cnn_layer if self.prms_layer: if tensor_vals == 'random': self.initial_tensor = nn.Parameter(2*torch.randn(1, nch_in, npix, npix)-1) elif tensor_vals == 'zeros': self.initial_tensor = nn.Parameter(torch.zeros(1, nch_in, npix, npix)) elif tensor_vals == 'ones': self.initial_tensor = nn.Parameter(torch.ones(1, nch_in, npix, npix)) elif tensor_vals == 'mean': self.initial_tensor = nn.Parameter(0.5*torch.ones(1, nch_in, npix, npix)) elif tensor_vals == 'random_positive': self.initial_tensor = nn.Parameter(torch.randn(1, nch_in, npix, npix)) elif tensor_vals == 'custom': try: self.initial_tensor = nn.Parameter(tensor_initial) except: print('Custom tensor not provided. Using random tensor instead') self.initial_tensor = nn.Parameter(torch.randn(1, nch_in, npix, npix)) if self.cnn_layer: layers = [] layers.append(nn.Conv2d(nch_in, nfilts, kernel_size=3, stride=1, padding=padding)) # 'same' padding in PyTorch is usually done by manually specifying the padding if norm_type=='layer': if padding=='valid': layers.append(nn.LayerNorm([nfilts, self.npix -2, self.npix -2])) else: layers.append(nn.LayerNorm([nfilts, self.npix, self.npix])) elif norm_type=='instance': layers.append(nn.InstanceNorm2d(nfilts, affine = True)) elif norm_type=='batchnorm': layers.append(nn.BatchNorm2d(nfilts)) layers.append(nn.ReLU()) for layer in range(nlayers): layers.append(nn.Conv2d(nfilts, nfilts, kernel_size=3, stride=1, padding=padding)) if norm_type=='layer': if padding=='valid': layers.append(nn.LayerNorm([nfilts, self.npix -2*(layer + 2), self.npix -2*(layer + 2)])) else: layers.append(nn.LayerNorm([nfilts, self.npix, self.npix])) elif norm_type=='instance': layers.append(nn.InstanceNorm2d(nfilts, affine = True)) elif norm_type=='batchnorm': layers.append(nn.BatchNorm2d(nfilts)) layers.append(nn.ReLU()) layers.append(nn.Conv2d(nfilts, nch_out, kernel_size=3, stride=1, padding=padding)) layers.append(nn.Sigmoid()) self.cnn2d = nn.Sequential(*layers)
[docs] def forward(self, x): if self.prms_layer and self.cnn_layer: out = self.cnn2d(torch.sigmoid(self.initial_tensor)) elif self.cnn_layer and not self.prms_layer: out = self.cnn2d(x) elif self.prms_layer and not self.cnn_layer: out = torch.sigmoid(self.initial_tensor) return out
[docs] class CNN1D(nn.Module): """ A 1D convolutional neural network for sequential or spectral data processing with optional normalization and residual connections. Parameters ---------- nch_in : int Number of input channels. nch_out : int Number of output channels. nfilts : int Number of filters in the convolutional layers. nlayers : int Number of intermediate convolutional blocks. norm_type : str or None Type of normalization: 'batch', 'layer', or None. activation : str Final activation type. If 'Sigmoid', a Sigmoid activation is appended after the last layer. Forward ------- x : torch.Tensor Input tensor of shape (batch_size, nch_in, sequence_length). residual : bool If True, adds input x to the output (residual connection). Returns ------- torch.Tensor Output tensor of shape (batch_size, nch_out, sequence_length). """ def __init__(self, nch_in=1, nch_out=1, nfilts=32, nlayers=4, norm_type='layer', activation='Linear'): super(CNN1D, self).__init__() layers = [] layers.append(nn.Conv1d(nch_in, nfilts, kernel_size=3, stride=1, padding=1)) if norm_type is not None: self.add_norm_layer(layers, nfilts, norm_type) layers.append(nn.ReLU()) for layer in range(nlayers): layers.append(nn.Conv1d(nfilts, nfilts, kernel_size=3, stride=1, padding=1)) if norm_type is not None: self.add_norm_layer(layers, nfilts, norm_type) layers.append(nn.ReLU()) layers.append(nn.Conv1d(nfilts, nch_out, kernel_size=3, stride=1, padding=1)) if activation == 'Sigmoid': layers.append(nn.Sigmoid()) self.cnn1d = nn.Sequential(*layers)
[docs] def add_norm_layer(self, layers, nfilts, norm_type): if norm_type == 'batch': layers.append(nn.BatchNorm1d(nfilts)) elif norm_type == 'layer': layers.append(nn.LayerNorm([nfilts])) else: raise ValueError('Invalid normalization type')
[docs] def forward(self, x, residual=False): if residual: out = self.cnn1d(x) + x else: out = self.cnn1d(x) return out
[docs] class CNN2D(nn.Module): """ A configurable 2D Convolutional Neural Network (CNN) for image processing tasks. Parameters ---------- npix : int The spatial size (height and width) of the input images. nch_in : int, optional, default=1 Number of input channels (e.g., 1 for grayscale images, 3 for RGB). nch_out : int, optional, default=1 Number of output channels. nfilts : int, optional, default=32 Number of filters (channels) in the intermediate convolutional layers. nlayers : int, optional, default=4 Number of intermediate convolutional layers. norm_type : str or None, optional, default='layer' Normalization type to apply after convolutions: - 'batch': Batch normalization. - 'layer': Layer normalization. - None: No normalization. activation : str, optional, default='Linear' Final activation function to apply: - 'Sigmoid': Applies a sigmoid activation. - 'Linear': No activation applied. Methods ------- forward(x, residual=False): Performs a forward pass through the network. - `residual`: If True, adds the input `x` to the output. """ def __init__(self, npix, nch_in=1, nch_out=1, nfilts=32, nlayers=4, norm_type='layer', activation='Linear'): super(CNN2D, self).__init__() self.npix = npix layers = [] # Input convolutional layer layers.append(nn.Conv2d(nch_in, nfilts, kernel_size=3, stride=1, padding=1)) # Same padding if norm_type: self.add_norm_layer(layers, nfilts, norm_type) layers.append(nn.ReLU()) # Intermediate convolutional layers for _ in range(nlayers): layers.append(nn.Conv2d(nfilts, nfilts, kernel_size=3, stride=1, padding=1)) if norm_type: self.add_norm_layer(layers, nfilts, norm_type) layers.append(nn.ReLU()) # Output convolutional layer layers.append(nn.Conv2d(nfilts, nch_out, kernel_size=3, stride=1, padding=1)) if activation == 'Sigmoid': layers.append(nn.Sigmoid()) # Combine all layers into a sequential model self.cnn2d = nn.Sequential(*layers)
[docs] def add_norm_layer(self, layers, nfilts, norm_type): """ Adds a normalization layer to the list of layers. Parameters ---------- layers : list List of layers to which the normalization layer is appended. nfilts : int Number of channels for normalization. norm_type : str Type of normalization ('batch' or 'layer'). """ if norm_type == 'batch': layers.append(nn.BatchNorm2d(nfilts)) elif norm_type == 'layer': layers.append(nn.LayerNorm([nfilts, self.npix, self.npix])) else: raise ValueError(f"Invalid normalization type: {norm_type}. Choose 'batch', 'layer', or None.")
[docs] def forward(self, x, residual=False): """ Performs a forward pass through the network. Parameters ---------- x : torch.Tensor Input tensor of shape (batch_size, nch_in, npix, npix). residual : bool, optional, default=False If True, adds the input `x` to the output. Returns ------- torch.Tensor Output tensor of shape (batch_size, nch_out, npix, npix). """ if residual: out = self.cnn2d(x) + x else: out = self.cnn2d(x) return out
[docs] class CNN3D(nn.Module): """ A 3D convolutional neural network for volumetric data processing with optional normalization and configurable depth. Parameters ---------- npix : int Size of the 3D cube (assumes cube of shape npix x npix x npix). nch_in : int Number of input channels. nch_out : int Number of output channels. nfilts : int Number of filters in convolutional layers. nlayers : int Number of intermediate convolutional blocks. norm_type : str or None Type of normalization: 'batch', 'layer', or None. activation : str Final activation type. If 'Sigmoid', a Sigmoid activation is appended. Forward ------- x : torch.Tensor Input tensor of shape (batch_size, nch_in, D, H, W). residual : bool If True, adds the input x to the output (residual connection). Returns ------- torch.Tensor Output tensor of shape (batch_size, nch_out, D, H, W). """ def __init__(self, npix, nch_in=1, nch_out=1, nfilts=32, nlayers=4, norm_type='layer', activation='Linear'): super(CNN3D, self).__init__() self.npix = npix layers = [] layers.append(nn.Conv3d(nch_in, nfilts, kernel_size=3, stride=1, padding=1)) # 'same' padding in PyTorch is usually done by manually specifying the padding if norm_type is not None: self.add_norm_layer(layers, nfilts, norm_type) layers.append(nn.ReLU()) for layer in range(nlayers): layers.append(nn.Conv3d(nfilts, nfilts, kernel_size=3, stride=1, padding=1)) if norm_type is not None: self.add_norm_layer(layers, nfilts, norm_type) layers.append(nn.ReLU()) layers.append(nn.Conv3d(nfilts, nch_out, kernel_size=3, stride=1, padding=1)) if activation == 'Sigmoid': layers.append(nn.Sigmoid()) self.cnn3d = nn.Sequential(*layers)
[docs] def add_norm_layer(self, layers, nfilts, norm_type): if norm_type == 'batch': layers.append(nn.BatchNorm3d(nfilts)) elif norm_type == 'layer': layers.append(nn.LayerNorm([nfilts, self.npix, self.npix, self.npix])) else: raise ValueError('Invalid normalization type')
[docs] def forward(self, x, residual=False): if residual: out = self.cnn3d(x) + x else: out = self.cnn3d(x) return out
[docs] class VolumeModel(nn.Module): """ A simple learnable 3D volume model where each slice in the volume is trainable. Supports differential updates or direct use of the internal volume. Parameters ---------- npix : int Number of pixels in each spatial dimension (assumes square slices). num_slices : int Number of slices along the depth axis. vol : torch.Tensor or None Optional initial volume of shape (num_slices, npix, npix). If None, initializes to zeros. device : str Device to place the model parameters on ('cuda' or 'cpu'). Forward ------- input_volume : torch.Tensor External volume to add to the internal volume (shape: num_slices, npix, npix). diff : bool If True, output is input_volume + self.volume. If False, output is self.volume only. Returns ------- torch.Tensor Transformed volume of shape (num_slices, npix, npix). """ def __init__(self, npix, num_slices, vol=None, device='cuda'): super(VolumeModel, self).__init__() self.num_slices = num_slices if vol is None: self.volume = nn.Parameter(torch.zeros((num_slices, npix, npix)).to(device)) else: self.volume = nn.Parameter(vol)
[docs] def forward(self, input_volume, diff = True): if diff: transformed_volume = input_volume + self.volume else: transformed_volume = self.volume return transformed_volume
[docs] class PeakModel(nn.Module): """ A peak fitting model that represents parameterized 1D functions (e.g. Gaussian or Pseudo-Voigt) using learnable normalized parameters constrained to [0, 1]. Converts normalized parameters to their physical range before evaluating the function. Parameters ---------- prms : dict Dictionary containing: - 'val': torch.Tensor of shape (n_params, npix, npix), the initial normalized parameters. - 'min': dict of minimum values for each parameter. - 'max': dict of maximum values for each parameter. device : str Device to place parameters on ('cuda' or 'cpu'). Forward ------- x : torch.Tensor Input tensor of shape (N, 1) representing the x-axis values for function evaluation. model : str Peak function model to use: 'Gaussian' or 'PseudoVoigt'. Returns ------- y : torch.Tensor Output tensor of shape (N, 1), the evaluated function for each pixel. """ def __init__(self, prms, device='cuda'): super(PeakModel, self).__init__() # Call the parent class's constructor first self.prms = nn.Parameter(prms['val']) self.min = prms['min'] self.max = prms['max'] self.nch = prms['val'].shape[0] self.npix = prms['val'].shape[1]
[docs] def forward(self, x, model = 'Gaussian'): prms = torch.reshape(self.prms, (self.nch, self.npix*self.npix)) prms = torch.transpose(prms, 1, 0) # Apply constraints to parameters prms = torch.clamp(prms, 0, 1) if model == 'Gaussian': y = ((self.min['Area'] + (self.max['Area']-self.min['Area'])*prms[:, 0:1]) * torch.exp(-(x - (self.min['Position'] + (self.max['Position'] - self.min['Position'])*prms[:, 1:2]))**2 / (2 * (self.min['FWHM'] + (self.max['FWHM'] - self.min['FWHM'])*prms[:, 2:3])**2)) + self.min['Slope'] + (self.max['Slope'] - self.min['Slope'])*prms[:, 3:4]*x + self.min['Intercept'] + (self.max['Intercept'] - self.min['Intercept'])*prms[:, 4:5]) elif model == 'PseudoVoigt': y = ((self.min['Fraction'] + (self.max['Fraction']-self.min['Fraction'])*prms[:, 5:6])* (self.min['Area'] + (self.max['Area']-self.min['Area'])*prms[:, 0:1]) * torch.exp(-(x - (self.min['Position'] + (self.max['Position'] - self.min['Position'])*prms[:, 1:2]))**2 / (2 * (self.min['FWHM'] + (self.max['FWHM'] - self.min['FWHM'])*prms[:, 2:3])**2)) + (1-(self.min['Fraction'] + (self.max['Fraction']-self.min['Fraction'])*prms[:, 5:6]))* ((self.min['Area'] + (self.max['Area']-self.min['Area'])*prms[:, 0:1])/ torch.pi) * ( (self.min['FWHM'] + (self.max['FWHM'] - self.min['FWHM'])*prms[:, 2:3]) / (((x - (self.min['Position'] + (self.max['Position'] - self.min['Position'])*prms[:, 1:2]))**2) + (self.min['FWHM'] + (self.max['FWHM'] - self.min['FWHM'])*prms[:, 2:3])**2)) + self.min['Slope'] + (self.max['Slope'] - self.min['Slope'])*prms[:, 3:4]*x + self.min['Intercept'] + (self.max['Intercept'] - self.min['Intercept'])*prms[:, 4:5]) return y
[docs] class ResNetBlock(nn.Module): """ A basic 2D residual block consisting of two convolutional layers with optional normalization and ReLU activation, followed by a skip connection. Parameters ---------- nfilts : int Number of filters in the convolutional layers. npix : int Spatial dimension (assumes square input for LayerNorm). kernel_size : int Size of convolutional kernel. stride : int Stride for the convolution. padding : int Padding to apply to convolution. norm_type : str or None Type of normalization: 'batch', 'layer', or None. Returns ------- torch.Tensor Output tensor of same shape as input. """ def __init__(self, nfilts, npix, kernel_size=3, stride=1, padding=1, norm_type=None): super(ResNetBlock, self).__init__() self.conv1 = nn.Conv2d(nfilts, nfilts, kernel_size, stride, padding) self.conv2 = nn.Conv2d(nfilts, nfilts, kernel_size, stride, padding) self.norm_type = norm_type self.npix = npix self.norm1 = self.add_norm_layer(nfilts) self.norm2 = self.add_norm_layer(nfilts)
[docs] def add_norm_layer(self, nfilts): if self.norm_type == 'batch': return nn.BatchNorm2d(nfilts) elif self.norm_type == 'layer': return nn.LayerNorm([nfilts, self.npix, self.npix]) return None
[docs] def forward(self, x): residual = x out = self.conv1(x) if self.norm1 is not None: out = self.norm1(out) out = F.relu(out) out = self.conv2(out) if self.norm2 is not None: out = self.norm2(out) out += residual # Skip Connection out = F.relu(out) return out
[docs] class ResNet2D(nn.Module): """ A 2D ResNet-style convolutional network for image processing tasks with configurable residual blocks and optional normalization. Parameters ---------- npix : int Spatial size of the input (assumed square). nch_in : int Number of input channels. nch_out : int Number of output channels. nfilts : int Number of filters in convolutional layers. n_res_blocks : int Number of ResNet blocks in the model. norm_type : str or None Type of normalization: 'batch', 'layer', or None. activation : str Final activation function: 'Linear' or 'Sigmoid'. Returns ------- torch.Tensor Output tensor of shape (batch_size, nch_out, npix, npix). """ def __init__(self, npix, nch_in=1, nch_out=1, nfilts=32, n_res_blocks=4, norm_type='layer', activation='Linear'): super(ResNet2D, self).__init__() self.npix = npix self.n_res_blocks = n_res_blocks layers = [nn.Conv2d(nch_in, nfilts, kernel_size=3, stride=1, padding=1)] if norm_type is not None: layers.append(self.add_norm_layer(nfilts, norm_type)) layers.append(nn.ReLU()) for _ in range(n_res_blocks): layers.append(ResNetBlock(nfilts, npix=npix, norm_type=norm_type)) layers.append(nn.Conv2d(nfilts, nch_out, kernel_size=3, stride=1, padding=1)) if activation == 'Sigmoid': layers.append(nn.Sigmoid()) self.model = nn.Sequential(*layers)
[docs] def add_norm_layer(self, nfilts, norm_type): if norm_type == 'batch': return nn.BatchNorm2d(nfilts) elif norm_type == 'layer': return nn.LayerNorm([nfilts, self.npix, self.npix])
[docs] def forward(self, x, residual=False): out = self.model(x) if residual: out += x return out
[docs] class ConvBlock(nn.Module): """ A basic convolutional block with two Conv2D layers, LayerNorm (by default), and ReLU activation. Parameters ---------- in_channels : int Number of input channels. out_channels : int Number of output channels. spatial_dims : tuple[int, int] Spatial dimensions of the input (height, width). norm_type : str Normalization type: 'layer' (default) or 'batch'. Returns ------- torch.Tensor Output tensor of shape (batch_size, out_channels, H, W). """ def __init__(self, in_channels, out_channels, spatial_dims, norm_type='layer'): super(ConvBlock, self).__init__() self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1) self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1) self.norm_type = norm_type self.spatial_dims = spatial_dims self.out_channels = out_channels self.norm_layer1 = nn.LayerNorm([self.out_channels, self.spatial_dims[0], self.spatial_dims[1]]) self.norm_layer2 = nn.LayerNorm([self.out_channels, self.spatial_dims[0], self.spatial_dims[1]])
[docs] def forward(self, x): x = self.conv1(x) x = self.norm_layer1(x) x = F.relu(x) x = self.conv2(x) x = self.norm_layer2(x) x = F.relu(x) return x
[docs] class DownBlock(nn.Module): """ A downsampling block consisting of a Conv2D with stride=2 followed by a ConvBlock. Parameters ---------- in_channels : int Number of input channels. out_channels : int Number of output channels. spatial_dims : tuple[int, int] Spatial dimensions after downsampling. norm_type : str or None Type of normalization to use in the ConvBlock. Returns ------- torch.Tensor Output tensor of shape (batch_size, out_channels, H/2, W/2). """ def __init__(self, in_channels, out_channels, spatial_dims, norm_type=None): super(DownBlock, self).__init__() self.down_conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1) self.conv_block = ConvBlock(out_channels, out_channels, spatial_dims, norm_type=norm_type)
[docs] def forward(self, x): x = F.relu(self.down_conv(x)) return self.conv_block(x)
[docs] class UpBlock(nn.Module): """ An upsampling block with transposed convolution and concatenation with a skip connection (bridge), followed by a ConvBlock. Parameters ---------- in_channels : int Number of channels to upsample. bridge_channels : int Number of channels in the bridge tensor (from encoder). out_channels : int Number of output channels after convolution. spatial_dims : tuple[int, int] Spatial dimensions of the output. norm_type : str Normalization type to use in ConvBlock: 'layer' or 'batch'. Returns ------- torch.Tensor Output tensor of shape (batch_size, out_channels, H*2, W*2). """ def __init__(self, in_channels, bridge_channels, out_channels, spatial_dims, norm_type='layer'): super(UpBlock, self).__init__() self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2) # The combined channels from the upsampled layer and the bridge combined_channels = out_channels + bridge_channels # Use the spatial dimensions for the ConvBlock self.conv_block = ConvBlock(combined_channels, out_channels, spatial_dims, norm_type=norm_type)
[docs] def forward(self, x, bridge): x = self.up(x) # Determine padding for concatenation diffY = bridge.size()[2] - x.size()[2] diffX = bridge.size()[3] - x.size()[3] x = F.pad(x, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2]) # Concatenate along the channel dimension out = torch.cat([bridge, x], dim=1) return self.conv_block(out)
[docs] class UNet2D(nn.Module): """ A U-Net style 2D convolutional neural network with downsampling and upsampling paths and skip connections. Supports configurable depth and normalization. Parameters ---------- nch_in : int Number of input channels. nch_out : int Number of output channels. npix : int Input spatial dimension (assumes square input). base_nfilts : int Number of filters in the base layer. num_blocks : int Number of downsampling and upsampling blocks. norm_type : str or None Normalization type: 'batch', 'layer', or None. activation : str Final activation function: 'Sigmoid' or 'Linear'. Forward ------- x : torch.Tensor Input tensor of shape (batch_size, nch_in, npix, npix). Returns ------- torch.Tensor Output tensor of shape (batch_size, nch_out, npix, npix). """ def __init__(self, nch_in, nch_out, npix, base_nfilts=64, num_blocks=4, norm_type=None, activation='Linear'): super(UNet2D, self).__init__() spatial_dims = [npix, npix] # Initial spatial dimensions self.initial_conv_block = ConvBlock(nch_in, base_nfilts, spatial_dims, norm_type=norm_type) self.down_blocks = nn.ModuleList() self.bridge_channels = [] for i in range(num_blocks): spatial_dims = [s // 2 for s in spatial_dims] # Halve spatial dimensions after down block self.down_blocks.append(DownBlock(base_nfilts, base_nfilts, spatial_dims, norm_type=norm_type)) self.bridge_channels.append(base_nfilts) self.up_blocks = nn.ModuleList() # Adjust spatial_dims for UpBlocks for i in range(num_blocks-1, -1, -1): bridge_channels = self.bridge_channels[i] self.up_blocks.append(UpBlock(base_nfilts, bridge_channels, base_nfilts, spatial_dims, norm_type=norm_type)) spatial_dims = [s * 2 for s in spatial_dims] # Upsampling increases dimensions self.final_up_block = UpBlock(base_nfilts, bridge_channels, base_nfilts, spatial_dims, norm_type=norm_type) # Final convolution layer self.final_conv = nn.Conv2d(base_nfilts, nch_out, kernel_size=1) self.activation = nn.Sigmoid() if activation == 'Sigmoid' else None
[docs] def forward(self, x): bridges = [] # Initial convolution initial_conv_output = self.initial_conv_block(x) # Downsampling x = initial_conv_output for down_block in self.down_blocks: x = down_block(x) bridges.append(x) bridges = bridges[::-1] # Upsampling for i, up_block in enumerate(self.up_blocks): bridge = bridges[i] x = up_block(x, bridge) # Upsampling for the initial convolution block x = self.final_up_block(x, initial_conv_output) # Final convolution layer x = self.final_conv(x) if self.activation is not None: x = self.activation(x) return x
[docs] class Crop2D(nn.Module): """ Crop a 4D tensor along the height and width dimensions. Parameters ---------- top : int, optional Number of pixels to remove from the top. Default is 0. bottom : int, optional Number of pixels to remove from the bottom. Default is 0. left : int, optional Number of pixels to remove from the left. Default is 0. right : int, optional Number of pixels to remove from the right. Default is 0. Notes ----- - Expects input tensors in NCHW format: (batch, channels, height, width). - If any crop value is 0, that side is left unchanged. - This layer performs a simple slice, without resizing or interpolation. Examples -------- >>> crop = Crop2D(top=1, bottom=2, left=0, right=1) >>> x = torch.randn(2, 3, 10, 10) >>> y = crop(x) >>> y.shape torch.Size([2, 3, 7, 9]) """ def __init__(self, top=0, bottom=0, left=0, right=0): super().__init__() self.top, self.bottom, self.left, self.right = top, bottom, left, right
[docs] def forward(self, x): _, _, H, W = x.shape t, b, l, r = self.top, self.bottom, self.left, self.right return x[:, :, t:(H - b if b > 0 else H), l:(W - r if r > 0 else W)]
[docs] class SD2I(nn.Module): """ SD2I (Single Digit to Image) reconstruction network. This is a PyTorch reimplementation of the Keras/TensorFlow SD2I model. It reconstructs a 2D image from a single scalar input using a sequence of dense layers, optional progressive upsampling, and convolutional refinement. Parameters ---------- npix : int Target output image size (height = width = npix). factor : int, optional Number of feature channels after the dense-to-feature-map reshape. Default is 8. upsample : bool, optional If True, use the progressive upsampling branch (two stages of ×2 upsampling). If False, use the direct full-resolution branch without upsampling. Default is True. Architecture (upsample=True) ---------------------------- 1. Dense stack: [64 → ReLU] × 3, then linear projection to ceil(npix / 4) × ceil(npix / 4) × factor features. 2. Reshape to (factor, H0, W0), where H0 = W0 = ceil(npix / 4). 3. Upsample ×2 (nearest neighbor). - Conditional crop: remove 1 pixel from bottom & right if npix % 4 in {1, 2}. 4. Conv block: [3×3 conv, 64 filters, ReLU] × 3. 5. Upsample ×2 (nearest neighbor). - Conditional crop: remove 1 pixel from top & left if npix is odd. 6. Conv block: [3×3 conv, 64 filters, ReLU] × 3. 7. Final 3×3 conv to 1 channel (linear activation). Architecture (upsample=False) ----------------------------- 1. Dense stack: [128 → ReLU] × 3, then linear projection to npix × npix × factor features. 2. Reshape to (factor, npix, npix). 3. Conv stack: - 3×3 conv, 128 filters, ReLU - 3×3 conv, 128 filters, ReLU - 3×3 conv, 64 filters, ReLU - 3×3 conv, 1 filter, linear Notes ----- - All convolutions use "same" padding (padding=1 for kernel_size=3). - Upsampling uses nearest-neighbor to match Keras `UpSampling2D`. - Weights are initialized with a RandomNormal(mean=0.0, std=0.05) to match Keras `kernel_initializer='random_normal'`. Examples -------- >>> model = SD2I(npix=128, factor=8, upsample=True) >>> inp = torch.randn(4, 1) # batch of scalars >>> out = model(inp) >>> out.shape torch.Size([4, 1, 128, 128]) """ def __init__(self, npix: int, factor: int = 8, upsample: bool = True): super().__init__() self.npix = npix self.factor = factor self.upsample_flag = upsample self.flatten = nn.Flatten() if upsample: h0 = int(math.ceil(npix / 4)) self.fc = nn.Sequential( nn.Linear(1, 64), nn.ReLU(inplace=True), nn.Linear(64, 64), nn.ReLU(inplace=True), nn.Linear(64, 64), nn.ReLU(inplace=True), nn.Linear(64, h0 * h0 * factor), # linear ) self.reshape_h = h0 self.reshape_w = h0 self.up = nn.Upsample(scale_factor=2, mode='nearest') self.crop_after_up1 = Crop2D(top=0, bottom=1, left=0, right=1) \ if (npix % 4 == 1 or npix % 4 == 2) else nn.Identity() self.conv1 = nn.Sequential( nn.Conv2d(factor, 64, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.ReLU(inplace=True), ) self.crop_after_up2 = Crop2D(top=1, bottom=0, left=1, right=0) \ if (npix % 2 == 1) else nn.Identity() self.conv2 = nn.Sequential( nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.ReLU(inplace=True), ) self.to_out = nn.Conv2d(64, 1, kernel_size=3, padding=1) # linear else: self.fc = nn.Sequential( nn.Linear(1, 128), nn.ReLU(inplace=True), nn.Linear(128, 128), nn.ReLU(inplace=True), nn.Linear(128, 128), nn.ReLU(inplace=True), nn.Linear(128, npix * npix * factor), # linear ) self.reshape_h = npix self.reshape_w = npix self.conv_stack = nn.Sequential( nn.Conv2d(factor, 128, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.Conv2d(128, 128, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.Conv2d(128, 64, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.Conv2d(64, 1, kernel_size=3, padding=1), # linear ) self._init_random_normal(std=0.05) def _init_random_normal(self, std=0.05): """Initialize Conv2d and Linear layers with N(0, std^2) and zero bias.""" for m in self.modules(): if isinstance(m, (nn.Conv2d, nn.Linear)): nn.init.normal_(m.weight, mean=0.0, std=std) if m.bias is not None: nn.init.zeros_(m.bias)
[docs] def forward(self, x): x = self.flatten(x) x = self.fc(x) if self.upsample_flag: B = x.shape[0] x = x.view(B, self.factor, self.reshape_h, self.reshape_w) x = self.up(x) x = self.crop_after_up1(x) x = self.conv1(x) x = self.up(x) x = self.crop_after_up2(x) x = self.conv2(x) x = self.to_out(x) else: B = x.shape[0] x = x.view(B, self.factor, self.reshape_h, self.reshape_w) x = self.conv_stack(x) return x