File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change
11"""Dedicated LocalBackend smoke test for PipelineTrainer."""
22
33import asyncio
4+ import json
5+ import math
46import os
7+ from pathlib import Path
58import tempfile
69import uuid
710
@@ -163,6 +166,8 @@ async def rollout_fn(
163166 min_batch_size = 1 ,
164167 max_batch_size = 1 ,
165168 max_steps = 2 ,
169+ kl_penalty_coef = 0.25 ,
170+ kl_penalty_reference_step = 0 ,
166171 loss_fn = "cispo" ,
167172 eval_fn = None ,
168173 )
@@ -180,5 +185,23 @@ async def rollout_fn(
180185 model_ids = [m .id async for m in client .models .list ()]
181186 assert f"{ model .name } @0" in model_ids
182187 assert f"{ model .name } @{ latest_step } " in model_ids
188+
189+ history_path = (
190+ Path (tmpdir )
191+ / model .project
192+ / "models"
193+ / model .name
194+ / "history.jsonl"
195+ )
196+ history_rows = [
197+ json .loads (line ) for line in history_path .read_text ().splitlines ()
198+ ]
199+ kl_values = [
200+ row ["loss/kl_policy_ref" ]
201+ for row in history_rows
202+ if "loss/kl_policy_ref" in row
203+ ]
204+ assert kl_values
205+ assert all (math .isfinite (value ) for value in kl_values )
183206 finally :
184207 await client .close ()
You can’t perform that action at this time.
0 commit comments