Training a VAE variant for disentangled representation learning puts the theory from this chapter into practice: we train a $ \beta $-VAE to encourage disentanglement and then evaluate how well it succeeds using metrics such as the Mutual Information Gap (MIG) and Separated Attribute Predictability (SAP). The goal here isn't to provide a complete, copy-paste codebase, but rather to outline the essential steps and considerations, allowing you to experiment and deepen your understanding. We assume you're comfortable implementing a standard VAE in a framework like PyTorch or TensorFlow.

## Setting the Stage: Dataset and Libraries

For disentanglement experiments, synthetic datasets with known ground-truth factors of variation are invaluable. The dSprites dataset is a popular choice. It consists of 2D shapes (squares, ellipses, hearts) generated from 6 independent latent factors: color (always white), shape, scale, orientation, X-position, and Y-position. Having access to these true factors allows us to quantitatively measure how well our model disentangles them.

You'll need your standard deep learning toolkit:

- A deep learning framework (PyTorch or TensorFlow).
- NumPy for numerical operations.
- Scikit-learn for helper functions, especially for metric calculation (e.g., `mutual_info_regression`, `mutual_info_score`, or training simple classifiers).
- Matplotlib or Seaborn for visualizations.

## Implementing and Training a $ \beta $-VAE

Recall from our discussion that the $ \beta $-VAE modifies the standard VAE objective by introducing a coefficient $ \beta $ on the KL divergence term:

$$ \mathcal{L}_{\beta\text{-VAE}} = \mathbb{E}_{q_{\phi}(z|x)}[\log p_{\theta}(x|z)] - \beta \cdot D_{KL}(q_{\phi}(z|x) \| p(z)) $$

A $ \beta > 1 $ places a stronger constraint on the KL divergence, encouraging the approximate posterior $ q_{\phi}(z|x) $ to stay close to the prior $ p(z) $ (typically an isotropic Gaussian $ \mathcal{N}(0, I) $). This pressure can push the model toward more disentangled representations.

**1. Model Architecture:** Your VAE architecture can be a standard convolutional setup for image data like dSprites.

- Encoder: A few convolutional layers followed by fully connected layers that output the mean $ \mu_z $ and log-variance $ \log \sigma_z^2 $ of the latent distribution $ q_{\phi}(z|x) $.
- Decoder: Fully connected layers followed by deconvolutional (or upsampling + convolutional) layers that reconstruct the input image from a latent sample $ z $.

**2. The $ \beta $-VAE Loss Function:** The implementation change from a standard VAE is minimal. Assuming `reconstruction_loss` is your negative log-likelihood term (e.g., binary cross-entropy or mean squared error) and `kl_divergence` is the KL term, your combined loss calculation becomes:

```python
# beta-VAE loss (PyTorch-style pseudocode)
# mu, log_var are outputs from the encoder
# x_reconstructed is the output of the decoder
# x_original is the input image
# beta is the hyperparameter

reconstruction_loss = reconstruction_criterion(x_reconstructed, x_original)

# KL divergence between N(mu, sigma^2) and N(0, I):
# -0.5 * sum(1 + log_var - mu^2 - exp(log_var)) over the latent dimensions
kl_divergence = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1)
kl_divergence = torch.mean(kl_divergence)  # average over the batch

total_loss = reconstruction_loss + beta * kl_divergence
# backpropagate total_loss
```
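To see how these pieces fit together in a single training step, here is a minimal sketch. It assumes a hypothetical `BetaVAE` module exposing `encode`, `reparameterize`, and `decode` methods; the method names and the binary cross-entropy reconstruction term are illustrative assumptions, not a prescribed implementation.

```python
import torch
import torch.nn.functional as F

def train_step(model, x, optimizer, beta):
    """One beta-VAE update for a batch of images x with pixel values in [0, 1]."""
    optimizer.zero_grad()

    mu, log_var = model.encode(x)                  # parameters of q(z|x)
    z = model.reparameterize(mu, log_var)          # z = mu + sigma * eps, eps ~ N(0, I)
    x_reconstructed = model.decode(z)              # assumed to end in a sigmoid

    # Bernoulli likelihood (binary cross-entropy) is a common choice for dSprites' binary pixels
    reconstruction_loss = F.binary_cross_entropy(
        x_reconstructed, x, reduction="sum"
    ) / x.size(0)

    # KL divergence between N(mu, sigma^2) and N(0, I), averaged over the batch
    kl_divergence = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1).mean()

    total_loss = reconstruction_loss + beta * kl_divergence
    total_loss.backward()
    optimizer.step()
    return total_loss.item(), reconstruction_loss.item(), kl_divergence.item()
```

Summing the reconstruction term over pixels (rather than averaging) keeps its scale comparable to the summed KL term, which affects how a given $ \beta $ value behaves.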
**3. Training Insights:**

- Train your $ \beta $-VAE for a sufficient number of epochs.
- Experiment with different values of $ \beta $. Common values to try are $ \beta = 1 $ (a standard VAE) and $ \beta = 2, 4, 8, 16 $.
- Monitor both the reconstruction loss and the KL divergence. You'll likely observe that as $ \beta $ increases, the KL divergence term gets smaller (closer to the target of 0 for each dimension if using a standard Gaussian prior), but the reconstruction quality may degrade. This is the classic trade-off.

## Evaluating Disentanglement

Once your models are trained, you need to evaluate how disentangled their learned representations are.

**1. Qualitative Evaluation: Latent Traversal**

A simple yet insightful way to qualitatively assess disentanglement is by performing latent traversals:

- Take an input image and encode it to get its latent mean $ \mu_z $.
- Select one latent dimension $ z_i $.
- Generate new images by decoding variations of $ \mu_z $ where only $ z_i $ is varied across a range (e.g., from -3 to 3 standard deviations if $ p(z) = \mathcal{N}(0,I) $), while keeping the other latent dimensions fixed at their values from $ \mu_z $.
- If the $i$-th latent dimension is well-disentangled, varying only $ z_i $ should change a single, interpretable factor of variation in the generated images (e.g., only scale changes, or only X-position changes).

You can create a grid of images where each row (or column) corresponds to traversing a different latent dimension, as in the sketch below. This visual inspection can be very revealing.
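Here is a minimal sketch of such a traversal grid, assuming a trained model with `encode` and `decode` methods and a held-out example image `x`; the latent size, traversal range, and variable names are illustrative assumptions.

```python
import torch
import matplotlib.pyplot as plt

@torch.no_grad()
def latent_traversal(model, x, dim, values=(-3, -2, -1, 0, 1, 2, 3)):
    """Decode copies of x's latent mean in which only dimension `dim` is varied."""
    mu, _ = model.encode(x.unsqueeze(0))       # shape (1, num_latents)
    images = []
    for v in values:
        z = mu.clone()
        z[0, dim] = v                          # vary a single latent dimension
        images.append(model.decode(z).squeeze().cpu())
    return images

# One row per latent dimension, one column per traversal value
num_latents = 10                               # illustrative latent size
fig, axes = plt.subplots(num_latents, 7, figsize=(7, num_latents))
for d in range(num_latents):
    for col, img in enumerate(latent_traversal(model, x, d)):
        axes[d, col].imshow(img, cmap="gray")
        axes[d, col].axis("off")
plt.show()
```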
**2. Quantitative Metrics**

For a more rigorous assessment, we use quantitative metrics. These typically require the ground-truth factor labels from the dataset.

**Mutual Information Gap (MIG)**

MIG measures the extent to which each ground-truth factor is captured by a single latent dimension. For each ground-truth factor $ y_k $:

- Encode a batch of data to get the latent means $ Z $.
- For each latent dimension $ z_j $, compute the empirical mutual information $ I(z_j; y_k) $. This often involves discretizing $ z_j $ into bins.
- Find the latent dimension $ z_{j^*} $ with the highest mutual information with $ y_k $: $ \max_j I(z_j; y_k) $.
- Find the latent dimension $ z_{j^{**}} $ with the second-highest mutual information with $ y_k $: $ \max_{j \neq j^*} I(z_j; y_k) $.
- The gap for factor $ y_k $ is $ \frac{I(z_{j^*}; y_k) - I(z_{j^{**}}; y_k)}{H(y_k)} $, where $ H(y_k) $ is the entropy of the ground-truth factor $ y_k $.

The final MIG score is the average of these gaps over all ground-truth factors. A higher MIG score suggests better disentanglement, as it implies that each factor is primarily represented by one latent dimension, with a clear "gap" to the next most informative one.

```python
# Calculating MIG (simplified)
# latents: (N_samples, N_latents) - encoded means from the VAE
# factors: (N_samples, N_factors) - ground-truth factor values
import numpy as np
from sklearn.metrics import mutual_info_score

N_BINS = 20  # number of bins used to discretize each latent dimension

def discretize_latent(z, n_bins=N_BINS):
    """Bin a continuous latent dimension into equal-width bins."""
    edges = np.histogram_bin_edges(z, bins=n_bins)
    return np.digitize(z, edges[1:-1])

def calculate_entropy(y):
    """Entropy (in nats) of a discrete factor; dSprites factors are discrete."""
    _, counts = np.unique(y, return_counts=True)
    p = counts / counts.sum()
    return -np.sum(p * np.log(p))

def calculate_mig(latents, factors):
    num_latents = latents.shape[1]
    num_factors = factors.shape[1]
    mig_scores_per_factor = []
    for k in range(num_factors):                   # for each ground-truth factor y_k
        y_k = factors[:, k]
        h_y_k = calculate_entropy(y_k)
        mutual_informations = []
        for j in range(num_latents):               # for each latent dimension z_j
            z_j_discretized = discretize_latent(latents[:, j])
            # mutual_info_score expects two arrays of discrete labels
            mutual_informations.append(mutual_info_score(z_j_discretized, y_k))
        sorted_mi = sorted(mutual_informations, reverse=True)
        if len(sorted_mi) < 2:
            continue                               # not enough latents to compute a gap
        gap_k = (sorted_mi[0] - sorted_mi[1]) / h_y_k if h_y_k > 0 else 0.0
        mig_scores_per_factor.append(gap_k)
    return sum(mig_scores_per_factor) / len(mig_scores_per_factor) if mig_scores_per_factor else 0.0
```

For dSprites the factors are discrete, so the entropy and mutual information can be computed directly with `sklearn.metrics.mutual_info_score` after discretizing $ z_j $; for continuous factors, `sklearn.feature_selection.mutual_info_regression` can be used instead.

**Separated Attribute Predictability (SAP)**

SAP measures disentanglement by assessing how well each latent dimension predicts a single ground-truth factor. For each ground-truth factor $ y_k $:

- Train a simple linear classifier (e.g., logistic regression) to predict $ y_k $ using only latent dimension $ z_j $, and record its accuracy (or R-squared if $ y_k $ is continuous). Do this for all $ j $.
- Find the latent dimension $ z_{j^*} $ that is most predictive of $ y_k $.
- Find the latent dimension $ z_{j^{**}} $ that is second most predictive of $ y_k $.
- The SAP score for factor $ y_k $ is the difference in prediction scores (e.g., accuracy) between $ z_{j^*} $ and $ z_{j^{**}} $.

The final SAP score is the average of these score differences over all ground-truth factors. A higher SAP score indicates that individual latent dimensions are predictive of individual factors of variation.

```python
# Calculating SAP (simplified)
# latents: (N_samples, N_latents) - encoded means from the VAE
# factors: (N_samples, N_factors) - ground-truth factor values
# (Assumes factors are discrete so that classification accuracy applies.)
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

def calculate_sap(latents, factors, test_size=0.2, random_state=42):
    num_latents = latents.shape[1]
    num_factors = factors.shape[1]
    sap_scores_per_factor = []

    # Split data for training the probe classifiers.
    # Note: a more thorough evaluation would use cross-validation.
    latents_train, latents_test, factors_train, factors_test = train_test_split(
        latents, factors, test_size=test_size, random_state=random_state
    )

    # Scale latents (optional, but good practice for linear models)
    scaler = StandardScaler()
    latents_train_scaled = scaler.fit_transform(latents_train)
    latents_test_scaled = scaler.transform(latents_test)

    for k in range(num_factors):                   # for each ground-truth factor y_k
        y_k_train = factors_train[:, k]
        y_k_test = factors_test[:, k]
        prediction_scores = []
        for j in range(num_latents):               # for each latent dimension z_j
            z_j_train = latents_train_scaled[:, j].reshape(-1, 1)
            z_j_test = latents_test_scaled[:, j].reshape(-1, 1)
            if len(np.unique(y_k_train)) < 2:      # degenerate factor with a single class
                prediction_scores.append(0.0)
                continue
            probe = LogisticRegression(solver='liblinear', C=0.1)  # keep the probe simple
            probe.fit(z_j_train, y_k_train)
            prediction_scores.append(probe.score(z_j_test, y_k_test))

        sorted_scores = sorted(prediction_scores, reverse=True)
        if len(sorted_scores) < 2:
            continue                               # not enough latents to compute a gap
        sap_k = sorted_scores[0] - sorted_scores[1]  # difference between the top two scores
        sap_scores_per_factor.append(sap_k)

    return sum(sap_scores_per_factor) / len(sap_scores_per_factor) if sap_scores_per_factor else 0.0
```

**Note on Metric Implementation:** The code above simplifies certain aspects. Careful implementations require proper data splitting (train/validation/test for the metric classifiers), deliberate hyperparameter choices for the probe classifiers (simple models are typically preferred, since the goal is to test how predictable the factors are from the latents), and ideally averaging results over multiple runs. For dSprites, the factor values are discrete, which simplifies things.

## Putting It All Together: An Example Workflow

1. **Prepare Data:** Load the dSprites dataset. Separate the images and their ground-truth factor labels.
2. **Train Models:**
   - Train a standard VAE ($ \beta = 1 $) as a baseline.
   - Train several $ \beta $-VAEs with increasing $ \beta $ values (e.g., $ \beta \in \{2, 4, 8, 16, 32\} $).
   - For each model, save the learned encoder.
3. **Evaluate Models:** For each trained encoder:
   - Pass a held-out test set of dSprites images through the encoder to obtain their latent representations (e.g., the means $ \mu_z $).
   - Perform latent traversals on a few example images to qualitatively assess disentanglement.
   - Calculate the MIG score using the latent representations and the ground-truth factors of the test set.
   - Calculate the SAP score.
   - Record the reconstruction error (e.g., MSE or BCE) on the test set for this model.
4. **Analyze Results:**
   - Plot the disentanglement scores (MIG, SAP) against the $ \beta $ value.
   - Plot the reconstruction error against the $ \beta $ value.

You should observe a trend: higher $ \beta $ often leads to better disentanglement scores (up to a point) but at the cost of higher reconstruction error, illustrating the trade-off. A sketch of this evaluation loop follows.
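A minimal sketch of that loop might look like the following; `train_beta_vae` and `encode_dataset` are hypothetical helpers standing in for your own training and encoding code, while `calculate_mig` and `calculate_sap` are the functions defined above.

```python
import matplotlib.pyplot as plt

betas = [1, 2, 4, 8, 16, 32]
results = {"beta": [], "mig": [], "sap": [], "recon": []}

for beta in betas:
    # Hypothetical helpers: train_beta_vae returns a trained model;
    # encode_dataset returns latent means, ground-truth factors, and the
    # reconstruction error on a held-out test set.
    model = train_beta_vae(train_loader, beta=beta)
    latents, factors, recon_error = encode_dataset(model, test_loader)

    results["beta"].append(beta)
    results["mig"].append(calculate_mig(latents, factors))
    results["sap"].append(calculate_sap(latents, factors))
    results["recon"].append(recon_error)

# Disentanglement scores and reconstruction error against beta
fig, ax1 = plt.subplots()
ax1.plot(results["beta"], results["mig"], marker="o", label="MIG")
ax1.plot(results["beta"], results["sap"], marker="o", label="SAP")
ax1.set_xscale("log")
ax1.set_xlabel("beta")
ax1.set_ylabel("disentanglement score")
ax1.legend(loc="upper left")

ax2 = ax1.twinx()  # reconstruction error on a secondary axis
ax2.plot(results["beta"], results["recon"], marker="s", color="red", label="reconstruction error")
ax2.set_ylabel("reconstruction error")
plt.show()
```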
*Figure: Disentanglement vs. Beta in Beta-VAE.* Results showing Mutual Information Gap (MIG) and Separated Attribute Predictability (SAP) scores increasing with $ \beta $, while reconstruction loss also tends to increase (plotted against a secondary axis; note the logarithmic scales for visualization across orders of magnitude). This illustrates the common trade-off in $ \beta $-VAEs.

## Further Exploration

This hands-on exercise provides a starting point. You can extend it by:

- **Implementing other metrics:** Try your hand at DCI (Disentanglement, Completeness, Informativeness) or the FactorVAE metric, which are more involved but provide different perspectives.
- **Training other models:** Implement and evaluate FactorVAE or TCVAE (Total Correlation VAE), which modify the VAE objective differently to encourage disentanglement, as discussed in the chapter. Compare their performance to the $ \beta $-VAE.
- **Using different datasets:** Experiment with other disentanglement datasets such as 3D Shapes or MPI3D, or try applying these techniques to more complex datasets (though metric calculation is harder without ground-truth factors).
- **Investigating limitations:** Observe how sensitive the metrics are to their hyperparameters (the number of bins for MIG, classifier complexity for SAP). Reflect on the limitations of current metrics and the ongoing research challenges in defining and achieving "true" disentanglement.

By actively engaging with these models and metrics, you'll build a much stronger intuition for the challenges and successes in the field of disentangled representation learning. Remember that this is an active research area, and perfect disentanglement, especially on complex datasets without supervision, remains an open problem.