@@ -1388,6 +1388,12 @@ def _is_art_adapter_param_name(name: str) -> bool:
 def _unwrap_art_wrapper_name(name: str) -> str:
     while name.startswith("module."):
         name = name[len("module.") :]
+    while name.startswith("_orig_mod."):
+        name = name[len("_orig_mod.") :]
+    while "._orig_mod." in name:
+        name = name.replace("._orig_mod.", ".")
+    if name.endswith("._orig_mod"):
+        name = name[: -len("._orig_mod")]
     for wrapped, unwrapped in (
         (".linear_proj.linear_proj.", ".linear_proj."),
         (".linear_qkv.linear_qkv.", ".linear_qkv."),
@@ -1480,24 +1486,35 @@ def _build_art_merge_handlers(
             continue
         if not _is_language_transformer_layer_name(module_name):
             continue
-        prefix = f"language_model.decoder.layers.{module.layer_number - 1}"
+        prefixes = (
+            f"decoder.layers.{module.layer_number - 1}",
+            f"language_model.decoder.layers.{module.layer_number - 1}",
+        )
         linear_proj = getattr(module.self_attention, "linear_proj", None)
         if isinstance(linear_proj, SelfAttentionLinearProjLoRA):
-            exact_handlers[f"{prefix}.self_attention.linear_proj.weight"] = linear_proj
+            for prefix in prefixes:
+                exact_handlers[f"{prefix}.self_attention.linear_proj.weight"] = (
+                    linear_proj
+                )
         linear_qkv = getattr(module.self_attention, "linear_qkv", None)
         if isinstance(linear_qkv, SelfAttentionLinearQKVLoRA):
-            exact_handlers[f"{prefix}.self_attention.linear_qkv.weight"] = linear_qkv
+            for prefix in prefixes:
+                exact_handlers[f"{prefix}.self_attention.linear_qkv.weight"] = (
+                    linear_qkv
+                )
         experts = getattr(module.mlp, "experts", None)
         if experts is None:
             continue
         if isinstance(experts.linear_fc1, MLPExpertsLinearFC1LoRA):
-            prefix_handlers[f"{prefix}.mlp.experts.linear_fc1.weight"] = (
-                experts.linear_fc1
-            )
+            for prefix in prefixes:
+                prefix_handlers[f"{prefix}.mlp.experts.linear_fc1.weight"] = (
+                    experts.linear_fc1
+                )
         if isinstance(experts.linear_fc2, MLPExpertsLinearFC2LoRA):
-            prefix_handlers[f"{prefix}.mlp.experts.linear_fc2.weight"] = (
-                experts.linear_fc2
-            )
+            for prefix in prefixes:
+                prefix_handlers[f"{prefix}.mlp.experts.linear_fc2.weight"] = (
+                    experts.linear_fc2
+                )
     return exact_handlers, prefix_handlers


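Registering each handler under both `decoder.layers.{n}` and `language_model.decoder.layers.{n}` presumably covers checkpoints where the language model is nested under a `language_model.` parent (for example a vision-language wrapper) as well as plain LLM checkpoints. A toy sketch of the resulting lookup, with a stand-in object in place of a real LoRA module:

layer_handler = object()  # stand-in for e.g. a SelfAttentionLinearProjLoRA module
exact_handlers = {}
for prefix in ("decoder.layers.0", "language_model.decoder.layers.0"):
    exact_handlers[f"{prefix}.self_attention.linear_proj.weight"] = layer_handler

# Both spellings of the same layer resolve to the same handler.
assert (
    exact_handlers["decoder.layers.0.self_attention.linear_proj.weight"]
    is exact_handlers[
        "language_model.decoder.layers.0.self_attention.linear_proj.weight"
    ]
)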
@@ -1542,21 +1559,15 @@ def _merge_art_lora_into_hf_weights(
         return converted_weights_dict
     if isinstance(handler, MLPExpertsLinearFC1LoRA):
         for hf_name, base_weight in list(converted_weights_dict.items()):
-            delta = (
-                torch.cat(
-                    [
-                        _lora_delta(
-                            handler.gate_lora, _expert_index_from_hf_name(hf_name)
-                        ),
-                        _lora_delta(
-                            handler.up_lora, _expert_index_from_hf_name(hf_name)
-                        ),
-                    ],
-                    dim=0,
-                )
-                if _hf_name_has_indexed_expert(hf_name)
-                else _stack_moe_fc1_deltas(handler)
-            )
+            if _hf_name_has_indexed_expert(hf_name):
+                expert_idx = _expert_index_from_hf_name(hf_name)
+                if ".gate_proj." in hf_name:
+                    delta = _lora_delta(handler.gate_lora, expert_idx)
+                else:
+                    assert ".up_proj." in hf_name, hf_name
+                    delta = _lora_delta(handler.up_lora, expert_idx)
+            else:
+                delta = _stack_moe_fc1_deltas(handler)
             converted_weights_dict[hf_name] = _merge_delta_into_weight(
                 hf_name,
                 base_weight,
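Rewriting the FC1 merge as an explicit if/else also changes its behaviour for per-expert HF names: instead of always concatenating the gate and up deltas, the code now applies only the delta that matches the `gate_proj` or `up_proj` tensor being merged. A toy shape check under assumed conventions (gate and up projections each of shape [ffn, hidden], LoRA delta computed as B @ A); the names here are illustrative stand-ins, not the repository's helpers:

import torch

hidden, ffn, rank = 8, 16, 2
lora_a = torch.randn(rank, hidden)   # LoRA "A" matrix
lora_b = torch.randn(ffn, rank)      # LoRA "B" matrix for gate_proj

gate_delta = lora_b @ lora_a                               # [ffn, hidden]
fused_delta = torch.cat([gate_delta, gate_delta], dim=0)   # [2 * ffn, hidden], gate+up stacked

base_gate = torch.zeros(ffn, hidden)  # per-expert gate_proj.weight
# The per-half delta matches the per-expert weight's shape; the fused gate+up
# delta produced by the old expression does not.
assert (base_gate + gate_delta).shape == base_gate.shape
assert fused_delta.shape[0] == 2 * ffn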