Writing high-performance GPU kernels directly in CUDA C++ often requires managing complex details like shared memory bank conflicts, warp divergence, and tensor core operand layouts. While this level of control is powerful, it introduces significant engineering overhead. Triton is a language and compiler stack that abstracts these complexities while often matching the performance of vendor-tuned libraries. It achieves this by exposing a block-based programming model rather than the Single Instruction Multiple Thread (SIMT) model used in CUDA.
In this practical section, we will implement a tiled matrix multiplication kernel suitable for deep learning inference. This implementation demonstrates how to map the memory hierarchy concepts discussed previously, specifically coalescing and shared memory management, into Triton's Python-based syntax.
Unlike CUDA, where you define the behavior of a single thread, Triton kernels define operations on blocks of data. The compiler automatically handles the partitioning of these blocks onto hardware threads and manages the movement of data through the memory hierarchy (Global Memory → Shared Memory → Registers).
When writing a kernel, you must first define a grid. The grid is the set of kernel instances (programs) that run in parallel. For a matrix multiplication C = A × B, where A has shape (M, K) and B has shape (K, N), we typically parallelize along the M and N dimensions. Each kernel instance computes one output block of C of size BLOCK_SIZE_M × BLOCK_SIZE_N.
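As a quick illustration, the launch grid size follows directly from these tile dimensions. The sketch below uses hypothetical problem and block sizes (a 1024×1024 output with 128×128 tiles) to count the program instances:

import triton

# Hypothetical sizes, chosen for illustration only
M, N = 1024, 1024
BLOCK_SIZE_M, BLOCK_SIZE_N = 128, 128

# One program instance per BLOCK_SIZE_M x BLOCK_SIZE_N tile of the output,
# so the 1D grid holds ceil(M / BLOCK_SIZE_M) * ceil(N / BLOCK_SIZE_N) instances.
num_programs = triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N)
print(num_programs)  # 8 * 8 = 64 program instances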
The following graphviz diagram illustrates the data flow within a single Triton program instance (a block). Note how the operations are defined on tiles rather than scalars.
Data flow execution model for a single Triton program instance processing a block.
The entry point to a Triton kernel is a Python function decorated with @triton.jit. The function signature typically accepts pointers to input tensors, tensor dimensions, and strides. Strides are critical for calculating the memory address of a specific element in a flattened 1D memory space: for a 2D tensor, the element at row i and column j is located at address base_ptr + i * stride_0 + j * stride_1.
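To make the stride arithmetic concrete, the following host-side sketch (a hypothetical 4×6 row-major matrix, with strides expressed in elements) verifies that the formula recovers the correct element:

import torch

# Hypothetical 4x6 row-major matrix; strides are counted in elements, not bytes.
a = torch.arange(24, dtype=torch.float32).reshape(4, 6)
stride_am, stride_ak = a.stride()  # (6, 1) for a contiguous row-major layout

# Flattened offset of element (i, j) = i * stride_am + j * stride_ak
i, j = 2, 3
offset = i * stride_am + j * stride_ak
assert a.flatten()[offset] == a[i, j]  # both equal 15.0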
Below is the initialization of a matrix multiplication kernel. We compute the 2D offsets for the pointers to matrices A and B.
import triton
import triton.language as tl


@triton.jit
def matmul_kernel(
    # Pointers to matrices
    a_ptr, b_ptr, c_ptr,
    # Matrix dimensions
    M, N, K,
    # Strides (memory layout)
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    # Meta-parameters (constants at compile time)
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    # 1. Map the program ID to the block in the output grid
    pid = tl.program_id(axis=0)
    # Calculate the number of blocks along the M and N axes
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    # Determine the row and column index for this specific block
    pid_m = pid % num_pid_m
    pid_n = pid // num_pid_m

    # 2. Create pointers for the first block of A and B
    # Range of offsets for the M dimension
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    # Range of offsets for the N dimension
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    # Range for the K dimension (reduction axis)
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    # Calculate the actual memory addresses
    # a_ptrs is a matrix of pointers [BLOCK_SIZE_M, BLOCK_SIZE_K]
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    # b_ptrs is a matrix of pointers [BLOCK_SIZE_K, BLOCK_SIZE_N]
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # Initialize the accumulator with zeros
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
In the snippet above, tl.arange generates a sequence of integers. By broadcasting these sequences (using [:, None] and [None, :]), we create grids of offsets. When we add these offsets to the base pointer a_ptr, Triton creates a block of pointers. This abstraction allows the compiler to issue vector load instructions suitable for the target architecture.
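Triton's broadcasting rules mirror NumPy's, so the pointer arithmetic can be previewed on the host. The sketch below uses plain NumPy with hypothetical tiny tile sizes (not Triton itself) to show the 2D offset grid that [:, None] and [None, :] produce:

import numpy as np

# Hypothetical tiny tile: BLOCK_SIZE_M = 4, BLOCK_SIZE_K = 3,
# row-major A with stride_am = 8 and stride_ak = 1.
offs_am = np.arange(4)   # rows handled by this program instance
offs_k = np.arange(3)    # columns of the current K-tile
stride_am, stride_ak = 8, 1

# [:, None] makes a column vector, [None, :] a row vector; adding them
# broadcasts to a [4, 3] grid of element offsets, one per pointer in a_ptrs.
offsets = offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
print(offsets)
# [[ 0  1  2]
#  [ 8  9 10]
#  [16 17 18]
#  [24 25 26]]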
The core computation occurs in a loop that iterates over the K dimension. We advance the pointers by BLOCK_SIZE_K in each iteration. This is analogous to the tiling techniques discussed in Chapter 3, but here the tile is handled implicitly by Triton's block-level types.
Inside the loop, we load blocks of A and B from global memory into SRAM (Shared Memory). The tl.dot operation then performs the matrix multiplication. If the hardware supports it (like NVIDIA Tensor Cores), Triton will automatically lower this high-level operation to the appropriate intrinsics (e.g., mma.sync).
    # 3. Iterate to compute a block of the C matrix
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of A and B
        # mask is used to handle boundary conditions if K is not a multiple of BLOCK_SIZE_K
        # other=0.0 ensures padding with zeros does not affect the sum
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        # Accumulate partial dot products
        accumulator += tl.dot(a, b)
        # Advance pointers to the next K-block
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
The mask argument in tl.load is essential for correctness. It prevents the kernel from reading out-of-bounds memory, which would otherwise cause illegal memory access errors or silent data corruption. The other=0.0 parameter ensures that masked-out values do not affect the dot product result. For example, if K = 100 and BLOCK_SIZE_K = 32, the final iteration has only 4 valid columns remaining (100 - 3 * 32), so the other 28 positions are loaded as zeros and contribute nothing to the accumulation.
Once the reduction loop finishes, the accumulator holds the final values for the corresponding tile of C. We must store this back to global memory. Efficient stores require memory coalescing, where consecutive threads access consecutive memory addresses. Triton manages the layout of the register file to optimize for this, provided the block sizes and strides are aligned.
    # 4. Store the result
    # Create pointers for C
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    # Cast the float32 accumulator to the output element type before storing
    c = accumulator.to(c_ptr.dtype.element_ty)
    # Write to global memory, applying a mask to avoid out-of-bounds writes
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)
The choice of BLOCK_SIZE_M, BLOCK_SIZE_N, and BLOCK_SIZE_K heavily influences performance. Large blocks maximize data reuse in the L1 cache/shared memory, reducing the demand on global memory bandwidth. However, blocks that are too large may exceed the available shared memory or register file size, reducing occupancy (the number of active warps).
The chart below displays a comparison of throughput (TFLOPS) for different block configurations on a typical GPU architecture. Notice how specific configurations yield significantly better performance because they align with hardware characteristics (e.g., 128-byte cache lines).
Throughput variance across different block size configurations, highlighting the importance of tuning.
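To get a feel for why block sizes are constrained, a rough shared-memory estimate can be computed by hand. The sketch below assumes each pipeline stage buffers one tile of A and one tile of B in fp16; the exact footprint depends on the compiler's layout decisions, so treat it only as an estimate:

def smem_estimate_bytes(block_m, block_n, block_k, elem_bytes=2, num_stages=4):
    """Rough shared-memory footprint of one program instance.

    Assumes each pipeline stage keeps one BLOCK_M x BLOCK_K tile of A and one
    BLOCK_K x BLOCK_N tile of B resident; the real footprint depends on the
    compiler's layout and padding decisions.
    """
    tile_a = block_m * block_k * elem_bytes
    tile_b = block_k * block_n * elem_bytes
    return (tile_a + tile_b) * num_stages

# 128x128x32 in fp16 with 4 stages: (128*32 + 32*128) * 2 * 4 = 65536 bytes (64 KiB)
print(smem_estimate_bytes(128, 128, 32))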
The configuration 128x128x32 often performs well because it balances tile size against register pressure. However, manual selection is tedious and rarely transfers across GPUs. Triton provides an auto-tuning mechanism through the @triton.autotune decorator. This feature compiles multiple versions of the kernel with different meta-parameters and selects the fastest one at runtime based on the actual input shapes.
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(...):
    # Kernel implementation...
By defining num_stages, you control software pipelining. This optimization, discussed in the loop scheduling chapter, allows the GPU to load the data for iteration k+1 while it computes iteration k, hiding global memory latency.
To execute the kernel, we calculate the grid dimensions based on the input matrix size and the chosen block size. The host-side Python code invokes the JIT-compiled kernel.
import torch


def matmul(a, b):
    # Shapes must be compatible: (M, K) x (K, N) -> (M, N)
    assert a.shape[1] == b.shape[0], "Incompatible dimensions"
    M, K = a.shape
    K, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=a.dtype)
    # Grid definition function: one program instance per output tile
    grid = lambda META: (
        triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
    )
    # Launch the JIT-compiled kernel
    matmul_kernel[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
    )
    return c
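As a quick sanity check (a hypothetical test harness, not part of the listing above), the Triton result can be compared against torch.matmul on random half-precision inputs:

import torch

# Hypothetical smoke test: requires a CUDA device and the matmul wrapper defined above.
a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
b = torch.randn((512, 512), device='cuda', dtype=torch.float16)

c_triton = matmul(a, b)
c_torch = torch.matmul(a, b)

# fp16 rounding means exact equality is not expected; a loose tolerance
# is typical when validating against the cuBLAS-backed torch.matmul.
print(torch.allclose(c_triton, c_torch, atol=1e-2, rtol=1e-2))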
This implementation provides a foundation for hardware-aware code generation. By manipulating block pointers and leveraging tl.dot, you effectively control the GPU's memory hierarchy and compute units without manually managing synchronization barriers or bank conflicts, which the compiler resolves during the lowering phase to LLVM IR and PTX.