While custom C++ extensions offer a way to integrate CPU-bound logic or external C++ libraries, performance bottlenecks in deep learning often reside on the GPU. When standard PyTorch operations don't suffice, or you need to implement a novel algorithm with maximum GPU efficiency, writing custom CUDA kernels becomes necessary. This section details how to create, build, and integrate custom CUDA C++ code directly into your PyTorch workflows.
The primary motivation for CUDA extensions is performance. You might have a specific mathematical operation, a data manipulation routine, or a kernel from existing CUDA research code that you want to execute directly on the GPU without the overhead of transferring data back and forth to the CPU or relying on potentially suboptimal sequences of standard PyTorch operations.
Integrating custom CUDA code involves several steps, similar to C++ extensions but with the added complexity of GPU programming:
1. Write the CUDA kernel (in a .cu file). This involves writing functions (__global__ for kernels launched from the host, __device__ for helper functions called from the GPU) that operate on data residing in GPU memory.
2. Write a C++ wrapper (in a .cpp file) that acts as an interface between PyTorch and your CUDA kernel. This wrapper validates the input tensors and extracts raw data pointers (tensor.data_ptr()) for GPU memory access before launching the kernel.
3. Use PyTorch's build utilities (torch.utils.cpp_extension) to compile the CUDA kernel and C++ wrapper and make the wrapped function callable from Python. This can be done Just-In-Time (JIT) or via a setup.py script.

Let's illustrate this with a simple vector addition kernel.
1. CUDA Kernel (vector_add_kernel.cu)
#include <cuda.h>
#include <cuda_runtime.h>
#include <stdio.h> // For printf in kernel if needed for debugging
// CUDA kernel for element-wise vector addition
__global__ void vector_add_kernel(const float* a, const float* b, float* c, int n) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
int stride = gridDim.x * blockDim.x;
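// Grid-stride loop: each thread starts at its global index and advances by the total
// number of launched threads, so the kernel also works when n exceeds that total.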
for (int i = index; i < n; i += stride) {
c[i] = a[i] + b[i];
}
}
// A simple function to call the kernel (could be more complex)
// Note: Error checking (cudaGetLastError) is omitted for brevity but crucial in production.
void vector_add_cuda_launcher(const float* a, const float* b, float* c, int n) {
int threadsPerBlock = 256;
// Use integer ceil division
int blocksPerGrid = (n + threadsPerBlock - 1) / threadsPerBlock;
// Launch the kernel
vector_add_kernel<<<blocksPerGrid, threadsPerBlock>>>(a, b, c, n);
// Optional: Synchronize device after kernel launch if needed immediately
// cudaDeviceSynchronize(); // Be mindful of performance impact
}
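As the comment above notes, launch error checking is omitted for brevity. A minimal sketch of what it might look like inside vector_add_cuda_launcher, using only the CUDA runtime API (not part of the example build):

// Sketch: checking for launch errors right after the kernel launch
vector_add_kernel<<<blocksPerGrid, threadsPerBlock>>>(a, b, c, n);
cudaError_t err = cudaGetLastError();  // Reports invalid launch configurations, etc.
if (err != cudaSuccess) {
    // In a PyTorch extension, the C++ wrapper would typically turn this into an
    // exception (e.g., via TORCH_CHECK) so it surfaces in Python.
    fprintf(stderr, "vector_add_kernel launch failed: %s\n", cudaGetErrorString(err));
}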
2. C++ Wrapper (vector_add.cpp)
This file bridges PyTorch and our CUDA launcher function.
#include <torch/extension.h>
#include <vector>
// CUDA forward declarations (assuming vector_add_cuda_launcher is defined elsewhere, e.g., in the .cu file or a header)
void vector_add_cuda_launcher(const float* a, const float* b, float* c, int n);
// C++ interface function adhering to PyTorch's C++ API
// Note: TORCH_CHECK macros ensure tensors are on the correct device and have the expected type/shape.
torch::Tensor vector_add(torch::Tensor a, torch::Tensor b) {
// Input validation
TORCH_CHECK(a.device().is_cuda(), "Input tensor a must be a CUDA tensor");
TORCH_CHECK(b.device().is_cuda(), "Input tensor b must be a CUDA tensor");
TORCH_CHECK(a.is_contiguous(), "Input tensor a must be contiguous");
TORCH_CHECK(b.is_contiguous(), "Input tensor b must be contiguous");
TORCH_CHECK(a.dtype() == torch::kFloat32, "Input tensor a must be float32");
TORCH_CHECK(b.dtype() == torch::kFloat32, "Input tensor b must be float32");
TORCH_CHECK(a.sizes() == b.sizes(), "Input tensors must have the same shape");
// Create the output tensor on the same device as input
torch::Tensor c = torch::empty_like(a);
int n = a.numel(); // Total number of elements
// Call the CUDA launcher function
vector_add_cuda_launcher(
a.data_ptr<float>(),
b.data_ptr<float>(),
c.data_ptr<float>(),
n);
return c;
}
// Binding: expose the C++ vector_add function as 'forward' in the compiled Python module
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &vector_add, "CUDA vector addition forward");
// If you have a backward pass, you'd bind it here too.
}
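This wrapper is hard-wired to float32 (see the data type note in the considerations below). To support several floating-point types, one common pattern is to template the kernel and dispatch on the tensor's scalar type with ATen's dispatch macros. The sketch below illustrates the idea; the names vector_add_kernel_t and vector_add_cuda_launcher_any are hypothetical and not part of the example build:

// Sketch (would live in the .cu file): a templated kernel plus a dtype-dispatching launcher
#include <torch/extension.h>

template <typename scalar_t>
__global__ void vector_add_kernel_t(const scalar_t* a, const scalar_t* b, scalar_t* c, int n) {
    int index = blockIdx.x * blockDim.x + threadIdx.x;
    int stride = gridDim.x * blockDim.x;
    for (int i = index; i < n; i += stride) {
        c[i] = a[i] + b[i];
    }
}

void vector_add_cuda_launcher_any(torch::Tensor a, torch::Tensor b, torch::Tensor c) {
    int n = a.numel();
    int threadsPerBlock = 256;
    int blocksPerGrid = (n + threadsPerBlock - 1) / threadsPerBlock;
    // AT_DISPATCH_FLOATING_TYPES expands the lambda once per supported dtype
    // (float, double), binding scalar_t to the matching C++ type.
    AT_DISPATCH_FLOATING_TYPES(a.scalar_type(), "vector_add_cuda", ([&] {
        vector_add_kernel_t<scalar_t><<<blocksPerGrid, threadsPerBlock>>>(
            a.data_ptr<scalar_t>(), b.data_ptr<scalar_t>(), c.data_ptr<scalar_t>(), n);
    }));
}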
3. Python Binding and Build (using JIT)
The easiest way to compile and load for simple cases is using torch.utils.cpp_extension.load.
import torch
import time
from torch.utils.cpp_extension import load
# Load the CUDA extension, JIT compiling it if necessary
# 'verbose=True' shows the compilation commands
vector_add_module = load(
name='vector_add_cuda',
sources=['vector_add.cpp', 'vector_add_kernel.cu'],
verbose=True
)
# Prepare input tensors on the GPU
device = torch.device('cuda')
size = 10000000 # Large vector size
a = torch.randn(size, device=device, dtype=torch.float32)
b = torch.randn(size, device=device, dtype=torch.float32)
# --- Using the PyTorch default add ---
torch.cuda.synchronize() # Make sure the asynchronous tensor creation has finished before timing
start_time = time.time()
c_pytorch = a + b
torch.cuda.synchronize() # Wait for GPU operations to complete
pytorch_time = time.time() - start_time
print(f"PyTorch default add time: {pytorch_time:.6f} seconds")
# --- Using the custom CUDA extension ---
start_time = time.time()
c_cuda = vector_add_module.forward(a, b)
torch.cuda.synchronize() # Wait for GPU operations to complete
cuda_time = time.time() - start_time
print(f"Custom CUDA add time: {cuda_time:.6f} seconds")
# Verify results (allowing for small floating-point differences)
diff = torch.abs(c_pytorch - c_cuda).max()
print(f"Maximum difference between PyTorch and CUDA results: {diff.item()}")
assert torch.allclose(c_pytorch, c_cuda, atol=1e-6), "Results differ significantly!"
print("CUDA extension test passed!")
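Wall-clock timing with time.time() plus torch.cuda.synchronize() is adequate here, but CUDA events measure device time more directly. A short sketch of that alternative, timing just the custom kernel (result in milliseconds):

# Sketch: timing the custom kernel with CUDA events instead of wall-clock time
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

start_event.record()
c_cuda = vector_add_module.forward(a, b)
end_event.record()
torch.cuda.synchronize()  # Wait so the elapsed time can be read
print(f"Custom CUDA add (event timing): {start_event.elapsed_time(end_event):.3f} ms")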
This example uses the JIT compiler. For larger projects or distribution, you would typically use a setup.py file.

Example setup.py:
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup(
name='vector_add_cuda',
ext_modules=[
CUDAExtension('vector_add_cuda', [ # This name becomes TORCH_EXTENSION_NAME in the C++ code
'vector_add.cpp',
'vector_add_kernel.cu',
]),
],
cmdclass={
'build_ext': BuildExtension
})
You would then build this using python setup.py install. After installation, you can import it like a regular Python module: import vector_add_cuda.
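As a quick sketch, using the installed package mirrors the JIT-loaded module (assuming the build above succeeded):

# Sketch: calling the extension installed via setup.py
import torch
import vector_add_cuda

a = torch.randn(1024, device='cuda')
b = torch.randn(1024, device='cuda')
c = vector_add_cuda.forward(a, b)  # 'forward' is the name bound in PYBIND11_MODULE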
Build process for a PyTorch CUDA extension. Python tools orchestrate the compilation of C++ and CUDA code via system compilers (like g++ and NVCC) into a loadable shared library.
Several practical points are worth keeping in mind when writing CUDA extensions:

- Data types: Kernels are typically written for specific data types (e.g., float, half). The C++ wrapper must handle type checking and potentially dispatch to different kernel versions or perform type conversions. Using templates in both C++ and CUDA can help manage this (see the dispatch sketch after the C++ wrapper above).
- Memory contiguity: The TORCH_CHECK(tensor.is_contiguous(), ...) assertion is important. If a tensor isn't contiguous, you might need to call .contiguous() in the Python code or handle non-contiguous memory access carefully within the kernel (which is generally much less efficient).
- Device placement: Verify that input tensors are on the GPU (tensor.device().is_cuda()) and that the output tensor is created on the same device.
- Launch configuration: Choosing gridDim (number of blocks) and blockDim (threads per block) can significantly impact performance and depends on the specific GPU architecture and kernel logic. This often requires experimentation.
- Synchronization: torch.cuda.synchronize() waits for all preceding CUDA operations on the current stream to complete. Using it excessively, however, can hurt performance. Often, synchronization is implicitly handled when data is copied back to the CPU or used by another PyTorch CUDA operation.
- Error handling: Check cudaGetLastError() after kernel launches and other CUDA API calls within your C++/CUDA code to catch runtime errors. These errors can be propagated back to Python as exceptions.
- Build system: A setup.py script provides better control for complex builds, linking external libraries, and distribution.
- Autograd support: To support backpropagation, implement a backward function (often requiring another custom CUDA kernel) and bind it using torch::autograd::Function, similar to the custom C++ autograd functions discussed in Chapter 1, but with CUDA kernels handling the computation. A minimal Python-side sketch follows below.

Building custom CUDA extensions requires familiarity with CUDA C++ programming alongside PyTorch's C++ API. However, it provides the ultimate control over GPU execution, enabling significant performance gains for specialized, compute-intensive operations within your deep learning models.
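To make the autograd point concrete, here is a minimal Python-side sketch that wraps the JIT-loaded module from earlier in a torch.autograd.Function. It relies on the fact that the gradient of an elementwise addition is simply the incoming gradient for both inputs; a C++-side torch::autograd::Function would follow the same structure.

# Sketch: wrapping the custom op so it participates in autograd
# (assumes vector_add_module from the JIT example above is in scope)
import torch

class VectorAddFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b):
        # Run the custom CUDA kernel for the forward pass
        return vector_add_module.forward(a.contiguous(), b.contiguous())

    @staticmethod
    def backward(ctx, grad_output):
        # d(a + b)/da = 1 and d(a + b)/db = 1, so both inputs receive grad_output
        return grad_output, grad_output

a = torch.randn(1000, device='cuda', requires_grad=True)
b = torch.randn(1000, device='cuda', requires_grad=True)
out = VectorAddFunction.apply(a, b)
out.sum().backward()
print(a.grad.shape, b.grad.shape)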