Alright, let's roll up our sleeves and put theory into practice. In the preceding sections of this chapter, we discussed how Variational Autoencoders can be adapted for sequential and structured data. Now, we'll focus on a hands-on example: implementing a VAE to model and generate sequential data. Specifically, we'll outline the steps to build a Recurrent VAE (RVAE) for character-level text generation. This exercise will solidify your understanding of how to use RNNs within the VAE framework to capture temporal dependencies.
Our goal is to train an RVAE that can learn a compressed representation of text sequences and then use this representation to generate new, plausible text.
We'll work with character-level text generation. This means our model will learn to predict the next character in a sequence given the preceding characters. While word-level models are also common, character-level models are simpler to set up in terms of vocabulary management and can generate novel words or styles.
Corpus Selection: Choose a text corpus. For learning purposes, a moderately sized, coherent text works well, for example a single author's collected works or a set of articles on one topic. Save it as corpus.txt so the snippets below can load it.
Vocabulary Creation: First, we need to determine our vocabulary, which, in this case, is the set of unique characters in the corpus.
# Illustrative Python-like pseudocode
text = open('corpus.txt', 'r').read()
chars = sorted(list(set(text)))
char_to_int = {ch: i for i, ch in enumerate(chars)}
int_to_char = {i: ch for i, ch in enumerate(chars)}
vocab_size = len(chars)
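As a quick check (a small illustrative snippet, not part of the original walkthrough), you can round-trip a short piece of text through these mappings to confirm they are consistent:
# Illustrative check: encode a snippet and decode it back
sample = text[:20]
encoded = [char_to_int[ch] for ch in sample]
decoded = "".join(int_to_char[i] for i in encoded)
assert decoded == sample
print(f"Vocabulary size: {vocab_size}")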
Creating Input-Output Sequences:
We need to transform the raw text into sequences that our RVAE can process. We'll use a sliding window approach. For a given sequence_length, we'll create input sequences and corresponding target sequences (which are typically the input sequence shifted by one character).
# Illustrative Python-like pseudocode
sequence_length = 50  # Example length
data_X = []  # Input sequences
data_y = []  # Target sequences (for decoder reconstruction)
for i in range(0, len(text) - sequence_length, 1):
    seq_in = text[i:i + sequence_length]
    seq_out = text[i + 1:i + sequence_length + 1]  # Shifted-by-one target
    data_X.append([char_to_int[char] for char in seq_in])
    # For an RVAE, the decoder will try to reconstruct seq_in,
    # or generate seq_out if conditioned on seq_in and z.
    # For simplicity here, we assume the decoder aims to reconstruct seq_in.
    data_y.append([char_to_int[char] for char in seq_in])  # Or seq_out if that's the design
num_sequences = len(data_X)
Note: For a "classic" RVAE aimed at generation from a latent z, the decoder typically reconstructs the input sequence seq_in
. The RNN structure itself handles the sequential prediction.
Data Formatting:
The input data needs to be shaped appropriately for RNNs, typically (num_sequences, sequence_length, feature_dim). For character-level models, feature_dim is often 1 (if passing integer indices that will go through an embedding layer) or vocab_size (if using one-hot encoded inputs). We'll also normalize integer inputs if they are not fed into an embedding layer first.
# Illustrative Python-like pseudocode
# Assuming using embedding layer, so input is (num_sequences, sequence_length)
# X = np.reshape(data_X, (num_sequences, sequence_length))
# y = np.reshape(data_y, (num_sequences, sequence_length))
# PyTorch/TensorFlow will handle batching
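To make this concrete, here is a minimal sketch (assuming the data_X and data_y lists built earlier; the batch size of 64 is an arbitrary choice) that wraps the integer sequences in PyTorch tensors and a DataLoader:
# Illustrative sketch: tensors and DataLoader (assumes data_X, data_y from above)
import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader

X = torch.tensor(np.array(data_X), dtype=torch.long)  # (num_sequences, sequence_length)
y = torch.tensor(np.array(data_y), dtype=torch.long)  # (num_sequences, sequence_length)

dataset = TensorDataset(X, y)
data_loader = DataLoader(dataset, batch_size=64, shuffle=True)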
Our RVAE will consist of an RNN-based encoder, a latent space sampling mechanism, and an RNN-based decoder.
The encoder's job is to take an input sequence x and map it to the parameters of the approximate posterior distribution q(z|x), which we assume is a Gaussian \mathcal{N}(\mu_z, \mathrm{diag}(\sigma_z^2)).
# PyTorch-like pseudocode for Encoder
# self.embedding = nn.Embedding(vocab_size, embedding_dim)
# self.encoder_rnn = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
# self.fc_mu = nn.Linear(hidden_dim, latent_dim)
# self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
# def encode(self, x_sequence):
#     embedded = self.embedding(x_sequence)     # (batch, seq_len, embedding_dim)
#     _, (h_n, _) = self.encoder_rnn(embedded)  # h_n is (1, batch, hidden_dim)
#     h_n_last_layer = h_n.squeeze(0)           # (batch, hidden_dim)
#     mu = self.fc_mu(h_n_last_layer)
#     logvar = self.fc_logvar(h_n_last_layer)
#     return mu, logvar
The h_n returned by an LSTM contains the final hidden state for each sequence in the batch. To obtain z from the encoder's outputs, we apply the standard reparameterization trick: z = \mu_z + \sigma_z \odot \epsilon, where \epsilon \sim \mathcal{N}(0, I) and \sigma_z = \exp(0.5 \cdot \log \sigma_z^2).
# PyTorch-like pseudocode for reparameterization
# def reparameterize(self, mu, logvar):
#     std = torch.exp(0.5 * logvar)
#     eps = torch.randn_like(std)
#     return mu + eps * std
The decoder takes a sample z from the latent space and aims to reconstruct the original input sequence (or generate a new sequence if z is sampled from the prior p(z)).
# PyTorch-like pseudocode for Decoder
# self.decoder_embedding = nn.Embedding(vocab_size, embedding_dim)
# self.decoder_rnn_cell = nn.LSTMCell(embedding_dim + latent_dim, hidden_dim)  # Example: concat z
# # Or: self.latent_to_hidden = nn.Linear(latent_dim, hidden_dim) for h0, c0 init
# self.fc_out = nn.Linear(hidden_dim, vocab_size)
# def decode(self, z, target_sequence, teacher_forcing_ratio=0.5):
#     batch_size = z.size(0)
#     seq_len = target_sequence.size(1)
#
#     # Initialize hidden state (e.g., from z or zeros)
#     # hx = self.latent_to_hidden(z)  # (batch, hidden_dim)
#     # cx = self.latent_to_hidden(z)  # (batch, hidden_dim)
#     # Or if concatenating z:
#     hx = torch.zeros(batch_size, hidden_dim).to(z.device)
#     cx = torch.zeros(batch_size, hidden_dim).to(z.device)
#
#     # Start token (e.g., embedding of a <SOS> character, or first char of target)
#     current_input_char_idx = target_sequence[:, 0]  # Example: use first char of target
#     outputs = []
#
#     for t in range(seq_len):
#         embedded_char = self.decoder_embedding(current_input_char_idx)  # (batch, embedding_dim)
#
#         # Option 1: Concatenate z with each input
#         rnn_input = torch.cat((embedded_char, z), dim=1)  # (batch, embedding_dim + latent_dim)
#         hx, cx = self.decoder_rnn_cell(rnn_input, (hx, cx))
#
#         # Option 2: Use z to initialize hx, cx (done before the loop)
#         # hx, cx = self.decoder_rnn_cell(embedded_char, (hx, cx))
#
#         output_logits_t = self.fc_out(hx)  # (batch, vocab_size)
#         outputs.append(output_logits_t)
#
#         use_teacher_force = random.random() < teacher_forcing_ratio
#         if use_teacher_force and t < seq_len - 1:
#             current_input_char_idx = target_sequence[:, t + 1]
#         else:
#             _, top_idx = output_logits_t.topk(1)
#             current_input_char_idx = top_idx.squeeze(1).detach()  # Use model's own prediction
#
#     return torch.stack(outputs, dim=1)  # (batch, seq_len, vocab_size)
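Before defining the loss, it can help to see how these pieces fit together. The following is a minimal, runnable sketch (one possible arrangement, not the only one) that assembles the encoder, reparameterization, and decoder into a single nn.Module, following the z-concatenation option and the layer names used above; the default hyperparameter values are illustrative assumptions:
# Minimal sketch: assembling the RVAE (hyperparameter defaults are illustrative)
import random
import torch
import torch.nn as nn

class RVAE(nn.Module):
    def __init__(self, vocab_size, embedding_dim=64, hidden_dim=256, latent_dim=32):
        super().__init__()
        self.hidden_dim = hidden_dim
        # Encoder layers
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.encoder_rnn = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        # Decoder layers (Option 1: concatenate z with each input embedding)
        self.decoder_embedding = nn.Embedding(vocab_size, embedding_dim)
        self.decoder_rnn_cell = nn.LSTMCell(embedding_dim + latent_dim, hidden_dim)
        self.fc_out = nn.Linear(hidden_dim, vocab_size)

    def encode(self, x_sequence):
        embedded = self.embedding(x_sequence)        # (batch, seq_len, embedding_dim)
        _, (h_n, _) = self.encoder_rnn(embedded)     # h_n: (1, batch, hidden_dim)
        h_last = h_n.squeeze(0)                      # (batch, hidden_dim)
        return self.fc_mu(h_last), self.fc_logvar(h_last)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, target_sequence, teacher_forcing_ratio=0.5):
        batch_size, seq_len = target_sequence.size()
        hx = torch.zeros(batch_size, self.hidden_dim, device=z.device)
        cx = torch.zeros(batch_size, self.hidden_dim, device=z.device)
        current_idx = target_sequence[:, 0]          # first character as start token
        outputs = []
        for t in range(seq_len):
            embedded = self.decoder_embedding(current_idx)     # (batch, embedding_dim)
            rnn_input = torch.cat((embedded, z), dim=1)
            hx, cx = self.decoder_rnn_cell(rnn_input, (hx, cx))
            logits_t = self.fc_out(hx)                         # (batch, vocab_size)
            outputs.append(logits_t)
            if random.random() < teacher_forcing_ratio and t < seq_len - 1:
                current_idx = target_sequence[:, t + 1]        # teacher forcing
            else:
                current_idx = logits_t.argmax(dim=1).detach()  # model's own prediction
        return torch.stack(outputs, dim=1)           # (batch, seq_len, vocab_size)

    def forward(self, x_sequence):
        mu, logvar = self.encode(x_sequence)
        z = self.reparameterize(mu, logvar)
        return self.decode(z, x_sequence), mu, logvar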
The RVAE loss function is the standard VAE ELBO, but the reconstruction term is now a sum over the sequence elements.
L_{RVAE}(x, \hat{x}, \mu_z, \log \sigma_z^2) = L_{recon} + \beta \cdot D_{KL}(q(z|x) \,||\, p(z))

Reconstruction Loss (L_{recon}): For character-level generation, this is typically the sum of cross-entropy losses between the predicted character distributions and the actual target characters at each position in the sequence:

L_{recon} = -\sum_{t=1}^{T} \log p(x_t \mid x_{<t}, z)

In practice, you'd use your framework's CrossEntropyLoss function, applied across the sequence dimension. Ensure the logits and targets are shaped correctly (e.g., logits: (batch_size * seq_len, vocab_size), targets: (batch_size * seq_len)).
KL Divergence (D_{KL}): The KL divergence between the approximate posterior q(z|x) and the prior p(z) (usually \mathcal{N}(0, I)):

D_{KL}(q(z|x) \,||\, p(z)) = -0.5 \sum_{j=1}^{\text{latent\_dim}} (1 + \log(\sigma_{z_j}^2) - \mu_{z_j}^2 - \sigma_{z_j}^2)

The β term comes from β-VAEs and can be used to control the emphasis on disentanglement versus reconstruction quality. For a standard VAE, β = 1.
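As one way to compute this in code, here is an illustrative helper (the function name and tensor shapes are assumptions: logits of shape (batch, seq_len, vocab_size) and integer targets of shape (batch, seq_len)). Summing the cross-entropy over all characters and dividing both terms by the batch size keeps them on the same per-sequence scale:
# Illustrative loss helper (assumed shapes; other scaling conventions also work)
import torch
import torch.nn.functional as F

def rvae_loss(logits, targets, mu, logvar, beta=1.0):
    batch_size = targets.size(0)
    vocab_size = logits.size(-1)
    recon = F.cross_entropy(
        logits.reshape(-1, vocab_size),   # (batch * seq_len, vocab_size)
        targets.reshape(-1),              # (batch * seq_len,)
        reduction='sum'
    ) / batch_size
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / batch_size
    return recon + beta * kl, recon, kl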
The training loop involves encoding each batch, sampling z via the reparameterization trick, decoding, computing the reconstruction and KL terms, and backpropagating their weighted sum:
# Illustrative training step pseudocode
# rvae_model = RVAE(...)
# criterion_recon = nn.CrossEntropyLoss(reduction='sum')  # sum over all characters
# optimizer = Adam(rvae_model.parameters(), lr=1e-3)
#
# for epoch in range(num_epochs):
#     for batch_sequences_x, batch_sequences_y in data_loader:
#         optimizer.zero_grad()
#
#         mu, logvar = rvae_model.encode(batch_sequences_x)
#         z = rvae_model.reparameterize(mu, logvar)
#         decoded_logits = rvae_model.decode(z, batch_sequences_y)  # batch_sequences_y for teacher forcing
#
#         # Reconstruction loss
#         # Reshape for CrossEntropyLoss: (Batch * SeqLen, VocabSize) and (Batch * SeqLen)
#         recon_loss = criterion_recon(
#             decoded_logits.view(-1, vocab_size),
#             batch_sequences_y.view(-1)
#         ) / batch_sequences_x.size(0)  # Average over batch, matching the KL term below
#
#         # KL divergence
#         kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
#         kl_div = kl_div / batch_sequences_x.size(0)  # Average over batch
#
#         loss = recon_loss + beta * kl_div
#         loss.backward()
#         optimizer.step()
#
#     # Log losses, generate samples periodically
Once the model is trained, you can generate new text sequences: sample z from the prior p(z), run the decoder one step at a time, and at each step choose the next character from the predicted distribution (e.g., by sampling with torch.multinomial or by taking the argmax).
# Illustrative generation pseudocode
# def generate_sequence(rvae_model, z_sample, start_token_idx, max_len=100):
#     rvae_model.eval()
#     generated_sequence_indices = [start_token_idx]
#     current_input_char_idx = torch.tensor([[start_token_idx]], device=device)  # Batch size 1
#
#     # Initialize decoder hidden state (from z_sample, or zeros if z is concatenated)
#     # hx, cx = ... initialized based on z_sample
#
#     with torch.no_grad():
#         for _ in range(max_len - 1):
#             embedded_char = rvae_model.decoder_embedding(current_input_char_idx)
#             # rnn_input = torch.cat((embedded_char.squeeze(1), z_sample), dim=1) if concatenating z
#             # hx, cx = rvae_model.decoder_rnn_cell(rnn_input, (hx, cx))
#             # output_logits = rvae_model.fc_out(hx)
#
#             # Simplified: assume a decode_step function in the model
#             output_logits, hx, cx = rvae_model.decode_step(current_input_char_idx, z_sample, hx, cx)
#
#             # Sample next character (can add temperature for diversity)
#             # probabilities = F.softmax(output_logits / temperature, dim=-1)
#             # next_char_idx = torch.multinomial(probabilities, 1)
#             _, next_char_idx = output_logits.topk(1, dim=-1)
#
#             generated_sequence_indices.append(next_char_idx.item())
#             current_input_char_idx = next_char_idx
#
#             # if next_char_idx.item() == eos_token_idx: break
#
#     return "".join([int_to_char[idx] for idx in generated_sequence_indices])
#
# z_prior = torch.randn(1, latent_dim).to(device)
# generated_text = generate_sequence(rvae_model, z_prior, char_to_int['A'])
# print(generated_text)
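If you want the temperature-based sampling mentioned in the comments above, a small helper might look like the following sketch (the function name and the default temperature of 0.8 are illustrative choices, not part of the walkthrough's code):
# Illustrative temperature sampling helper for a single generation step
import torch
import torch.nn.functional as F

def sample_next_char(output_logits, temperature=0.8):
    # output_logits: (1, vocab_size); lower temperature -> more conservative output
    probabilities = F.softmax(output_logits / temperature, dim=-1)
    next_char_idx = torch.multinomial(probabilities, num_samples=1)  # (1, 1)
    return next_char_idx
Replacing the topk call in generate_sequence with a sampler like this typically produces more varied, less repetitive text.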
This practical walkthrough provides a blueprint for implementing VAEs for sequential data. The RVAE is a foundational model, and many extensions and variations exist, such as those incorporating attention mechanisms (which we discussed earlier in this chapter) for handling longer-range dependencies more effectively. Experiment with these components, observe their effects, and consult research papers for more advanced techniques as you tackle more complex sequential modeling tasks.