Skip to content

Commit 90aa5cb

Browse files
committed
fix: Restore Megatron dedicated merged sync
1 parent e7ebfef commit 90aa5cb

File tree

1 file changed

+35
-24
lines changed

1 file changed

+35
-24
lines changed

src/art/megatron/train.py

Lines changed: 35 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1388,6 +1388,12 @@ def _is_art_adapter_param_name(name: str) -> bool:
13881388
def _unwrap_art_wrapper_name(name: str) -> str:
13891389
while name.startswith("module."):
13901390
name = name[len("module.") :]
1391+
while name.startswith("_orig_mod."):
1392+
name = name[len("_orig_mod.") :]
1393+
while "._orig_mod." in name:
1394+
name = name.replace("._orig_mod.", ".")
1395+
if name.endswith("._orig_mod"):
1396+
name = name[: -len("._orig_mod")]
13911397
for wrapped, unwrapped in (
13921398
(".linear_proj.linear_proj.", ".linear_proj."),
13931399
(".linear_qkv.linear_qkv.", ".linear_qkv."),
@@ -1480,24 +1486,35 @@ def _build_art_merge_handlers(
14801486
continue
14811487
if not _is_language_transformer_layer_name(module_name):
14821488
continue
1483-
prefix = f"language_model.decoder.layers.{module.layer_number - 1}"
1489+
prefixes = (
1490+
f"decoder.layers.{module.layer_number - 1}",
1491+
f"language_model.decoder.layers.{module.layer_number - 1}",
1492+
)
14841493
linear_proj = getattr(module.self_attention, "linear_proj", None)
14851494
if isinstance(linear_proj, SelfAttentionLinearProjLoRA):
1486-
exact_handlers[f"{prefix}.self_attention.linear_proj.weight"] = linear_proj
1495+
for prefix in prefixes:
1496+
exact_handlers[f"{prefix}.self_attention.linear_proj.weight"] = (
1497+
linear_proj
1498+
)
14871499
linear_qkv = getattr(module.self_attention, "linear_qkv", None)
14881500
if isinstance(linear_qkv, SelfAttentionLinearQKVLoRA):
1489-
exact_handlers[f"{prefix}.self_attention.linear_qkv.weight"] = linear_qkv
1501+
for prefix in prefixes:
1502+
exact_handlers[f"{prefix}.self_attention.linear_qkv.weight"] = (
1503+
linear_qkv
1504+
)
14901505
experts = getattr(module.mlp, "experts", None)
14911506
if experts is None:
14921507
continue
14931508
if isinstance(experts.linear_fc1, MLPExpertsLinearFC1LoRA):
1494-
prefix_handlers[f"{prefix}.mlp.experts.linear_fc1.weight"] = (
1495-
experts.linear_fc1
1496-
)
1509+
for prefix in prefixes:
1510+
prefix_handlers[f"{prefix}.mlp.experts.linear_fc1.weight"] = (
1511+
experts.linear_fc1
1512+
)
14971513
if isinstance(experts.linear_fc2, MLPExpertsLinearFC2LoRA):
1498-
prefix_handlers[f"{prefix}.mlp.experts.linear_fc2.weight"] = (
1499-
experts.linear_fc2
1500-
)
1514+
for prefix in prefixes:
1515+
prefix_handlers[f"{prefix}.mlp.experts.linear_fc2.weight"] = (
1516+
experts.linear_fc2
1517+
)
15011518
return exact_handlers, prefix_handlers
15021519

15031520

@@ -1542,21 +1559,15 @@ def _merge_art_lora_into_hf_weights(
15421559
return converted_weights_dict
15431560
if isinstance(handler, MLPExpertsLinearFC1LoRA):
15441561
for hf_name, base_weight in list(converted_weights_dict.items()):
1545-
delta = (
1546-
torch.cat(
1547-
[
1548-
_lora_delta(
1549-
handler.gate_lora, _expert_index_from_hf_name(hf_name)
1550-
),
1551-
_lora_delta(
1552-
handler.up_lora, _expert_index_from_hf_name(hf_name)
1553-
),
1554-
],
1555-
dim=0,
1556-
)
1557-
if _hf_name_has_indexed_expert(hf_name)
1558-
else _stack_moe_fc1_deltas(handler)
1559-
)
1562+
if _hf_name_has_indexed_expert(hf_name):
1563+
expert_idx = _expert_index_from_hf_name(hf_name)
1564+
if ".gate_proj." in hf_name:
1565+
delta = _lora_delta(handler.gate_lora, expert_idx)
1566+
else:
1567+
assert ".up_proj." in hf_name, hf_name
1568+
delta = _lora_delta(handler.up_lora, expert_idx)
1569+
else:
1570+
delta = _stack_moe_fc1_deltas(handler)
15601571
converted_weights_dict[hf_name] = _merge_delta_into_weight(
15611572
hf_name,
15621573
base_weight,

0 commit comments

Comments (0)