68 questions
1 vote, 0 answers, 53 views
How to vectorize (ensemble) nnx.Modules with separate parameters using nnx.vmap in JAX/Flax
I have a vectorized (ensemble) Q-network implemented using Flax Linen that works as expected. Each critic in the ensemble has separate parameters, and the output is stacked along the first dimension (...
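A minimal sketch of the general idea, written in plain JAX rather than with the `nnx.vmap` API the question asks about: each ensemble member gets its own parameters by vmapping the initializer over a batch of PRNG keys, and the outputs come back stacked along the first dimension. The names `init_member` and `apply_member` and all sizes are hypothetical and not taken from the question.

```python
import jax
import jax.numpy as jnp


def init_member(key, in_dim=4, out_dim=1):
    # Hypothetical single-layer "critic"; one key yields one parameter set.
    return {
        "w": 0.1 * jax.random.normal(key, (in_dim, out_dim)),
        "b": jnp.zeros((out_dim,)),
    }


def apply_member(params, x):
    # Forward pass of a single ensemble member.
    return x @ params["w"] + params["b"]


num_members = 5
keys = jax.random.split(jax.random.PRNGKey(0), num_members)

# vmap over the keys: every leaf gains a leading axis of size num_members.
ensemble_params = jax.vmap(init_member)(keys)

x = jnp.ones((8, 4))
# Map over the parameter axis, share the same input across all members.
q_values = jax.vmap(apply_member, in_axes=(0, None))(ensemble_params, x)
```

Here `q_values` has one leading entry per member, which mirrors the stacked-output behavior described in the excerpt.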
0 votes, 0 answers, 102 views
Overhead of instantiating a flax model
Is it expensive to keep recreating a Flax network, such as
class QNetwork(nn.Module):
    dim: int
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(120)(x)
        x = nn.relu(x)
        ...
1 vote, 1 answer, 104 views
How to control hyperparameter within flax.nnx loss function using an optax.schedule?
from jax import numpy as jnp
from jax import random
from flax import nnx
import optax
from matplotlib import pyplot as plt
if __name__ == '__main__':
    shape = (2, 55, 1)
    epochs = 123
    rngs = ...
0 votes, 1 answer, 118 views
JAX shard on GPU and shard on CPU in subroutine, all with JIT
Duplicating my question here: https://github.com/google/flax/discussions/4825
I want to have a JAX or NNX jitted function that consumes and returns GPU-sharded tensors. However, inside the function, I ...
1 vote, 2 answers, 275 views
Does vmap correctly split the RNG keys?
When I remove the vmap from the following code, I get the expected randomized behavior; with vmap, I no longer do. Isn't this supposed to be one of the features of nnx.vmap?
import jax
import ...
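The question's code is truncated here, so its actual cause is not known. As a general illustration only, the plain-JAX sketch below (using `jax.vmap`, not `nnx.vmap`) shows one way randomness can collapse under vmap: broadcasting a single key to every batch element yields identical samples, whereas splitting the key and mapping over the resulting keys yields independent ones. The function `noisy` and the shapes are hypothetical.

```python
import jax
import jax.numpy as jnp


def noisy(key, x):
    # Add Gaussian noise drawn from the given key.
    return x + jax.random.normal(key, x.shape)


x = jnp.zeros((4, 3))
key = jax.random.PRNGKey(0)

# One key broadcast to every row (in_axes=None): all rows get the same noise.
same_key = jax.vmap(noisy, in_axes=(None, 0))(key, x)

# One key per row: split first, then map over the keys axis.
keys = jax.random.split(key, x.shape[0])
split_keys = jax.vmap(noisy, in_axes=(0, 0))(keys, x)
```

Comparing the rows of `same_key` and `split_keys` makes the difference visible.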
0 votes, 1 answer, 55 views
Efficient multi-host TPU dataset processing
I want to train an LLM on a TPUv4-32 using JAX/Flax. The dataset is stored in a mounted Google storage bucket. The dataset (Red-Pajama-v2) consists of 5000 shards, which are stored in .json.gz files: ~/...
0 votes, 1 answer, 100 views
Differentiable weight setting in flax NNX
I'm doing some experiments with Flax NNX (not Linen!).
What I'm trying to do is compute the weights of a network using another network:
A hypernetwork receives some input parameters W and outputs a ...
1 vote, 0 answers, 77 views
LAPACK Inconsistent across multiple different operating systems and devices
Description
I have a deterministic program that uses JAX and is heavy on linear algebra operations.
I ran this code on CPU, using three different CPUs. Two macOS systems (one on Sequoia (M1 Pro), ...
2 votes, 1 answer, 110 views
How to type hint `flax.linen.Module.apply`'s output correctly?
As of writing, this code does not pass the Pyright type checker:
import jax
import jax.numpy as jnp
import jax.typing as jt
import flax.linen as nn
class MLP(nn.Module):
    @nn.compact
    def ...
1 vote, 1 answer, 305 views
Would using lists rather than jax.numpy arrays lead to more accurate numerical transformations?
I am doing a project with RNNs using JAX and Flax, and I have noticed some behavior that I do not really understand.
My code is basically an optimization loop where the user provides the initial ...
1 vote, 1 answer, 301 views
Is there a way to update weights of an nnx.Module in Flax's NNX using the lax.scan function?
I have a neural network (nnx.Module) written in Flax's NNX. I want to train this network efficiently using lax.scan instead of a for loop. However, as scan doesn't allow in-place changes, how can I ...
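A minimal sketch of the functional pattern in plain JAX, not the NNX-specific answer: because `jax.lax.scan` cannot mutate state in place, the parameters travel through the scan carry and each step returns an updated copy. The linear model, the learning rate of 0.1, and the names `loss_fn` and `train_step` are hypothetical; an `nnx.Module` would first need its state extracted into a pytree, which is not shown here.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    # Mean squared error of a hypothetical linear model.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


def train_step(params, batch):
    # scan body: takes the carry (params) and one batch, returns the new
    # carry plus a per-step output (the loss).
    x, y = batch
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    new_params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
    return new_params, loss


num_steps = 200
x = jnp.linspace(-1.0, 1.0, 32).reshape(-1, 1)
y = 3.0 * x[:, 0] + 1.0

# Stack the (here identical) batches along a leading "step" axis for scan.
xs = jnp.broadcast_to(x, (num_steps,) + x.shape)
ys = jnp.broadcast_to(y, (num_steps,) + y.shape)

params = {"w": jnp.zeros((1,)), "b": jnp.zeros(())}
final_params, losses = jax.lax.scan(train_step, params, (xs, ys))
```

The carry must keep the same structure, shapes, and dtypes from step to step, which is why the update is expressed as a pure function returning new parameters.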
1 vote, 1 answer, 389 views
Freezing filtered parameter collections with Flax.nnx
I'm trying to work out how to do transfer learning with flax.nnx. Below is my attempt to freeze the kernel of my nnx.Linear instance and optimize the bias. I think maybe I'm not correctly setting up ...
2 votes, 1 answer, 219 views
Debug JAX in VS Code
Why can't I use the VS Code debugger to debug JAX code, specifically pure functions? I understand that JAX provides its own debugging tools, but the VS Code debugger is quite comfortable. Is this ...
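One option that may help, sketched under the assumption that the difficulty comes from `jax.jit` tracing: inside the `jax.disable_jit()` context manager, a jitted function runs eagerly, so a breakpoint set by an ordinary Python debugger inside the function sees concrete array values instead of tracers. The function `scaled_sine` is hypothetical.

```python
import jax
import jax.numpy as jnp


@jax.jit
def scaled_sine(x):
    # With jit disabled, a debugger breakpoint on the next line can inspect
    # `y` as a concrete array rather than an abstract tracer.
    y = jnp.sin(x)
    return 2.0 * y


# Run the jitted function eagerly for the duration of this block.
with jax.disable_jit():
    out = scaled_sine(jnp.arange(3.0))
```

Outside the `with` block the function is compiled as usual, so this is best kept to debugging sessions.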
0 votes, 1 answer, 295 views
Flax nnx / jax: tree.map for layers of incongruent size
I am trying to figure out how to use nnx.split_rngs. Can somebody give a version of the code below that uses nnx.split_rngs with jax.tree.map to produce an arbitrary number of Linear layers with ...
1 vote, 1 answer, 94 views
Jax / Flax potential tracing issue
I'm currently using Flax for neural network implementations. My model takes two inputs:
x and θ. It first processes x through an LSTM, then concatenates the LSTM's output with θ — or more precisely, ...