Skip to main content
Stack Overflow
  1. About
  2. For Teams

Return to Revisions

2 of 2
added 6 characters in body
David Braun
  • 820
  • 1
  • 11
  • 18

I'm not sure nnx.vmap and nnx.split_rngs are necessary in vfdev's answer. Also, having a train kwarg is unnecessary in most situations since NNX models can dynamically jump between train=True, train=False with .train() and .eval()

import jax
import jax.numpy as jnp
from flax import nnx
class SimpleDropoutModel(nnx.Module):
 def __init__(self, *, rngs: nnx.Rngs):
 """Intializes the model."""
 self.dropout = nnx.Dropout(rate=0.5, rngs=rngs)
 self.linear = nnx.Linear(in_features=10, out_features=5, rngs=rngs)
 def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
 x = self.linear(x)
 x = self.dropout(x)
 return x
key = jax.random.PRNGKey(42)
model = SimpleDropoutModel(rngs=nnx.Rngs(params=key, dropout=key))
print("Model initialized successfully.")
print("Dropout Rate:", model.dropout.rate)
print("-" * 30)
@nnx.jit
def batched_apply(model: SimpleDropoutModel, x: jnp.ndarray):
 """Applies the model to a batch of inputs."""
 return model(x)
bs = 4
batch_input = jnp.ones((bs, 10))
print(f"Input batch shape: {batch_input.shape}")
print("Input batch:")
print(batch_input)
print("-" * 30)
print("Running the batched model in training mode (dropout is active)...")
# Enable training. This works because Dropout layers have a .deterministic property
# that can be modified.
model.train()
output_batch = batched_apply(model, batch_input)
print(f"Output batch shape: {output_batch.shape}\n")
print("Output batch:")
print(output_batch)
print("-" * 30)
first_output = output_batch[0]
all_same = jnp.all(jnp.all(output_batch == first_output, axis=1))
if not all_same:
 print("✅ Verification successful: The outputs are different for each sample in the batch.")
 print("This proves nnx.vmap correctly split the 'dropout' RNG stream.")
else:
 print("❌ Verification failed: All outputs were the same.")

output:

Model initialized successfully.
Dropout Rate: 0.5
------------------------------
Input batch shape: (4, 10)
Input batch:
[[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
------------------------------
Running the batched model in training mode (dropout is active)...
Output batch shape: (4, 5)
Output batch:
[[0. 0.1736668 0. 0. 0. ]
 [0. 0. 1.6533196 1.0752656 0. ]
 [0. 0. 0. 0. 0.7218913 ]
 [0.09358063 0. 0. 1.0752656 0. ]]
------------------------------
✅ Verification successful: The outputs are different for each sample in the batch.
This proves nnx.vmap correctly split the 'dropout' RNG stream.

and if instead you do model.eval()

Output batch:
[[0.04679031 0.0868334 0.8266598 0.5376328 0.36094564]
 [0.04679031 0.0868334 0.8266598 0.5376328 0.36094564]
 [0.04679031 0.0868334 0.8266598 0.5376328 0.36094564]
 [0.04679031 0.0868334 0.8266598 0.5376328 0.36094564]]
------------------------------
❌ Verification failed: All outputs were the same.
David Braun
  • 820
  • 1
  • 11
  • 18

AltStyle によって変換されたページ (->オリジナル) /