A core function of a modern ML runtime system is managing tensors whose dimensions are not fully known before execution begins. Unlike Ahead-of-Time (AOT) compilation scenarios where tensor shapes are often fixed (e.g., 224×224 images, batch size 32), many real-world applications involve dynamic shapes. Examples include natural language processing models handling variable-length sequences, object detection models processing images of different resolutions, or inference servers processing requests with varying batch sizes. This dynamism presents significant challenges compared to static-shape execution, primarily impacting memory management, kernel selection, and overall performance predictability.
Consider a tensor with shape N×C×H×W. In a static scenario, N,C,H,W are constants known at compile time. This allows the compiler and runtime to pre-plan memory allocation precisely, select or generate highly optimized kernels for these exact dimensions, and schedule operations efficiently. However, if one or more dimensions are symbolic or unknown until runtime (e.g., N×C×?×?), the runtime must employ specific strategies to handle this uncertainty.
One approach is for the runtime system to perform shape inference dynamically. As operations in the computation graph execute, the runtime can propagate concrete shape information. For instance, if an operation takes two input tensors with known shapes A×B and B×C and produces an output, the runtime can deduce the output shape will be A×C. This relies on the runtime having access to the graph structure and the shape transfer functions for each operator. Often, the Intermediate Representation (IR) used by the compiler (like MLIR) encodes these shape functions or provides mechanisms to represent unknown dimensions (often denoted as ? or -1). While effective for many cases, dynamic shape inference might not resolve all ambiguities, especially with complex control flow or operations whose output shapes depend on input values (not just shapes).
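As a concrete illustration, the sketch below propagates shapes through a tiny chain of matmul operators, representing an unknown dimension as None. The names (Dim, infer_matmul) are hypothetical, not from any particular runtime:

```python
from typing import Optional

Dim = Optional[int]  # None represents an unknown (dynamic) dimension

def infer_matmul(a: tuple[Dim, Dim], b: tuple[Dim, Dim]) -> tuple[Dim, Dim]:
    """Shape transfer function for a 2-D matmul: (A, B) x (B, C) -> (A, C)."""
    a_rows, a_cols = a
    b_rows, b_cols = b
    # If both inner dimensions are concrete, they must agree.
    if a_cols is not None and b_rows is not None and a_cols != b_rows:
        raise ValueError(f"inner dimensions mismatch: {a_cols} vs {b_rows}")
    return (a_rows, b_cols)

# Propagate shapes through a small graph; unknown dimensions flow
# through until concrete values become available at runtime.
x = (None, 128)          # batch size unknown until execution
w1 = (128, 256)
w2 = (256, 10)

h = infer_matmul(x, w1)  # -> (None, 256)
y = infer_matmul(h, w2)  # -> (None, 10)
print(h, y)
```

Note that the unknown batch dimension propagates unchanged: the runtime can still verify the known inner dimensions and only needs the concrete batch size when it finally allocates memory or launches a kernel.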
When exact shapes become known just before an operation executes, the runtime can select a pre-compiled kernel optimized for those specific dimensions or even trigger a Just-In-Time (JIT) compilation process to generate a specialized kernel on the fly. This is particularly relevant for compute-intensive operations like convolutions or matrix multiplications where performance is highly sensitive to dimensions.
A simplified flow showing runtime shape checking and kernel selection/JIT compilation.
This strategy maintains high performance for the actual computation but introduces potential latency. JIT compilation takes time, and managing a cache of specialized kernels adds complexity. Techniques like shape polymorphism aim to mitigate this by generating kernels that can efficiently handle a constrained range of input shapes, reducing the need for excessive specialization. The runtime must balance the cost of specialization (compilation time, cache management) against the performance gains of using dimension-specific code.
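A minimal sketch of this caching pattern follows, assuming a hypothetical jit_compile function that stands in for a backend emitting a shape-specialized kernel:

```python
def jit_compile(op: str, shape: tuple[int, ...]):
    """Hypothetical stand-in for a JIT backend that emits a kernel
    specialized for the given concrete shape."""
    print(f"compiling {op} for shape {shape}")
    return lambda *args: f"<run {op} specialized for {shape}>"

class KernelCache:
    def __init__(self):
        self._cache = {}

    def get(self, op: str, shape: tuple[int, ...]):
        key = (op, shape)
        if key not in self._cache:
            # Cache miss: pay the one-time JIT compilation cost.
            self._cache[key] = jit_compile(op, shape)
        # Cache hit: reuse the previously specialized kernel.
        return self._cache[key]

cache = KernelCache()
k1 = cache.get("matmul", (32, 128, 256))   # compiles
k2 = cache.get("matmul", (32, 128, 256))   # cached, no recompilation
assert k1 is k2
```

The cache key here is the full concrete shape; shape-polymorphic kernels would instead key on a coarser signature (such as a bucket or a rank), trading some per-shape performance for fewer compilations.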
Two common techniques impose regularity by restricting the set of distinct shapes the runtime must handle:

Padding: If a dimension can vary up to a certain maximum size (e.g., sequence length up to 512), the runtime can allocate buffers based on this maximum size and pad smaller inputs to fit.

Bucketing: To reduce the overhead of padding, inputs can be grouped into "buckets" based on their shape characteristics. For instance, sequences of length 1-64 might go into one bucket, 65-128 into another, and so on. Padding is then applied only within each bucket's range.
Comparing memory allocation/processing size for padding to a fixed maximum versus using discrete buckets. Bucketing reduces waste for smaller inputs.
Bucketing requires the runtime to sort or batch inputs dynamically based on shape, adding scheduling complexity but often yielding better resource utilization than naive padding.
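The sketch below shows one way a runtime might implement bucketed padding; the bucket boundaries are illustrative, not prescribed:

```python
BUCKETS = (64, 128, 256, 512)  # illustrative bucket boundaries (padded lengths)

def pick_bucket(length: int) -> int:
    """Return the smallest bucket that fits the given sequence length."""
    for bound in BUCKETS:
        if length <= bound:
            return bound
    raise ValueError(f"length {length} exceeds maximum bucket {BUCKETS[-1]}")

def pad_to_bucket(seq: list[int], pad_value: int = 0) -> list[int]:
    """Pad a sequence to its bucket size instead of the global maximum."""
    target = pick_bucket(len(seq))
    return seq + [pad_value] * (target - len(seq))

print(len(pad_to_bucket(list(range(50)))))   # 64, not 512
print(len(pad_to_bucket(list(range(200)))))  # 256
```

A length-50 sequence now wastes 14 padded elements instead of 462, at the cost of the runtime maintaining one buffer size and kernel variant per bucket.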
More advanced runtimes might incorporate symbolic shape computation. Instead of requiring concrete dimensions immediately, they manipulate symbolic expressions representing shapes (e.g., N for batch size). Operations are performed on these symbolic shapes until a point where concrete values are needed for memory allocation or kernel execution. This can delay specialization or allocation decisions. Often, this is combined with upper bounding, where the compiler or user provides maximum possible values for dynamic dimensions. The runtime can then use these upper bounds for initial memory reservations (often within an arena allocator) while potentially refining the actual used portion later based on the concrete shapes encountered.
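The sketch below gives one possible shape for such an API: a symbolic dimension carries an upper bound used to size the reservation before its concrete value is known. All names here (SymbolicDim, reserve_size) are hypothetical:

```python
from dataclasses import dataclass

@dataclass
class SymbolicDim:
    name: str
    upper_bound: int              # provided by the compiler or the user
    concrete: int | None = None   # bound at runtime once the input arrives

    def reserve_size(self) -> int:
        """Size used for reservations: worst case while still unresolved."""
        return self.concrete if self.concrete is not None else self.upper_bound

# Reserve a buffer for an N x 512 float32 tensor with N <= 64.
batch = SymbolicDim("N", upper_bound=64)
reserved_bytes = batch.reserve_size() * 512 * 4   # 131072 bytes reserved

# Later, a request with batch size 16 arrives; only part of the
# reservation is actually used.
batch.concrete = 16
used_bytes = batch.reserve_size() * 512 * 4       # 32768 bytes used
print(reserved_bytes, used_bytes)
```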
Dynamic shapes fundamentally complicate memory management. Static memory planning, where all buffer addresses and sizes are determined offline, is no longer fully applicable. The runtime memory manager must handle requests for buffers whose sizes are only known during execution. This often necessitates dynamic memory allocators (like arena allocators designed to handle varying sizes efficiently) and strategies to mitigate fragmentation. Failure to manage memory efficiently can lead to out-of-memory errors or significant performance degradation due to allocation overhead or poor data locality.
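A minimal bump-pointer arena illustrates the allocator side of this: variable-size requests are carved out of one pre-reserved region, and the whole arena is reset between iterations, which sidesteps per-buffer fragmentation. This is a simplification; production allocators also handle alignment, growth, and freeing individual blocks:

```python
class Arena:
    """Bump-pointer arena: fast variable-size allocation, bulk reset."""
    def __init__(self, capacity: int):
        self.buffer = bytearray(capacity)
        self.offset = 0

    def alloc(self, size: int) -> memoryview:
        if self.offset + size > len(self.buffer):
            raise MemoryError("arena exhausted; grow or fall back")
        view = memoryview(self.buffer)[self.offset:self.offset + size]
        self.offset += size
        return view

    def reset(self):
        # Free everything at once, e.g., at the end of one inference step.
        self.offset = 0

arena = Arena(capacity=1 << 20)   # 1 MiB reserved up front
a = arena.alloc(16 * 512 * 4)     # activation buffer for batch size 16
b = arena.alloc(16 * 10 * 4)      # logits buffer for batch size 16
arena.reset()                     # reuse the whole region for the next request
```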
Similarly, scheduling becomes more complex. The execution time of operations can vary significantly depending on the input shapes. This makes it harder for the scheduler to predict costs, overlap computation and data movement effectively, or balance load across heterogeneous devices. Schedulers might need to adapt dynamically based on observed execution times for specific shapes.
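One simple adaptive approach is to record observed execution times per (operator, shape) key and use a running average as the scheduler's cost estimate. A sketch with hypothetical names:

```python
import time
from collections import defaultdict

class CostModel:
    """Cost estimates learned from observed per-shape execution times."""
    def __init__(self, default_cost: float = 1e-3):
        self.totals = defaultdict(float)
        self.counts = defaultdict(int)
        self.default_cost = default_cost

    def record(self, op: str, shape: tuple[int, ...], seconds: float):
        key = (op, shape)
        self.totals[key] += seconds
        self.counts[key] += 1

    def estimate(self, op: str, shape: tuple[int, ...]) -> float:
        key = (op, shape)
        if self.counts[key] == 0:
            return self.default_cost  # no observation for this shape yet
        return self.totals[key] / self.counts[key]

model = CostModel()
start = time.perf_counter()
sum(range(1_000_000))   # stand-in for an actual kernel execution
model.record("matmul", (32, 128, 256), time.perf_counter() - start)
print(model.estimate("matmul", (32, 128, 256)))
```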
Handling dynamic shapes involves a trade-off between flexibility and performance predictability. The optimal strategy depends heavily on the specific model architecture, the degree and nature of shape variability, the target hardware capabilities (e.g., JIT compilation speed, memory capacity), and the application's latency requirements. Profiling tools become essential for understanding how dynamic shapes impact memory usage, kernel execution times, and potential JIT overhead within a given runtime system. Effective handling often involves combining multiple techniques, such as using upper bounds for initial allocation, performing runtime shape inference, and employing shape-specialized kernels with bucketing to manage variability efficiently.