Right now, unseeded calls to e.g. `keras.random.uniform` acquire static seeds at trace time. This has a few undesirable consequences:
- Subsequent calls produce the same randomness every time (e.g. dropout applies a fixed mask instead of a fresh one each step).
- The JAX compiler cache will ~never hit, since the constant RNG seed values differ between traces.

To get around this, some kind of RNG state management is necessary. Flax does this with hierarchical management of RNGs from the `Scope`. Such an approach is fairly complex, however, and there may be simpler options, e.g. a single global RNG state that gets included with the training state in `model.fit`. Unseeded RNG calls would then do something along the lines of:
```python
state.seed, local_seed = jax.random.split(state.seed)
```
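A minimal sketch of how that could look, assuming a hypothetical `TrainState` that carries a single JAX PRNG key (the names `TrainState` and `dropout_mask` are illustrative, not actual Keras API):

```python
import jax


class TrainState:
    """Hypothetical training state carrying one global PRNG key."""

    def __init__(self, seed: int):
        # The key would be threaded through model.fit alongside the
        # rest of the training state, rather than baked in at trace time.
        self.seed = jax.random.PRNGKey(seed)


def dropout_mask(state: TrainState, shape, rate: float):
    # Split the global key on each unseeded call: the state advances,
    # so subsequent calls see fresh randomness, and no constant seed
    # value gets traced into the compiled function.
    state.seed, local_seed = jax.random.split(state.seed)
    return jax.random.bernoulli(local_seed, p=1.0 - rate, shape=shape)


state = TrainState(0)
mask_step1 = dropout_mask(state, (8,), rate=0.5)
mask_step2 = dropout_mask(state, (8,), rate=0.5)
```

Because the key is split rather than reused, `mask_step1` and `mask_step2` are drawn from different subkeys, which is exactly the behavior dropout needs across steps.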