Okay, let's put theory into practice by building a simple CUDA extension. This exercise will guide you through creating a custom CUDA kernel for a basic operation (scaled vector addition), writing the necessary C++ bindings, compiling it, and finally calling it from Python with PyTorch tensors. This process demonstrates the fundamental steps involved in accelerating specific computations on the GPU beyond what standard PyTorch operations can do efficiently.
Our goal is to implement a function scaled_add(x, y, alpha) that computes z = α * x + y, where α is a scalar and x, y, z are vectors (1D tensors). We will write the core computation as a CUDA kernel and integrate it as a PyTorch C++ extension.
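For reference, the same computation can be expressed with standard PyTorch operations in a couple of lines; our extension should reproduce this result (up to floating-point rounding):

# Reference behavior in plain PyTorch; the custom kernel below should match this.
import torch

alpha = 2.5
x = torch.randn(8)
y = torch.randn(8)
z = alpha * x + y   # the operation our CUDA kernel will implement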
Ensure you have the following installed and configured:

- PyTorch built with CUDA support (torch.cuda.is_available() should return True).
- The CUDA Toolkit: the nvcc compiler must be in your system's PATH.

Let's organize our code. Create a directory structure like this:
simple_cuda_extension/
├── setup.py
└── src/
├── scaled_add.cpp
└── scaled_add_kernel.cu
The CUDA Kernel (scaled_add_kernel.cu)
This file contains the actual GPU code. We define a CUDA kernel function that performs the scaled addition element-wise.
// src/scaled_add_kernel.cu
#include <cuda.h>
#include <cuda_runtime.h>
#include <math.h>   // CUDA math functions, not strictly needed for this example
#include <cstdio>   // for fprintf used in the launch error check below
// CUDA Kernel definition
// Computes z = alpha * x + y for each element
__global__ void scaled_add_kernel(const float* x, const float* y, float* z, float alpha, int N) {
// Calculate the global thread index
int index = blockIdx.x * blockDim.x + threadIdx.x;
int stride = gridDim.x * blockDim.x; // Total number of threads in the grid
// Use grid-stride loop to ensure all elements are processed
// even if N > number of threads launched.
for (int i = index; i < N; i += stride) {
z[i] = alpha * x[i] + y[i];
}
}
// C++ wrapper function (optional but good practice)
// This can be called from the main C++ binding code.
// It sets up the kernel launch configuration.
void scaled_add_cuda_launcher(const float* x, const float* y, float* z, float alpha, int N) {
// Define block and grid dimensions
// Generally, choose block size as a multiple of 32 (warp size)
// Common choices are 128, 256, 512, 1024
int blockSize = 256;
// Calculate grid size needed to cover all N elements
// Equivalent to ceil(N / blockSize)
int gridSize = (N + blockSize - 1) / blockSize;
// Launch the kernel
scaled_add_kernel<<<gridSize, blockSize>>>(x, y, z, alpha, N);
// Optional: Check for kernel launch errors (useful for debugging)
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
fprintf(stderr, "CUDA kernel launch failed: %s\n", cudaGetErrorString(err));
// Consider throwing an exception here in a real application
}
// Optional: Synchronize device (wait for kernel to finish) if needed immediately
// cudaDeviceSynchronize(); // Usually not needed if subsequent operations use the same stream
}
Explanation:
- __global__ void scaled_add_kernel(...): Defines a function that runs on the GPU.
- blockIdx.x, blockDim.x, threadIdx.x, gridDim.x: Built-in CUDA variables that give each thread its unique ID and context within the grid of threads launched.
- index = blockIdx.x * blockDim.x + threadIdx.x: Calculates a unique global index for each thread.
- stride = gridDim.x * blockDim.x: The total number of threads in the grid.
- The for loop (for (int i = index; i < N; i += stride)) is essential. It allows a fixed number of threads (potentially fewer than N) to process all N elements by having each thread process multiple elements spaced stride apart. This is more robust than assuming N is perfectly divisible by the block size or that the grid size exactly matches N / blockSize. A small CPU-side sketch of this coverage argument follows this list.
- scaled_add_cuda_launcher: A helper C++ function that configures and launches the kernel. It calculates the number of blocks (gridSize) needed based on the input size N and a chosen block size (blockSize). <<<gridSize, blockSize>>> is the CUDA syntax for launching the kernel.
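To make the coverage argument concrete, here is a small CPU-side sketch in plain Python (illustrative only, not part of the extension) that mimics the grid-stride loop with deliberately few "threads" and checks that every element is still processed exactly once:

# CPU-side simulation of the grid-stride loop: even with far fewer threads than
# elements, every index in [0, N) is visited exactly once.
N = 10_000
block_size = 256
grid_size = 4                        # deliberately small: only 1024 "threads" for 10,000 elements
total_threads = grid_size * block_size

visits = [0] * N
for tid in range(total_threads):     # each simulated thread starts at its global index
    i = tid
    while i < N:                     # same structure as the CUDA for-loop
        visits[i] += 1
        i += total_threads           # stride = gridDim.x * blockDim.x
assert all(v == 1 for v in visits)   # full coverage, no duplicates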
The C++ Binding Code (scaled_add.cpp)
This file interfaces the CUDA code with PyTorch. It defines a function callable from Python, handles tensor data access, and calls the CUDA launcher.
// src/scaled_add.cpp
#include <torch/extension.h>
#include <vector>
// Forward declaration of the CUDA launcher function from scaled_add_kernel.cu
void scaled_add_cuda_launcher(const float* x, const float* y, float* z, float alpha, int N);
// C++ interface function that will be bound to Python
// It accepts PyTorch tensors as input
torch::Tensor scaled_add(torch::Tensor x, torch::Tensor y, float alpha) {
// Input validation: Ensure tensors are on the GPU and have the same shape/dtype
TORCH_CHECK(x.device().is_cuda(), "Input tensor x must be a CUDA tensor");
TORCH_CHECK(y.device().is_cuda(), "Input tensor y must be a CUDA tensor");
TORCH_CHECK(x.scalar_type() == torch::kFloat32, "Input tensor x must be float32");
TORCH_CHECK(y.scalar_type() == torch::kFloat32, "Input tensor y must be float32");
TORCH_CHECK(x.is_contiguous(), "Input tensor x must be contiguous");
TORCH_CHECK(y.is_contiguous(), "Input tensor y must be contiguous");
TORCH_CHECK(x.sizes() == y.sizes(), "Input tensors x and y must have the same shape");
TORCH_CHECK(x.dim() == 1, "Input tensor x must be 1D"); // Simple check for this example
TORCH_CHECK(y.dim() == 1, "Input tensor y must be 1D"); // Simple check for this example
// Get the number of elements
int N = x.numel();
// Create the output tensor (on the same device as inputs)
auto z = torch::empty_like(x); // Creates tensor with same shape, dtype, device
// Get raw data pointers
// .data_ptr<float>() gives access to the underlying C++ float* data
const float* x_ptr = x.data_ptr<float>();
const float* y_ptr = y.data_ptr<float>();
float* z_ptr = z.data_ptr<float>();
// Call the CUDA kernel launcher function defined in the .cu file
scaled_add_cuda_launcher(x_ptr, y_ptr, z_ptr, alpha, N);
return z;
}
// Binding code using PYBIND11_MODULE macro
// This creates the Python module named 'simple_cuda_extension_cpp'
// The second argument 'm' is the module object
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// Expose the C++ 'scaled_add' function to Python as 'scaled_add'
m.def("scaled_add", &scaled_add, "Scaled vector addition (alpha * x + y) computed on CUDA");
}
Explanation:
- #include <torch/extension.h>: The primary header for PyTorch C++ extensions.
- Forward declaration: We declare scaled_add_cuda_launcher so the compiler knows about it before it's used. The actual implementation is in the .cu file and will be linked later.
- scaled_add(torch::Tensor x, torch::Tensor y, float alpha): The function exposed to Python. It takes two PyTorch tensors and a float.
- TORCH_CHECK(...): PyTorch's assertion macro. It checks conditions and throws informative C++ exceptions (which get translated to Python exceptions) if they fail. We validate device, data type, contiguity, shape, and dimensionality. Contiguity is important because CUDA kernels often assume data is laid out sequentially in memory. A short example of how such a failure looks from Python follows this list.
- torch::empty_like(x): Creates an output tensor z with the same properties (size, dtype, device) as x, but without initializing the memory contents.
- .data_ptr<float>(): Gets the raw C-style pointer to the tensor's underlying data buffer. This is needed to pass to the CUDA kernel.
- scaled_add_cuda_launcher(...): Calls the function defined in our .cu file.
- PYBIND11_MODULE(TORCH_EXTENSION_NAME, m): This macro (provided by torch/extension.h, which includes pybind11) creates the entry point for the Python module. TORCH_EXTENSION_NAME is a placeholder that is replaced by the module name specified in setup.py.
- m.def("scaled_add", ...): Binds the C++ function scaled_add to the Python name scaled_add within the module m. The string is the docstring for the Python function.
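As an illustration of how these checks surface on the Python side, the following sketch (assuming the extension has already been compiled and is importable) passes CPU tensors and catches the resulting exception; TORCH_CHECK failures arrive as a RuntimeError carrying the message we wrote:

# Sketch: TORCH_CHECK failures are raised in Python as RuntimeError.
import torch
import simple_cuda_extension_cpp as ext   # assumes the extension was built successfully

x_cpu = torch.randn(4)                    # CPU tensors fail the device check
y_cpu = torch.randn(4)
try:
    ext.scaled_add(x_cpu, y_cpu, 2.0)
except RuntimeError as err:
    print(err)                            # includes "Input tensor x must be a CUDA tensor"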
The Build Script (setup.py)
This script uses Python's setuptools and PyTorch's build utilities to compile the C++ and CUDA code into a Python extension module.
# setup.py
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup(
name='simple_cuda_extension_cpp', # Package name, can be anything
ext_modules=[
CUDAExtension(
name='simple_cuda_extension_cpp', # Python module name users will import
sources=[
'src/scaled_add.cpp',
'src/scaled_add_kernel.cu',
]
)
],
cmdclass={
'build_ext': BuildExtension
}
)
Explanation:
- from torch.utils.cpp_extension import BuildExtension, CUDAExtension: Imports the necessary build utilities from PyTorch.
- CUDAExtension(...): Specifies that we are building an extension involving CUDA code.
  - name: The name of the resulting Python module (e.g., import simple_cuda_extension_cpp). At build time this name is substituted for the TORCH_EXTENSION_NAME placeholder used in the PYBIND11_MODULE macro.
  - sources: A list of all source files (.cpp and .cu) needed for the extension.
- cmdclass={'build_ext': BuildExtension}: Tells setuptools to use PyTorch's custom build command, which knows how to handle CUDA compilation (nvcc) and linking against the PyTorch libraries.

Navigate to the simple_cuda_extension directory in your terminal (the one containing setup.py) and run the build command:
# Option 1: Build and install into your Python environment
python setup.py install
# Option 2: Build in place (creates the .so or .pyd file in the current directory)
# Useful for development
python setup.py build_ext --inplace
If successful, this command will invoke the C++ compiler and nvcc
to compile your code and link it against PyTorch libraries, producing a shared object file (e.g., simple_cuda_extension_cpp.cpython-39-x86_64-linux-gnu.so
on Linux or simple_cuda_extension_cpp.pyd
on Windows) that Python can import.
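As an alternative to setup.py, PyTorch can also compile the same sources on the fly with torch.utils.cpp_extension.load. A minimal sketch, assuming you run it from the simple_cuda_extension directory:

# JIT-compile the extension at import time instead of running setup.py.
from torch.utils.cpp_extension import load

ext = load(
    name='simple_cuda_extension_cpp',
    sources=['src/scaled_add.cpp', 'src/scaled_add_kernel.cu'],
    verbose=True,   # print compiler output while building
)
# 'ext' now exposes the bound function, e.g. ext.scaled_add(x_gpu, y_gpu, 2.5)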
Now you can import and use your custom CUDA function just like any other Python module.
# test_extension.py (place this outside the simple_cuda_extension directory, or after installing)
import torch
import time
# Try to import the compiled extension
try:
import simple_cuda_extension_cpp
print("Successfully imported CUDA extension.")
except ImportError:
print("Error importing CUDA extension. Did you compile it successfully?")
print("Run: python setup.py build_ext --inplace (from the extension directory)")
exit()
# Define input tensors on the CPU first
N = 1024 * 1024 # Size of vectors
alpha = 2.5
x_cpu = torch.randn(N, dtype=torch.float32)
y_cpu = torch.randn(N, dtype=torch.float32)
# Move tensors to the GPU
if torch.cuda.is_available():
device = torch.device('cuda')
x_gpu = x_cpu.to(device)
y_gpu = y_cpu.to(device)
print(f"Using device: {device}")
else:
print("CUDA not available. Exiting.")
exit()
# Ensure inputs are contiguous (important for .data_ptr())
x_gpu = x_gpu.contiguous()
y_gpu = y_gpu.contiguous()
# --- Using the Custom CUDA Extension ---
print("\nTesting Custom CUDA Extension:")
# Warm-up GPU
_ = simple_cuda_extension_cpp.scaled_add(x_gpu, y_gpu, alpha)
torch.cuda.synchronize() # Wait for warm-up to finish
start_time = time.time()
z_gpu_custom = simple_cuda_extension_cpp.scaled_add(x_gpu, y_gpu, alpha)
torch.cuda.synchronize() # Wait for the kernel to complete before stopping timer
end_time = time.time()
print(f"Custom CUDA extension time: {(end_time - start_time)*1000:.4f} ms")
# --- Using Standard PyTorch Operations for Verification ---
print("\nTesting Standard PyTorch Operations:")
# Warm-up GPU
_ = alpha * x_gpu + y_gpu
torch.cuda.synchronize()
start_time = time.time()
z_gpu_pytorch = alpha * x_gpu + y_gpu
torch.cuda.synchronize()
end_time = time.time()
print(f"Standard PyTorch time: {(end_time - start_time)*1000:.4f} ms")
# --- Verification ---
# Check if the results are close (allowing for floating-point differences)
difference = torch.abs(z_gpu_custom - z_gpu_pytorch).mean()
print(f"\nMean absolute difference between custom and PyTorch results: {difference.item()}")
if torch.allclose(z_gpu_custom, z_gpu_pytorch, atol=1e-6):
print("Results match!")
else:
print("Results DO NOT match!")
# Example: Print first few elements if needed
# print("Custom output (first 10):", z_gpu_custom[:10])
# print("PyTorch output (first 10):", z_gpu_pytorch[:10])
Running the Test:
Save the Python code above (e.g., as test_extension.py) and run it: python test_extension.py.
You should see output indicating whether the import was successful, the execution times for both your custom kernel and the standard PyTorch operation, and a check confirming that the results are numerically very close. For simple operations like this on modern GPUs, standard PyTorch operations are highly optimized, so don't be surprised if the PyTorch version is faster or comparable. The benefit of custom extensions becomes more apparent for complex, non-standard operations or sequences of operations that can be fused into a single kernel.
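For more precise measurements than time.time(), you can time the operation with CUDA events, which are recorded on the GPU stream itself. A short sketch, reusing x_gpu, y_gpu, and alpha from the test script above:

# Timing with CUDA events instead of wall-clock time.
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()
z = simple_cuda_extension_cpp.scaled_add(x_gpu, y_gpu, alpha)
end.record()
torch.cuda.synchronize()   # wait until both recorded events have completed
print(f"Custom kernel: {start.elapsed_time(end):.4f} ms")   # elapsed_time returns milliseconds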
This practical exercise demonstrated the end-to-end process of creating a PyTorch CUDA extension:
- Writing the CUDA kernel (.cu).
- Writing the C++ wrapper function that validates tensors and calls the kernel (.cpp).
- Binding the C++ function to Python with pybind11 (.cpp).
- Using setuptools and torch.utils.cpp_extension to compile the CUDA and C++ code (setup.py).

While this example was basic, it establishes the fundamental workflow. Real-world extensions often involve more sophisticated kernels, potentially handling different data types and multiple dimensions, and require defining custom backward passes if autograd support is needed (refer back to Chapter 1 on Custom Autograd Functions). Building extensions requires careful attention to memory management, data types, device placement, and synchronization, but it provides a powerful way to optimize performance-critical sections of your PyTorch models.
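If you need the operation to participate in autograd, one possible approach (a minimal sketch, assuming the compiled module is importable) is to wrap the extension in a torch.autograd.Function and supply the analytic gradients, which for z = α * x + y are ∂z/∂x = α and ∂z/∂y = 1:

# Sketch: autograd support for the custom op via torch.autograd.Function.
import torch
import simple_cuda_extension_cpp

class ScaledAddFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, y, alpha):
        ctx.alpha = alpha   # alpha is a plain float; no tensors need to be saved
        return simple_cuda_extension_cpp.scaled_add(x, y, alpha)

    @staticmethod
    def backward(ctx, grad_z):
        # dz/dx = alpha, dz/dy = 1; alpha is not a tensor, so it receives no gradient
        return ctx.alpha * grad_z, grad_z, None

# Usage: z = ScaledAddFn.apply(x_gpu, y_gpu, 2.5)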