Skip to content

FSDP + torch.nn.Parameter (MoE layer) lora fine-tuning doesn't work #3080

@dskhudia

Description

@dskhudia

System Info

ParamWrapper.get_delta_weight doesn't account for FSDP-sharded lora_A/lora_B parameters. The fix would likely involve either calling torch.distributed.fsdp.FullyShardedDataParallel.summon_full_params() around the weight access, or restructuring the delta computation to work within the FSDP all-gather lifecycle.

Error:

[rank6]:   File ".venv/lib/python3.12/site-packages/peft/tuners/lora/layer.py", line 2228, in forward
[rank6]:     with self._activate_lora(self.active_adapters):
[rank6]:          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]:   File ".local/share/uv/python/cpython-3.12.12-linux-x86_64-gnu/lib/python3.12/contextlib.py", line 137, in __enter__
[rank6]:     return next(self.gen)
[rank6]:            ^^^^^^^^^^^^^^
[rank6]:   File ".venv/lib/python3.12/site-packages/peft/tuners/lora/layer.py", line 2104, in _activate_lora
[rank6]:     delta_weight = self.get_delta_weight(active_adapter)
[rank6]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]:   File ".venv/lib/python3.12/site-packages/peft/tuners/lora/layer.py", line 2081, in get_delta_weight
[rank6]:     weight_A = weight_A.reshape(self.num_experts, -1, weight_A.shape[-1])
[rank6]:                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]: RuntimeError: shape '[128, -1, 368640]' is invalid for input of size 368640

Who can help?

@BenjaminBossan because #2638

Reproduction

On 8x H100

pip install "trl>=0.20.0" "peft>=0.17.0" "transformers>=4.55.0"
pip install datasets accelerate
pip install hf_transfer
pip install "datasets==3.2.0" --force-reinstall

script to repro:

#!/usr/bin/env python3
"""
Fine-tune OpenAI GPT-OSS 120B with TRL SFTTrainer + FSDP + LoRA on 8 H100 GPUs.

Usage:
    torchrun --nproc_per_node=8 train_gpt_oss_120b.py
"""

import os
import torch
import torch.distributed as dist
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTTrainer, SFTConfig
from datasets import load_dataset
from peft import LoraConfig, get_peft_model


def infer_transformer_blocks_for_fsdp(model):
    """Auto-detect transformer block class names for FSDP wrapping.

    Scans ``model.named_modules()`` for known decoder-block class names; if
    none match, falls back to a name heuristic ("Block"/"DecoderLayer"/"Layer").

    Args:
        model: Any ``torch.nn.Module`` (e.g. a HF causal-LM, possibly already
            PEFT-wrapped); only submodule class *names* are inspected.

    Returns:
        Sorted list of unique class-name strings suitable for
        ``fsdp_transformer_layer_cls_to_wrap``. May be empty if nothing matches.
    """
    COMMON = {
        "LlamaDecoderLayer", "MistralDecoderLayer", "MixtralDecoderLayer",
        "Qwen2DecoderLayer", "Gemma2DecoderLayer", "Phi3DecoderLayer",
        "GPTNeoXLayer", "MPTBlock", "BloomBlock", "FalconDecoderLayer",
        "DecoderLayer", "GPTJBlock", "OPTDecoderLayer",
    }
    hits = {
        m.__class__.__name__
        for _, m in model.named_modules()
        if m.__class__.__name__ in COMMON
    }
    if not hits:
        # Fallback: grab anything that looks like a decoder block.
        # Exclude sub-block module types whose names also contain "Layer"
        # (e.g. nn.LayerNorm) — previously only "Embedding" was excluded,
        # so norm/dropout modules could be wrongly wrapped as FSDP units,
        # which breaks sharding granularity.
        EXCLUDE = ("Embedding", "Norm", "Dropout")
        for _, m in model.named_modules():
            name = m.__class__.__name__
            looks_like_block = any(
                s in name for s in ("Block", "DecoderLayer", "Layer")
            )
            if looks_like_block and not any(e in name for e in EXCLUDE):
                hits.add(name)
    return sorted(hits)


