I'm not sure nnx.vmap and nnx.split_rngs are necessary in vfdev's answer. Also, a train kwarg is unnecessary in most situations, since NNX models can switch between training and evaluation behavior dynamically with .train() and .eval().
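As a quick illustration of that toggle, here is a minimal sketch (model refers to the SimpleDropoutModel instance constructed below; as I understand it, .train() and .eval() are shortcuts for set_attributes on the flags that control Dropout and BatchNorm, so treat the exact attribute names as an assumption):

# Sketch: what .train() / .eval() change for Dropout (assuming `model` is the
# SimpleDropoutModel instance constructed below).
model.train()                        # roughly: model.set_attributes(deterministic=False, use_running_average=False)
print(model.dropout.deterministic)   # False -> dropout masks are applied
model.eval()                         # roughly: model.set_attributes(deterministic=True, use_running_average=True)
print(model.dropout.deterministic)   # True  -> dropout becomes a no-op

The full example: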
import jax
import jax.numpy as jnp
from flax import nnx
class SimpleDropoutModel(nnx.Module):
    def __init__(self, *, rngs: nnx.Rngs):
        """Initializes the model."""
        self.dropout = nnx.Dropout(rate=0.5, rngs=rngs)
        self.linear = nnx.Linear(in_features=10, out_features=5, rngs=rngs)

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        x = self.linear(x)
        x = self.dropout(x)
        return x
key = jax.random.PRNGKey(42)
model = SimpleDropoutModel(rngs=nnx.Rngs(params=key, dropout=key))
print("Model initialized successfully.")
print("Dropout Rate:", model.dropout.rate)
print("-" * 30)
@nnx.jit
def batched_apply(model: SimpleDropoutModel, x: jnp.ndarray):
    """Applies the model to a batch of inputs."""
    return model(x)
bs = 4
batch_input = jnp.ones((bs, 10))
print(f"Input batch shape: {batch_input.shape}")
print("Input batch:")
print(batch_input)
print("-" * 30)
print("Running the batched model in training mode (dropout is active)...")
# Enable training mode: model.train() sets deterministic=False on the Dropout
# layer, so dropout masks are applied on this forward pass.
model.train()
output_batch = batched_apply(model, batch_input)
print(f"Output batch shape: {output_batch.shape}\n")
print("Output batch:")
print(output_batch)
print("-" * 30)
first_output = output_batch[0]
all_same = jnp.all(jnp.all(output_batch == first_output, axis=1))
if not all_same:
    print("✅ Verification successful: The outputs are different for each sample in the batch.")
    print("Dropout drew a different mask for each sample, with no nnx.vmap or nnx.split_rngs needed.")
else:
    print("❌ Verification failed: All outputs were the same.")
Output:
Model initialized successfully.
Dropout Rate: 0.5
------------------------------
Input batch shape: (4, 10)
Input batch:
[[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
------------------------------
Running the batched model in training mode (dropout is active)...
Output batch shape: (4, 5)
Output batch:
[[0. 0.1736668 0. 0. 0. ]
[0. 0. 1.6533196 1.0752656 0. ]
[0. 0. 0. 0. 0.7218913 ]
[0.09358063 0. 0. 1.0752656 0. ]]
------------------------------
✅ Verification successful: The outputs are different for each sample in the batch.
Dropout drew a different mask for each sample, with no nnx.vmap or nnx.split_rngs needed.
And if you instead call model.eval(), dropout is disabled and every sample produces the same output:
Output batch:
[[0.04679031 0.0868334 0.8266598 0.5376328 0.36094564]
[0.04679031 0.0868334 0.8266598 0.5376328 0.36094564]
[0.04679031 0.0868334 0.8266598 0.5376328 0.36094564]
[0.04679031 0.0868334 0.8266598 0.5376328 0.36094564]]
------------------------------
❌ Verification failed: All outputs were the same.
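For completeness, if you do want each sample to consume its own fork of the 'dropout' RNG stream (the case nnx.vmap and nnx.split_rngs target in vfdev's answer), something like the sketch below should work. The StateAxes spec and decorator stacking follow my reading of the Flax transforms guide, so treat the exact axis mapping as an assumption:

# Sketch (not required for the example above): explicitly fork the RNG state per
# sample while broadcasting the parameters. Axis specs here are assumptions on my part.
state_axes = nnx.StateAxes({nnx.RngState: 0, ...: None})

@nnx.split_rngs(splits=bs)                       # fork the 'dropout' stream bs ways
@nnx.vmap(in_axes=(state_axes, 0), out_axes=0)   # map over samples, broadcast params
def per_sample_apply(model: SimpleDropoutModel, x: jnp.ndarray):
    return model(x)                              # x is a single (10,) sample inside vmap

model.train()
per_sample_out = per_sample_apply(model, batch_input)   # shape (4, 5), one dropout mask per sample

In practice the plain nnx.jit version above already yields a different mask per sample, because the single dropout call draws one Bernoulli mask over the whole (batch, features) array; the vmap + split_rngs version mainly matters when you need the per-sample RNG forks to be explicit, e.g. when vmapping over an ensemble of models.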
David Braun