I'm not sure `nnx.vmap` and `nnx.split_rngs` are necessary in vfdev's answer (a sketch of when they *are* needed is at the end of this answer). Also, a `train` kwarg is unnecessary in most situations, since NNX models can switch between training and evaluation behavior dynamically with `.train()` and `.eval()`.

```python
import jax
import jax.numpy as jnp
from flax import nnx

class SimpleDropoutModel(nnx.Module):
    def __init__(self, *, rngs: nnx.Rngs):
        """Initializes the model."""
        self.dropout = nnx.Dropout(rate=0.5, rngs=rngs)
        self.linear = nnx.Linear(in_features=10, out_features=5, rngs=rngs)

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        x = self.linear(x)
        x = self.dropout(x)
        return x

key = jax.random.PRNGKey(42)

model = SimpleDropoutModel(rngs=nnx.Rngs(params=key, dropout=key))

print("Model initialized successfully.")
print("Dropout Rate:", model.dropout.rate)
print("-" * 30)

@nnx.jit
def batched_apply(model: SimpleDropoutModel, x: jnp.ndarray):
 """Applies the model to a batch of inputs."""
 return model(x)

bs = 4
batch_input = jnp.ones((bs, 10))

print(f"Input batch shape: {batch_input.shape}")
print("Input batch:")
print(batch_input)
print("-" * 30)
print("Running the batched model in training mode (dropout is active)...")

# Switch to training mode: model.train() recursively sets deterministic=False
# on every Dropout layer, so dropout masks are applied.
model.train()

output_batch = batched_apply(model, batch_input)

print(f"Output batch shape: {output_batch.shape}\n")
print("Output batch:")
print(output_batch)
print("-" * 30)

first_output = output_batch[0]
all_same = bool(jnp.all(output_batch == first_output))

if not all_same:
 print("✅ Verification successful: The outputs are different for each sample in the batch.")
 print("This proves nnx.vmap correctly split the 'dropout' RNG stream.")
else:
 print("❌ Verification failed: All outputs were the same.")
```

output:

```none
Model initialized successfully.
Dropout Rate: 0.5
------------------------------
Input batch shape: (4, 10)
Input batch:
[[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
------------------------------
Running the batched model in training mode (dropout is active)...
Output batch shape: (4, 5)

Output batch:
[[0.         0.1736668  0.         0.         0.        ]
 [0.         0.         1.6533196  1.0752656  0.        ]
 [0.         0.         0.         0.         0.7218913 ]
 [0.09358063 0.         0.         1.0752656  0.        ]]
------------------------------
✅ Verification successful: The outputs are different for each sample in the batch.
Dropout drew a fresh mask for each sample without nnx.vmap or nnx.split_rngs.
```
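Under the hood, `.train()` and `.eval()` simply flip the `deterministic` flag on each `nnx.Dropout` layer (the property the comment in the code refers to). A minimal check on the `model` from above:

```python
model.train()                       # recursively sets deterministic=False
print(model.dropout.deterministic)  # False -> dropout masks are applied

model.eval()                        # recursively sets deterministic=True
print(model.dropout.deterministic)  # True -> dropout is a no-op
```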

And if instead you do `model.eval()`, dropout is disabled, so every sample in the batch gets the identical output.
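The rerun reuses `batched_apply` and `batch_input` from above:

```python
model.eval()  # dropout now passes its input through unchanged
output_batch = batched_apply(model, batch_input)
```

which prints: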

```none
Output batch:
[[0.04679031 0.0868334  0.8266598  0.5376328  0.36094564]
 [0.04679031 0.0868334  0.8266598  0.5376328  0.36094564]
 [0.04679031 0.0868334  0.8266598  0.5376328  0.36094564]
 [0.04679031 0.0868334  0.8266598  0.5376328  0.36094564]]
------------------------------
❌ Verification failed: All outputs were the same.
```
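For contrast, `nnx.split_rngs` (together with `nnx.vmap`) does become necessary when you vmap over the model's RNG state itself and want each vmapped instance to draw from an independent 'dropout' stream. A rough sketch of that pattern, based on the Flax NNX randomness guide (treat the exact `StateAxes` mapping as an assumption to adapt to your model):

```python
# Sketch: fork the 'dropout' stream once per sample and map it over axis 0,
# while broadcasting params and all other state unchanged.
state_axes = nnx.StateAxes({'dropout': 0, ...: None})

@nnx.split_rngs(splits=4, only='dropout')       # one 'dropout' key per sample
@nnx.vmap(in_axes=(state_axes, 0), out_axes=0)  # map RNG state and inputs together
def forward(model: SimpleDropoutModel, x: jnp.ndarray):
    return model(x)

model.train()  # make dropout stochastic again
out = forward(model, batch_input)  # shape (4, 5): an independent mask per row
```

In the simple batched case above, though, `nnx.Dropout` already samples an i.i.d. mask over the whole `(batch, features)` array, which is why the rows differ without any of this machinery.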
