In many deep learning scenarios, you'll need to combine multiple tensors into one or break a larger tensor down into smaller pieces. This might involve aggregating results from different processing steps, preparing data batches, or separating features. PyTorch provides several functions for efficiently joining and splitting tensors, building upon the reshaping techniques you've already seen.
Combining tensors is a frequent operation, especially when dealing with batches of data or merging feature representations. PyTorch offers two primary ways to join tensors: concatenation (torch.cat) and stacking (torch.stack). The main difference lies in whether they operate along an existing dimension or introduce a new one.
torch.cat
The torch.cat function joins a sequence of tensors along an existing dimension. All tensors in the sequence must either have the same shape (except in the concatenating dimension) or be empty.
import torch
# Create two tensors
tensor_a = torch.randn(2, 3)
tensor_b = torch.randn(2, 3)
print(f"Tensor A (Shape: {tensor_a.shape}):\n{tensor_a}")
print(f"Tensor B (Shape: {tensor_b.shape}):\n{tensor_b}\n")
# Concatenate along dimension 0 (rows)
# Resulting shape: (2+2, 3) = (4, 3)
cat_dim0 = torch.cat((tensor_a, tensor_b), dim=0)
print(f"Concatenated along dim=0 (Shape: {cat_dim0.shape}):\n{cat_dim0}\n")
# Concatenate along dimension 1 (columns)
# Tensors must match in other dimensions (dim 0)
# Resulting shape: (2, 3+3) = (2, 6)
cat_dim1 = torch.cat((tensor_a, tensor_b), dim=1)
print(f"Concatenated along dim=1 (Shape: {cat_dim1.shape}):\n{cat_dim1}")
# Example with 3D tensors
tensor_c = torch.randn(1, 2, 3)
tensor_d = torch.randn(1, 2, 3)
# Concatenate along dim 0 (batch dimension)
# Resulting shape: (1+1, 2, 3) = (2, 2, 3)
cat_3d_dim0 = torch.cat((tensor_c, tensor_d), dim=0)
print(f"\nConcatenated 3D along dim=0 (Shape: {cat_3d_dim0.shape})")
Notice how torch.cat increases the size of the specified dimension while keeping the other dimensions the same. The tensors must match in size in all dimensions except the one you're concatenating along.
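As a quick check of that rule, here is a small sketch (the tensor names tall and short are illustrative): the two tensors differ in size only along dim 0, so concatenating along dim 0 works, while concatenating along dim 1 fails because their dim 0 sizes disagree.
import torch
# Shapes differ only in dim 0: (4, 3) and (2, 3)
tall = torch.randn(4, 3)
short = torch.randn(2, 3)
# Allowed: joining along dim 0, the dimension where sizes may differ
combined = torch.cat((tall, short), dim=0)
print(combined.shape)  # torch.Size([6, 3])
# Not allowed: joining along dim 1 requires matching sizes in dim 0 (4 vs 2)
try:
    torch.cat((tall, short), dim=1)
except RuntimeError as err:
    print(f"Error: {err}")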
Visual comparison of torch.cat along dimension 0 and dimension 1 for two 2x3 tensors.
torch.stack
In contrast to cat, torch.stack joins a sequence of tensors along a new dimension. This is useful when you want to create a batch from individual examples or group related tensors. For stack to work, all tensors in the input sequence must have the exact same shape.
import torch
# Create two tensors with the same shape
tensor_e = torch.arange(6).reshape(2, 3)
tensor_f = torch.arange(6, 12).reshape(2, 3)
print(f"Tensor E (Shape: {tensor_e.shape}):\n{tensor_e}")
print(f"Tensor F (Shape: {tensor_f.shape}):\n{tensor_f}\n")
# Stack along a new dimension 0
# Resulting shape: (2, 2, 3)
stack_dim0 = torch.stack((tensor_e, tensor_f), dim=0)
print(f"Stacked along new dim=0 (Shape: {stack_dim0.shape}):\n{stack_dim0}\n")
# Stack along a new dimension 1
# Resulting shape: (2, 2, 3)
stack_dim1 = torch.stack((tensor_e, tensor_f), dim=1)
print(f"Stacked along new dim=1 (Shape: {stack_dim1.shape}):\n{stack_dim1}\n")
# Stack along a new dimension 2 (last dimension)
# Resulting shape: (2, 3, 2)
stack_dim2 = torch.stack((tensor_e, tensor_f), dim=2)
print(f"Stacked along new dim=2 (Shape: {stack_dim2.shape}):\n{stack_dim2}")
Visual comparison of torch.stack inserting a new dimension at dim=0 and dim=1. Notice how the original tensors become slices within the new tensor.
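You can verify the slicing behavior directly: indexing the stacked tensor along the new dimension returns the original inputs. A minimal sketch, reusing tensor_e and tensor_f from the example above:
import torch
tensor_e = torch.arange(6).reshape(2, 3)
tensor_f = torch.arange(6, 12).reshape(2, 3)
stacked = torch.stack((tensor_e, tensor_f), dim=0)  # Shape: (2, 2, 3)
# Each index along the new dim 0 is one of the original tensors
print(torch.equal(stacked[0], tensor_e))  # True
print(torch.equal(stacked[1], tensor_f))  # True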
Choosing between cat and stack depends on whether you want to merge along an existing dimension or create a new one. cat is often used to combine batches or features horizontally or vertically, while stack is common for creating batches from individual samples.
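One way to see the relationship between the two: stacking along a new dimension is equivalent to unsqueezing each tensor at that position and then concatenating. The sketch below illustrates this with the same tensors as before.
import torch
tensor_e = torch.arange(6).reshape(2, 3)
tensor_f = torch.arange(6, 12).reshape(2, 3)
# stack inserts the new dimension for you...
stacked = torch.stack((tensor_e, tensor_f), dim=0)
# ...which matches adding the dimension manually and concatenating
cat_equivalent = torch.cat((tensor_e.unsqueeze(0), tensor_f.unsqueeze(0)), dim=0)
print(stacked.shape, cat_equivalent.shape)   # torch.Size([2, 2, 3]) twice
print(torch.equal(stacked, cat_equivalent))  # True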
Just as you can join tensors, you often need to split them apart. This could involve separating a batch back into individual samples, dividing features from labels, or partitioning data for parallel processing. PyTorch offers torch.split and torch.chunk for these tasks.
torch.split
The torch.split function divides a tensor into chunks along a specified dimension. You can specify either the size of each chunk (if you want equal parts) or a list containing the sizes of each desired chunk.
import torch
# Create a tensor to split
tensor_g = torch.arange(12).reshape(6, 2)
print(f"Original Tensor (Shape: {tensor_g.shape}):\n{tensor_g}\n")
# Split into chunks of size 2 along dimension 0 (rows)
# 6 rows / 2 rows/chunk = 3 chunks
split_equal = torch.split(tensor_g, 2, dim=0)
print("Split into equal chunks of size 2 (dim=0):")
for i, chunk in enumerate(split_equal):
    print(f" Chunk {i} (Shape: {chunk.shape}):\n{chunk}")
    print("-" * 20)
# Split into chunks of sizes [1, 2, 3] along dimension 0
# Total size must sum to the dimension size (1 + 2 + 3 = 6)
split_unequal = torch.split(tensor_g, [1, 2, 3], dim=0)
print("\nSplit into unequal chunks [1, 2, 3] (dim=0):")
for i, chunk in enumerate(split_unequal):
    print(f" Chunk {i} (Shape: {chunk.shape}):\n{chunk}")
    print("-" * 20)
# Split along dimension 1 (columns)
# Shape: (6, 2). Split into chunks of size 1 along dim=1
split_dim1 = torch.split(tensor_g, 1, dim=1)
print("\nSplit into equal chunks of size 1 (dim=1):")
for i, chunk in enumerate(split_dim1):
    # Using squeeze removes the dimension of size 1 for clearer display
    print(f" Chunk {i} (Shape: {chunk.shape}):\n{chunk.squeeze()}")
torch.split returns a tuple of tensors. If you provide an integer for the split_size_or_sections argument, PyTorch divides the tensor into chunks of that size along the specified dim. If the dimension size isn't perfectly divisible by the split size, the last chunk will simply be smaller. If you provide a list of sizes, their sum must equal the size of the dimension being split.
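To make the uneven case concrete, here is a minimal sketch: splitting a dimension of size 5 into chunks of size 2 produces two full chunks and a smaller final chunk.
import torch
uneven = torch.arange(5)  # Size 5 along dim 0, not divisible by 2
pieces = torch.split(uneven, 2, dim=0)
for piece in pieces:
    print(piece)
# tensor([0, 1])
# tensor([2, 3])
# tensor([4])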
torch.chunk
Alternatively, torch.chunk splits a tensor into a specified number of chunks along a given dimension. PyTorch attempts to make the chunks as equal in size as possible. Unlike torch.split, which requires specifying chunk sizes, chunk only needs the desired number of chunks.
import torch
# Create a tensor
tensor_h = torch.arange(10).reshape(5, 2) # Size 5 along dim 0
print(f"Original Tensor (Shape: {tensor_h.shape}):\n{tensor_h}\n")
# Split into 3 chunks along dimension 0
# 5 rows / 3 chunks -> sizes will be [2, 2, 1] (ceil(5/3) = 2 for the first chunks)
chunked_tensor = torch.chunk(tensor_h, 3, dim=0)
print("Chunked into 3 parts (dim=0):")
for i, chunk in enumerate(chunked_tensor):
    print(f" Chunk {i} (Shape: {chunk.shape}):\n{chunk}")
    print("-" * 20)
# Create another tensor
tensor_i = torch.arange(12).reshape(3, 4) # Size 4 along dim 1
print(f"\nOriginal Tensor (Shape: {tensor_i.shape}):\n{tensor_i}\n")
# Split into 2 chunks along dimension 1
# 4 cols / 2 chunks -> sizes will be [2, 2] (ceil(4/2) = 2)
chunked_tensor_dim1 = torch.chunk(tensor_i, 2, dim=1)
print("Chunked into 2 parts (dim=1):")
for i, chunk in enumerate(chunked_tensor_dim1):
    print(f" Chunk {i} (Shape: {chunk.shape}):\n{chunk}")
torch.chunk is convenient when you know how many pieces you want, regardless of whether the dimension size divides evenly. torch.split gives you more control when you need chunks of exact, potentially varying, sizes.
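One subtlety worth noting: if the dimension is too small to yield the requested number of chunks, torch.chunk simply returns fewer chunks. The short sketch below shows this with a new illustrative tensor, and also shows that torch.cat can reassemble the pieces produced by either splitting function.
import torch
tensor_j = torch.arange(4)  # Illustrative tensor, size 4 along dim 0
# Requesting 3 chunks: each chunk gets ceil(4/3) = 2 elements,
# so only 2 chunks are actually returned
pieces = torch.chunk(tensor_j, 3, dim=0)
print(len(pieces))  # 2
# Concatenating the pieces along the same dimension restores the original
restored = torch.cat(pieces, dim=0)
print(torch.equal(restored, tensor_j))  # True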
Mastering these joining and splitting operations is important for manipulating data effectively as it flows through different stages of your deep learning pipeline, from initial loading and preprocessing to batching for training and dissecting model outputs.