torch_nd_conv is a pure-Python implementation of the n-dimensional convolution operation, with PyTorch as its only dependency.
- Arbitrary Dimension Convolution: implements the convolution (`Conv`) operation for input spaces with any number of dimensions.
- Auxiliary Classes: provides `Fold` and `Unfold` classes, instrumental for manipulating N-dimensional feature spaces.
- PyTorch Integration: all classes inherit from `torch.nn.Module`, ensuring seamless integration with existing PyTorch workflows.
- Python 3.12+: developed using Python 3.12, taking advantage of some of its new features and optimizations.
While the modules provided in this repository are built upon PyTorch's architecture and conventions, there are some key differences in their usage compared to PyTorch's native modules. Below are important considerations and guidelines for effectively utilizing the Fold, Unfold, and Conv classes.
Both Fold and Unfold accept optional arguments that precompute shape-related data during construction, reducing work in each forward call:
- `input_size` (`Fold`, `Unfold`)
  - Purpose: If you know the non-batched spatial dimensions (e.g., `(C, D, H, W)` for a batch of 3D volumes), passing `input_size` lets the module build its index masks and compute dilations/strides once in `__init__`. This avoids repeating those calculations at runtime.
  - Usage: Provide a tuple (or integer) matching the input's convolutional dimensions. The module checks at `forward` that the incoming tensor's shape aligns with this size.
- `output_size` (`Fold` only)
  - Purpose: Validates at construction that your kernel parameters (kernel size, stride, dilation, padding) will fold an unfolded tensor back to the correct shape, catching mismatches early.
  - Usage: Supply the expected output-volume shape (excluding batch), e.g. `(C, D′, H′, W′)`. If the parameters wouldn't produce that size, `Fold` raises an error in `__init__`.
- `kernel_position` (`Fold` only)
  - Purpose: Lets you specify whether the kernel dimensions come before or after the convolutional-output axes in the unfolded input:
    - `"last"` (default): Input to `forward` is `(..., D_out, H_out, W_out, K_d, K_h, K_w)`.
    - `"first"`: Input to `forward` is `(..., K_d, K_h, K_w, D_out, H_out, W_out)`.
  - Usage: Match it to how your data pipeline organizes those axes so you don't need extra `permute` calls.
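The shape checks described above rest on the usual sliding-window arithmetic. The sketch below assumes torch_nd_conv follows the same per-axis formula that PyTorch documents for its own convolution and unfold modules; `conv_output_size` is a hypothetical helper written for illustration, not part of the library.

```python
def conv_output_size(input_size, kernel_size, stride, padding, dilation):
    """Number of sliding-window positions along each convolutional axis.

    Hypothetical helper: applies the standard formula
    floor((n + 2p - d(k - 1) - 1) / s) + 1 independently per axis.
    """
    sizes = []
    for n, k, s, p, d in zip(input_size, kernel_size, stride, padding, dilation):
        span = d * (k - 1) + 1  # extent covered by one dilated kernel
        if n + 2 * p < span:
            raise ValueError(f"kernel span {span} exceeds padded size {n + 2 * p}")
        sizes.append((n + 2 * p - span) // s + 1)
    return tuple(sizes)


# A (D, H, W) = (8, 16, 16) volume with a 3x3x3 kernel, stride 1, padding 1.
same = conv_output_size((8, 16, 16), (3, 3, 3), (1, 1, 1), (1, 1, 1), (1, 1, 1))
print(same)  # (8, 16, 16)

# Same volume with stride 2 and no padding.
strided = conv_output_size((8, 16, 16), (3, 3, 3), (2, 2, 2), (0, 0, 0), (1, 1, 1))
print(strided)  # (3, 7, 7)
```

Comparing such a computed tuple with the shape you expect is the kind of mismatch that passing `output_size` is meant to surface at construction time rather than in the middle of a forward pass.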
Unlike PyTorch’s 2D-only modules, which flatten all kernel dims into one channel axis, torch_nd_conv keeps kernel dimensions separate for clarity in N-D operations:
- `Unfold`
  - PyTorch 2D: Takes `(N, C, H, W)` → `(N, C×K_h×K_w, L)`, collapsing the `K_h×K_w` patch into a single channel and listing `L` sliding-window positions.
  - torch_nd_conv N-D: From `(N, C, D, H, W)` with kernel `(K_d, K_h, K_w)`, returns `(N, C, D_out, H_out, W_out, K_d, K_h, K_w)`.
    - Convolutional output axes `(D_out, H_out, W_out)` remain distinct.
    - Kernel dims `(K_d, K_h, K_w)` stay separate, so each patch element's location is obvious.
- `Fold`
  - PyTorch 2D: Expects `(N, C×K_h×K_w, L)` → reconstructs `(N, C, H, W)` by summing overlaps.
  - torch_nd_conv N-D: Takes `(N, C, D_out, H_out, W_out, K_d, K_h, K_w)` → reconstructs `(N, C, D′, H′, W′)`.
    - Gathers each of the `K_d×K_h×K_w` elements from their `(D_out, H_out, W_out)` positions.
    - Sums them along the reconstruction axes to rebuild the original volume.
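To see what the flattened PyTorch 2D layout hides, the sketch below decodes a flat index back into the separate axes that torch_nd_conv keeps explicit. It relies on the ordering PyTorch uses for `(N, C×K_h×K_w, L)` (channel first, then kernel row, then kernel column; positions in row-major order). Both function names are hypothetical and exist only for this illustration.

```python
def decode_flat_channel(flat_index, kernel_size):
    """Split an index into the C*K_h*K_w axis into (channel, k_h, k_w)."""
    k_h, k_w = kernel_size
    channel, offset = divmod(flat_index, k_h * k_w)
    row, col = divmod(offset, k_w)
    return channel, row, col


def decode_position(l_index, output_size):
    """Split an index into the L axis into (h_out, w_out), row-major."""
    _, w_out = output_size
    return divmod(l_index, w_out)


# Entry 14 of the flattened axis for a 3x3 kernel:
print(decode_flat_channel(14, (3, 3)))  # (1, 1, 2)

# Sliding-window position 10 on a 4x6 output grid:
print(decode_position(10, (4, 6)))  # (1, 4)
```

With the separated layout this bookkeeping is unnecessary, because the channel, output-position, and kernel indices are already individual tensor axes.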
By preserving kernel dimensions, torch_nd_conv makes it straightforward to generalize beyond 2D. No manual reshaping or axis permutation is needed when moving to 3D, 4D, or higher.
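The unfold and fold behaviour described above can be illustrated in one dimension with plain Python lists. This is a conceptual sketch only, not the library's implementation: `unfold_1d` and `fold_1d` are hypothetical names, padding and dilation are left out, and the windows use the kernel-last layout `(L_out, K)`.

```python
def unfold_1d(signal, kernel_size, stride=1):
    """Return sliding windows with layout (L_out, K): positions first, kernel last."""
    n_out = (len(signal) - kernel_size) // stride + 1
    return [
        [signal[p * stride + k] for k in range(kernel_size)]
        for p in range(n_out)
    ]


def fold_1d(windows, output_length, stride=1):
    """Scatter-add every window element back to the position it came from."""
    result = [0] * output_length
    for p, window in enumerate(windows):
        for k, value in enumerate(window):
            result[p * stride + k] += value
    return result


signal = [1, 2, 3, 4, 5]
windows = unfold_1d(signal, kernel_size=3)
print(windows)  # [[1, 2, 3], [2, 3, 4], [3, 4, 5]]

folded = fold_1d(windows, output_length=len(signal))
print(folded)  # [1, 4, 9, 8, 5]
```

Because overlapping windows are summed, folding an unfolded signal does not return the original values: each position is scaled by the number of windows that cover it (here 1, 2, 3, 2, 1).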
Below is a basic example demonstrating how to utilize the Conv module alongside Fold and Unfold for a 3D convolution operation:
import torch
from torch_nd_conv import Conv, Fold, Unfold
# Define input dimensions: (batch_size, channels, depth, height, width)
input_tensor = torch.randn(8, 3, 8, 16, 16)
# Initialize Fold and Unfold
fold = Fold(kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), kernel_position="last")
unfold = Unfold(kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
# Initialize Conv
conv = Conv(input_channels=3, output_channels=2, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
# Perform unfold
unfolded = unfold(input_tensor)
# Perform fold
folded_output = fold(unfolded)
# Perform convolution
output = conv(unfolded)

To get started with the torch_nd_conv repository, follow these steps to set up your development environment:
- Clone the Repository

      git clone https://github.com/mntsx/torch_nd_conv.git

- Install Dependencies

  It's recommended to use a virtual environment to manage dependencies.

  - Windows:

        py -3.12 -m venv .venv
        .venv\Scripts\activate
        python -m pip install --upgrade pip
        cd torch_nd_conv
        python -m pip install -r requirements.txt

  - macOS/Linux:

        python3.12 -m venv .venv
        source .venv/bin/activate
        python3 -m pip install --upgrade pip
        cd torch_nd_conv
        python3 -m pip install -r requirements.txt

To evaluate the performance of the custom n-dimensional convolution against PyTorch's native convolution functions, run the `benchmarks.conv` submodule:
cd torch_nd_conv
python -m benchmarks.conv

This will output the execution times and performance ratios for both 2D and 3D convolution operations.
Ensure that all modules are functioning correctly by running the test suites using pytest:
cd torch_nd_conv
pytest .

The repository is organized for functional separation. Below is an overview of the primary directories and their respective contents:
torch_nd_conv/
│
├── benchmarks/
│ ├── conv.py
│ └── __init__.py
│
├── src/
│ ├── conv.py
│ ├── fold.py
│ ├── internal_types.py
│ ├── utils.py
│ └── __init__.py
│
├── tests/
│ ├── pytest.ini
│ ├── test_conv.py
│ ├── test_fold.py
│ ├── test_unfold.py
│ └── __init__.py
│
├── __init__.py
├── .gitignore
├── pytest.ini
├── requirements.txt
└── README.md
The `benchmarks/` directory contains benchmarking scripts that compare the performance of the custom n-dimensional convolution functions against PyTorch's native convolution operations in the dimensions where PyTorch provides built-in support.
- `conv.py`: Contains benchmarks that evaluate the performance of the n-dimensional convolution (`Conv`) against PyTorch's 2D (`conv2d`) and 3D (`conv3d`) convolution functions.
The core implementations of the convolution operations and their auxiliary classes reside in the `src/` directory.
- `conv.py`: Defines the `Conv` module, implementing the n-dimensional convolution operation as a subclass of `torch.nn.Module`.
- `fold.py`: Contains the definitions for the n-dimensional `Fold` and `Unfold` classes, which are essential for preparing and reconstructing data for convolution operations in arbitrary dimensions.
- `internal_types.py`: Includes custom type definitions that enhance readability and maintainability through improved type hinting.
- `utils.py`: Provides utility functions used for validating hyperparameters and inputs for the `Fold`, `Unfold`, and `Conv` classes.
- `__init__.py`: Initializes the `src` package, facilitating easy imports of the modules within.
The `tests/` directory houses test suites that ensure the correctness and reliability of the implemented modules using pytest.
- `test_conv.py`: Contains tests verifying the functionality and performance of the `Conv` module.
- `test_fold.py`: Includes tests for the `Fold` class, ensuring accurate data folding operations.
- `test_unfold.py`: Comprises tests for the `Unfold` class, validating the unfolding process.
- `__init__.py`: Initializes the `tests` package, enabling straightforward test discovery and execution.