ParamWrapper.get_delta_weight doesn't account for FSDP-sharded lora_A/lora_B parameters. The fix would likely involve either calling torch.distributed.fsdp.FullyShardedDataParallel.summon_full_params() around the weight access, or restructuring the delta computation to work within the FSDP all-gather lifecycle.
[rank6]: File ".venv/lib/python3.12/site-packages/peft/tuners/lora/layer.py", line 2228, in forward
[rank6]: with self._activate_lora(self.active_adapters):
[rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]: File ".local/share/uv/python/cpython-3.12.12-linux-x86_64-gnu/lib/python3.12/contextlib.py", line 137, in __enter__
[rank6]: return next(self.gen)
[rank6]: ^^^^^^^^^^^^^^
[rank6]: File ".venv/lib/python3.12/site-packages/peft/tuners/lora/layer.py", line 2104, in _activate_lora
[rank6]: delta_weight = self.get_delta_weight(active_adapter)
[rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]: File ".venv/lib/python3.12/site-packages/peft/tuners/lora/layer.py", line 2081, in get_delta_weight
[rank6]: weight_A = weight_A.reshape(self.num_experts, -1, weight_A.shape[-1])
[rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]: RuntimeError: shape '[128, -1, 368640]' is invalid for input of size 368640
pip install "trl>=0.20.0" "peft>=0.17.0" "transformers>=4.55.0"
pip install datasets accelerate
pip install hf_transfer
# Pin datasets to 3.2.0, deliberately overriding the newer version pulled in above
pip install "datasets==3.2.0" --force-reinstall
#!/usr/bin/env python3
"""
Fine-tune OpenAI GPT-OSS 120B with TRL SFTTrainer + FSDP + LoRA on 8 H100 GPUs.
Usage:
torchrun --nproc_per_node=8 train_gpt_oss_120b.py
"""
import os
import torch
import torch.distributed as dist
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTTrainer, SFTConfig
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
def infer_transformer_blocks_for_fsdp(model):
    """Auto-detect transformer block class names for FSDP auto-wrapping.

    Scans ``model.named_modules()`` for module classes whose names match a
    known set of decoder-block classes. If none match, falls back to a
    heuristic that accepts class names that look like decoder blocks.

    Args:
        model: Any object exposing ``named_modules()`` yielding
            ``(name, module)`` pairs (e.g. a ``torch.nn.Module``).

    Returns:
        Sorted list of matching class-name strings. May be empty only if the
        fallback heuristic also finds nothing.
    """
    COMMON = {
        "LlamaDecoderLayer", "MistralDecoderLayer", "MixtralDecoderLayer",
        "Qwen2DecoderLayer", "Gemma2DecoderLayer", "Phi3DecoderLayer",
        "GPTNeoXLayer", "MPTBlock", "BloomBlock", "FalconDecoderLayer",
        "DecoderLayer", "GPTJBlock", "OPTDecoderLayer",
    }
    hits = {
        m.__class__.__name__
        for _, m in model.named_modules()
        if m.__class__.__name__ in COMMON
    }
    if not hits:
        # Fallback: grab anything that looks like a decoder block.
        # BUGFIX: also exclude "*Norm" classes — e.g. "LayerNorm" contains
        # "Layer" but must never be handed to FSDP as a transformer block.
        EXCLUDED = ("Embedding", "Norm")
        for _, m in model.named_modules():
            name = m.__class__.__name__
            if (any(s in name for s in ("Block", "DecoderLayer", "Layer"))
                    and not any(s in name for s in EXCLUDED)):
                hits.add(name)
    return sorted(hits)
