While PyTorch offers a wide array of established optimization algorithms in torch.optim, from SGD to Adam, research and practical applications often benefit from custom optimization strategies. You might need to implement a novel algorithm from a recent paper, adapt an existing optimizer for specific constraints (such as layer-wise adaptive learning rates not covered by parameter groups alone), or combine steps from different optimizers. This section details how to create your own optimizers by subclassing torch.optim.Optimizer, giving you full control over the parameter update process.
The torch.optim.Optimizer Base Class

At its core, any PyTorch optimizer inherits from the torch.optim.Optimizer base class. Understanding its structure is essential for building your own. Key components include:
- __init__(self, params, defaults): The constructor.
  - params: An iterable of parameters (tensors) to optimize, or an iterable of dictionaries defining parameter groups. Parameter groups allow applying different hyperparameters (like learning rates) to different parts of your model.
  - defaults: A dictionary containing the default hyperparameters for the optimizer (e.g., {'lr': 0.01, 'momentum': 0.9}). These defaults are used for parameter groups that don't explicitly override them.
  The first call in your __init__ should always be super().__init__(params, defaults). This call handles the setup of self.param_groups, which stores the parameters and their associated hyperparameters.
- step(self, closure=None): This method performs a single optimization step (parameter update). It is typically called after gradients have been computed via loss.backward(), and it reads the .grad attribute of each parameter. The closure argument is a callable function that re-evaluates the model and returns the loss. Some optimization algorithms, like L-BFGS, require re-evaluating the loss multiple times per step, making the closure necessary. For most common optimizers (SGD, Adam, etc.), the closure is not needed.
- zero_grad(self, set_to_none=True): Clears the gradients of all optimized parameters. Setting set_to_none=True (the default in current PyTorch releases) assigns param.grad = None instead of filling it with zeros. This can sometimes offer a minor performance benefit by freeing memory sooner and avoiding a memory write operation, but requires careful handling downstream if code expects .grad to always be a Tensor.
- state: A dictionary (usually a collections.defaultdict(dict)) that holds the optimizer's state for each parameter. For example, momentum optimizers store momentum buffers here, and Adam stores moving averages of gradients and squared gradients. The state is typically keyed by the parameter object itself (self.state[param]).
- param_groups: A list of dictionaries. Each dictionary represents a group of parameters and contains the keys:
  - 'params': A list of parameter tensors belonging to this group.
  - The hyperparameter keys for this group (e.g., 'lr', 'momentum', 'weight_decay').
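To make these components concrete before building a full optimizer, here is a minimal sketch of a bare-bones subclass. The class name SkeletonOptimizer is illustrative, and its step performs plain gradient descent so the structure stays visible:

import torch
from torch.optim import Optimizer

class SkeletonOptimizer(Optimizer):
    """A minimal optimizer skeleton: plain gradient descent, no extra state."""
    def __init__(self, params, lr=0.01):
        defaults = dict(lr=lr)                  # default hyperparameters for every group
        super().__init__(params, defaults)      # builds self.param_groups (and self.state)

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:         # each group carries its own hyperparameters
            for p in group['params']:
                if p.grad is None:
                    continue
                p.add_(p.grad, alpha=-group['lr'])  # p <- p - lr * grad
        return loss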
Let's implement Stochastic Gradient Descent (SGD) with momentum from scratch to illustrate the process. The update rule for momentum SGD is:
$$ v_{t+1} = \mu\, v_t + g_{t+1} $$
$$ p_{t+1} = p_t - \alpha\, v_{t+1} $$

where $p_t$ are the parameters, $g_{t+1}$ is the gradient of the loss with respect to $p_t$, $v_t$ is the velocity (momentum buffer), $\mu$ is the momentum coefficient, and $\alpha$ is the learning rate.
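For a concrete feel of how the velocity accumulates, take illustrative values $\mu = 0.9$, $\alpha = 0.1$, $v_0 = 0$, and two successive gradients $g_1 = g_2 = 1.0$: then $v_1 = 0.9 \cdot 0 + 1.0 = 1.0$ and $p_1 = p_0 - 0.1 \cdot 1.0$, while $v_2 = 0.9 \cdot 1.0 + 1.0 = 1.9$ and $p_2 = p_1 - 0.1 \cdot 1.9$, so the effective step grows while gradients keep pointing in the same direction.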
Here's how you can implement this:
import torch
from torch.optim import Optimizer

class CustomSGD(Optimizer):
    """Implements Stochastic Gradient Descent with Momentum."""

    def __init__(self, params, lr=0.01, momentum=0.0, weight_decay=0.0):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if momentum < 0.0:
            raise ValueError(f"Invalid momentum value: {momentum}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay)
        super().__init__(params, defaults)
        # State could be initialized here, but it is usually done lazily in step():
        # for group in self.param_groups:
        #     for p in group['params']:
        #         self.state[p] = dict(momentum_buffer=None)

    @torch.no_grad()  # Important: disable gradient tracking within the optimizer step
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():  # Ensure gradients are enabled for the closure
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            momentum = group['momentum']
            weight_decay = group['weight_decay']

            for p in group['params']:
                if p.grad is None:
                    continue  # Skip parameters without gradients

                grad = p.grad  # Get the gradient tensor

                # Apply weight decay (L2 penalty) if specified.
                # Note: this is the standard approach of adding it to the gradient.
                if weight_decay != 0:
                    grad = grad.add(p, alpha=weight_decay)

                # Access and update the per-parameter state (momentum buffer)
                param_state = self.state[p]
                if 'momentum_buffer' not in param_state:
                    # Initialize the momentum buffer lazily on the first step
                    param_state['momentum_buffer'] = torch.clone(grad).detach()
                else:
                    param_state['momentum_buffer'].mul_(momentum).add_(grad)  # v = mu*v + grad

                momentum_buffer = param_state['momentum_buffer']

                # Perform the parameter update: p = p - lr * v
                p.add_(momentum_buffer, alpha=-lr)

        return loss
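As an optional sanity check (not part of the original example; it assumes the CustomSGD class above is in scope), you can verify that the update matches torch.optim.SGD with the same hyperparameters:

import torch

p_custom = torch.nn.Parameter(torch.tensor([1.0, -2.0, 3.0]))
p_builtin = torch.nn.Parameter(p_custom.detach().clone())

opt_custom = CustomSGD([p_custom], lr=0.1, momentum=0.9, weight_decay=1e-2)
opt_builtin = torch.optim.SGD([p_builtin], lr=0.1, momentum=0.9, weight_decay=1e-2)

for _ in range(5):
    for p, opt in ((p_custom, opt_custom), (p_builtin, opt_builtin)):
        opt.zero_grad()
        (p ** 2).sum().backward()  # simple quadratic loss
        opt.step()

print(torch.allclose(p_custom, p_builtin))  # Expected: True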
Key Implementation Points:

- @torch.no_grad(): Decorating the step method with @torch.no_grad() is crucial. Optimization steps should not be part of the computational graph tracked by autograd.
- Iteration: The step method loops over self.param_groups and then over the params within each group.
- Gradient check: We skip parameters with if p.grad is None: because some parameters in a model might not receive gradients (e.g., if they are not used in the forward pass or are detached from the graph).
- Lazy state: The momentum_buffer is stored in self.state[p]. It's initialized lazily (the first time step is called for a parameter) to avoid allocating memory upfront if the parameter never gets a gradient.
- In-place updates: Parameters and buffers are updated with in-place operations like add_ and mul_. This modifies the parameter tensor directly without creating new tensors, which is essential for the optimizer to actually update the model weights.
- Per-group hyperparameters: Hyperparameters (lr, momentum, weight_decay) are accessed from the group dictionary, allowing different groups to have different settings.

Using your custom optimizer is just like using a built-in one:
# Assuming 'model' is your torch.nn.Module
# Instantiate the custom optimizer
optimizer = CustomSGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)

# In your training loop:
for inputs, targets in dataloader:
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()  # Performs the update using CustomSGD logic
Parameter groups work exactly as they do with built-in optimizers; for example, you can give the classifier its own learning rate:

optimizer = CustomSGD([
    {'params': model.base.parameters()},
    {'params': model.classifier.parameters(), 'lr': 1e-3}  # Different LR for classifier
], lr=1e-4, momentum=0.9)  # Default LR for other params (e.g., base)
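Groups that omit a hyperparameter inherit it from the defaults passed to the constructor. You can confirm this by inspecting param_groups; the output shown in comments is illustrative and assumes the optimizer above:

for i, group in enumerate(optimizer.param_groups):
    print(i, group['lr'], group['momentum'])
# 0 0.0001 0.9   <- base parameters fall back to the default lr
# 1 0.001 0.9    <- classifier group overrides lr only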
If your algorithm needs to re-evaluate the loss during the update (as L-BFGS does), define a closure function and pass it to optimizer.step(closure). Your step implementation must then call closure() appropriately, potentially multiple times, usually within a with torch.enable_grad(): block.
# Example conceptual closure usage (specific optimizer logic varies)
def closure():
    optimizer.zero_grad()
    output = model(input)
    loss = loss_fn(output, target)
    loss.backward()
    return loss

# In the step method (conceptual, for L-BFGS-like optimizers):
# loss = closure()  # May be called multiple times internally
# Use loss and gradients to update parameters...
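Since the CustomSGD.step shown earlier already accepts an optional closure and evaluates it once under torch.enable_grad(), the same pattern can be used with it directly (illustrative, assuming the closure defined above):

loss = optimizer.step(closure)  # CustomSGD calls closure() once and returns its loss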
Finally, two integration details worth noting:

- Learning rate schedulers: Optimizers that keep their hyperparameters in self.param_groups generally work seamlessly with PyTorch's learning rate schedulers (torch.optim.lr_scheduler). The schedulers modify the 'lr' value within each param_group, which your custom step function then reads (a short sketch follows below).
- Saving and loading state: The base class provides state_dict() and load_state_dict(), which work as long as you keep all per-parameter state (such as momentum buffers) in the self.state[p] dictionary.
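For instance, here is a sketch of pairing CustomSGD with a standard StepLR scheduler; the schedule values are illustrative, and model, dataloader, and criterion are assumed to exist:

optimizer = CustomSGD(model.parameters(), lr=0.1, momentum=0.9)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

for epoch in range(30):
    for inputs, targets in dataloader:
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()
    scheduler.step()  # halves group['lr'] every 10 epochs; CustomSGD simply reads the new value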
By subclassing torch.optim.Optimizer, you gain the power to implement virtually any parameter update rule, integrating novel optimization research directly into your PyTorch training workflows and fine-tuning the learning process for your specific needs.