Skip to content

Commit d7a9a06

Browse files
committed
fix: Respect max batch size in PipelineTrainer
1 parent c656608 commit d7a9a06

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

src/art/pipeline_trainer/trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -552,7 +552,7 @@ async def _collect_batch(
552552
continue
553553
batch.append(item)
554554

555-
while not saw_sentinel:
555+
while not saw_sentinel and len(batch) < self.max_batch_size:
556556
try:
557557
item = self._output_queue.get_nowait()
558558
except asyncio.QueueEmpty:

0 commit comments

Comments
 (0)