Writing high-performance GPU kernels directly in CUDA C++ often requires managing complex details like shared memory bank conflicts, warp divergence, and tensor core operand layouts. While this level of control is powerful, it introduces significant engineering overhead. Triton is a language and compiler stack that abstracts these complexities while maintaining performance parity with vendor-tuned libraries. It achieves this by exposing a block-based programming model rather than the Single Instruction Multiple Thread (SIMT) model found in CUDA.

In this practical section, we implement a tiled matrix multiplication kernel suitable for deep learning inference. The implementation demonstrates how to map the memory hierarchy concepts discussed previously, specifically coalescing and shared memory management, into Triton's Python-based syntax.

## The Block-Based Programming Model

Unlike CUDA, where you define the behavior of a single thread, Triton kernels define operations on blocks of data. The compiler automatically handles the partitioning of these blocks onto hardware threads and manages the movement of data through the memory hierarchy (Global Memory $\rightarrow$ Shared Memory $\rightarrow$ Registers).

When writing a kernel, you must first define a grid: the set of instances of your kernel that run in parallel. For a matrix multiplication of size $(M, N, K)$, we typically parallelize along the $M$ and $N$ dimensions, so each kernel instance computes one output block of size BLOCK_SIZE_M $\times$ BLOCK_SIZE_N.

The following diagram illustrates the data flow within a single Triton program instance (a block). Note how the operations are defined on tiles rather than scalars: tiles of A (BLOCK_M $\times$ BLOCK_K) and B (BLOCK_K $\times$ BLOCK_N) are loaded from global memory (HBM) into on-chip SRAM with `tl.load()`, multiplied on the tensor cores or FPUs by `tl.dot()`, accumulated into a BLOCK_M $\times$ BLOCK_N accumulator, and finally written back to C with `tl.store()`.

*Figure: Data flow execution model for a single Triton program instance processing a block.*

## Defining the Kernel Signature

The entry point to a Triton kernel is a Python function decorated with `@triton.jit`. The function signature typically accepts pointers to the input tensors, the tensor dimensions, and the strides. Strides are critical for calculating the memory address of a specific element $(i, j)$ in a flattened 1D memory space, defined by the equation:

$$ \text{Address}(i, j) = \text{Base} + i \times \text{stride}_0 + j \times \text{stride}_1 $$
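As a quick sanity check of this formula, the strides reported by PyTorch for a contiguous row-major tensor plug directly into it; these are the same values we will later pass to the kernel as `a.stride(0)` and `a.stride(1)`. The shapes below are purely illustrative.

```python
import torch

# For a contiguous row-major (M, K) tensor, PyTorch reports strides in elements:
# stride_0 = K (rows are K elements apart), stride_1 = 1 (columns are adjacent).
M, K = 4, 6
a = torch.arange(M * K, dtype=torch.float32).reshape(M, K)
stride_0, stride_1 = a.stride()
assert (stride_0, stride_1) == (K, 1)

# Address(i, j) = Base + i * stride_0 + j * stride_1 recovers the flattened index.
i, j = 2, 5
flat_index = i * stride_0 + j * stride_1
assert a.flatten()[flat_index] == a[i, j]
```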
Below is the initialization of the matrix multiplication kernel. We compute the 2D offsets for the pointers to matrices A and B.

```python
import triton
import triton.language as tl


@triton.jit
def matmul_kernel(
    # Pointers to matrices
    a_ptr, b_ptr, c_ptr,
    # Matrix dimensions
    M, N, K,
    # Strides (memory layout)
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    # Meta-parameters (constants at compile time)
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    # 1. Map the program ID to the block in the output grid
    pid = tl.program_id(axis=0)
    # Calculate the number of blocks along the M and N axes
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    # Determine the row and column index for this specific block
    pid_m = pid % num_pid_m
    pid_n = pid // num_pid_m

    # 2. Create pointers for the first block of A and B
    # Range of offsets for the M dimension
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    # Range of offsets for the N dimension
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    # Range for the K dimension (reduction axis)
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    # Calculate the actual memory addresses
    # a_ptrs is a matrix of pointers [BLOCK_SIZE_M, BLOCK_SIZE_K]
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    # b_ptrs is a matrix of pointers [BLOCK_SIZE_K, BLOCK_SIZE_N]
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # Initialize the accumulator with zeros
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
```

In the snippet above, `tl.arange` generates a sequence of integers. By broadcasting these sequences (using `[:, None]` and `[None, :]`), we create 2D grids of offsets. When we add these offsets to the base pointer `a_ptr`, Triton produces a block of pointers. This abstraction allows the compiler to issue vectorized load instructions suitable for the target architecture.

## The Reduction Loop

The core computation occurs in a loop that iterates over the $K$ dimension, advancing the A and B pointers by `BLOCK_SIZE_K` elements along that axis on each iteration. This is analogous to the tiling techniques discussed in Chapter 3, but here the tiles are expressed directly through Triton's block-level operations rather than managed by hand.

Inside the loop, we load blocks of A and B from global memory into SRAM (shared memory). The `tl.dot` operation then performs the tile matrix multiplication. If the hardware supports it (as NVIDIA Tensor Cores do), Triton automatically lowers this high-level operation to the appropriate intrinsics (e.g., `mma.sync`).

```python
    # 3. Iterate to compute a block of the C matrix
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of A and B.
        # The mask handles boundary conditions when K is not a multiple of BLOCK_SIZE_K;
        # other=0.0 ensures the zero padding does not affect the sum.
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)

        # Accumulate partial dot products
        accumulator += tl.dot(a, b)

        # Advance pointers to the next K-block
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
```

The `mask` argument in `tl.load` is essential for correctness: it prevents the kernel from reading out-of-bounds memory, which would otherwise cause illegal memory access errors or silent data corruption. The `other=0.0` parameter ensures that masked-out values do not affect the dot product result.
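To make the effect of the mask concrete, the following CPU sketch emulates the same blocked K-reduction in NumPy. The names (`BLOCK_K`, the small shapes) are illustrative and not part of the kernel above; the point is that zero-padding the final, narrower K-tile, just as `mask=...` with `other=0.0` does, leaves the accumulated result identical to a full matrix multiply.

```python
import numpy as np

# CPU emulation of the blocked K-reduction with a zero-padded tail tile.
# BLOCK_K deliberately does not divide K, mirroring the masked tl.load case.
M, N, K, BLOCK_K = 8, 8, 10, 4
rng = np.random.default_rng(0)
A = rng.standard_normal((M, K)).astype(np.float32)
B = rng.standard_normal((K, N)).astype(np.float32)

acc = np.zeros((M, N), dtype=np.float32)
for k0 in range(0, K, BLOCK_K):
    # Slice one K-tile; the last tile may be narrower than BLOCK_K.
    a_tile = A[:, k0:k0 + BLOCK_K]
    b_tile = B[k0:k0 + BLOCK_K, :]
    # Zero-pad to the full tile width, like mask=... with other=0.0 in tl.load.
    pad = BLOCK_K - a_tile.shape[1]
    a_tile = np.pad(a_tile, ((0, 0), (0, pad)))
    b_tile = np.pad(b_tile, ((0, pad), (0, 0)))
    acc += a_tile @ b_tile  # analogous to accumulator += tl.dot(a, b)

assert np.allclose(acc, A @ B)
```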
## Storing Results and Coalescing

Once the reduction loop finishes, the accumulator holds the final values for this tile of $C$, and we must write it back to global memory. Efficient stores require memory coalescing, where consecutive threads access consecutive memory addresses. Triton manages the layout of the register file to optimize for this, provided the block sizes and strides are aligned.

```python
    # 4. Store the result
    # Create pointers for C
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]

    # Cast the float32 accumulator to the output element type (e.g., float16)
    c = accumulator.to(c_ptr.dtype.element_ty)

    # Write to global memory, applying a mask to avoid out-of-bounds writes
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)
```

## Performance and Block Sizing

The choice of `BLOCK_SIZE_M`, `BLOCK_SIZE_N`, and `BLOCK_SIZE_K` heavily influences performance. Large blocks maximize data reuse in the L1 cache/shared memory, reducing the demand on global memory bandwidth. However, blocks that are too large may exceed the available shared memory or register file size, reducing occupancy (the number of active warps).

The table below compares throughput (TFLOPS) for different block configurations on a standard GPU architecture with float16 inputs. Notice how specific configurations yield significantly better performance because they align with hardware specifics such as 128-byte cache lines.

| Block configuration (M, N, K) | Throughput (TFLOPS) |
|---|---|
| 32 x 32 x 32 | 25.5 |
| 64 x 64 x 32 | 68.2 |
| 128 x 128 x 32 | 92.4 |
| 128 x 256 x 64 | 88.1 |
| 256 x 256 x 64 | 45.3 |

*Throughput variance across different block size configurations, highlighting the importance of tuning.*

The 128x128x32 configuration often performs well because it balances tile size with register pressure. However, manual selection is inefficient. Triton provides an auto-tuning mechanism through the `@triton.autotune` decorator, which compiles multiple versions of the kernel with different meta-parameters and selects the fastest one at runtime based on the actual input shapes.

```python
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32},
                      num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32},
                      num_stages=4, num_warps=4),
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(...):
    # Kernel implementation...
```

By setting `num_stages`, you control software pipelining. This optimization, discussed in the loop scheduling chapter, lets the GPU load the data for iteration $i+1$ while computing iteration $i$, hiding global memory latency.

## Kernel Launch

To execute the kernel, we compute the grid dimensions from the input matrix sizes and the chosen block sizes. The host-side Python code then invokes the JIT-compiled kernel.

```python
import torch


def matmul(a, b):
    M, K = a.shape
    K, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=a.dtype)

    # Grid definition function: one program per output block
    grid = lambda META: (
        triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
    )

    # Launch the kernel
    matmul_kernel[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
    )
    return c
```

Note that the block sizes do not appear at the call site: because `matmul_kernel` is wrapped in `@triton.autotune`, the selected configuration is injected automatically and exposed to the grid function through `META`.
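Assuming a CUDA-capable GPU, a quick way to validate the wrapper is to compare it against `torch.matmul` on random half-precision inputs. This is a minimal sketch; the tolerances are loose because both paths round their float32 accumulators down to float16.

```python
import torch

# Quick correctness check of the matmul wrapper above against PyTorch.
torch.manual_seed(0)
a = torch.randn((512, 512), device="cuda", dtype=torch.float16)
b = torch.randn((512, 512), device="cuda", dtype=torch.float16)

c_triton = matmul(a, b)       # first call also triggers autotuning
c_torch = torch.matmul(a, b)

# Accumulation order differs between the two paths, so allow a loose tolerance.
assert torch.allclose(c_triton, c_torch, atol=1e-2, rtol=1e-2)
print("Triton and torch.matmul agree within tolerance")
```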
This implementation provides a foundation for hardware-aware code generation. By manipulating block pointers and leveraging `tl.dot`, you control the GPU's memory hierarchy and compute units without manually managing synchronization barriers or bank conflicts; the compiler resolves those details while lowering the kernel to LLVM IR and PTX.
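To see how block configurations like those in the table above translate into throughput on your own hardware, you can time the wrapper with Triton's built-in benchmarking helper. This is a minimal sketch: `triton.testing.do_bench` is available in recent Triton releases, but its exact return format has varied between versions (recent ones return a time in milliseconds as a float).

```python
import torch
import triton

# Benchmark the matmul wrapper defined above and convert the timing to TFLOPS.
M = N = K = 4096
a = torch.randn((M, K), device="cuda", dtype=torch.float16)
b = torch.randn((K, N), device="cuda", dtype=torch.float16)

# do_bench repeatedly times the callable on the GPU; recent releases return
# the runtime in milliseconds (older ones returned quantile tuples).
ms = triton.testing.do_bench(lambda: matmul(a, b))

# A matrix multiplication performs 2 * M * N * K floating-point operations.
tflops = 2 * M * N * K / (ms * 1e-3) / 1e12
print(f"{ms:.3f} ms  ->  {tflops:.1f} TFLOPS")
```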