Reproduction
`vllm_serve.py` already returns `completion_ids` as a flat `list[list[int]]`, one token-id list per completion (`trl/trl/scripts/vllm_serve.py`, lines 631 to 635 at `f3e9ac1`):

```python
# Flatten and combine all results
all_outputs = list(chain.from_iterable(all_outputs))  # from list of list to single list
prompt_ids = [output.prompt_token_ids for output in all_outputs]
completion_ids = [list(output.token_ids) for outputs in all_outputs for output in outputs.outputs]
logprobs, logprob_token_ids = extract_logprobs(all_outputs)
```
But `OnlineDPOTrainer._generate_vllm_server()` flattens that result a second time (`trl/trl/experimental/online_dpo/online_dpo_trainer.py`, lines 703 to 704 at `f3e9ac1`):

```python
# Flatten: each prompt generates 2 completions
completion_ids = [[comp_id] for prompt_completions in completion_ids for comp_id in prompt_completions]
```
This splits each multi-token completion into a series of single-token completions before decoding.
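Run in isolation on example token-id lists, the second comprehension wraps every individual token id in its own list:

```python
# Per-completion token-id lists, as the server returns them
# (example values; any multi-token completions behave the same way)
completion_ids = [
    [11, 12, 13],  # completion 1: three tokens
    [21, 22],      # completion 2: two tokens
]

# The second flatten from _generate_vllm_server
flattened = [[comp_id] for prompt_completions in completion_ids for comp_id in prompt_completions]
print(flattened)  # [[11], [12], [13], [21], [22]]
```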
Minimal repro:

```python
from types import SimpleNamespace

import torch

import trl.experimental.online_dpo.online_dpo_trainer as module_under_test
from trl.experimental.online_dpo.online_dpo_trainer import OnlineDPOTrainer


class DummyProcessor:
    def __call__(self, text, return_tensors, padding, padding_side, add_special_tokens):
        return {"input_ids": torch.tensor([[101, 102, 103]])}


# Stub out the distributed collectives so the repro runs in a single process
module_under_test.gather_object = lambda x: x
module_under_test.broadcast_object_list = lambda x, from_process=0: x

# Build a trainer instance without running __init__
trainer = OnlineDPOTrainer.__new__(OnlineDPOTrainer)
trainer.accelerator = SimpleNamespace(is_main_process=True, process_index=0)
trainer.state = SimpleNamespace(global_step=0)
trainer._move_model_to_vllm = lambda: None
trainer._last_loaded_step = -1
trainer.num_generations = 2
trainer.repetition_penalty = 1.0
trainer.temperature = 1.0
trainer.top_p = 1.0
trainer.top_k = None
trainer.min_p = None
trainer.generation_config = SimpleNamespace(max_tokens=16)
trainer.args = SimpleNamespace(generation_kwargs={})
trainer.processing_class = DummyProcessor()

# Fake vLLM client returning one token-id list per completion,
# matching what vllm_serve.py actually sends back
trainer.vllm_client = SimpleNamespace(
    generate=lambda **kwargs: {
        "completion_ids": [
            [11, 12, 13],
            [21, 22],
        ]
    }
)

completion_ids, prompt_ids = trainer._generate_vllm_server(["plain prompt", "plain prompt"])
print(completion_ids)
print(prompt_ids)
```
Actual output:

```
[[11], [12], [13], [21]]
[[101, 102, 103], [101, 102, 103]]
```

Expected output:

```
[[11, 12, 13], [21, 22]]
[[101, 102, 103], [101, 102, 103]]
```
A minimal fix seems to be removing the second flatten (`trl/trl/experimental/online_dpo/online_dpo_trainer.py`, lines 703 to 704 at `f3e9ac1`):

```python
# Flatten: each prompt generates 2 completions
completion_ids = [[comp_id] for prompt_completions in completion_ids for comp_id in prompt_completions]
```

since the server already returns one token-id list per completion.
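A small self-contained check contrasting the current and the proposed behavior (the helper names here are illustrative, not from the trainer):

```python
def postprocess_buggy(completion_ids):
    # Current behavior: wraps each individual token id in its own list
    return [[comp_id] for prompt_completions in completion_ids for comp_id in prompt_completions]


def postprocess_fixed(completion_ids):
    # Proposed behavior: pass through, since the server already returns
    # one token-id list per completion
    return completion_ids


server_output = [[11, 12, 13], [21, 22]]
print(postprocess_buggy(server_output))  # [[11], [12], [13], [21], [22]]
print(postprocess_fixed(server_output))  # [[11, 12, 13], [21, 22]]
```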
System Info
Output of `trl env` in the local environment:
- Platform: Linux-6.5.13-65-650-4141-22041-coreweave-amd64-85c45edc-x86_64-with-glibc2.35
- Python version: 3.12.12
- TRL version: 0.29.1
- PyTorch version: 2.8.0
- accelerator(s): cpu
- Transformers version: 5.5.0
- Accelerate version: 1.11.0
- Accelerate config: not found
- Datasets version: 3.6.0
- HF Hub version: 1.9.0
- bitsandbytes version: 0.48.2
- DeepSpeed version: 0.18.2
- Liger-Kernel version: 0.7.0
- LLM-Blender version: not installed
- OpenAI version: 2.24.0
- PEFT version: 0.18.0
- vLLM version: 0.11.0
Checklist