
Action Heads

The only part of Neon that trains from scratch. Everything else — the video backbone, the Whisper encoder — is frozen. These small, precise decoders convert visual understanding into physical movement.

Parameter Golf v2

Since v2, all action heads use techniques from the Parameter Golf competition: ReLU², RMSNorm, learnable residual scales, U-Net skip connections, and logit soft-capping.
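
For reference, here are minimal PyTorch definitions of the first two techniques as they are commonly implemented (illustrative class names, not Neon's internals):

import torch
import torch.nn as nn

class ReLUSquared(nn.Module):
    """relu(x)**2: zero below zero, smooth quadratic growth above."""
    def forward(self, x):
        return torch.relu(x) ** 2

class RMSNorm(nn.Module):
    """Normalize by root-mean-square only: no mean subtraction, no bias."""
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        return x * rms * self.weight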


Design: Separate Heads Per Group

Neon uses separate MLP decoders for each joint group rather than one decoder for everything:

graph TD
    F["Fused Features (2048-d)"]
    F --> ARM["Arms Head<br/>ActionChunkingHead<br/>→ (batch, 16, 14)"]
    F --> LOCO["Locomotion Head<br/>ActionChunkingHead<br/>→ (batch, 16, 3)"]
    F --> TH["Torso+Head<br/>MLPDecoder<br/>→ (batch, 3)"]
    F --> LEG["Legs Head<br/>MLPDecoder<br/>→ (batch, 12)"]

    ARM --> CAT["Concatenate"]
    LOCO --> CAT
    TH --> CAT
    LEG --> CAT
    CAT --> OUT["(batch, 16, 17+)"]

    style F fill:#333,color:#fff
    style ARM fill:#1b5e20,color:#fff
    style LOCO fill:#e65100,color:#fff

Why separate?

  1. Different dynamics — Arms need fine motor control (precise grasping). Locomotion needs smooth velocity (no jerking). Different problems, different MLPs.
  2. Progressive training — Start with arms_only, add locomotion later, scale to whole_body. Each mode adds heads without changing existing ones.
  3. Per-group debugging — Monitor arm loss and locomotion loss independently. Know exactly where the model struggles.
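
The pattern itself is simple. A hypothetical sketch (the real heads are the ActionChunkingHead and MLPDecoder classes described below):

import torch
import torch.nn as nn

class PerGroupHeads(nn.Module):
    # Hypothetical sketch: one small decoder per joint group,
    # outputs concatenated along the action dimension.
    def __init__(self, feature_dim, groups):
        super().__init__()
        self.heads = nn.ModuleDict({
            name: nn.Sequential(
                nn.Linear(feature_dim, 512), nn.GELU(), nn.Linear(512, dim),
            )
            for name, dim in groups.items()
        })

    def forward(self, features):
        # Decode each joint group independently, then concatenate
        return torch.cat([h(features) for h in self.heads.values()], dim=-1)

heads = PerGroupHeads(2048, {"arms": 14, "locomotion": 3})
actions = heads(torch.randn(8, 2048))   # (8, 17)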

The MLPDecoder (v2)

Every layer is enhanced:

graph TD
    IN["Input Features"] --> PROJ["Linear(in, hidden)"]
    PROJ --> N1["RMSNorm"]
    N1 --> A1["ReLU²"]
    A1 --> |"save as skip"| D1["Dropout"]
    D1 --> N2["RMSNorm"]
    N2 --> A2["ReLU²"]
    A2 --> D2["Dropout"]
    D2 --> |"concat skip"| CAT["[h ∥ skip]"]
    CAT --> L3["Linear(2×hidden, hidden)"]
    L3 --> N3["RMSNorm"]
    N3 --> A3["ReLU²"]
    A3 --> RES["h + α · residual"]
    RES --> OUT["Linear(hidden, action_dim)"]
    OUT --> SC["soft_cap(x, 1.0)"]
    SC --> ACTIONS["Actions ∈ (-1, 1)"]

    style A1 fill:#e65100,color:#fff
    style A2 fill:#e65100,color:#fff
    style A3 fill:#e65100,color:#fff
    style SC fill:#1b5e20,color:#fff

from neon.model.action_heads import MLPDecoder

decoder = MLPDecoder(
    input_dim=2048,
    output_dim=14,              # 14 arm joints
    hidden_dim=512,
    num_layers=3,
    dropout=0.1,
    use_relu_squared=True,      # ReLU² — smoother
    use_rmsnorm=True,           # RMSNorm — lighter
    use_residual_scale=True,    # Learned α — self-adjusting
    use_skip_connections=True,  # U-Net — gradient highway
    soft_cap_value=1.0,         # Soft-cap — never kills gradients
)
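
The soft-cap in the final stage is the standard scaled-tanh formulation (assumed here from the soft_cap(x, 1.0) node in the diagram; not verified against Neon's source):

import torch

def soft_cap(x, cap=1.0):
    # Smoothly bounds outputs to (-cap, cap). Unlike a hard tanh/clamp,
    # whose gradient is exactly zero outside the bounds, the gradient
    # here decays but never vanishes entirely.
    return cap * torch.tanh(x / cap)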

The ActionChunkingHead (v2)

Predicts 16 future timesteps using learnable step embeddings:

class ActionChunkingHead(nn.Module):
    # Input:  (batch, hidden_size)
    # Output: (batch, 16, action_dim)
    #
    # step_embed[0]  → "what to do right now"
    # step_embed[7]  → "what to do in 140ms"
    # step_embed[15] → "what to do in 300ms"
    #
    # Same observation, different temporal perspectives.
    # v2: RMSNorm + ReLU² + soft-cap throughout.
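
A plausible minimal implementation of the step-embedding idea (a sketch under assumptions, not Neon's actual code):

import torch
import torch.nn as nn

class ChunkingSketch(nn.Module):
    def __init__(self, hidden_size=2048, num_steps=16, action_dim=14):
        super().__init__()
        # One learnable embedding per future timestep
        self.step_embed = nn.Parameter(torch.randn(num_steps, hidden_size) * 0.02)
        self.decoder = nn.Linear(hidden_size, action_dim)

    def forward(self, features):                      # (batch, hidden_size)
        h = features.unsqueeze(1) + self.step_embed   # (batch, 16, hidden_size)
        return self.decoder(h)                        # (batch, 16, action_dim)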

Parameter Counts

For arms_only mode (17 DoF, separate chunking heads):

| Component | Parameters |
|---|---|
| Arms ActionChunkingHead | ~1.5M |
| Locomotion ActionChunkingHead | ~400K |
| Fusion layer | ~4.2M |
| Proprioception encoder | ~100K |
| RMSNorm weights | ~3K |
| Residual scales | ~10 |
| Total trainable | ~6M |

The frozen backbone behind them has 3–7 billion parameters, so the trainable fraction is roughly 0.1–0.2%. This is the power of standing on the shoulders of a video model: you need almost nothing to teach it to act.
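
A quick way to verify these numbers on any model (plain PyTorch; model stands in for whatever nn.Module you built):

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
print(f"trainable: {trainable / 1e6:.1f}M, frozen: {frozen / 1e9:.1f}B, "
      f"ratio: {100 * trainable / (trainable + frozen):.2f}%")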


Loss Function

MSE loss with optional masking for padded timesteps:

def compute_loss(self, predicted, target, mask=None):
    # Per-element MSE so padded timesteps can be masked out
    loss = F.mse_loss(predicted, target, reduction="none")
    if mask is None:
        return loss.mean()
    loss = loss * mask.unsqueeze(-1)          # mask: (batch, num_steps)
    # Average over real (unmasked) elements, not the padded total
    return loss.sum() / (mask.sum() * predicted.shape[-1]).clamp(min=1)

Per-group loss weighting

If manipulation accuracy matters more than smooth walking, give arms 2× the loss weight. Configurable in training config.
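
A weighted sum over per-group losses might look like this (illustrative sketch; per_group_losses and the exact config keys are hypothetical):

# Hypothetical weighting: manipulation accuracy matters 2x more
loss_weights = {"arms": 2.0, "locomotion": 1.0}
total_loss = sum(
    loss_weights[name] * group_loss
    for name, group_loss in per_group_losses.items()
)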


Configuration

from neon.model.action_heads import ActionHeadConfig

config = ActionHeadConfig(
    hidden_size=2048,            # Must match backbone output
    action_dim=17,               # Auto-set from action space
    num_action_steps=16,         # Chunk size
    mlp_hidden=512,              # MLP width
    num_layers=3,                # MLP depth
    dropout=0.1,
    use_separate_heads=True,     # Per-group decoders
    use_action_chunking=True,    # Temporal prediction
    # Parameter Golf v2 — all on by default
    use_relu_squared=True,
    use_rmsnorm=True,
    use_residual_scale=True,
    use_skip_connections=True,
    soft_cap_value=1.0,
)

Backward Compatibility

Disable v2 to get vanilla MLP behavior:

config = ActionHeadConfig(
    use_relu_squared=False,      # → GELU
    use_rmsnorm=False,           # → Identity
    use_residual_scale=False,    # → Standard residuals
    use_skip_connections=False,  # → No skip
    soft_cap_value=0.0,          # → Hard Tanh
)

Advanced Heads

Beyond the standard MLP/Chunking decoders, Neon v2 ships four advanced action head architectures — each for different scenarios.

FlowMatchingHead

Flow-matching for multi-modal action distributions. Derived from π₀ (Physical Intelligence) and rectified flow. Unlike MSE (unimodal), flow matching captures multiple valid action modes — critical when the same visual scene permits several correct trajectories.

graph LR
    F["Features"] --> COND["Condition MLP"]
    NOISE["x₀ ~ N(0,I)"] --> INTERP["x_t = (1-t)·noise + t·action"]
    INTERP --> NET["Velocity Network<br/>MLP + RMSNorm + ReLU²"]
    COND --> NET
    T["t ~ Uniform(0,1)"] --> NET
    NET --> V["v_pred"]
    V --> LOSS["||v_pred - v_target||²"]

    style NET fill:#e65100,color:#fff
    style LOSS fill:#1b5e20,color:#fff
from neon.model.action_heads import FlowMatchingHead

head = FlowMatchingHead(
    input_dim=2048,
    action_dim=17,
    num_steps=16,
    hidden_dim=512,
    num_layers=4,
    num_denoise_steps=4,        # Euler integration steps at inference
    soft_cap_value=1.0,
    noise_beta_alpha=0.0,       # Optional: Beta noise scheduling
    use_cross_attention=False,  # Optional: cross-attend to backbone features
)

# Training: returns flow matching loss
loss = head.compute_loss(features, target_actions, mask)

# Inference: Euler integration from noise → action
actions = head.sample(features, num_denoise_steps=4)  # (batch, 16, 17)
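
Under the hood, sample() amounts to Euler integration of the learned velocity field. A sketch assuming the standard rectified-flow update (velocity_net and cond are illustrative names):

import torch

x = torch.randn(batch_size, 16, 17)          # start from pure noise
dt = 1.0 / num_denoise_steps
for i in range(num_denoise_steps):
    t = torch.full((batch_size,), i * dt)
    x = x + dt * velocity_net(x, t, cond)    # follow the predicted velocity
# x is now the predicted action chunk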

When to use: Tasks with ambiguous action distributions (e.g., "pick up one of the blocks" — any block is valid). ~4M params.


DiTActionHead

Diffusion Transformer inspired by GR00T N1.6. Uses adaptive layer norm (adaLN) conditioning from features + timestep, with DDIM sampling.

graph TD
    F["Features + t"] --> ADALN["adaLN → scale, shift, gate"]
    ADALN --> SA["Self-Attention<br/>over 16 action tokens"]
    SA --> MLP["MLP + ReLU² + RMSNorm"]
    MLP --> RES["Residual + α·skip"]
    RES --> |"× N layers"| SA

    style SA fill:#0097a7,color:#fff
    style MLP fill:#e65100,color:#fff
from neon.model.action_heads import DiTActionHead

head = DiTActionHead(
    input_dim=2048,
    action_dim=17,
    num_steps=16,
    hidden_dim=384,
    num_layers=8,               # 8 DiT blocks (vs GR00T's 32 — stronger backbone)
    num_heads=6,
    num_denoise_steps=4,
    soft_cap_value=1.0,
)

# Training: diffusion loss
loss = head.compute_loss(features, target_actions, mask)

# Inference: DDIM sampling
actions = head.sample(features, num_denoise_steps=4)  # (batch, 16, 17)

When to use: Complex, long-horizon sequences where inter-step dependencies matter. Self-attention across timesteps captures temporal correlations that MLPs miss. ~6M params.


DiTBlock

The building block of DiTActionHead. Each block contains:

  1. adaLN — Adaptive LayerNorm conditioned on features + timestep → scale, shift, gate
  2. Self-Attention — Over action tokens (16 tokens = 16 timesteps)
  3. Cross-Attention (optional) — Action tokens attend to backbone VL features
  4. MLP — ReLU² + RMSNorm with gated residual
from neon.model.action_heads import DiTBlock

block = DiTBlock(
    hidden_dim=384,
    num_heads=6,
    dropout=0.1,
    cross_attention_dim=2048,  # Optional: enable cross-attention to backbone
)
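
The adaLN conditioning in step 1 typically modulates each sub-layer like this (a sketch of the standard DiT recipe, not Neon's exact code):

# scale, shift, gate come from a Linear over the conditioning (features + t)
def modulated_sublayer(x, sublayer, norm, scale, shift, gate):
    h = norm(x) * (1 + scale) + shift   # adaptive scale/shift
    return x + gate * sublayer(h)       # gated residual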

StateRelativeHead

Wrapper that converts any action head to predict Δ from current state instead of absolute positions. Inspired by GR00T N1.6's relative action formulation.

action_absolute = current_state + inner_head(features)

Reduces error accumulation in closed-loop control — small prediction errors stay small instead of drifting.

from neon.model.action_heads import StateRelativeHead, ActionChunkingHead

# Wrap any head
inner = ActionChunkingHead(input_dim=2048, action_dim=17, num_steps=16)
head = StateRelativeHead(inner_head=inner, action_dim=17, num_steps=16)

# Forward pass with current proprioception
actions = head(features, current_state=joint_positions)  # predicts deltas, adds to state

Zero extra parameters — it's a pure wrapper. Use with MLP, Flow, or DiT.
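
The wrapper logic is small enough to sketch in full (hypothetical class mirroring the formula above):

import torch.nn as nn

class StateRelativeSketch(nn.Module):
    def __init__(self, inner_head):
        super().__init__()
        self.inner = inner_head

    def forward(self, features, current_state):
        deltas = self.inner(features)          # (batch, 16, action_dim)
        # Broadcast the current state across all predicted steps
        return current_state.unsqueeze(1) + deltas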


EnsembleHead

Gated ensemble of MLP + Flow + DiT. A learned gate (conditioned on features) decides which head to trust for each input:

graph LR
    F["Features"] --> MLP["MLP Head<br/>~3M params"]
    F --> FLOW["Flow Head<br/>~4M params"]
    F --> DIT["DiT Head<br/>~6M params"]
    F --> GATE["Gate Network<br/>softmax(Linear)"]

    MLP --> MIX["gate₀·MLP + gate₁·Flow + gate₂·DiT"]
    FLOW --> MIX
    DIT --> MIX
    GATE --> MIX
    MIX --> OUT["Actions (16, 17)"]

    style GATE fill:#7b1fa2,color:#fff
    style MIX fill:#1b5e20,color:#fff
from neon.model.action_heads import EnsembleHead

head = EnsembleHead(
    input_dim=2048,
    action_dim=17,
    num_steps=16,
    mlp_hidden=512,
    flow_hidden=512,
    dit_hidden=384,
    dit_layers=8,
    dit_heads=6,
    flow_layers=4,
    num_denoise_steps=4,
)
# Total: ~13M params (MLP ~3M + Flow ~4M + DiT ~6M + gate ~6K)

When to use: Maximum accuracy when you can afford the compute. The gate learns that MLP is fast for simple motions, Flow handles ambiguous actions, DiT excels at complex sequences.
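
The mixing step itself is a softmax-weighted sum. A sketch (illustrative names; mlp_out, flow_out, and dit_out are each (batch, 16, 17)):

import torch.nn.functional as F

gates = F.softmax(gate_logits, dim=-1)       # (batch, 3), from a Linear
actions = (
    gates[:, 0, None, None] * mlp_out
    + gates[:, 1, None, None] * flow_out
    + gates[:, 2, None, None] * dit_out
)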


Head Comparison

| Head | Params | Inference Speed | Multi-modal | Temporal Modeling | Best For |
|---|---|---|---|---|---|
| MLPDecoder | ~1.5M | Fastest | ❌ | ❌ | Single-joint groups, edge |
| ActionChunkingHead | ~1.5M | Fast | ❌ | Step embeddings | Standard, production |
| FlowMatchingHead | ~4M | 🔄 4 denoise steps | ✅ | ❌ | Ambiguous actions |
| DiTActionHead | ~6M | 🔄 4 denoise steps | ✅ | Self-attention | Complex sequences |
| StateRelativeHead | +0 | Same as inner | Depends on inner | Depends on inner | Closed-loop control |
| EnsembleHead | ~13M | 🐢 Slowest | ✅ | ✅ | Maximum accuracy |

Next: Inference & Control — deploy predictions to a real robot