
In the following code, when I remove the vmap I get the right randomized behavior (see the sketch after the code below), but with vmap I don't anymore. Isn't this supposed to be one of the features of nnx.vmap?

import jax
import jax.numpy as jnp
from flax import nnx
# --- 1. Define a Simple Model with a Stateful Layer (Dropout) ---
# We use nnx.Dropout because it requires random numbers, making it a stateful
# operation that benefits from nnx.vmap's automatic RNG splitting.
class SimpleDropoutModel(nnx.Module):
    def __init__(self, *, rngs: nnx.Rngs):
        """Initializes the model."""
        # The dropout layer needs an RNG stream to generate random masks.
        self.dropout = nnx.Dropout(rate=0.5, rngs=rngs)
        self.linear = nnx.Linear(in_features=10, out_features=5, rngs=rngs)

    def __call__(self, x: jnp.ndarray, *, train: bool) -> jnp.ndarray:
        """Applies the model to a single input."""
        # The `deterministic` flag controls whether dropout is active.
        # We pass `not train` to it.
        x = self.linear(x)
        x = self.dropout(x, deterministic=not train)
        return x
# --- 2. Initialization ---
# Create a PRNG key for reproducibility.
key = jax.random.PRNGKey(42)
# Instantiate the model. NNX requires an `nnx.Rngs` object to manage
# different random number streams (e.g., for 'params' and 'dropout').
# We need to provide an RNG stream for 'params' as well for the Linear layer.
model = SimpleDropoutModel(rngs=nnx.Rngs(params=key, dropout=key))
print("Model initialized successfully.")
print("Dropout Rate:", model.dropout.rate)
print("-" * 30)
# --- 3. Define and Transform the Batched Apply Function ---
# We want to apply our model to a whole batch of data.
# We compose nnx.vmap and nnx.jit to create an efficient, batched function.
# Define a helper function that takes the model, inputs, and train flag.
# Apply nnx.vmap and nnx.jit as decorators.
# Apply vmap first, then jit.
@nnx.vmap(
    in_axes=(None, 0, None),  # model is not vmapped, x is vmapped, train is not vmapped
    out_axes=0,               # output is vmapped
)
@nnx.jit(static_argnames=["train"])
def batched_apply(model: SimpleDropoutModel, x: jnp.ndarray, train: bool):
    """Applies the model to a batch of inputs."""
    # NNX will handle the state and RNGs of the model instance passed to this function.
    return model(x, train=train)
# --- 4. Run the Demonstration ---
# Create a dummy batch of 4 identical inputs. Each input is a vector of 10 ones.
batch_input = jnp.ones((4, 10))
print(f"Input batch shape: {batch_input.shape}")
print("Input batch:")
print(batch_input)
print("-" * 30)
print("Running the batched model in training mode (dropout is active)...")
# Run the JIT-compiled, vmapped function.
# Pass the model instance as the first argument. NNX will handle its state and RNGs.
output_batch = batched_apply(model, batch_input, train=True)
print(f"Output batch shape: {output_batch.shape}\n")
print("Output batch:")
print(output_batch)
print("-" * 30)
# --- 5. Verification ---
# Because dropout is random and nnx.vmap correctly split the RNG keys,
# each row in the output batch should be different, even though the inputs were identical.
# We verify that not all outputs are the same.
first_output = output_batch[0]
all_same = jnp.all(jnp.all(output_batch == first_output, axis=1))
if not all_same:
    print("✅ Verification successful: The outputs are different for each sample in the batch.")
    print("This proves nnx.vmap correctly split the 'dropout' RNG stream.")
else:
    print("❌ Verification failed: All outputs were the same.")
asked Jul 11 at 12:04

2 Answers


To make dropout work together with nnx.vmap in Flax, we need nnx.split_rngs and nnx.StateAxes: nnx.split_rngs forks the 'dropout' RNG stream into one key per batch element before the call, and nnx.StateAxes tells nnx.vmap to map that RNG state along axis 0 while broadcasting the rest of the model state:

import jax
import jax.numpy as jnp
from flax import nnx
# --- 1. Define a Simple Model with a Stateful Layer (Dropout) ---
# We use nnx.Dropout because it requires random numbers, making it a stateful
# operation that benefits from nnx.vmap's automatic RNG splitting.
class SimpleDropoutModel(nnx.Module):
    def __init__(self, *, rngs: nnx.Rngs):
        """Initializes the model."""
        # The dropout layer needs an RNG stream to generate random masks.
        self.dropout = nnx.Dropout(rate=0.5, rngs=rngs)
        self.linear = nnx.Linear(in_features=10, out_features=5, rngs=rngs)

    def __call__(self, x: jnp.ndarray, *, train: bool) -> jnp.ndarray:
        """Applies the model to a single input."""
        # The `deterministic` flag controls whether dropout is active.
        # We pass `not train` to it.
        x = self.linear(x)
        x = self.dropout(x, deterministic=not train)
        return x
# --- 2. Initialization ---
# Create a PRNG key for reproducibility.
key = jax.random.PRNGKey(42)
# Instantiate the model. NNX requires an `nnx.Rngs` object to manage
# different random number streams (e.g., for 'params' and 'dropout').
# We need to provide an RNG stream for 'params' as well for the Linear layer.
model = SimpleDropoutModel(rngs=nnx.Rngs(params=key, dropout=key))
print("Model initialized successfully.")
print("Dropout Rate:", model.dropout.rate)
print("-" * 30)
# --- 3. Define and Transform the Batched Apply Function ---
# We want to apply our model to a whole batch of data.
# We compose nnx.vmap and nnx.jit to create an efficient, batched function.
# Define a helper function that takes the model, inputs, and train flag.
# Apply nnx.vmap and nnx.jit as decorators.
# Apply vmap first, then jit.
bs = 4
state_axes = nnx.StateAxes({'dropout': 0, ...: None})
@nnx.split_rngs(splits=bs, only='dropout')
@nnx.vmap(
    in_axes=(state_axes, 0, None),  # dropout RNG state is mapped along axis 0, the rest of the model is broadcast; x is vmapped, train is not
    out_axes=0,                     # output is vmapped
)
@nnx.jit(static_argnames=["train"])
def batched_apply(model: SimpleDropoutModel, x: jnp.ndarray, train: bool):
    """Applies the model to a batch of inputs."""
    # NNX will handle the state and RNGs of the model instance passed to this function.
    return model(x, train=train)
# --- 4. Run the Demonstration ---
# Create a dummy batch of 4 identical inputs. Each input is a vector of 10 ones.
batch_input = jnp.ones((bs, 10))
print(f"Input batch shape: {batch_input.shape}")
print("Input batch:")
print(batch_input)
print("-" * 30)
print("Running the batched model in training mode (dropout is active)...")
model.train()
# Run the JIT-compiled, vmapped function.
# Pass the model instance as the first argument. NNX will handle its state and RNGs.
output_batch = batched_apply(model, batch_input, train=True)
print(f"Output batch shape: {output_batch.shape}\n")
print("Output batch:")
print(output_batch)
print("-" * 30)
# --- 5. Verification ---
# Because dropout is random and nnx.vmap correctly split the RNG keys,
# each row in the output batch should be different, even though the inputs were identical.
# We verify that not all outputs are the same.
first_output = output_batch[0]
all_same = jnp.all(jnp.all(output_batch == first_output, axis=1))
if not all_same:
    print("✅ Verification successful: The outputs are different for each sample in the batch.")
    print("This proves nnx.vmap correctly split the 'dropout' RNG stream.")
else:
    print("❌ Verification failed: All outputs were the same.")

Output with jax: 0.7.0.dev20250704, flax: 0.10.6

Output batch:
[[0. 0.1736668 1.6533196 0. 0. ]
 [0. 0. 1.6533196 0. 0.7218913 ]
 [0.09358063 0. 1.6533196 0. 0.7218913 ]
 [0.09358063 0. 1.6533196 0. 0.7218913 ]]
------------------------------
✅ Verification successful: The outputs are different for each sample in the batch.
This proves nnx.vmap correctly split the 'dropout' RNG stream.
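
For contrast, here is a minimal sketch of why the original in_axes=(None, 0, None) version gave identical rows: with None the whole model state, including its single 'dropout' key, is broadcast to every mapped call, so each call draws the same mask (reuses model and batch_input from above; broadcast_apply is just an illustrative name):

@nnx.vmap(in_axes=(None, 0), out_axes=0)
def broadcast_apply(model: SimpleDropoutModel, x: jnp.ndarray):
    # train is hard-coded here to keep the sketch minimal.
    return model(x, train=True)

same_rows = broadcast_apply(model, batch_input)
print(jnp.all(same_rows == same_rows[0]))  # expected: True, i.e. every row is identical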
answered Jul 11 at 15:27



I'm not sure nnx.vmap and nnx.split_rngs are necessary in vfdev's answer: nnx.Dropout samples its random mask with the same shape as its input, so applying the model to the whole (4, 10) batch already gives each row an independent mask. Also, having a train kwarg is unnecessary in most situations, since NNX models can switch between training and evaluation behavior dynamically with .train() and .eval().

import jax
import jax.numpy as jnp
from flax import nnx
class SimpleDropoutModel(nnx.Module):
    def __init__(self, *, rngs: nnx.Rngs):
        """Initializes the model."""
        self.dropout = nnx.Dropout(rate=0.5, rngs=rngs)
        self.linear = nnx.Linear(in_features=10, out_features=5, rngs=rngs)

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        x = self.linear(x)
        x = self.dropout(x)
        return x
key = jax.random.PRNGKey(42)
model = SimpleDropoutModel(rngs=nnx.Rngs(params=key, dropout=key))
print("Model initialized successfully.")
print("Dropout Rate:", model.dropout.rate)
print("-" * 30)
@nnx.jit
def batched_apply(model: SimpleDropoutModel, x: jnp.ndarray):
    """Applies the model to a batch of inputs."""
    return model(x)
bs = 4
batch_input = jnp.ones((bs, 10))
print(f"Input batch shape: {batch_input.shape}")
print("Input batch:")
print(batch_input)
print("-" * 30)
print("Running the batched model in training mode (dropout is active)...")
# Enable training. This works because Dropout layers have a .deterministic property
# that can be modified.
model.train()
output_batch = batched_apply(model, batch_input)
print(f"Output batch shape: {output_batch.shape}\n")
print("Output batch:")
print(output_batch)
print("-" * 30)
first_output = output_batch[0]
all_same = jnp.all(jnp.all(output_batch == first_output, axis=1))
if not all_same:
    print("✅ Verification successful: The outputs are different for each sample in the batch.")
    print("This proves nnx.vmap correctly split the 'dropout' RNG stream.")
else:
    print("❌ Verification failed: All outputs were the same.")

output:

Model initialized successfully.
Dropout Rate: 0.5
------------------------------
Input batch shape: (4, 10)
Input batch:
[[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
------------------------------
Running the batched model in training mode (dropout is active)...
Output batch shape: (4, 5)
Output batch:
[[0. 0.1736668 0. 0. 0. ]
 [0. 0. 1.6533196 1.0752656 0. ]
 [0. 0. 0. 0. 0.7218913 ]
 [0.09358063 0. 0. 1.0752656 0. ]]
------------------------------
✅ Verification successful: The outputs are different for each sample in the batch.
This proves nnx.vmap correctly split the 'dropout' RNG stream.

And if instead you do model.eval():

Output batch:
[[0.04679031 0.0868334 0.8266598 0.5376328 0.36094564]
 [0.04679031 0.0868334 0.8266598 0.5376328 0.36094564]
 [0.04679031 0.0868334 0.8266598 0.5376328 0.36094564]
 [0.04679031 0.0868334 0.8266598 0.5376328 0.36094564]]
------------------------------
❌ Verification failed: All outputs were the same.
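
As a possible sanity check (a sketch reusing the names above): with rate=0.5, the surviving train-mode activations should be the eval-mode activations scaled by 1 / (1 - rate) = 2, which matches the numbers printed above.

model.eval()
eval_out = batched_apply(model, batch_input)
model.train()
train_out = batched_apply(model, batch_input)

scale = 1.0 / (1.0 - model.dropout.rate)
kept = train_out != 0  # positions where dropout kept the activation
print(jnp.allclose(train_out[kept], (eval_out * scale)[kept]))  # expected: True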
answered Jul 21 at 16:58

