
Commit 7871978

Add dtype harmonisation patch for mixed bf16/fp16 LoRA on H200
On H200 GPUs, base model activations run in bf16 while LoRA adapter weights are fp16. Unsloth's fused matmul_lora and fast_linear_forward call addmm_/addmv_, which raise a RuntimeError on mixed dtypes. This patch casts tensors to a common dtype before those ops. Applied automatically when UnslothService._state is first accessed.
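For context, a minimal sketch of the failure mode the patch works around, using only PyTorch; the shapes and scaling factor are illustrative, not taken from the training code:

import torch

# bf16 base output and activations, but a LoRA weight still in fp16.
out = torch.zeros(4, 8, dtype=torch.bfloat16)    # base matmul result
xa = torch.randn(4, 16, dtype=torch.bfloat16)    # X @ A.t() in activation dtype
b_t = torch.randn(16, 8, dtype=torch.float16)    # LoRA B.t() left in fp16

try:
    out.addmm_(xa, b_t, alpha=0.5)               # in-place fused op, mixed dtypes
except RuntimeError as err:
    print("mixed-dtype addmm_ fails:", err)

# The patch casts operands to one dtype first, so the fused op succeeds.
out.addmm_(xa, b_t.to(torch.bfloat16), alpha=0.5)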
1 parent 62e4fbc commit 7871978

2 files changed

Lines changed: 172 additions & 0 deletions


src/art/unsloth/dtype_patch.py

Lines changed: 168 additions & 0 deletions
@@ -0,0 +1,168 @@
"""Patch Unsloth's fused LoRA kernels to handle mixed bf16/fp16 dtypes.

On certain GPU accelerators (e.g. H200), base model activations run in bf16
while LoRA adapter weights remain in fp16. Unsloth's ``matmul_lora`` and
``fast_linear_forward`` call ``addmm_`` / ``addmv_`` which require matching
dtypes, causing a RuntimeError. This module patches those functions to cast
tensors to a common dtype before the fused ops.

Apply once at startup via :func:`ensure_dtype_patch`.
"""

from __future__ import annotations

import logging
from typing import Any, Callable

_PATCHED = False


def _cast_if_needed(tensor: Any, dtype: Any) -> Any:
    if tensor is None:
        return None
    if getattr(tensor, "dtype", None) == dtype:
        return tensor
    try:
        return tensor.to(dtype)
    except AttributeError:
        return tensor


def ensure_dtype_patch(log: logging.Logger | None = None) -> bool:
    """Patch Unsloth LoRA helpers for mixed-precision safety. Idempotent."""
    global _PATCHED
    if _PATCHED:
        return True

    try:
        import torch
        import unsloth.kernels.utils as utils
    except ImportError:
        if log:
            log.debug("Unsloth not available; skipping dtype patch.")
        return False

    Float8Tensor = getattr(utils, "Float8Tensor", None)
    torch_matmul: Callable[..., Any] = utils.torch_matmul
    fast_dequantize: Callable[..., Any] = utils.fast_dequantize
    fp8_linear: Callable[..., Any] | None = getattr(utils, "fp8_linear", None)
    fast_gemv: Callable[..., Any] | None = getattr(utils, "fast_gemv", None)
    torch_mm: Callable[..., Any] = utils.torch_mm
    torch_mv: Callable[..., Any] = utils.torch_mv
    get_lora_parameters_bias: Callable[..., Any] = utils.get_lora_parameters_bias

    original_fast_linear_forward = utils.fast_linear_forward
    original_matmul_lora = utils.matmul_lora

    bf16 = torch.bfloat16

    def _target_dtype(out_tensor: Any, hidden_dtype: Any) -> Any:
        if hidden_dtype == bf16:
            return bf16
        if out_tensor is not None:
            return out_tensor.dtype
        return hidden_dtype

    def patched_matmul_lora(
        X: Any,
        W: Any,
        W_quant: Any,
        A: Any,
        B: Any,
        s: Any,
        out: Any = None,
    ) -> Any:
        dtype = X.dtype
        reshape = False
        if X.dim() == 3:
            batch, seq_len, _ = X.shape
            X = X.view(-1, X.shape[-1])
            reshape = True

        if Float8Tensor is not None and isinstance(W, Float8Tensor):
            if W.ndim != 2:
                raise ValueError("Expected 2D Float8Tensor for LoRA matmul.")
            if W.block_size[0] == W.shape[0] and W.block_size[1] == 1:
                W_full = W.dequantize()
            else:
                W_full = W.contiguous()
            out = torch_matmul(X, W_full.t(), out=out)
        elif getattr(W, "dtype", None) == getattr(torch, "float8_e4m3fn", None):
            if fp8_linear is None:
                raise RuntimeError("FP8 weights detected but fp8_linear unavailable.")
            out = fp8_linear(X, W, W_quant)
        else:
            W_full = fast_dequantize(W, W_quant, use_global_buffer=True)
            out = torch_matmul(X, W_full.t(), out=out)

        if A is not None:
            td = _target_dtype(out, dtype)
            XA = torch_matmul(_cast_if_needed(X, td), _cast_if_needed(A.t(), td))
            out = _cast_if_needed(out, td)
            out = out.addmm_(XA, _cast_if_needed(B.t(), td), alpha=s)

        return out.view(batch, seq_len, -1) if reshape else out

    def patched_fast_linear_forward(
        proj: Any, X: Any, temp_lora: Any = None, out: Any = None
    ) -> Any:
        W, W_quant, lora_A, lora_B, lora_S, bias = get_lora_parameters_bias(proj)
        bsz, q_len, in_dim = X.shape

        if q_len != 1:
            return patched_matmul_lora(X, W, W_quant, lora_A, lora_B, lora_S)

        if W_quant is None:
            out = torch_matmul(X, W.t(), out=out)
        elif getattr(W, "dtype", None) == getattr(torch, "float8_e4m3fn", None):
            if fp8_linear is None:
                raise RuntimeError("FP8 weights detected but fp8_linear unavailable.")
            out = fp8_linear(X, W, W_quant, bias)
        elif fast_gemv is not None and bsz == 1 and q_len == 1:
            out = fast_gemv(X, W, W_quant, out=out)
        else:
            W_full = fast_dequantize(W.t(), W_quant, use_global_buffer=True)
            out = torch_matmul(X, W_full, out=out)

        if lora_A is not None:
            td = _target_dtype(out, X.dtype)
            if (
                not hasattr(lora_A, "_fast_lora")
                or getattr(lora_A._fast_lora, "dtype", None) != td
            ):
                lora_A._fast_lora = lora_A.to(td)
                lora_B._fast_lora = lora_B.to(td)

            X_lora = _cast_if_needed(X, td)
            out = _cast_if_needed(out, td)
            out_dim = out.shape[2]

            if bsz == 1:
                out = out.view(out_dim)
                temp_lora = torch_mv(lora_A._fast_lora, X_lora.ravel(), out=temp_lora)
                out.addmv_(lora_B._fast_lora, temp_lora, alpha=lora_S)
                out = out.view(1, 1, out_dim)
            else:
                out = out.view(bsz, out_dim)
                temp_lora = torch_mm(
                    X_lora.view(bsz, in_dim),
                    lora_A._fast_lora.t(),
                    out=temp_lora,
                )
                out.addmm_(temp_lora, lora_B._fast_lora.t(), alpha=lora_S)
                out = out.view(bsz, 1, out_dim)

        if bias is not None:
            out = out + _cast_if_needed(bias, out.dtype)

        return out

    utils.matmul_lora = patched_matmul_lora  # type: ignore[assignment]
    utils.fast_linear_forward = patched_fast_linear_forward  # type: ignore[assignment]
    utils._original_fast_linear_forward = original_fast_linear_forward  # type: ignore[attr-defined]
    utils._original_matmul_lora = original_matmul_lora  # type: ignore[attr-defined]

    _PATCHED = True
    if log:
        log.debug("Applied Unsloth LoRA dtype harmonisation patch.")
    return True
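
A minimal usage sketch for the patch above; the package path art.unsloth.dtype_patch and the logger name are assumptions for illustration, and it requires torch and unsloth to be installed:

import logging

from art.unsloth.dtype_patch import ensure_dtype_patch

log = logging.getLogger("art.unsloth")

# Idempotent: only the first call rebinds the Unsloth helpers.
if ensure_dtype_patch(log):
    import unsloth.kernels.utils as utils

    # The originals remain reachable for debugging or manual rollback.
    assert utils.matmul_lora.__name__ == "patched_matmul_lora"
    assert hasattr(utils, "_original_matmul_lora")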

src/art/unsloth/service.py

Lines changed: 4 additions & 0 deletions
@@ -822,6 +822,10 @@ async def train_sft(
 
     @cached_property
     def _state(self) -> UnslothTrainContext:
+        from .dtype_patch import ensure_dtype_patch
+
+        ensure_dtype_patch()
+
         init_args = dict(self.config.get("init_args", {}))
         checkpoint_dir = get_last_checkpoint_dir(self.output_dir)
         if checkpoint_dir:
