
Use float32 for LoRA weights to avoid the risk of underflow and overflow.#22559

Open
james77777778 wants to merge 3 commits into keras-team:master from james77777778:use-float32-for-lora-weights

Conversation

Contributor

james77777778 commented Mar 27, 2026

Description

As reported in keras-team/keras-hub#2629

We should use high precision (float32) for LoRA weights to stabilize fine-tuning.
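A minimal numpy sketch (not Keras internals) of why float32 matters here: with a float16 accumulator, a small-but-real gradient step underflows to zero, so the LoRA weights never move.

```python
import numpy as np

# A tiny but meaningful per-step update. It is below float16's smallest
# subnormal (~6e-8), so casting it to float16 rounds it to exactly 0.
step = np.float64(1e-8)

acc_fp16 = np.float16(0.0)
acc_fp32 = np.float32(0.0)
for _ in range(1000):
    acc_fp16 = np.float16(acc_fp16 + np.float16(step))  # every step underflows
    acc_fp32 = np.float32(acc_fp32 + np.float32(step))  # updates accumulate

print(acc_fp16)  # 0.0  -- the float16 weight never moved
print(acc_fp32)  # ~1e-5 -- the float32 weight accumulated all 1000 steps
```

The same effect applies in reverse for overflow: float16 saturates at ~65504, which large intermediate products can exceed.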

References:

Contributor Agreement

Please check all boxes below before submitting your PR for review:

  • I am a human, and not a bot.
  • I will be responsible for responding to review comments in a timely manner.
  • I will work with the maintainers to push this PR forward until submission.

Note: Failing to adhere to this agreement may result in your future PRs no longer being reviewed.

gemini-code-assist bot left a comment

Code Review

This pull request updates the LoRA implementation across several layers—including Convolutional, Dense, EinsumDense, and Embedding—to ensure that LoRA weights are initialized as float32 to prevent numerical instability. It also introduces explicit casting to the appropriate variable or compute dtypes during kernel composition and forward passes. A critical issue was identified in the EinsumDense layer where a trailing comma incorrectly converts the LoRA update into a tuple, which will cause a TypeError during tensor operations.

Comment thread keras/src/layers/core/einsum_dense.py Outdated
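The trailing-comma issue flagged in the automated review can be reproduced in isolation. This is an illustrative sketch with hypothetical names, not the actual `einsum_dense.py` code: a stray comma turns the LoRA delta into a 1-tuple instead of a tensor, which backends that expect tensor inputs will reject.

```python
import numpy as np

lora_kernel_a = np.ones((3, 2), dtype=np.float32)
lora_kernel_b = np.ones((2, 4), dtype=np.float32)

buggy_delta = (lora_kernel_a @ lora_kernel_b),  # trailing comma: a 1-tuple
fixed_delta = lora_kernel_a @ lora_kernel_b     # a (3, 4) array, as intended

print(type(buggy_delta))  # <class 'tuple'>
print(fixed_delta.shape)  # (3, 4)
```

Passing `buggy_delta` where a tensor is expected either raises (e.g. a TypeError in backends with strict input checking) or silently converts the tuple to an array with an extra leading dimension, so the fix is simply to drop the comma.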

codecov-commenter commented Mar 27, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 82.96%. Comparing base (e94cb07) to head (5822cef).

Additional details and impacted files
@@           Coverage Diff           @@
##           master   #22559   +/-   ##
=======================================
  Coverage   82.95%   82.96%           
=======================================
  Files         596      596           
  Lines       69252    69259    +7     
  Branches    10814    10814           
=======================================
+ Hits        57451    57458    +7     
  Misses       8973     8973           
  Partials     2828     2828           
Flag               Coverage            Δ
keras              82.78% <100.00%>    (+<0.01%) ⬆️
keras-jax          58.72% <100.00%>    (+<0.01%) ⬆️
keras-numpy        54.56% <69.23%>     (-0.01%) ⬇️
keras-openvino     59.42% <69.23%>     (-0.01%) ⬇️
keras-tensorflow   60.29% <100.00%>    (+<0.01%) ⬆️
keras-torch        59.06% <100.00%>    (+<0.01%) ⬆️

Flags with carried forward coverage won't be shown.


james77777778 force-pushed the use-float32-for-lora-weights branch 2 times, most recently from 7722450 to 28a8fe9 on March 27, 2026 06:20
keerthanakadiri added the stat:awaiting keras-eng label (Awaiting response from Keras engineer) on Mar 27, 2026
hertschuh added the keras-team-review-pending label (Pending review by a Keras team member) on Mar 31, 2026
james77777778 force-pushed the use-float32-for-lora-weights branch from 28a8fe9 to 4080af9 on April 7, 2026 06:51
james77777778 (Contributor, Author)

PR rebased. The OpenVINO test failure should be unrelated to this PR.

james77777778 force-pushed the use-float32-for-lora-weights branch from b91abca to 9818dc4 on April 18, 2026 11:02
amitsrivastava78 (Collaborator)

Thanks for the PR; it's very well written. Please check my minor comment below.

"lora is already enabled. This can only be done once per layer."
)
self._tracker.unlock()

Collaborator

The PR itself is fine; just worth adding a note that users should merge LoRA weights before deploying for inference.

Contributor Author

Sure. The comments have been updated:

        # LoRA weights should be float32 to avoid the risk of underflow or
        # overflow during fine-tuning.
        # When deploying the model, these weights should be merged with the
        # original kernel while maintaining the original kernel's dtype.
        ...
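The deployment guidance in that comment can be sketched as follows. This is a hedged illustration with hypothetical names (`merge_lora` is not a Keras API): the float32 low-rank update is folded into the kernel in high precision, then cast back to the kernel's original, possibly lower-precision, dtype.

```python
import numpy as np

def merge_lora(kernel, lora_a, lora_b, alpha=1.0):
    """Fold float32 LoRA weights into the kernel, preserving its dtype."""
    original_dtype = kernel.dtype
    # Compose in float32 to avoid underflow/overflow during the merge.
    merged = kernel.astype(np.float32) + alpha * (lora_a @ lora_b)
    # Restore the original kernel dtype for inference.
    return merged.astype(original_dtype)

rng = np.random.default_rng(0)
kernel = rng.normal(size=(8, 4)).astype(np.float16)
lora_a = rng.normal(size=(8, 2)).astype(np.float32)
lora_b = np.zeros((2, 4), dtype=np.float32)  # B starts at zero, so the merge is a no-op

merged = merge_lora(kernel, lora_a, lora_b)
print(merged.dtype)  # float16 -- the kernel's dtype is maintained
```

After merging, the LoRA variables can be discarded and inference runs on the single fused kernel at its original precision.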

Add notes for deploying with lora weights.
james77777778 force-pushed the use-float32-for-lora-weights branch from 9818dc4 to 5822cef on April 20, 2026 13:56

Labels

  • awaiting review
  • keras-team-review-pending — Pending review by a Keras team member.
  • size:M
  • stat:awaiting keras-eng — Awaiting response from Keras engineer

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants