Many deep learning scenarios require combining multiple tensors into one or breaking a larger tensor down into smaller pieces. This might involve aggregating results from different processing steps, preparing data batches, or separating features. PyTorch provides several functions for efficiently joining and splitting tensors.
Combining tensors is a frequent operation, especially when dealing with batches of data or merging feature representations. PyTorch offers two primary ways to join tensors: concatenation (torch.cat) and stacking (torch.stack). The main difference lies in whether they operate along an existing dimension or introduce a new one.
torch.cat
The torch.cat function joins a sequence of tensors along an existing dimension. All tensors in the sequence must either have the same shape (except in the concatenating dimension) or be empty.
import torch
# Create two tensors
tensor_a = torch.randn(2, 3)
tensor_b = torch.randn(2, 3)
print(f"Tensor A (Shape: {tensor_a.shape}):\n{tensor_a}")
print(f"Tensor B (Shape: {tensor_b.shape}):\n{tensor_b}\n")
# Concatenate along dimension 0 (rows)
# Resulting shape: (2+2, 3) = (4, 3)
cat_dim0 = torch.cat((tensor_a, tensor_b), dim=0)
print(f"Concatenated along dim=0 (Shape: {cat_dim0.shape}):\n{cat_dim0}\n")
# Concatenate along dimension 1 (columns)
# Tensors must match in other dimensions (dim 0)
# Resulting shape: (2, 3+3) = (2, 6)
cat_dim1 = torch.cat((tensor_a, tensor_b), dim=1)
print(f"Concatenated along dim=1 (Shape: {cat_dim1.shape}):\n{cat_dim1}")
# Example with 3D tensors
tensor_c = torch.randn(1, 2, 3)
tensor_d = torch.randn(1, 2, 3)
# Concatenate along dim 0 (batch dimension)
# Resulting shape: (1+1, 2, 3) = (2, 2, 3)
cat_3d_dim0 = torch.cat((tensor_c, tensor_d), dim=0)
print(f"\nConcatenated 3D along dim=0 (Shape: {cat_3d_dim0.shape})")
Notice how torch.cat increases the size of the specified dimension while keeping other dimensions the same. The tensors must match sizes in all dimensions except the one you're concatenating along.
Visual comparison of torch.cat along dimension 0 and dimension 1 for two 2x3 tensors.
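To see the shape requirement in action, here is a small sketch (the tensor shapes are chosen purely for illustration) showing that concatenating a (2, 3) tensor with a (2, 4) tensor fails along dim=0 but succeeds along dim=1.
import torch
tensor_a = torch.randn(2, 3)
tensor_x = torch.randn(2, 4)  # differs from tensor_a along dim 1
# Concatenating along dim=0 fails: the non-concatenated dimension (dim 1) does not match (3 vs 4)
try:
    torch.cat((tensor_a, tensor_x), dim=0)
except RuntimeError as e:
    print(f"Shape mismatch: {e}")
# Concatenating along dim=1 works: the non-concatenated dimension (dim 0) matches
print(torch.cat((tensor_a, tensor_x), dim=1).shape)  # torch.Size([2, 7])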
torch.stack
In contrast to cat, torch.stack joins a sequence of tensors along a new dimension. This is useful when you want to create a batch from individual examples or group related tensors. For stack to work, all tensors in the input sequence must have the exact same shape.
import torch
# Create two tensors with the same shape
tensor_e = torch.arange(6).reshape(2, 3)
tensor_f = torch.arange(6, 12).reshape(2, 3)
print(f"Tensor E (Shape: {tensor_e.shape}):\n{tensor_e}")
print(f"Tensor F (Shape: {tensor_f.shape}):\n{tensor_f}\n")
# Stack along a new dimension 0
# Resulting shape: (2, 2, 3)
stack_dim0 = torch.stack((tensor_e, tensor_f), dim=0)
print(f"Stacked along new dim=0 (Shape: {stack_dim0.shape}):\n{stack_dim0}\n")
# Stack along a new dimension 1
# Resulting shape: (2, 2, 3)
stack_dim1 = torch.stack((tensor_e, tensor_f), dim=1)
print(f"Stacked along new dim=1 (Shape: {stack_dim1.shape}):\n{stack_dim1}\n")
# Stack along a new dimension 2 (last dimension)
# Resulting shape: (2, 3, 2)
stack_dim2 = torch.stack((tensor_e, tensor_f), dim=2)
print(f"Stacked along new dim=2 (Shape: {stack_dim2.shape}):\n{stack_dim2}")
Visual comparison of torch.stack inserting a new dimension at dim=0 and dim=1. Notice how the original tensors become slices within the new tensor.
Choosing between cat and stack depends on whether you want to merge along an existing dimension or create a new one. cat is often used to combine batches or features horizontally/vertically, while stack is common for creating batches from individual samples.
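One way to see the relationship between the two functions: stacking along a new dimension is equivalent to unsqueezing each tensor at that position and then concatenating along it. The short sketch below verifies this for dim=0, recreating the tensors used above.
import torch
tensor_e = torch.arange(6).reshape(2, 3)
tensor_f = torch.arange(6, 12).reshape(2, 3)
# stack along a new dim=0 gives the same result as unsqueeze + cat
stacked = torch.stack((tensor_e, tensor_f), dim=0)
cat_unsqueezed = torch.cat((tensor_e.unsqueeze(0), tensor_f.unsqueeze(0)), dim=0)
print(torch.equal(stacked, cat_unsqueezed))  # True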
Just as you can join tensors, you often need to split them apart. This could involve separating a batch back into individual samples, dividing features from labels, or partitioning data for parallel processing. PyTorch offers torch.split and torch.chunk for these tasks.
torch.split
The torch.split function divides a tensor into chunks along a specified dimension. You can specify either the size of each chunk (if you want equal parts) or a list containing the sizes of each desired chunk.
import torch
# Create a tensor to split
tensor_g = torch.arange(12).reshape(6, 2)
print(f"Original Tensor (Shape: {tensor_g.shape}):\n{tensor_g}\n")
# Split into chunks of size 2 along dimension 0 (rows)
# 6 rows / 2 rows/chunk = 3 chunks
split_equal = torch.split(tensor_g, 2, dim=0)
print("Split into equal chunks of size 2 (dim=0):")
for i, chunk in enumerate(split_equal):
    print(f" Chunk {i} (Shape: {chunk.shape}):\n{chunk}")
    print("-" * 20)
# Split into chunks of sizes [1, 2, 3] along dimension 0
# Total size must sum to the dimension size (1 + 2 + 3 = 6)
split_unequal = torch.split(tensor_g, [1, 2, 3], dim=0)
print("\nSplit into unequal chunks [1, 2, 3] (dim=0):")
for i, chunk in enumerate(split_unequal):
    print(f" Chunk {i} (Shape: {chunk.shape}):\n{chunk}")
    print("-" * 20)
# Split along dimension 1 (columns)
# Shape: (6, 2). Split into chunks of size 1 along dim=1
split_dim1 = torch.split(tensor_g, 1, dim=1)
print("\nSplit into equal chunks of size 1 (dim=1):")
for i, chunk in enumerate(split_dim1):
    # Using squeeze removes the dimension of size 1 for clearer display
    print(f" Chunk {i} (Shape: {chunk.shape}):\n{chunk.squeeze()}")
torch.split returns a tuple of tensors. If you provide an integer for the split_size_or_sections argument, PyTorch divides the tensor into chunks of that size along the specified dim. If the dimension size isn't perfectly divisible by the split size, the last chunk will simply be smaller. If you provide a list of sizes, their sum must equal the size of the dimension being split.
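The sketch below illustrates the non-divisible case: splitting a tensor of size 5 into chunks of size 2 leaves a smaller final chunk.
import torch
# Dimension size 5 is not divisible by the chunk size 2,
# so the last chunk simply holds the remaining element
uneven = torch.arange(5)
parts = torch.split(uneven, 2)
print([p.shape for p in parts])  # [torch.Size([2]), torch.Size([2]), torch.Size([1])]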
torch.chunk
Alternatively, torch.chunk splits a tensor into a specified number of chunks along a given dimension. PyTorch attempts to make the chunks as equal in size as possible. Unlike torch.split, which requires specifying chunk sizes, chunk only needs the desired number of chunks.
import torch
# Create a tensor
tensor_h = torch.arange(10).reshape(5, 2) # Size 5 along dim 0
print(f"Original Tensor (Shape: {tensor_h.shape}):\n{tensor_h}\n")
# Split into 3 chunks along dimension 0
# 5 rows / 3 chunks -> sizes will be [2, 2, 1] (ceil(5/3) = 2 for the first chunks)
chunked_tensor = torch.chunk(tensor_h, 3, dim=0)
print("Chunked into 3 parts (dim=0):")
for i, chunk in enumerate(chunked_tensor):
    print(f" Chunk {i} (Shape: {chunk.shape}):\n{chunk}")
    print("-" * 20)
# Create another tensor
tensor_i = torch.arange(12).reshape(3, 4) # Size 4 along dim 1
print(f"\nOriginal Tensor (Shape: {tensor_i.shape}):\n{tensor_i}\n")
# Split into 2 chunks along dimension 1
# 4 cols / 2 chunks -> sizes will be [2, 2] (ceil(4/2) = 2)
chunked_tensor_dim1 = torch.chunk(tensor_i, 2, dim=1)
print("Chunked into 2 parts (dim=1):")
for i, chunk in enumerate(chunked_tensor_dim1):
    print(f" Chunk {i} (Shape: {chunk.shape}):\n{chunk}")
torch.chunk is convenient when you know how many pieces you want, regardless of whether the dimension size divides evenly. torch.split gives you more control when you need chunks of exact, potentially varying, sizes.
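One detail worth remembering: because each chunk has size ceil(dimension size / number of chunks), torch.chunk may return fewer chunks than requested. The short sketch below shows this behavior.
import torch
# Asking for 4 chunks of a size-6 tensor: each chunk gets ceil(6/4) = 2 elements,
# so only 3 chunks are actually returned
t = torch.arange(6)
chunks = torch.chunk(t, 4)
print(len(chunks))                # 3
print([c.shape for c in chunks])  # [torch.Size([2]), torch.Size([2]), torch.Size([2])]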
Mastering these joining and splitting operations is important for manipulating data effectively as it flows through different stages of your deep learning pipeline, from initial loading and preprocessing to batching for training and dissecting model outputs.
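As a final illustration, here is a minimal sketch of one such pattern, assuming a hypothetical layout in which each sample is a 1-D tensor holding four feature values followed by one label: individual samples are stacked into a batch, and the batch is then split back into feature and label columns.
import torch
# Hypothetical samples: 4 feature values followed by 1 label each
samples = [torch.randn(5) for _ in range(8)]
# Stack individual samples into a batch of shape (8, 5)
batch = torch.stack(samples, dim=0)
# Split the columns into features (4) and labels (1)
features, labels = torch.split(batch, [4, 1], dim=1)
print(batch.shape, features.shape, labels.shape)  # shapes (8, 5), (8, 4), (8, 1)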