def main() -> None:
    """Fine-tune GPT-OSS 120B with TRL SFTTrainer + FSDP + LoRA.

    Expects to be launched via ``torchrun`` so that RANK / LOCAL_RANK /
    WORLD_SIZE are present in the environment. Device placement is left
    entirely to Trainer/Accelerate + FSDP (no device_map, no manual
    ``.to(device)``).
    """
    # ---------- DDP / CUDA binding ----------
    # Bind this process to its local GPU before any CUDA work happens.
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    torch.cuda.set_device(local_rank)
    os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
    os.environ.setdefault("NCCL_DEBUG", "WARN")
    os.environ.setdefault("CUDA_LAUNCH_BLOCKING", "0")
    # Fail fast on NCCL collective errors instead of hanging.
    os.environ.setdefault("TORCH_NCCL_ASYNC_ERROR_HANDLING", "1")
    # ---------- Config ----------
    MODEL_NAME = "openai/gpt-oss-120b"
    MAX_LENGTH = 2048
    PER_DEVICE_BATCH = 1
    GRAD_ACCUM = 4  # effective batch = PER_DEVICE_BATCH * GRAD_ACCUM * world_size
    LR = 1.5e-4
    EPOCHS = 1
    OUTPUT_DIR = "/tmp/gpt-oss-120b-finetune"
    is_main = int(os.environ.get("RANK", "0")) == 0
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    if is_main:
        print("=" * 60)
        print("FSDP (full_shard) launch for 120B")
        print(f"WORLD_SIZE={world_size} | LOCAL_RANK={local_rank}")
        print("=" * 60)
    # ---------- Tokenizer ----------
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    # Fall back to EOS as the pad token when the checkpoint defines none.
    if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.model_max_length = MAX_LENGTH
    tokenizer.truncation_side = "right"
    # ---------- Model ----------
    # No device_map, no .to(device) — let Trainer/Accelerate+FSDP handle placement
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.bfloat16,
        attn_implementation="eager",
        use_cache=False,
        low_cpu_mem_usage=True,
    )
    # ---------- LoRA ----------
    # Earlier, simpler config kept for reference (linear layers only, no
    # target_parameters):
    #peft_config = LoraConfig(
    #    r=16,
    #    lora_alpha=16,
    #    target_modules="all-linear",
    #    lora_dropout=0.05,
    #    bias="none",
    #    task_type="CAUSAL_LM",
    #)
    peft_config = LoraConfig(
        r=32,
        lora_alpha=32,
        target_modules="all-linear",
        # Lower rank for the very large MoE expert weight tensors.
        rank_pattern={
            "mlp.experts.gate_up_proj": 8,
            "mlp.experts.down_proj": 8
        },
        # NOTE(review): target_parameters attaches LoRA to raw expert weight
        # tensors (peft ParamWrapper). Per the traceback in this report, its
        # get_delta_weight reshape fails under FSDP full_shard because
        # lora_A/lora_B are sharded — confirm against peft PR #2638 before
        # relying on this path with FSDP.
        target_parameters=["mlp.experts.gate_up_proj", "mlp.experts.down_proj"],
        lora_dropout=0.0,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, peft_config)
    if is_main:
        model.print_trainable_parameters()
    # ---------- Data ----------
    dataset = load_dataset("HuggingFaceH4/Multilingual-Thinking", split="train")
    if is_main:
        print(f"Dataset size: {len(dataset)}")
    # ---------- FSDP settings ----------
    # Detect block classes AFTER PEFT wrapping so the names reflect the
    # modules actually present in the wrapped model.
    fsdp_wrap_classes = infer_transformer_blocks_for_fsdp(model)
    if not fsdp_wrap_classes:
        raise RuntimeError(
            "Could not infer transformer block classes for FSDP wrapping; "
            "print(model) and add the block class explicitly."
        )
    if is_main:
        print(f"FSDP wrapping classes: {fsdp_wrap_classes}")
    training_args = SFTConfig(
        output_dir=OUTPUT_DIR,
        num_train_epochs=EPOCHS,
        per_device_train_batch_size=PER_DEVICE_BATCH,
        gradient_accumulation_steps=GRAD_ACCUM,
        learning_rate=LR,
        warmup_ratio=0.03,
        lr_scheduler_type="cosine",
        bf16=True,
        logging_steps=5,
        logging_strategy="steps",
        save_strategy="no",
        report_to="none",
        ddp_find_unused_parameters=False,
        dataloader_pin_memory=True,
        max_length=MAX_LENGTH,
        # Outer HF flag off; checkpointing is driven via fsdp_config below.
        gradient_checkpointing=False,
        # ---- FSDP knobs ----
        fsdp="full_shard auto_wrap",
        fsdp_config={
            "fsdp_transformer_layer_cls_to_wrap": fsdp_wrap_classes,
            "activation_checkpointing": True,
            "activation_checkpointing_reentrant": False,
            "xla": False,
            "limit_all_gathers": True,
            "use_orig_params": True,
            "sync_module_states": True,
        },
    )
    # ---------- Trainer ----------
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        processing_class=tokenizer,
    )
    # Verify distributed init & FSDP (printed from every rank on purpose).
    rank = int(os.getenv("RANK", "0"))
    print(f"[rank {rank}] dist.is_initialized() -> {dist.is_initialized()}")
    acc = getattr(trainer, "accelerator", None)
    print(f"[rank {rank}] accelerator.distributed_type = "
          f"{getattr(getattr(acc, 'state', None), 'distributed_type', 'n/a')}")
    print(f"[rank {rank}] accelerator.num_processes = "
          f"{getattr(acc, 'num_processes', 'n/a')}")
    # ---------- Train ----------
    result = trainer.train()
    if is_main:
        print("\nTraining complete (FSDP).")
        print(result.metrics)
# Entry point: run under torchrun so RANK/LOCAL_RANK/WORLD_SIZE are set,
# e.g. `torchrun --nproc_per_node=8 train_gpt_oss_120b.py`.
if __name__ == "__main__":
    main()
Should train fine.
System Info
ParamWrapper.get_delta_weight doesn't account for FSDP-sharded lora_A/lora_B parameters. The fix would likely involve either calling torch.distributed.fsdp.FullyShardedDataParallel.summon_full_params() around the weight access, or restructuring the delta computation to work within the FSDP all-gather lifecycle.
Error:
Who can help?
@BenjaminBossan — this appears related to the LoRA `target_parameters` support introduced in #2638.
Reproduction
On 8x H100
script to repro:
Expected behavior
Should train fine.