Skip to content

Use jax.random.split for SeedGenerator state advancement#22631

Open
MarcosAsh wants to merge 1 commit intokeras-team:masterfrom
MarcosAsh:jax-rng-split
Open

Use jax.random.split for SeedGenerator state advancement#22631
MarcosAsh wants to merge 1 commit intokeras-team:masterfrom
MarcosAsh:jax-rng-split

Conversation

@MarcosAsh
Copy link
Copy Markdown
Contributor

@MarcosAsh MarcosAsh commented Apr 3, 2026

SeedGenerator.next() previously used a custom counter to advance RNG state. This is a known footgun on JAX where proper key splitting via jax.random.split is needed for statistical independence between successive keys.

This adds a split_seed backend op and refactors SeedGenerator.next() to use it. On JAX it calls jax.random.split for proper PRNG key splitting. Other backends keep the existing counter-based approach since their stateless random ops handle entropy from any unique seed pair.

No public API changes. Non-JAX backends produce identical seed sequences as before.

Fixes #18426

Contributor Agreement

  • I am a human, and not a bot.
  • I will be responsible for responding to review comments in a timely manner.
  • I will work with the maintainers to push this PR forward until submission.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request refactors SeedGenerator to delegate seed splitting to backend-specific split_seed functions, enabling the JAX backend to use its native jax.random.split mechanism. The update includes corresponding implementations across all backends and new JAX-specific tests for seed independence and dropout mask consistency. Review feedback highlights that the inclusion of cuDNN-optimized LSTM for JAX is out of scope for this PR and should be moved to a separate submission. Furthermore, the OpenVINO split_seed implementation needs to return native OpenVINOKerasTensor objects rather than NumPy arrays to ensure backend consistency and robustness.

Comment thread keras/src/backend/jax/rnn.py
Comment thread keras/src/backend/openvino/random.py Outdated
@MarcosAsh
Copy link
Copy Markdown
Contributor Author

/gemini review

@codecov-commenter
Copy link
Copy Markdown

codecov-commenter commented Apr 3, 2026

Codecov Report

❌ Patch coverage is 77.77778% with 8 lines in your changes missing coverage. Please review.
✅ Project coverage is 75.20%. Comparing base (2b176bc) to head (1fb4e34).

Files with missing lines Patch % Lines
keras/src/backend/openvino/random.py 11.11% 8 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #22631      +/-   ##
==========================================
+ Coverage   69.26%   75.20%   +5.94%     
==========================================
  Files         596      596              
  Lines       68575    68603      +28     
  Branches    10717    10721       +4     
==========================================
+ Hits        47498    51596    +4098     
+ Misses      18529    14477    -4052     
+ Partials     2548     2530      -18     
Flag Coverage Δ
keras 75.02% <77.77%> (+5.90%) ⬆️
keras-jax 59.33% <36.11%> (-0.02%) ⬇️
keras-numpy 54.97% <47.22%> (-0.02%) ⬇️
keras-tensorflow 60.66% <36.11%> (?)
keras-torch 59.49% <47.22%> (-0.02%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a unified split_seed function across multiple backends (JAX, NumPy, OpenVINO, TensorFlow, and PyTorch) to standardize random seed splitting, replacing the previous counter-based approach in SeedGenerator. The changes also include updated tests for the JAX backend to verify independent seed generation. I have no further feedback as the implementation correctly addresses the need for backend-specific seed splitting logic.

@MarcosAsh MarcosAsh force-pushed the jax-rng-split branch 7 times, most recently from 22d520d to 710fcf4 Compare April 6, 2026 23:34
@keerthanakadiri keerthanakadiri added the stat:awaiting keras-eng Awaiting response from Keras engineer label Apr 7, 2026
@hertschuh
Copy link
Copy Markdown
Collaborator

Two comments about this:

  • The first concern about this change is that it breaks backwards compatibility and cross-backend compatibility. I.e. the same code, even when seeded, will no longer produce the same results with the new version of Keras and no longer produce the same results across backend.
  • The other comment is that it doesn't address the core of current rng setup is full of footguns in jax #18426 which is that it is very easy to have hardcoded seeds in JAX without realizing. But fixing this is a deeper change, and probably also a breaking change.

@hertschuh hertschuh added stat:awaiting response from contributor and removed stat:awaiting keras-eng Awaiting response from Keras engineer labels Apr 20, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

current rng setup is full of footguns in jax

5 participants