Often, you'll find that the tensor you have isn't quite in the right structure for the next step in your computation, especially when feeding data into specific neural network layers. PyTorch provides flexible tools to change a tensor's shape or rearrange its dimensions without altering the underlying data elements themselves. This section covers the primary methods: view(), reshape(), and permute().
view() and reshape()
Both view() and reshape() allow you to change the dimensions of a tensor, provided the total number of elements remains constant. They are useful for tasks like flattening a multi-dimensional tensor before passing it to a linear layer, or adding and removing dimensions of size 1.
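For example, here is a minimal sketch, with arbitrary batch and layer sizes, of flattening a batch of image-like inputs before a fully connected layer:
import torch
import torch.nn as nn
# A minimal sketch with arbitrary sizes: flatten a batch of 1x28x28 inputs
# so that each sample becomes a single 784-element vector for nn.Linear.
batch = torch.randn(16, 1, 28, 28)   # (batch, channels, height, width)
flat = batch.view(16, -1)            # keep the batch dimension, flatten the rest
print(flat.shape)                    # torch.Size([16, 784])
linear = nn.Linear(28 * 28, 10)      # 784 inputs -> 10 outputs (illustrative)
out = linear(flat)
print(out.shape)                     # torch.Size([16, 10])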
view()
The view() method returns a new tensor that shares the same underlying data as the original tensor but has a different shape. It's very efficient because it avoids copying data. However, view() requires the tensor to be contiguous in memory. A contiguous tensor is one whose elements are stored sequentially in memory without gaps, following the dimension order. Most freshly created tensors are contiguous, but some operations (such as slicing or transposing with t()) can create non-contiguous tensors.
Let's look at an example:
import torch
# Create a contiguous tensor
x = torch.arange(12) # tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
print(f"Original tensor: {x}")
print(f"Original shape: {x.shape}")
print(f"Is contiguous? {x.is_contiguous()}")
# Reshape using view()
y = x.view(3, 4)
print("\nTensor after view(3, 4):")
print(y)
print(f"New shape: {y.shape}")
print(f"Shares storage with x? {y.storage().data_ptr() == x.storage().data_ptr()}") # Check if they share memory
print(f"Is y contiguous? {y.is_contiguous()}")
# Try another view
z = y.view(2, 6)
print("\nTensor after view(2, 6):")
print(z)
print(f"New shape: {z.shape}")
print(f"Shares storage with x? {z.storage().data_ptr() == x.storage().data_ptr()}")
print(f"Is z contiguous? {z.is_contiguous()}")
You can use -1 for one dimension in the view() call, and PyTorch will automatically infer the correct size for that dimension from the total number of elements and the sizes of the other dimensions.
# Using -1 for inference
w = x.view(2, 2, -1) # Infers the last dimension to be 3 (12 / (2*2) = 3)
print("\nTensor after view(2, 2, -1):")
print(w)
print(f"New shape: {w.shape}")
If you try to call view() on a non-contiguous tensor, you'll get a RuntimeError.
# Example of view() failing on a non-contiguous tensor
a = torch.arange(12).view(3, 4)
b = a.t() # Transpose creates a non-contiguous tensor
print(f"\nIs b contiguous? {b.is_contiguous()}")
try:
    c = b.view(12)
except RuntimeError as e:
    print(f"\nError trying b.view(12): {e}")
reshape()
The reshape() method behaves similarly to view() but offers more flexibility. It returns a view when the requested shape is compatible with the tensor's existing memory layout. When a view isn't possible (for example, because the original tensor isn't contiguous in a way compatible with the new shape), reshape() copies the data into a new, contiguous tensor with the desired shape. This makes reshape() generally safer and more versatile, although potentially less performant when a copy occurs.
Let's revisit the transpose example using reshape():
# Using reshape() on the non-contiguous tensor 'b'
print(f"\nOriginal non-contiguous tensor b:\n{b}")
print(f"Shape of b: {b.shape}")
print(f"Is b contiguous? {b.is_contiguous()}")
# Reshape works even if 'b' is not contiguous
c = b.reshape(12)
print(f"\nTensor c after b.reshape(12):\n{c}")
print(f"Shape of c: {c.shape}")
print(f"Is c contiguous? {c.is_contiguous()}")
# Check if 'c' shares storage with 'b'. It doesn't here, because reshape had to copy.
print(f"Shares storage with b? {c.storage().data_ptr() == b.storage().data_ptr()}")
# Reshape can also infer dimensions with -1
d = b.reshape(2, -1) # Infers the last dimension to be 6
print(f"\nTensor d after b.reshape(2, -1):\n{d}")
print(f"Shape of d: {d.shape}")
When to use which?
- Use view() when you are certain the tensor is contiguous and you want to guarantee that no data copy occurs, for maximum performance. Be prepared to handle a potential RuntimeError if the contiguity assumption is wrong.
- Use reshape() for a more robust approach that works on both contiguous and non-contiguous tensors. It returns a view when possible and makes a copy otherwise. This is often the preferred method unless performance is critical and you can ensure contiguity.
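As a quick illustration of this rule of thumb, here is a small sketch. safe_flatten is a hypothetical helper name; in practice, a plain reshape() call already gives you the same view-if-possible, copy-otherwise behavior.
# Hypothetical helper that makes the choice explicit (plain reshape() behaves
# the same way: view when possible, copy otherwise).
def safe_flatten(t: torch.Tensor) -> torch.Tensor:
    if t.is_contiguous():
        return t.view(-1)      # guaranteed zero-copy
    return t.reshape(-1)       # may copy, but never raises

print(safe_flatten(torch.arange(6)))                 # contiguous: view path
print(safe_flatten(torch.arange(6).view(2, 3).t()))  # transposed: reshape path (copies)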
permute()
While view() and reshape() change the shape by reinterpreting how elements are distributed across dimensions, permute() explicitly reorders the dimensions themselves. It doesn't change the total number of elements or the size of any individual axis; it changes which axis corresponds to which original dimension.
Imagine you have image data stored as (Channels, Height, Width) but need it in (Height, Width, Channels) format for a specific library or visualization. permute() is the tool for this: you provide the desired order of dimensions as arguments.
# Create a 3D tensor (e.g., representing C, H, W)
image_tensor = torch.randn(3, 32, 32) # Channels, Height, Width
print(f"Original shape: {image_tensor.shape}") # torch.Size([3, 32, 32])
# Permute to (Height, Width, Channels)
permuted_tensor = image_tensor.permute(1, 2, 0) # Specify new order: Dim 1, Dim 2, Dim 0
print(f"Permuted shape: {permuted_tensor.shape}") # torch.Size([32, 32, 3])
# Permute usually returns a non-contiguous view
print(f"Is permuted_tensor contiguous? {permuted_tensor.is_contiguous()}")
# Permuting back
original_again = permuted_tensor.permute(2, 0, 1) # Back to C, H, W
print(f"Shape after permuting back: {original_again.shape}") # torch.Size([3, 32, 32])
print(f"Is original_again contiguous? {original_again.is_contiguous()}") # Might still be non-contiguous
# Check storage sharing
print(f"Shares storage with original? {original_again.storage().data_ptr() == image_tensor.storage().data_ptr()}")
Like view(), permute() returns a tensor that shares the underlying data with the original; it does not copy data. However, the resulting tensor is typically not contiguous. If you need a contiguous tensor after permuting (for example, to use view() afterwards), you can chain a call to .contiguous():
# Make the permuted tensor contiguous
contiguous_permuted = permuted_tensor.contiguous()
print(f"\nIs contiguous_permuted contiguous? {contiguous_permuted.is_contiguous()}")
# Now view() can be used safely
flattened_permuted = contiguous_permuted.view(-1)
print(f"Shape after flattening: {flattened_permuted.shape}")
Mastering view(), reshape(), and permute() gives you precise control over the structure of your tensors, a necessary skill for adapting data to the requirements of different PyTorch operations and model layers. Remember the trade-offs: view() is fast but requires contiguity, reshape() is flexible but might copy, and permute() swaps dimensions without copying but often results in a non-contiguous tensor.