OnlineDPOTrainer._generate_vllm_server() flattens vllm-serve completion_ids twice #5514

@JohnGiorgi

Description

Reproduction

vllm_serve.py already returns completion_ids as a flat list[list[int]], one token-id list per completion:

# Flatten and combine all results
all_outputs = list(chain.from_iterable(all_outputs)) # from list of list to single list
prompt_ids = [output.prompt_token_ids for output in all_outputs]
completion_ids = [list(output.token_ids) for outputs in all_outputs for output in outputs.outputs]
logprobs, logprob_token_ids = extract_logprobs(all_outputs)
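
For two sampled completions, this yields one token-id list each, e.g.:

completion_ids = [[11, 12, 13], [21, 22]]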

But OnlineDPOTrainer._generate_vllm_server() flattens that result a second time:

# Flatten: each prompt generates 2 completions
completion_ids = [[comp_id] for prompt_completions in completion_ids for comp_id in prompt_completions]

This splits each multi-token completion into a series of single-token completions before decoding.
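
In isolation, the second comprehension does the following (a self-contained snippet mirroring the line above):

completion_ids = [[11, 12, 13], [21, 22]]  # two completions, as returned by the server
flattened = [[comp_id] for prompt_completions in completion_ids for comp_id in prompt_completions]
print(flattened)  # [[11], [12], [13], [21], [22]]: every token becomes its own "completion"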

Minimal repro:

from types import SimpleNamespace

import torch

import trl.experimental.online_dpo.online_dpo_trainer as module_under_test
from trl.experimental.online_dpo.online_dpo_trainer import OnlineDPOTrainer


class DummyProcessor:
    # Tokenizer stub: returns fixed prompt token ids for any input text
    def __call__(self, text, return_tensors, padding, padding_side, add_special_tokens):
        return {"input_ids": torch.tensor([[101, 102, 103]])}


# Stub the accelerate collectives so the repro runs in a single process
module_under_test.gather_object = lambda x: x
module_under_test.broadcast_object_list = lambda x, from_process=0: x

# Build a bare trainer, wiring up only the attributes _generate_vllm_server reads
trainer = OnlineDPOTrainer.__new__(OnlineDPOTrainer)
trainer.accelerator = SimpleNamespace(is_main_process=True, process_index=0)
trainer.state = SimpleNamespace(global_step=0)
trainer._move_model_to_vllm = lambda: None
trainer._last_loaded_step = -1
trainer.num_generations = 2
trainer.repetition_penalty = 1.0
trainer.temperature = 1.0
trainer.top_p = 1.0
trainer.top_k = None
trainer.min_p = None
trainer.generation_config = SimpleNamespace(max_tokens=16)
trainer.args = SimpleNamespace(generation_kwargs={})
trainer.processing_class = DummyProcessor()
# Fake vLLM client that returns one flat token-id list per completion,
# the same shape vllm_serve.py already produces
trainer.vllm_client = SimpleNamespace(
    generate=lambda **kwargs: {
        "completion_ids": [
            [11, 12, 13],
            [21, 22],
        ]
    }
)

completion_ids, prompt_ids = trainer._generate_vllm_server(["plain prompt", "plain prompt"])

print(completion_ids)
print(prompt_ids)

Actual output:

[[11], [12], [13], [21]]
[[101, 102, 103], [101, 102, 103]]

(The second flatten actually yields five one-token entries, but only four survive here; presumably _generate_vllm_server later keeps just len(prompts) * num_generations entries, so [22] is dropped outright.)

Expected output:

[[11, 12, 13], [21, 22]]
[[101, 102, 103], [101, 102, 103]]

A minimal fix seems to be removing the second flatten here:

# Flatten: each prompt generates 2 completions
completion_ids = [[comp_id] for prompt_completions in completion_ids for comp_id in prompt_completions]

since the server already returns one token-id list per completion.
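
With that line removed, the repro above should match the expected output; a minimal check, reusing the stubs defined earlier:

completion_ids, prompt_ids = trainer._generate_vllm_server(["plain prompt", "plain prompt"])
assert completion_ids == [[11, 12, 13], [21, 22]]  # token lists arrive intact
assert prompt_ids == [[101, 102, 103], [101, 102, 103]]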

System Info

Output of trl env in the local environment:

- Platform: Linux-6.5.13-65-650-4141-22041-coreweave-amd64-85c45edc-x86_64-with-glibc2.35
- Python version: 3.12.12
- TRL version: 0.29.1
- PyTorch version: 2.8.0
- accelerator(s): cpu
- Transformers version: 5.5.0
- Accelerate version: 1.11.0
- Accelerate config: not found
- Datasets version: 3.6.0
- HF Hub version: 1.9.0
- bitsandbytes version: 0.48.2
- DeepSpeed version: 0.18.2
- Liger-Kernel version: 0.7.0
- LLM-Blender version: not installed
- OpenAI version: 2.24.0
- PEFT version: 0.18.0
- vLLM version: 0.11.0

Checklist

  • I have checked that my issue isn't already filed
  • I have included my system information
  • Any code provided is minimal, complete, and reproducible
  • Any code provided is properly formatted in code blocks
  • Any traceback provided is complete
