researchengineer.ing

February 19, 2026

Phase 1

World Models 101

A ground-up reading of Ha & Schmidhuber's World Models — the MDN-RNN memory architecture, the dream environment trick, and a minimal reimplementation in PyTorch.

Archaeologist · Reviewer · Hacker
#World Models #Reinforcement Learning #Generative Deep Learning
Research Taste · Mathematical Rigor & Reasoning · Implementation Depth
paper: Ha & Schmidhuber, 2018, World Models. NeurIPS 2018.

What I Built

A careful read and partial reimplementation of Ha & Schmidhuber's 2018 World Models paper. The goal: understand the three-component architecture (Vision → Memory → Controller) from first principles, trace why each design choice was made, and get the VAE + MDN-RNN pipeline running on CarRacing-v2.

This is not a reproduction of their full results. It's a dissection — what does each component actually do, what breaks if you remove it, and what the paper leaves unsaid.


The Architecture

World Models decomposes the agent into three modules trained largely independently:

V — Vision (Variational Autoencoder)

Compresses each raw frame $x_t \in \mathbb{R}^{64 \times 64 \times 3}$ to a latent vector $z_t \in \mathbb{R}^{32}$ via:

$$q_\phi(z_t \mid x_t) = \mathcal{N}\big(\mu_\phi(x_t),\, \text{diag}(\sigma^2_\phi(x_t))\big)$$

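The encoder outputs $\mu$ and $\log \sigma^2$; we sample $z_t = \mu + \sigma \odot \epsilon$ with $\epsilon \sim \mathcal{N}(0, I)$ (the reparameterization trick, which keeps sampling differentiable with respect to $\mu$ and $\sigma$). Training maximizes the ELBO, here with a weight $\beta$ on the KL term ($\beta = 1$ gives the standard ELBO):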

$$\mathcal{L}_\text{VAE} = \mathbb{E}_{q_\phi}[\log p_\theta(x_t \mid z_t)] - \beta \cdot D_{\text{KL}}\big(q_\phi(z_t \mid x_t) \,\|\, \mathcal{N}(0, I)\big)$$

The $\beta$ term is doing real work here. Too high → blurry reconstructions but a well-structured latent space. Too low → sharp reconstructions, but the latent space has no useful geometry for the controller.
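As a minimal sketch of the objective above, here is the negative $\beta$-weighted ELBO in PyTorch. It assumes a Gaussian decoder with fixed variance, so $-\log p_\theta(x_t \mid z_t)$ reduces (up to constants) to a summed squared error, and it uses the closed-form KL between a diagonal Gaussian and $\mathcal{N}(0, I)$. The function name `beta_vae_loss` is hypothetical.

```python
import torch
import torch.nn.functional as F


def beta_vae_loss(recon, x, mu, logvar, beta=1.0):
    """Negative beta-weighted ELBO, averaged over the batch.

    Assumes a fixed-variance Gaussian decoder, so -log p(x|z) becomes
    (up to constants) a per-image sum of squared errors.
    """
    # Reconstruction term: sum squared error over pixels, then average over batch
    recon_loss = F.mse_loss(recon, x, reduction="none").flatten(1).sum(dim=1).mean()
    # Closed-form KL(N(mu, sigma^2) || N(0, I)), summed over latent dimensions
    kl = (-0.5 * (1 + logvar - mu.pow(2) - logvar.exp())).sum(dim=1).mean()
    return recon_loss + beta * kl, recon_loss, kl
```

Returning the two terms separately makes it easy to log the KL on its own, which is where a posterior collapse like the one in the Failure Log becomes visible.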

M — Memory (MDN-RNN)

A recurrent model that predicts the distribution over the next latent, not just the next latent itself:

$$p(z_{t+1} \mid a_t, z_t, h_t) = \sum_{k=1}^{K} \pi_k \, \mathcal{N}\big(z_{t+1};\, \mu_k,\, \text{diag}(\sigma_k^2)\big)$$

where $h_t$ is the LSTM hidden state, and $\pi_k$, $\mu_k$, $\sigma_k \in \mathbb{R}^{32}$ are produced by the MDN head from the LSTM output (each component has a per-dimension standard deviation, matching the implementation below). The Mixture Density Network (MDN) output head replaces the usual single Gaussian with a $K$-component mixture, which is critical for modeling the multimodal uncertainty in environment transitions (e.g., a car can skid left or right from a given state).

The loss is the negative log-likelihood under the mixture:

$$\mathcal{L}_\text{MDN} = -\log \sum_{k=1}^{K} \pi_k \, \mathcal{N}\big(z_{t+1};\, \mu_k,\, \text{diag}(\sigma_k^2)\big)$$
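A sketch of $\mathcal{L}_\text{MDN}$ computed entirely in log space, assuming per-dimension (diagonal) component standard deviations as in the `MDNRNN` class in the Implementation section. Unlike that class, this hypothetical `mdn_nll` takes raw mixing logits and applies `log_softmax` itself, which avoids taking the log of a softmax output that may have underflowed.

```python
import torch


def mdn_nll(pi_logits, mu, sigma, z_next):
    """Mean negative log-likelihood of z_next under a diagonal Gaussian mixture.

    pi_logits: (B, K) unnormalized mixing logits
    mu, sigma: (B, K, D) component means and standard deviations
    z_next:    (B, D) target latents
    """
    log_pi = torch.log_softmax(pi_logits, dim=-1)  # (B, K)
    dist = torch.distributions.Normal(mu, sigma)
    # Diagonal covariance: sum per-dimension log-densities -> (B, K)
    comp_log_prob = dist.log_prob(z_next.unsqueeze(1)).sum(dim=-1)
    # log sum_k pi_k N(z; mu_k, sigma_k), without ever leaving log space
    return -torch.logsumexp(log_pi + comp_log_prob, dim=-1).mean()
```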

C — Controller (Linear)

The controller maps $(z_t, h_t)$ directly to an action with a single linear layer and no hidden units:

$$a_t = W_c\,[z_t; h_t] + b_c$$

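This is deliberately minimal. The idea: if V and M are doing their jobs, the controller doesn't need to be complex. The world model learns the representation; the controller learns to exploit it. Because C is trained separately and is tiny (with the sizes used here, $(32 + 256) \times 3$ weights plus 3 biases, i.e. 867 parameters), it can be optimized with CMA-ES, a gradient-free evolution strategy, without backpropagating through V or M.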
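To make the parameter count concrete, here is a sketch of the controller as a flat parameter vector, which is the form a black-box optimizer such as CMA-ES works with. The sizes are the ones used in this post ($z$: 32, $h$: 256, three action dimensions); the helper names are hypothetical, and the sketch does not squash outputs into CarRacing's valid action ranges, which a real rollout would need.

```python
import numpy as np

Z_DIM, H_DIM, A_DIM = 32, 256, 3  # sizes used in this post's CarRacing setup


def n_params(z_dim=Z_DIM, h_dim=H_DIM, a_dim=A_DIM):
    # W_c has (z_dim + h_dim) * a_dim entries; b_c has a_dim entries
    return (z_dim + h_dim) * a_dim + a_dim


def controller_action(theta, z, h):
    """a_t = W_c [z_t; h_t] + b_c, with W_c and b_c unpacked from flat theta."""
    in_dim = z.shape[0] + h.shape[0]
    a_dim = theta.shape[0] // (in_dim + 1)
    W = theta[: in_dim * a_dim].reshape(a_dim, in_dim)
    b = theta[in_dim * a_dim:]
    return W @ np.concatenate([z, h]) + b
```

An optimizer only ever proposes vectors `theta` of length `n_params()` and reads back a scalar fitness; it never needs gradients of V or M.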


Implementation

import torch
import torch.nn as nn

# --- VAE (encoder half) ---
class Encoder(nn.Module):
    def __init__(self, z_dim=32):
        super().__init__()
        # Input (B, 3, 64, 64); each 4x4 stride-2 conv shrinks the side: 64 -> 31 -> 14 -> 6 -> 2
        self.conv = nn.Sequential(
            nn.Conv2d(3, 32, 4, stride=2), nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2), nn.ReLU(),
            nn.Conv2d(64, 128, 4, stride=2), nn.ReLU(),
            nn.Conv2d(128, 256, 4, stride=2), nn.ReLU(),
        )
        # Flattened 2x2x256 features feed two heads: mean and log-variance of q(z|x)
        self.fc_mu = nn.Linear(2 * 2 * 256, z_dim)
        self.fc_logvar = nn.Linear(2 * 2 * 256, z_dim)

    def forward(self, x):
        h = self.conv(x).flatten(1)
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        # z = mu + sigma * eps, with sigma = exp(0.5 * logvar) and eps ~ N(0, I)
        std = (0.5 * logvar).exp()
        return mu + std * torch.randn_like(std)

# --- MDN-RNN ---
class MDNRNN(nn.Module):
    def __init__(self, z_dim=32, a_dim=3, h_dim=256, n_mix=5):
        super().__init__()
        # One recurrent step consumes the concatenation [z_t; a_t]
        self.lstm = nn.LSTMCell(z_dim + a_dim, h_dim)
        # Per mixture component: 1 mixing logit, z_dim means, z_dim log-stds
        self.mdn = nn.Linear(h_dim, n_mix * (1 + 2 * z_dim))
        self.n_mix = n_mix
        self.z_dim = z_dim

    def forward(self, z, a, h, c):
        inp = torch.cat([z, a], dim=-1)
        h, c = self.lstm(inp, (h, c))  # updated hidden and cell state
        out = self.mdn(h)
        pi, mu, log_sigma = out.split(
            [self.n_mix, self.n_mix * self.z_dim, self.n_mix * self.z_dim], dim=-1
        )
        pi = pi.softmax(-1)                                       # (B, K) mixing weights
        mu = mu.view(-1, self.n_mix, self.z_dim)                  # (B, K, z_dim)
        sigma = log_sigma.view(-1, self.n_mix, self.z_dim).exp()  # (B, K, z_dim), always > 0
        return pi, mu, sigma, h, c
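The Encoder above covers only half of the VAE. A possible decoder that mirrors it, as a sketch: the layer widths are an assumption, and the transposed-convolution kernel sizes (5, 5, 6, 6, all stride 2) are chosen so a 1×1 feature map grows to exactly 64×64. The final sigmoid assumes pixels scaled to [0, 1].

```python
import torch
import torch.nn as nn


class Decoder(nn.Module):
    """Hypothetical decoder mirroring Encoder: z (B, 32) -> image (B, 3, 64, 64)."""

    def __init__(self, z_dim=32):
        super().__init__()
        self.fc = nn.Linear(z_dim, 1024)
        # ConvTranspose2d output side = (in - 1) * stride + kernel: 1 -> 5 -> 13 -> 30 -> 64
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(1024, 128, 5, stride=2), nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 5, stride=2), nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 6, stride=2), nn.ReLU(),
            nn.ConvTranspose2d(32, 3, 6, stride=2), nn.Sigmoid(),
        )

    def forward(self, z):
        # Treat the 1024-dim projection as a 1x1 feature map with 1024 channels
        return self.deconv(self.fc(z).view(-1, 1024, 1, 1))
```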

The Dream Environment Trick

The most interesting architectural decision in the paper: once V and M are trained, you can train the controller entirely inside the world model's imagination — no environment calls needed.

The LSTM's hidden state $h_t$ encodes a summary of everything seen so far. At each step, instead of receiving a real frame, the controller receives $(z_t, h_t)$, takes an action, and the MDN-RNN samples the next latent $z_{t+1}$. The controller is optimized against the simulated reward (in the paper's VizDoom experiment, M also predicts whether the agent dies at the next step, and the reward is the number of steps survived).
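The dream loop described above can be sketched as follows, with the LSTM cell, MDN head, and controller passed in as plain modules (a hypothetical interface, not the paper's code). Mixture sampling takes a temperature `tau`, applied by dividing the mixing logits by `tau` and scaling each standard deviation by `sqrt(tau)`, a common convention for MDN sampling. Scoring the rollout needs a reward or done signal, which this sketch leaves out.

```python
import torch


def sample_mdn(pi_logits, mu, sigma, tau=1.0):
    """Draw one latent per batch row from a diagonal Gaussian mixture.

    tau < 1 sharpens the mixture, tau > 1 flattens it: logits are divided
    by tau and standard deviations are scaled by sqrt(tau).
    """
    k = torch.distributions.Categorical(logits=pi_logits / tau).sample()  # (B,)
    idx = k.view(-1, 1, 1).expand(-1, 1, mu.shape[-1])                     # (B, 1, D)
    mu_k = mu.gather(1, idx).squeeze(1)
    sigma_k = sigma.gather(1, idx).squeeze(1) * tau ** 0.5
    return mu_k + sigma_k * torch.randn_like(mu_k)


@torch.no_grad()
def dream_rollout(lstm, mdn_head, controller, z0, steps=50, n_mix=5, tau=1.0):
    """Roll the controller forward inside M only: no environment calls."""
    B, D = z0.shape
    h = torch.zeros(B, lstm.hidden_size)
    c = torch.zeros(B, lstm.hidden_size)
    z, latents = z0, [z0]
    for _ in range(steps):
        a = controller(torch.cat([z, h], dim=-1))       # a_t from (z_t, h_t)
        h, c = lstm(torch.cat([z, a], dim=-1), (h, c))  # advance the memory
        pi_logits, mu, log_sigma = mdn_head(h).split([n_mix, n_mix * D, n_mix * D], dim=-1)
        z = sample_mdn(pi_logits, mu.view(B, n_mix, D), log_sigma.view(B, n_mix, D).exp(), tau)
        latents.append(z)
    return torch.stack(latents, dim=1)  # (B, steps + 1, D)
```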

This works because the controller is linear — it needs very few parameters, so the biases introduced by the imperfect world model don't catastrophically mislead the optimization. The authors explicitly note that a more powerful controller would overfit to the dream's inaccuracies.


Experiments

Can the VAE latent space preserve task-relevant structure?

I trained the VAE on 10k random rollouts from CarRacing-v2. The reconstruction quality was decent at $\beta = 1$, but the latent space had poor geometry (no clear separation between track, grass, and car). At $\beta = 4$ the reconstructions blurred, but the latent space clustered meaningfully: similar road orientations mapped to nearby $z$ vectors.

Does the MDN matter vs. a single Gaussian?

Replaced the MDN output with a single Gaussian ($K=1$). The RNN trained fine on average prediction but badly underestimated uncertainty in stochastic transitions (lane changes, skids). The mixture model captured the bimodal distributions that appeared at decision points. With $K=1$, the controller downstream over-committed to one branch.
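A toy illustration of why a single Gaussian struggles at a decision point, using synthetic 1-D data rather than CarRacing latents: two equally likely outcomes near -2 and +2. A moment-matched single Gaussian puts its mean near 0, where almost no data lies; a two-component mixture placed on the modes (not fitted here, for brevity) gives the data a much lower negative log-likelihood. This illustrates the mechanism only and is not a reproduction of the experiment above.

```python
import torch

torch.manual_seed(0)
# Synthetic "skid left or skid right" targets: two modes at -2 and +2
z = torch.cat([torch.randn(500) * 0.3 - 2.0, torch.randn(500) * 0.3 + 2.0])

# K = 1: moment-matched single Gaussian
single = torch.distributions.Normal(z.mean(), z.std())

# K = 2: mixture with one component on each mode
mix = torch.distributions.MixtureSameFamily(
    torch.distributions.Categorical(probs=torch.tensor([0.5, 0.5])),
    torch.distributions.Normal(torch.tensor([-2.0, 2.0]), torch.tensor([0.3, 0.3])),
)

nll_single = -single.log_prob(z).mean()
nll_mix = -mix.log_prob(z).mean()
# The K=1 mean sits between the modes, in a region with almost no data
print(f"K=1: mean={z.mean().item():.2f}, NLL={nll_single.item():.3f}")
print(f"K=2: NLL={nll_mix.item():.3f}")
```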


Failure Log

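Posterior collapse early in training. The KL term vanished ($D_\text{KL} \approx 0$, posterior collapsed to the prior) for the first ~5k steps, so the latent carried almost no information about the frame. Fixed by KL annealing: ramping $\beta$ from 0 to 1 over 10k steps, which lets the reconstruction term shape the latent space before the KL penalty reaches full weight.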
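The annealing fix amounts to a one-line schedule. A minimal sketch, assuming a linear ramp (the post gives only the endpoints, 0 to 1 over 10k steps; the linear shape and the name `kl_beta` are assumptions):

```python
def kl_beta(step, warmup_steps=10_000, beta_max=1.0):
    """Linear KL annealing: beta ramps from 0 to beta_max over warmup_steps, then stays flat."""
    return beta_max * min(1.0, step / warmup_steps)
```

In the training loop this would multiply the KL term, e.g. `loss = recon + kl_beta(step) * kl`; passing `beta_max=4.0` would ramp toward the $\beta = 4$ setting from the Experiments section.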

MDN numerical instability. Log-sum-exp needed to be done carefully:

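# unstable: log(sum(pi * N(z; mu, sigma))): products of 32 per-dim densities can underflow to 0, and log(0) = -inf
# stable: stay in log space with torch.logsumexp
# log_pi: (B, K); gaussian_log_prob: (B, K), per-dimension log-densities summed over z_dim
log_prob = torch.logsumexp(log_pi + gaussian_log_prob, dim=1)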

Without this, loss went NaN within 100 steps on GPU.
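A small float32 reproduction of the underflow that breaks the naive mixture likelihood. The numbers are contrived (one target far from both components of a 32-dimensional diagonal mixture), but the mechanism is general: each component's log-density is around -2300, `exp` rounds it to 0, and `log(0)` is `-inf`, while `logsumexp` stays finite. Backpropagating through that `log(0)` produces NaN gradients, which is one plausible route to the NaN loss described above.

```python
import torch

# One target far from both components of a 32-dim diagonal mixture (float32)
D = 32
z = torch.full((1, D), 6.0)
mu = torch.zeros(1, 2, D)
sigma = torch.full((1, 2, D), 0.5)
log_pi = torch.log(torch.tensor([[0.5, 0.5]]))

comp = torch.distributions.Normal(mu, sigma)
gaussian_log_prob = comp.log_prob(z.unsqueeze(1)).sum(-1)  # (1, 2), roughly -2311 each

# Naive: exponentiate, mix, then take the log; exp(-2311) underflows to 0
naive = torch.log((log_pi.exp() * gaussian_log_prob.exp()).sum(dim=1))
# Stable: combine in log space
stable = torch.logsumexp(log_pi + gaussian_log_prob, dim=1)
```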

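The controller evaluation loop in "dream" mode diverged. The MDN-RNN samples a next latent, but with no temperature control the samples drifted into out-of-distribution regions after ~50 steps: the generated "frames" became incoherent and the reward model was useless. Fixed by capping the MDN-RNN temperature at 0.5 during controller rollouts. The paper does discuss sampling temperature, which I missed on the first read, though its VizDoom experiments report the trade-off in the other direction for controller training: lower temperatures make the dream easier for the controller to exploit, and a temperature of 1.15 transferred best to the real environment.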


What the Paper Doesn't Say

  • The VAE training rollouts are collected with random actions. This means the learned latent space only covers states reachable by a random policy — which may miss task-critical configurations. If the optimal policy visits rare states (tight corners at speed), V may never have seen them.

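  • CMA-ES behaves largely as a local search method, so optimizing C with it implicitly favors a fitness landscape with one dominant basin; for environments with multiple locally optimal strategies, it may settle on whichever one the search distribution reaches first. For CarRacing, the paper evaluates a population of 64 candidate controllers, each scored by its average return over 16 rollouts with different random seeds. That smooths evaluation noise but does not by itself address multimodality.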

  • Dream-mode training amplifies model errors. If the MDN-RNN confidently predicts a wrong transition, the controller will optimize for that wrong world. The paper partially addresses this with an early-termination heuristic but doesn't analyze the failure mode systematically.


Researcher Hat — Extensions

The most interesting open question: what happens when V and M are trained jointly with a task signal, rather than independently? The paper deliberately decouples them to keep training tractable, but joint training should produce latent representations more aligned with what the controller actually needs.

Dreamer (Hafner et al., 2019) is exactly this: it replaces the independently trained V and M with a jointly trained RSSM (Recurrent State Space Model) and backpropagates through imagined rollouts. Worth reading as a direct successor.

Also: the MDN-RNN is predicting $p(z_{t+1} \mid z_t, a_t, h_t)$, but it's not predicting reward. Dreamer adds a reward model alongside the transition model, so the controller can directly optimize imagined cumulative reward. That's the piece missing here.
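A sketch of where that missing piece could attach in this codebase: the same LSTM cell and MDN head, plus a hypothetical scalar reward head on the updated hidden state, trained with a squared-error loss alongside the MDN NLL. This is not Dreamer's architecture (Dreamer uses an RSSM); the class and function names are illustrative only.

```python
import torch
import torch.nn as nn


class MDNRNNWithReward(nn.Module):
    """Hypothetical extension: MDN-RNN plus a scalar reward head on the updated h."""

    def __init__(self, z_dim=32, a_dim=3, h_dim=256, n_mix=5):
        super().__init__()
        self.lstm = nn.LSTMCell(z_dim + a_dim, h_dim)
        self.mdn = nn.Linear(h_dim, n_mix * (1 + 2 * z_dim))
        self.reward = nn.Linear(h_dim, 1)  # predicts the reward for this transition

    def forward(self, z, a, h, c):
        h, c = self.lstm(torch.cat([z, a], dim=-1), (h, c))
        # Raw MDN parameters (split as in MDNRNN), predicted reward (B,), new state
        return self.mdn(h), self.reward(h).squeeze(-1), h, c


def reward_loss(pred_r, true_r):
    # Regress observed rewards; would be added to the MDN NLL during training
    return nn.functional.mse_loss(pred_r, true_r)
```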


Time Log

  • Paper read (first pass, skimming): 1 hour
  • Paper read (second pass, derivations): 2 hours
  • VAE implementation + debugging: 3 hours
  • MDN-RNN implementation + NaN debugging: 2.5 hours
  • Dream loop + controller stub: 2 hours
  • Writing this up: 1.5 hours

Total: ~12 hours. The MDN numerical stability issue ate a full afternoon.