@@ -228,19 +228,53 @@ def load_experts_weight(
228228 if is_ffn_merged :
229229 for i in range (self .num_local_experts ):
230230 expert_idx = self .expert_id_offset + i
231+ down_proj_expert_weight_key_name = down_proj_expert_weight_key .format (expert_idx )
232+ up_gate_proj_expert_weight_key_name = up_gate_proj_expert_weight_key .format (expert_idx )
231233 up_gate_proj_weights .append (
232- get_tensor (state_dict .pop (up_gate_proj_expert_weight_key .format (expert_idx )))
234+ get_tensor (
235+ state_dict .pop (up_gate_proj_expert_weight_key_name )
236+ if up_gate_proj_expert_weight_key_name in state_dict
237+ else up_gate_proj_expert_weight_key_name ,
238+ self .fd_config .parallel_config .model_name_or_path ,
239+ )
240+ )
241+ down_proj_weights .append (
242+ get_tensor (
243+ state_dict .pop (down_proj_expert_weight_key_name )
244+ if down_proj_expert_weight_key_name in state_dict
245+ else down_proj_expert_weight_key_name ,
246+ self .fd_config .parallel_config .model_name_or_path ,
247+ )
233248 )
234- down_proj_weights .append (get_tensor (state_dict .pop (down_proj_expert_weight_key .format (expert_idx ))))
235249 else :
236250 gate_expert_weight_key = up_gate_proj_expert_weight_key .replace ("up_gate_proj" , "gate_proj" )
237251 up_expert_weight_key = up_gate_proj_expert_weight_key .replace ("up_gate_proj" , "up_proj" )
238252 for j in range (self .num_local_experts ):
239253 expert_idx = self .expert_id_offset + j
240- gate = get_tensor (state_dict .pop (gate_expert_weight_key .format (expert_idx )))
241- up = get_tensor (state_dict .pop (up_expert_weight_key .format (expert_idx )))
254+ gate_expert_weight_key_name = gate_expert_weight_key .format (expert_idx )
255+ up_expert_weight_key_name = up_expert_weight_key .format (expert_idx )
256+ down_proj_expert_weight_key_name = down_proj_expert_weight_key .format (expert_idx )
257+ gate = get_tensor (
258+ state_dict .pop (gate_expert_weight_key_name )
259+ if gate_expert_weight_key_name in state_dict
260+ else gate_expert_weight_key_name ,
261+ self .fd_config .parallel_config .model_name_or_path ,
262+ )
263+ up = get_tensor (
264+ state_dict .pop (up_expert_weight_key_name )
265+ if up_expert_weight_key_name in state_dict
266+ else up_expert_weight_key_name ,
267+ self .fd_config .parallel_config .model_name_or_path ,
268+ )
242269 up_gate_proj_weights .append (paddle .concat ([gate , up ], axis = - 1 ))
243- down_proj_weights .append (get_tensor (state_dict .pop (down_proj_expert_weight_key .format (expert_idx ))))
270+ down_proj_weights .append (
271+ get_tensor (
272+ state_dict .pop (down_proj_expert_weight_key_name )
273+ if down_proj_expert_weight_key_name in state_dict
274+ else down_proj_expert_weight_key_name ,
275+ self .fd_config .parallel_config .model_name_or_path ,
276+ )
277+ )
244278 return up_gate_proj_weights , down_proj_weights
245279
246280 def extract_moe_ffn_weights (self , state_dict : dict ):
0 commit comments