Skip to content

Commit 0d53531

Browse files
committed
Fix LocalBackend fork to load forked LoRA on both vLLM and trainer
Two fixes after _experimental_fork_checkpoint copies the checkpoint:
1. Overwrite checkpoints/0000 with the forked weights so vLLM loads the correct adapter on startup (it uses @0 by default).
2. Invalidate the UnslothService _state cache so the trainer re-initializes with the forked checkpoint path instead of the base model.
1 parent dc20d8f commit 0d53531

File tree

1 file changed

+20
-3
lines changed

1 file changed

+20
-3
lines changed

src/art/local/backend.py

Lines changed: 20 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1434,15 +1434,32 @@ async def _experimental_fork_checkpoint(
14341434

14351435
shutil.copytree(source_checkpoint_dir, dest_checkpoint_dir)
14361436

1437-
# Also overwrite the initial empty checkpoint at step 0 so vLLM
1438-
# loads the forked weights on startup (it uses @0 by default)
1437+
# Overwrite the initial empty checkpoint at step 0 so both vLLM
1438+
# (which loads @0) and the Unsloth trainer (which may have already
1439+
# cached _state from the empty checkpoint) pick up the forked weights.
14391440
step0_dir = get_step_checkpoint_dir(dest_model_dir, 0)
14401441
if os.path.exists(step0_dir) and step0_dir != dest_checkpoint_dir:
14411442
if verbose:
1442-
print(f"Overwriting initial checkpoint at {step0_dir} with forked weights")
1443+
print(
1444+
f"Overwriting initial checkpoint at {step0_dir} with forked weights"
1445+
)
14431446
shutil.rmtree(step0_dir)
14441447
shutil.copytree(dest_checkpoint_dir, step0_dir)
14451448

1449+
# Invalidate the UnslothService _state cache so the trainer
1450+
# re-initializes with the forked checkpoint instead of the base model.
1451+
try:
1452+
service = await self._get_service(cast(TrainableModel, model))
1453+
if "_state" in service.__dict__:
1454+
del service.__dict__["_state"]
1455+
if verbose:
1456+
print(
1457+
"Invalidated UnslothService _state cache "
1458+
"to pick up forked checkpoint"
1459+
)
1460+
except Exception:
1461+
pass
1462+
14461463
if verbose:
14471464
print(
14481465
f"Successfully forked checkpoint from {from_model} (step {selected_step}) to {model.name}"

0 commit comments

Comments
 (0)