Skip to content

Commit a02920e

Browse files
committed
refactor: move TCP port helper to shared utils
1 parent b16a5bc commit a02920e

2 files changed

Lines changed: 13 additions & 8 deletions

File tree

src/art/unsloth/service.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import json
77
import logging
88
import os
9-
import socket
109
import subprocess
1110
import sys
1211
from typing import TYPE_CHECKING, Any, AsyncIterator, Literal, Protocol, cast
@@ -34,6 +33,7 @@
3433
from ..preprocessing.tokenize import SFTBatch
3534
from ..utils.convert_moe_lora import convert_checkpoint_if_needed
3635
from ..utils.get_model_step import get_step_from_dir
36+
from ..utils.network import find_free_tcp_port
3737
from ..utils.output_dirs import get_step_checkpoint_dir
3838
from ..vllm import get_llm, get_worker, openai_server_task, run_on_workers
3939
from .train import StopTrainingLoop, gc_and_empty_cuda_cache, train
@@ -208,12 +208,6 @@ def _get_trainer_optimizer(trainer: GRPOTrainer) -> Optimizer:
208208
return optimizer
209209

210210

211-
def _find_free_tcp_port() -> int:
212-
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
213-
sock.bind(("127.0.0.1", 0))
214-
return cast(int, sock.getsockname()[1])
215-
216-
217211
def _normalize_merged_checkpoint_name(name: str) -> str:
218212
# PEFT wraps adapted modules under `.base_layer`, but vLLM expects the
219213
# original checkpoint parameter names during update_weights().
@@ -223,6 +217,9 @@ def _normalize_merged_checkpoint_name(name: str) -> str:
223217
return normalized
224218

225219

220+
# Alias under the previous private name for the shared `find_free_tcp_port`
# helper imported from `..utils.network`.
# NOTE(review): presumably kept so code that still imports or patches
# `_find_free_tcp_port` from this module keeps working — confirm, and drop the
# alias once nothing references the old name.
_find_free_tcp_port = find_free_tcp_port
221+
222+
226223
# ============================================================================
227224
# Model Classes
228225
# ============================================================================
@@ -523,7 +520,7 @@ async def _init_merged_weight_transfer(self) -> None:
523520
) from exc
524521
inference_world_size = int(world_size_response.json()["world_size"])
525522

526-
master_port = _find_free_tcp_port()
523+
master_port = find_free_tcp_port()
527524
init_info = {
528525
"master_address": "127.0.0.1",
529526
"master_port": master_port,

src/art/utils/network.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
import socket
2+
from typing import cast
3+
4+
5+
def find_free_tcp_port(host: str = "127.0.0.1") -> int:
    """Return a TCP port number that was unused on *host* at call time.

    A temporary IPv4 stream socket is bound to port 0 so the operating
    system picks an unused port; that port number is read back and the
    socket is closed before returning.

    The port is NOT reserved: the probing socket is already closed when this
    function returns, so another process may claim the port before the caller
    binds it. Callers should bind promptly and be prepared for the bind to
    fail.

    Args:
        host: IPv4 address to probe. Defaults to the loopback address, which
            matches the previous hard-coded behavior.

    Returns:
        The port number the operating system assigned to the probe socket.

    Raises:
        OSError: If the socket cannot be created or bound on *host*.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        # Port 0 asks the OS to choose an unused port for us.
        sock.bind((host, 0))
        # For AF_INET, getsockname() returns a (host, port) pair.
        return int(sock.getsockname()[1])

0 commit comments

Comments
 (0)