66import json
77import logging
88import os
9- import socket
109import subprocess
1110import sys
1211from typing import TYPE_CHECKING , Any , AsyncIterator , Literal , Protocol , cast
3433from ..preprocessing .tokenize import SFTBatch
3534from ..utils .convert_moe_lora import convert_checkpoint_if_needed
3635from ..utils .get_model_step import get_step_from_dir
36+ from ..utils .network import find_free_tcp_port
3737from ..utils .output_dirs import get_step_checkpoint_dir
3838from ..vllm import get_llm , get_worker , openai_server_task , run_on_workers
3939from .train import StopTrainingLoop , gc_and_empty_cuda_cache , train
@@ -208,12 +208,6 @@ def _get_trainer_optimizer(trainer: GRPOTrainer) -> Optimizer:
208208 return optimizer
209209
210210
211- def _find_free_tcp_port () -> int :
212- with socket .socket (socket .AF_INET , socket .SOCK_STREAM ) as sock :
213- sock .bind (("127.0.0.1" , 0 ))
214- return cast (int , sock .getsockname ()[1 ])
215-
216-
217211def _normalize_merged_checkpoint_name (name : str ) -> str :
218212 # PEFT wraps adapted modules under `.base_layer`, but vLLM expects the
219213 # original checkpoint parameter names during update_weights().
@@ -223,6 +217,9 @@ def _normalize_merged_checkpoint_name(name: str) -> str:
223217 return normalized
224218
225219
# Backward-compatible alias: the implementation moved to
# ..utils.network.find_free_tcp_port, but external callers (and tests)
# may still reference the old private name from this module.
_find_free_tcp_port = find_free_tcp_port

222+
226223# ============================================================================
227224# Model Classes
228225# ============================================================================
@@ -523,7 +520,7 @@ async def _init_merged_weight_transfer(self) -> None:
523520 ) from exc
524521 inference_world_size = int (world_size_response .json ()["world_size" ])
525522
526- master_port = _find_free_tcp_port ()
523+ master_port = find_free_tcp_port ()
527524 init_info = {
528525 "master_address" : "127.0.0.1" ,
529526 "master_port" : master_port ,
0 commit comments