Use jax.random.split for SeedGenerator state advancement#22631
Use jax.random.split for SeedGenerator state advancement#22631MarcosAsh wants to merge 1 commit intokeras-team:masterfrom
Conversation
There was a problem hiding this comment.
Code Review
This pull request refactors SeedGenerator to delegate seed splitting to backend-specific split_seed functions, enabling the JAX backend to use its native jax.random.split mechanism. The update includes corresponding implementations across all backends and new JAX-specific tests for seed independence and dropout mask consistency. Review feedback highlights that the inclusion of cuDNN-optimized LSTM for JAX is out of scope for this PR and should be moved to a separate submission. Furthermore, the OpenVINO split_seed implementation needs to return native OpenVINOKerasTensor objects rather than NumPy arrays to ensure backend consistency and robustness.
|
/gemini review |
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #22631 +/- ##
==========================================
+ Coverage 69.26% 75.20% +5.94%
==========================================
Files 596 596
Lines 68575 68603 +28
Branches 10717 10721 +4
==========================================
+ Hits 47498 51596 +4098
+ Misses 18529 14477 -4052
+ Partials 2548 2530 -18
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
There was a problem hiding this comment.
Code Review
This pull request introduces a unified split_seed function across multiple backends (JAX, NumPy, OpenVINO, TensorFlow, and PyTorch) to standardize random seed splitting, replacing the previous counter-based approach in SeedGenerator. The changes also include updated tests for the JAX backend to verify independent seed generation. I have no further feedback as the implementation correctly addresses the need for backend-specific seed splitting logic.
22d520d to
710fcf4
Compare
|
Two comments about this:
|
SeedGenerator.next() previously used a custom counter to advance RNG state. This is a known footgun on JAX where proper key splitting via jax.random.split is needed for statistical independence between successive keys.
This adds a split_seed backend op and refactors SeedGenerator.next() to use it. On JAX it calls jax.random.split for proper PRNG key splitting. Other backends keep the existing counter-based approach since their stateless random ops handle entropy from any unique seed pair.
No public API changes. Non-JAX backends produce identical seed sequences as before.
Fixes #18426
Contributor Agreement