
Current RNG setup is full of footguns in JAX #18426

@GallagherCommaJack


Right now, unseeded calls to e.g. `keras.random.uniform` acquire static seeds at trace time. This has a few undesirable consequences:

  1. Subsequent calls will have the same randomness each time (e.g. dropout will use a fixed mask instead of a new random one each step).
  2. The JAX compilation cache will almost never hit, since the constant RNG seed values baked into the program differ every time.
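A minimal plain-JAX sketch of consequence 1, assuming an unseeded op behaves like the hypothetical `unseeded_uniform` below, i.e. it draws its seed in Python. This is an illustration of the failure mode, not Keras's actual implementation. Under `jax.jit` the Python code only runs during tracing, so the seed is frozen into the compiled program as a constant:

```python
import numpy as np
import jax
import jax.numpy as jnp


def unseeded_uniform(shape):
    # Hypothetical stand-in for an unseeded random op: it draws a fresh seed
    # in Python, which under jax.jit only happens while the function is traced.
    seed = int(np.random.randint(0, 2**31 - 1))
    return jax.random.uniform(jax.random.PRNGKey(seed), shape)


@jax.jit
def dropout(x):
    rate = 0.5
    keep = unseeded_uniform(x.shape) >= rate
    return jnp.where(keep, x / (1.0 - rate), 0.0)


x = jnp.ones((1000,))
a = dropout(x)  # traces and compiles; the seed becomes a constant
b = dropout(x)  # cache hit: same compiled program, same constant seed
print(bool(jnp.all(a == b)))  # True: the "random" mask is identical
```

The same constant is also why consequence 2 follows: any retrace draws a different seed, producing a program that differs only in that constant and so does not match previously cached compilations.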

To get around this, some kind of RNG state management is necessary. Flax does this with hierarchical management of RNGs from the `Scope`. Such an approach is fairly complex, however, and there might be simpler options, e.g. a single global RNG state that gets included with the training state in `model.fit`. Unseeded RNG calls would then do something along the lines of:

```python
state.seed, local_seed = jax.random.split(state.seed)
```
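A minimal sketch of the single-global-state idea, assuming the seed is carried as a field of a plain dict that is passed into and returned from a jitted step function. `train_step`, `dropout`, and the `"seed"` field are hypothetical names for illustration, not Keras API:

```python
import jax
import jax.numpy as jnp


def dropout(x, key, rate=0.5):
    # Randomness comes from an explicit key argument, not a trace-time constant.
    keep = jax.random.uniform(key, x.shape) >= rate
    return jnp.where(keep, x / (1.0 - rate), 0.0)


@jax.jit
def train_step(state, x):
    # Split the carried seed: one key is written back into the state,
    # the other is consumed locally by this step's random ops.
    new_seed, local_seed = jax.random.split(state["seed"])
    y = dropout(x, local_seed)
    return {**state, "seed": new_seed}, y


x = jnp.ones((1000,))
state0 = {"seed": jax.random.PRNGKey(0)}
state1, y1 = train_step(state0, x)  # compiles once
state2, y2 = train_step(state1, x)  # reuses the compiled program
print(bool(jnp.any(y1 != y2)))  # masks differ between steps
```

Because the seed is a traced input rather than a constant, the compiled program is reused across steps while each step still gets a fresh mask, and restarting from the same initial seed reproduces the run.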

Labels

type:feature (the user is asking for a new feature)