As you build and train increasingly sophisticated models, managing memory becomes a significant part of development and debugging, particularly when working with GPUs, whose memory is usually far more limited than system RAM. Understanding how PyTorch allocates memory and how your operations influence it is essential for running efficiently and avoiding common out-of-memory errors. This section examines PyTorch's memory management mechanisms, connecting them to the tensor structures and autograd processes discussed earlier.
At its core, a PyTorch tensor (torch.Tensor) is a view over a contiguous block of memory managed by a torch.Storage object. The Storage object holds the actual numerical data, while the Tensor object contains metadata like shape (size), stride, and data type (dtype), along with an offset describing its position within the Storage.
Multiple tensors can share the same underlying Storage. For example, slicing a tensor or using operations like view() often creates a new tensor object that points to the same storage but with different metadata.
import torch
# Create a tensor; PyTorch allocates storage
x = torch.randn(2, 3)
print(f"x storage: {x.storage().data_ptr()}")
# Slicing creates a new tensor view sharing the storage
y = x[0, :]
print(f"y storage: {y.storage().data_ptr()}") # Same pointer
print(f"Do x and y share storage? {x.storage().data_ptr() == y.storage().data_ptr()}")
# Modifying y affects x because they share storage
y.fill_(1.0)
print("x after modifying y:\n", x)
This storage sharing is highly efficient as it avoids unnecessary data copies. However, it's important to be aware of it, especially when performing in-place operations.
The layout of a tensor in memory is determined by its stride. A tensor is considered contiguous if its elements are laid out sequentially in memory row-by-row (for 2D tensors) without gaps. Non-contiguous tensors can arise from operations like transposing or certain types of indexing.
# Contiguous tensor
a = torch.arange(6).reshape(2, 3)
print(f"a is contiguous: {a.is_contiguous()}, Stride: {a.stride()}") # Stride: (3, 1)
# Transposing creates a non-contiguous view
b = a.t()
print(f"b is contiguous: {b.is_contiguous()}, Stride: {b.stride()}") # Stride: (1, 3)
# Accessing elements is still correct, but memory access pattern differs
print("b:\n", b)
# Some PyTorch functions require contiguous tensors
# Attempting an operation like view on a non-contiguous tensor might fail
try:
    b.view(-1)
except RuntimeError as e:
    print(f"\nError viewing non-contiguous tensor: {e}")
# Use .contiguous() to get a contiguous copy
c = b.contiguous()
print(f"c is contiguous: {c.is_contiguous()}, Stride: {c.stride()}") # Stride: (2, 1)
print("c (contiguous version of b):\n", c)
print(f"Does b and c share storage? {b.storage().data_ptr() == c.storage().data_ptr()}") # False, new storage
While PyTorch operations often handle non-contiguous tensors correctly, some low-level operations or interfaces (like exporting to NumPy or certain custom extensions) might require contiguous data. Calling .contiguous() creates a new tensor with a fresh, contiguous copy of the data if the original tensor wasn't already contiguous. This incurs a memory copy overhead.
The data type (dtype) also directly impacts memory usage. A torch.float32 tensor uses 4 bytes per element, while torch.float16 uses 2 bytes and torch.int64 uses 8 bytes. Choosing the appropriate dtype is fundamental for memory efficiency.
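As a quick check, you can estimate a tensor's footprint from its element size and element count. A small sketch using the standard element_size() and nelement() methods:
x32 = torch.zeros(1_000_000, dtype=torch.float32)
x16 = torch.zeros(1_000_000, dtype=torch.float16)
# Bytes per element times number of elements
print(f"float32: {x32.element_size() * x32.nelement() / 1024**2:.2f} MiB")  # ~3.81 MiB
print(f"float16: {x16.element_size() * x16.nelement() / 1024**2:.2f} MiB")  # ~1.91 MiB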
Allocating and deallocating memory on GPUs through the CUDA APIs (cudaMalloc, cudaFree) can be slow. To mitigate this, PyTorch employs a caching memory allocator for GPU tensors. When a tensor is freed (e.g., it goes out of scope and its reference count drops to zero), the memory it occupied isn't necessarily returned to the CUDA driver immediately. Instead, PyTorch holds onto this block of memory in a cache.
When a new tensor needs to be allocated, PyTorch first checks its cache for an appropriately sized free block. If found, it reuses that block, avoiding the expensive call to the CUDA driver. This significantly speeds up tensor creation and deletion, which happens frequently during training.
Figure: A simplified view of the PyTorch caching allocator interacting with the CUDA driver and tensor memory.
You can inspect the state of the caching allocator:
- torch.cuda.memory_allocated(): Returns the total GPU memory currently occupied by tensors (in bytes) for the default device.
- torch.cuda.memory_reserved() or torch.cuda.memory_cached() (deprecated): Returns the total GPU memory managed by the caching allocator (both allocated tensors and cached free blocks).
- torch.cuda.max_memory_allocated(): Returns the maximum GPU memory occupied by tensors at any point during execution since the start or the last reset.
- torch.cuda.reset_peak_memory_stats(): Resets the peak memory counter.
- torch.cuda.memory_summary(): Provides a detailed report of allocated and cached memory, often useful for identifying fragmentation.

Sometimes you might want to clear the cached memory, perhaps to make it available to other GPU applications or libraries. You can use torch.cuda.empty_cache().
# Requires a GPU
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"Initial allocated: {torch.cuda.memory_allocated(device) / 1024**2:.2f} MiB")
    print(f"Initial reserved: {torch.cuda.memory_reserved(device) / 1024**2:.2f} MiB")
    # Allocate some tensors
    t1 = torch.randn(1024, 1024, device=device)
    t2 = torch.randn(512, 512, device=device)
    print("\nAfter allocation:")
    print(f"Allocated: {torch.cuda.memory_allocated(device) / 1024**2:.2f} MiB")
    print(f"Reserved: {torch.cuda.memory_reserved(device) / 1024**2:.2f} MiB")
    # Delete tensors
    del t1
    del t2
    print("\nAfter deleting tensors (before empty_cache):")
    # Allocated memory drops, but reserved memory remains high due to caching
    print(f"Allocated: {torch.cuda.memory_allocated(device) / 1024**2:.2f} MiB")
    print(f"Reserved: {torch.cuda.memory_reserved(device) / 1024**2:.2f} MiB")
    # Clear the cache
    torch.cuda.empty_cache()
    print("\nAfter empty_cache:")
    # Reserved memory also drops (though maybe not to zero due to internal allocations)
    print(f"Allocated: {torch.cuda.memory_allocated(device) / 1024**2:.2f} MiB")
    print(f"Reserved: {torch.cuda.memory_reserved(device) / 1024**2:.2f} MiB")
else:
    print("CUDA not available, skipping GPU memory examples.")
Important: torch.cuda.empty_cache() does not free memory currently used by active tensors. It only releases the cached blocks that are not currently backing any tensor. It's primarily useful for releasing memory back to the system for other processes, not for reducing the memory footprint of your running PyTorch script if tensors are still live. It also incurs a performance cost, as subsequent allocations will need to go back to the driver.

A side effect of the caching allocator is fragmentation. If you allocate and free tensors of varying sizes, the cache might end up holding many small, non-contiguous free blocks. Even if the total size of these cached blocks is large, you might be unable to allocate a large contiguous block, leading to an out-of-memory (OOM) error. torch.cuda.memory_summary() can help diagnose fragmentation issues.
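For example, a quick way to print the allocator's report, assuming a CUDA device is available (the abbreviated argument is optional and shortens the output):
if torch.cuda.is_available():
    # Condensed report of allocated blocks, cached blocks, and fragmentation
    print(torch.cuda.memory_summary(abbreviated=True))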
The autograd engine significantly impacts memory usage. To compute gradients during the backward pass, autograd typically needs to store intermediate activations (outputs of forward operations) that are part of the computational graph.
- When you perform operations on tensors that require gradients (requires_grad=True), PyTorch builds a graph storing these operations and references to the tensors involved. These references keep tensors alive in memory, even if they might seem to go out of scope in your Python code.
- When you call loss.backward(), autograd traverses this graph backward. It uses the stored intermediate values to compute gradients. Once a gradient has been computed and is no longer needed for further computations in the backward pass, the buffer holding the corresponding intermediate activation is often freed.
- retain_graph=True: If you call backward(retain_graph=True), PyTorch preserves the graph and the intermediate activation buffers even after the backward pass completes. This allows you to call backward() multiple times (e.g., for calculating gradients of different losses w.r.t. the same parameters), but it comes at the cost of holding onto potentially large amounts of memory. Use it only when necessary.
- torch.no_grad(): Wrapping code in a with torch.no_grad(): block signals to PyTorch that operations within this block should not be tracked by autograd. This prevents the creation of the computational graph for these operations and avoids storing intermediate activations, saving significant memory. It's standard practice to use this context manager during validation or inference loops.
- detach(): Calling .detach() on a tensor creates a new tensor that shares the same storage but is detached from the computational graph. It doesn't require gradients, and no operations involving it will be tracked. This is useful if you need to use a tensor's value without tracking its history (e.g., for logging or plotting).

Consider this simple example:
# Setup
a = torch.randn(100, 100, requires_grad=True)
b = torch.randn(100, 100, requires_grad=True)
# Operations tracked by autograd
c = a * b
d = c.sin()
loss = d.mean()
# Intermediate tensors 'c' and 'd' are kept in memory
# because they are needed for the backward pass.
# Calling backward frees buffers (unless retain_graph=True)
loss.backward() # Computes gradients for a and b
# Now, let's try without tracking gradients
with torch.no_grad():
    c_no_grad = a * b # Operation performed, but not tracked
    d_no_grad = c_no_grad.sin()
    loss_no_grad = d_no_grad.mean()
# PyTorch doesn't need to store 'c_no_grad' for a future backward pass
# Memory for intermediate results can be potentially freed sooner.
print(f"Gradient of a: {'Exists' if a.grad is not None else 'None'}")
# loss_no_grad.backward() # This would raise an error as history wasn't tracked.
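As a side note, when you only need a value for logging, detaching it (or converting it to a Python number) keeps its computation graph from being retained. A small sketch reusing the loss tensor from the example above:
logged_tensor = loss.detach()   # same value, no autograd history attached
logged_number = loss.item()     # plain Python float from a single-element tensor
print(f"Logged value: {logged_number:.4f}")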
Beyond understanding the fundamentals, here are practical strategies:
Scope and del: Python's garbage collector reclaims memory when objects are no longer referenced. Ensure large tensors that are no longer needed go out of scope. If necessary, use the del statement explicitly to remove references, particularly before potentially memory-intensive operations like backward() or allocating new large tensors.
def process_data(data):
    intermediate = data * 2 # Large intermediate tensor
    result = intermediate.sum()
    # 'intermediate' might stay in memory longer if not deleted
    del intermediate # Explicitly remove reference
    return result
In-Place Operations: Operations ending with an underscore (_), like add_() and relu_(), modify a tensor directly without creating a new one. This saves memory by avoiding allocation of a new tensor for the result.
Caution: Modifying tensors in-place that are needed for gradient computation can corrupt the backward pass. Autograd tracks in-place operations and will raise an error if it detects such a modification that interferes with gradient calculation. Use them carefully, often on tensors that are leaves in the graph or where you are certain it won't affect needed gradients.
x = torch.randn(1000, 1000)
y = torch.randn(1000, 1000)
# Not in-place: creates a new tensor z
z = x + y
# In-place: modifies x directly, saves memory for the result tensor
x.add_(y) # x now holds the result of x + y
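To illustrate the caution above, here is a minimal sketch of autograd detecting an in-place modification of a tensor it saved for the backward pass (the exact error message may vary between PyTorch versions):
a = torch.randn(5, requires_grad=True)
b = a * 2
c = b.sin()   # autograd saves b to compute the gradient of sin
b.add_(1.0)   # in-place change to the saved tensor
try:
    c.sum().backward()
except RuntimeError as e:
    print(f"Autograd detected the in-place modification: {e}")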
Gradient Checkpointing (Activation Checkpointing): For models with very deep architectures where storing all intermediate activations consumes too much memory, gradient checkpointing offers a trade-off. Instead of storing all activations during the forward pass, it stores only a subset. During the backward pass, it recomputes the necessary activations on-the-fly. This uses more computation time but drastically reduces peak memory usage. PyTorch provides torch.utils.checkpoint.checkpoint for this purpose.
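A minimal sketch of wrapping a block with torch.utils.checkpoint.checkpoint; the layer sizes here are arbitrary, and use_reentrant=False is the variant recommended in recent PyTorch releases:
import torch
from torch.utils.checkpoint import checkpoint

# A block whose intermediate activations we choose not to store
block = torch.nn.Sequential(
    torch.nn.Linear(1024, 1024),
    torch.nn.ReLU(),
    torch.nn.Linear(1024, 1024),
)
x = torch.randn(32, 1024, requires_grad=True)
# Forward pass without saving the block's internal activations;
# they are recomputed when backward() reaches this block.
out = checkpoint(block, x, use_reentrant=False)
out.sum().backward()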
Mixed-Precision Training: Using lower-precision data types like torch.float16 or torch.bfloat16 halves the memory required for storing activations, gradients, and parameters compared to torch.float32. Libraries like torch.cuda.amp (Automatic Mixed Precision) help manage this effectively (covered in Chapter 3).
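As a preview of what Chapter 3 covers in detail, a minimal sketch of one mixed-precision training step with torch.cuda.amp; the model and data here are placeholders, and a GPU is assumed:
if torch.cuda.is_available():
    model = torch.nn.Linear(1024, 1024).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    scaler = torch.cuda.amp.GradScaler()
    data = torch.randn(32, 1024, device="cuda")

    with torch.cuda.amp.autocast():
        output = model(data)                # runs in float16 where it is safe
        loss = output.float().pow(2).mean()

    scaler.scale(loss).backward()           # scale to avoid float16 gradient underflow
    scaler.step(optimizer)
    scaler.update()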
Data Loading and Batch Size: Ensure your data loading pipeline is efficient. If facing OOM errors, reducing the batch size is often the first step, as activations and their gradients scale with batch size.
When debugging out-of-memory errors, a few checks usually narrow down the cause:
- Use torch.cuda.memory_summary() to see the distribution of allocated blocks and cached fragments. High fragmentation can cause OOM even if total free memory seems sufficient.
- Call torch.cuda.memory_allocated() at different points in your training loop to pinpoint where memory usage spikes.
- Run validation and inference code inside a torch.no_grad() context.
- Avoid accumulating tensors that still carry autograd history, for example writing all_losses.append(loss) instead of all_losses.append(loss.item()) or all_losses.append(loss.detach()). Storing the original loss tensor keeps its entire computation graph alive. Use .item() to get the Python number from a single-element tensor or .detach() if you need the tensor value without its history.

Effective memory management is often an iterative process of understanding your model's behavior, applying appropriate techniques, and using PyTorch's tools to inspect and debug memory usage. A solid grasp of these concepts is indispensable when scaling up to larger datasets and more complex architectures.