
ensure_weight_tying = True does not respect the model config tie_word_embeddings flag #2944

@romitjain

Description


System Info

python3.12
peft==0.18.0

Who can help?

@BenjaminBossan @githubnemo

Reproduction

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

model_name = "Isotonic/TinyMixtral-4x248M-MoE"
device = "cuda:0"

model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device)
print(f"Is weight tying enabled?: {model.config.tie_word_embeddings}")

if not model.config.tie_word_embeddings:
    emb = model.get_input_embeddings()
    lm = model.get_output_embeddings()
    assert emb.weight.data_ptr() != lm.weight.data_ptr()
    print("Loaded model does not have tied embeddings")

modules_to_save = ["embed_tokens"]
target_modules = ["q_proj"]

lora_cfg = LoraConfig(
    modules_to_save=modules_to_save,
    target_modules=target_modules,
    task_type="CAUSAL_LM",
    ensure_weight_tying=True
)

model = get_peft_model(model, lora_cfg)

if not model.config.tie_word_embeddings:
    emb = model.get_input_embeddings()
    lm = model.get_output_embeddings()
    assert emb.weight.data_ptr() != lm.weight.data_ptr(), "PEFT model has tied embeddings"
    print("PEFT model does not have tied embeddings")

Expected behavior

The bug likely originates in

def _get_module_names_tied_with_embedding(model) -> list[str]:

which reads the architecture's class-level _tied_weights_keys attribute without consulting the model config. Since _tied_weights_keys is defined on the model class, it is present even when tie_word_embeddings=False in the loaded config, so PEFT ties the embeddings regardless.

I would expect it to respect the model config too.