def main() -> None:
    """Repro entry point: SFT fine-tune gpt-oss-120b with FSDP + LoRA.

    Launched via ``torchrun --nproc_per_node=8``; each rank binds to its
    LOCAL_RANK GPU, builds the PEFT-wrapped model, and hands FSDP setup to
    the Trainer/Accelerate via ``SFTConfig.fsdp*``. This script is the
    minimal reproduction for the ``get_delta_weight`` reshape crash on
    FSDP-sharded MoE expert parameters.
    """
    # ---------- DDP / CUDA binding ----------
    # torchrun sets LOCAL_RANK per process; default "0" allows single-GPU runs.
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    torch.cuda.set_device(local_rank)
    # setdefault: respect values already exported by the launcher/user.
    os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
    os.environ.setdefault("NCCL_DEBUG", "WARN")
    os.environ.setdefault("CUDA_LAUNCH_BLOCKING", "0")
    os.environ.setdefault("TORCH_NCCL_ASYNC_ERROR_HANDLING", "1")

    # ---------- Config ----------
    MODEL_NAME = "openai/gpt-oss-120b"
    MAX_LENGTH = 2048
    PER_DEVICE_BATCH = 1
    GRAD_ACCUM = 4          # effective batch = 1 * 4 * world_size
    LR = 1.5e-4
    EPOCHS = 1
    OUTPUT_DIR = "/tmp/gpt-oss-120b-finetune"

    # Only global rank 0 prints, to keep 8-rank logs readable.
    is_main = int(os.environ.get("RANK", "0")) == 0
    world_size = int(os.environ.get("WORLD_SIZE", "1"))

    if is_main:
        print("=" * 60)
        print("FSDP (full_shard) launch for 120B")
        print(f"WORLD_SIZE={world_size} | LOCAL_RANK={local_rank}")
        print("=" * 60)

    # ---------- Tokenizer ----------
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    # Some checkpoints ship without a pad token; reuse EOS so padding works.
    if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.model_max_length = MAX_LENGTH
    tokenizer.truncation_side = "right"

    # ---------- Model ----------
    # No device_map, no .to(device) — let Trainer/Accelerate+FSDP handle placement
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.bfloat16,
        attn_implementation="eager",
        use_cache=False,           # KV cache is useless (and wasteful) in training
        low_cpu_mem_usage=True,
    )

    # ---------- LoRA ----------
    # NOTE(review): the commented-out config below (all-linear only, no
    # target_parameters) reportedly does NOT trigger the crash; kept for
    # comparison.
    #peft_config = LoraConfig(
    #    r=16,
    #    lora_alpha=16,
    #    target_modules="all-linear",
    #    lora_dropout=0.05,
    #    bias="none",
    #    task_type="CAUSAL_LM",
    #)
    peft_config = LoraConfig(
        r=32,
        lora_alpha=32,
        target_modules="all-linear",
        # Lower rank for the (large) MoE expert params than for linears.
        rank_pattern={
            "mlp.experts.gate_up_proj": 8,
            "mlp.experts.down_proj": 8
        },
        # target_parameters puts LoRA on raw nn.Parameter MoE weights — this
        # is the path that hits get_delta_weight on FSDP-sharded lora_A/B
        # (the reshape RuntimeError in the traceback above).
        target_parameters=["mlp.experts.gate_up_proj", "mlp.experts.down_proj"],
        lora_dropout=0.0,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, peft_config)
    if is_main:
        model.print_trainable_parameters()

    # ---------- Data ----------
    dataset = load_dataset("HuggingFaceH4/Multilingual-Thinking", split="train")
    if is_main:
        print(f"Dataset size: {len(dataset)}")

    # ---------- FSDP settings ----------
    # Wrap classes must be detected AFTER get_peft_model so PEFT wrappers are
    # visible; fail loudly rather than let FSDP fall back to no auto-wrap.
    fsdp_wrap_classes = infer_transformer_blocks_for_fsdp(model)
    if not fsdp_wrap_classes:
        raise RuntimeError(
            "Could not infer transformer block classes for FSDP wrapping; "
            "print(model) and add the block class explicitly."
        )
    if is_main:
        print(f"FSDP wrapping classes: {fsdp_wrap_classes}")

    training_args = SFTConfig(
        output_dir=OUTPUT_DIR,
        num_train_epochs=EPOCHS,
        per_device_train_batch_size=PER_DEVICE_BATCH,
        gradient_accumulation_steps=GRAD_ACCUM,
        learning_rate=LR,
        warmup_ratio=0.03,
        lr_scheduler_type="cosine",
        bf16=True,
        logging_steps=5,
        logging_strategy="steps",
        save_strategy="no",          # repro only — no checkpoints
        report_to="none",
        ddp_find_unused_parameters=False,
        dataloader_pin_memory=True,
        max_length=MAX_LENGTH,
        # Outer flag off; checkpointing is enabled through fsdp_config below.
        gradient_checkpointing=False,
        # ---- FSDP knobs ----
        fsdp="full_shard auto_wrap",
        fsdp_config={
            "fsdp_transformer_layer_cls_to_wrap": fsdp_wrap_classes,
            "activation_checkpointing": True,
            "activation_checkpointing_reentrant": False,
            "xla": False,
            "limit_all_gathers": True,
            "use_orig_params": True,    # required for non-uniform requires_grad (LoRA)
            "sync_module_states": True, # broadcast rank-0 weights at init
        },
    )

    # ---------- Trainer ----------
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        processing_class=tokenizer,
    )

    # Verify distributed init & FSDP
    # Printed on EVERY rank deliberately, to confirm all 8 processes joined.
    rank = int(os.getenv("RANK", "0"))
    print(f"[rank {rank}] dist.is_initialized() -> {dist.is_initialized()}")
    acc = getattr(trainer, "accelerator", None)
    print(f"[rank {rank}] accelerator.distributed_type = "
          f"{getattr(getattr(acc, 'state', None), 'distributed_type', 'n/a')}")
    print(f"[rank {rank}] accelerator.num_processes = "
          f"{getattr(acc, 'num_processes', 'n/a')}")

    # ---------- Train ----------
    # Crashes here in the reported bug, inside LoRA get_delta_weight.
    result = trainer.train()

    if is_main:
        print("\nTraining complete (FSDP).")
        print(result.metrics)


# Entry guard: run training only when executed directly (e.g. under
# ``torchrun``), not when this module is imported.
if __name__ == "__main__":
    main()

Expected behavior

Training should start and run to completion without errors. Instead, the first forward pass crashes inside `get_delta_weight` with the reshape `RuntimeError` shown in the traceback above, because the FSDP-sharded `lora_A` weight no longer has the full `[num_experts * ..., in_features]` shape the reshape expects.

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions