While assigning fixed weights to different data sources gives you control over the overall mixture proportions, a more dynamic approach is sometimes beneficial. You might want to adjust how strongly the assigned weights influence the actual sampling probability during training. This is where temperature-based sampling comes into play, borrowing its name and concept from statistical physics and from its common use in controlling the output randomness of generative models. In the context of data sampling, temperature lets you modulate the "sharpness" of the probability distribution derived from your source weights.
Imagine you have several data sources, each assigned a score or log-weight $w_i$ reflecting its perceived importance or quality. Without temperature scaling, the probability of sampling from source $i$ might be determined by a standard softmax function over these scores. Temperature $T$ introduces a scaling factor applied to the scores before the softmax calculation.
The probability $P(i \mid T)$ of selecting data source $i$ given a temperature $T$ is calculated as:

$$P(i \mid T) = \frac{\exp(w_i / T)}{\sum_j \exp(w_j / T)}$$

Here, $w_i$ is the score or log-weight for source $i$, and the sum runs over all available sources $j$. The temperature $T$ is a positive value that controls the shape of this probability distribution: values of $T > 1$ flatten it toward a uniform distribution, $T = 1$ leaves the standard softmax over the raw scores unchanged, and $T < 1$ sharpens it toward the highest-scored source.
This mechanism provides a smooth way to interpolate between uniform sampling ($T \to \infty$) and greedy sampling of the highest-weighted source ($T \to 0^+$).
Temperature-based sampling is particularly useful when combined with annealing schedules. You might start training with a higher temperature ($T > 1$) to ensure the model sees a wide variety of data from all available sources, promoting exploration. As training progresses, you can gradually decrease the temperature (annealing it toward $T = 1$ or even lower). This shift focuses training on higher-quality or more relevant data sources later on, helping the model refine its capabilities based on the prioritized data mixture.
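A temperature schedule can be as simple as a linear interpolation between a starting and a final value. Below is a minimal sketch of that idea; the helper name annealed_temperature, the endpoint values, and the linear shape are illustrative assumptions rather than a prescribed recipe.

def annealed_temperature(step, total_steps, t_start=2.0, t_end=1.0):
    # Linearly interpolate from t_start at step 0 down to t_end at total_steps.
    progress = min(step / max(total_steps, 1), 1.0)
    return t_start + (t_end - t_start) * progress

# Early steps see a flatter, more exploratory mixture; later steps sharpen it.
for step in [0, 25_000, 50_000, 100_000]:
    T = annealed_temperature(step, total_steps=100_000)
    print(f"step {step:>7,}: T = {T:.2f}")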
Consider a scenario with three data sources: Web Text (score $w_1 = 2.0$), Books (score $w_2 = 3.0$), and Code (score $w_3 = 1.0$). Let's see how the sampling probabilities change with temperature.
import torch
import torch.nn.functional as F
# Scores (log-weights) for data sources
# Higher score means higher preference at T=1
scores = torch.tensor([2.0, 3.0, 1.0]) # Web, Books, Code
temperatures = [0.5, 1.0, 2.0, 10.0]
source_names = ["Web Text", "Books", "Code"]
print("Source Scores:", dict(zip(source_names, scores.tolist())))
print("-" * 30)
for T in temperatures:
    # Apply temperature scaling and calculate probabilities via softmax
    probs = F.softmax(scores / T, dim=0)
    print(f"Temperature T = {T:.1f}")
    for name, prob in zip(source_names, probs.tolist()):
        print(f" P({name}): {prob:.4f}")
    print("-" * 30)
# Example of sampling based on probabilities for T=1.0
T_sample = 1.0
sampling_probs = F.softmax(scores / T_sample, dim=0)
# Sample a source index based on the calculated probabilities
# In a real scenario, you'd sample thousands/millions of times
num_samples = 5
sampled_indices = torch.multinomial(
    sampling_probs,
    num_samples=num_samples,
    replacement=True
)
print(f"\nExample Sampling (T={T_sample:.1f}):")
sampled_sources = [source_names[i] for i in sampled_indices]
print(f"Sampled source indices: {sampled_indices.tolist()}")
print(f"Sampled source names: {sampled_sources}")
Running this code produces output similar to:
Source Scores: {'Web Text': 2.0, 'Books': 3.0, 'Code': 1.0}
------------------------------
Temperature T = 0.5
 P(Web Text): 0.1173
 P(Books): 0.8668
 P(Code): 0.0159
------------------------------
Temperature T = 1.0
 P(Web Text): 0.2447
 P(Books): 0.6652
 P(Code): 0.0900
------------------------------
Temperature T = 2.0
 P(Web Text): 0.3072
 P(Books): 0.5065
 P(Code): 0.1863
------------------------------
Temperature T = 10.0
 P(Web Text): 0.3322
 P(Books): 0.3672
 P(Code): 0.3006
------------------------------
Example Sampling (T=1.0):
Sampled source indices: [1, 1, 1, 0, 1]
Sampled source names: ['Books', 'Books', 'Books', 'Web Text', 'Books']
Notice how at $T = 0.5$, the probability for "Books" (the highest score) is dominant (0.8668). At $T = 1.0$, "Books" is still preferred (0.6652), but the other sources retain a reasonable chance. As $T$ increases to 2.0 and 10.0, the probabilities move much closer together, approaching the uniform value of $1/3 \approx 0.3333$.
Probabilities of sampling from three sources ("Web Text", "Books", "Code" with scores 2.0, 3.0, 1.0 respectively) at different temperatures. Lower temperatures sharpen the distribution towards the highest-scored source ("Books"), while higher temperatures flatten it towards uniform.
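If you want to regenerate curves like those described in the caption above, a short sweep over temperatures is enough. The sketch below assumes matplotlib is available; the log-spaced temperature range and variable names are illustrative choices.

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

scores = torch.tensor([2.0, 3.0, 1.0])  # Web Text, Books, Code
source_names = ["Web Text", "Books", "Code"]

# Sweep temperatures on a log scale from sharp (0.1) to nearly uniform (10.0).
temps = torch.logspace(-1, 1, steps=50)
probs = torch.stack([F.softmax(scores / t, dim=0) for t in temps])

for idx, name in enumerate(source_names):
    plt.plot(temps.tolist(), probs[:, idx].tolist(), label=name)
plt.xscale("log")
plt.xlabel("Temperature T")
plt.ylabel("Sampling probability P(i | T)")
plt.legend()
plt.show()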
A common implementation pattern is a custom Dataset that, for each batch or sample, first selects a source using temperature-based sampling and then retrieves an item from that source's dataset.

Temperature-based sampling provides a flexible knob for controlling the data mixture dynamically throughout training, enabling strategies that balance broad exploration with focused exploitation of high-value data sources.
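To make the custom Dataset pattern above concrete, here is a minimal sketch. It assumes each source is an ordinary map-style PyTorch dataset; the class name TemperatureMixtureDataset and the uniform item choice within a source are illustrative, not part of any specific library API.

import torch
import torch.nn.functional as F
from torch.utils.data import IterableDataset

class TemperatureMixtureDataset(IterableDataset):
    """Yields items by first picking a source at temperature T, then an item."""

    def __init__(self, sources, scores, temperature=1.0):
        self.sources = sources  # list of map-style datasets, one per source
        self.scores = torch.as_tensor(scores, dtype=torch.float)
        self.temperature = temperature  # can be updated as training progresses

    def __iter__(self):
        while True:
            probs = F.softmax(self.scores / self.temperature, dim=0)
            source_idx = torch.multinomial(probs, num_samples=1).item()
            source = self.sources[source_idx]
            # Pick an item uniformly within the chosen source (simplest choice).
            item_idx = torch.randint(len(source), (1,)).item()
            yield source[item_idx]

Lowering temperature between epochs, for example following the annealing schedule sketched earlier, shifts subsequent batches toward the higher-scored sources without modifying the underlying datasets.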