Skip to content

Commit 5489fde

Browse files
committed
Load forked adapter weights on first training call
After _experimental_fork_checkpoint, store the checkpoint path on the service. On the first _train_dedicated/_train_shared call, load the adapter weights via load_lora_adapter before training begins. This is needed because create_unsloth_train_context may initialize the LoRA architecture from adapter_config.json without loading the actual trained weights from adapter_model.safetensors, especially when the checkpoint was trained at a different precision than the current load config.
1 parent aa96333 commit 5489fde

2 files changed

Lines changed: 21 additions & 5 deletions

File tree

src/art/local/backend.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1434,18 +1434,24 @@ async def _experimental_fork_checkpoint(
14341434

14351435
shutil.copytree(source_checkpoint_dir, dest_checkpoint_dir)
14361436

1437-
# Invalidate the UnslothService _state cache so the trainer
1438-
# re-initializes with the forked checkpoint instead of the base model.
1439-
# _state is a cached_property that reads get_last_checkpoint_dir() on
1440-
# first access; if it was accessed before the fork, it cached the base
1441-
# model and will never pick up the forked weights.
1437+
# Ensure the trainer picks up the forked LoRA weights.
1438+
#
1439+
# 1. Invalidate the _state cache so create_unsloth_train_context
1440+
# re-initializes with the forked checkpoint path.
1441+
#
1442+
# 2. Store the forked checkpoint path so the first training call can
1443+
# explicitly load the adapter weights via load_lora_adapter. This
1444+
# is necessary because from_pretrained may set up the LoRA
1445+
# architecture without loading the actual trained weights
1446+
# (especially across precision mismatches).
14421447
service = await self._get_service(cast(TrainableModel, model))
14431448
if hasattr(service, "_state") and "_state" in service.__dict__:
14441449
del service.__dict__["_state"]
14451450
if verbose:
14461451
print(
14471452
"Invalidated UnslothService _state cache to pick up forked checkpoint"
14481453
)
1454+
service._forked_checkpoint_dir = dest_checkpoint_dir
14491455

14501456
if verbose:
14511457
print(

src/art/unsloth/service.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -588,6 +588,11 @@ async def _train_dedicated(
588588
verbose: bool = False,
589589
) -> AsyncIterator[dict[str, float]]:
590590
"""Train in dedicated mode — no sleep/wake, vLLM keeps running on separate GPU."""
591+
# Load forked adapter weights on first training call if needed.
592+
forked_dir = getattr(self, "_forked_checkpoint_dir", None)
593+
if forked_dir is not None:
594+
del self._forked_checkpoint_dir
595+
await self._state.load_lora_adapter(forked_dir)
591596
async for result in run_unsloth_rl_training(
592597
self._state,
593598
disk_packed_tensors=disk_packed_tensors,
@@ -629,6 +634,11 @@ async def _train_shared(
629634
verbose: bool = False,
630635
) -> AsyncIterator[dict[str, float]]:
631636
"""Train in shared mode — sleep/wake cycle with in-process vLLM."""
637+
# Load forked adapter weights on first training call if needed.
638+
forked_dir = getattr(self, "_forked_checkpoint_dir", None)
639+
if forked_dir is not None:
640+
del self._forked_checkpoint_dir
641+
await self._state.load_lora_adapter(forked_dir)
632642
llm = await self.llm
633643

634644
# Pause generation to prevent new requests during training

0 commit comments

Comments (0)