System Info
python3.12
peft==0.18.0
Who can help?
@BenjaminBossan @githubnemo
Reproduction
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

model_name = "Isotonic/TinyMixtral-4x248M-MoE"
device = "cuda:0"

model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device)
print(f"Is weight tying enabled?: {model.config.tie_word_embeddings}")
if not model.config.tie_word_embeddings:
    emb = model.get_input_embeddings()
    lm = model.get_output_embeddings()
    assert emb.weight.data_ptr() != lm.weight.data_ptr()
    print("Loaded model does not have tied embeddings")

modules_to_save = ["embed_tokens"]
target_modules = ["q_proj"]
lora_cfg = LoraConfig(
    modules_to_save=modules_to_save,
    target_modules=target_modules,
    task_type="CAUSAL_LM",
    ensure_weight_tying=True,
)
model = get_peft_model(model, lora_cfg)

if not model.config.tie_word_embeddings:
    emb = model.get_input_embeddings()
    lm = model.get_output_embeddings()
    assert emb.weight.data_ptr() != lm.weight.data_ptr(), "PEFT model has tied embeddings"
    print("PEFT model does not have tied embeddings")
Expected behavior
The bug likely originates from `_get_module_names_tied_with_embedding` (peft/src/peft/utils/other.py, line 1571 at f2c0668), which does not respect the model config and instead relies on the architecture's `_tied_weights_keys` attribute. As a result, `ensure_weight_tying=True` ties the embeddings even though `config.tie_word_embeddings` is `False` for the loaded model. I would expect it to respect the model config too.
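As a minimal sketch of the expected behavior (this is not PEFT's actual code; `get_tied_module_names`, `DummyConfig`, and `DummyModel` are hypothetical stand-ins): the helper should consult `config.tie_word_embeddings` before trusting the architecture's static `_tied_weights_keys` attribute, since an architecture can declare that attribute even when the instantiated model does not tie its weights.

```python
class DummyConfig:
    # Mirrors the reproduction above: tying is disabled in the config.
    tie_word_embeddings = False


class DummyModel:
    # Architectures may declare this statically even when the
    # instantiated model does not actually tie the weights.
    _tied_weights_keys = ["lm_head.weight"]
    config = DummyConfig()


def get_tied_module_names(model) -> list[str]:
    """Return tied-weight module names only when the config enables tying."""
    if not getattr(model.config, "tie_word_embeddings", False):
        return []
    return list(getattr(model, "_tied_weights_keys", []) or [])


print(get_tied_module_names(DummyModel()))  # → []
```

With such a guard in place, `ensure_weight_tying=True` would become a no-op for models whose config disables tying, matching the assertion in the reproduction.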