Skip to content

Commit bfeb664

Browse files
authored
1 parent 85a78d6 commit bfeb664

3 files changed

Lines changed: 58 additions & 29 deletions

File tree

fastdeploy/model_executor/layers/moe/moe.py

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -228,19 +228,53 @@ def load_experts_weight(
228228
if is_ffn_merged:
229229
for i in range(self.num_local_experts):
230230
expert_idx = self.expert_id_offset + i
231+
down_proj_expert_weight_key_name = down_proj_expert_weight_key.format(expert_idx)
232+
up_gate_proj_expert_weight_key_name = up_gate_proj_expert_weight_key.format(expert_idx)
231233
up_gate_proj_weights.append(
232-
get_tensor(state_dict.pop(up_gate_proj_expert_weight_key.format(expert_idx)))
234+
get_tensor(
235+
state_dict.pop(up_gate_proj_expert_weight_key_name)
236+
if up_gate_proj_expert_weight_key_name in state_dict
237+
else up_gate_proj_expert_weight_key_name,
238+
self.fd_config.parallel_config.model_name_or_path,
239+
)
240+
)
241+
down_proj_weights.append(
242+
get_tensor(
243+
state_dict.pop(down_proj_expert_weight_key_name)
244+
if down_proj_expert_weight_key_name in state_dict
245+
else down_proj_expert_weight_key_name,
246+
self.fd_config.parallel_config.model_name_or_path,
247+
)
233248
)
234-
down_proj_weights.append(get_tensor(state_dict.pop(down_proj_expert_weight_key.format(expert_idx))))
235249
else:
236250
gate_expert_weight_key = up_gate_proj_expert_weight_key.replace("up_gate_proj", "gate_proj")
237251
up_expert_weight_key = up_gate_proj_expert_weight_key.replace("up_gate_proj", "up_proj")
238252
for j in range(self.num_local_experts):
239253
expert_idx = self.expert_id_offset + j
240-
gate = get_tensor(state_dict.pop(gate_expert_weight_key.format(expert_idx)))
241-
up = get_tensor(state_dict.pop(up_expert_weight_key.format(expert_idx)))
254+
gate_expert_weight_key_name = gate_expert_weight_key.format(expert_idx)
255+
up_expert_weight_key_name = up_expert_weight_key.format(expert_idx)
256+
down_proj_expert_weight_key_name = down_proj_expert_weight_key.format(expert_idx)
257+
gate = get_tensor(
258+
state_dict.pop(gate_expert_weight_key_name)
259+
if gate_expert_weight_key_name in state_dict
260+
else gate_expert_weight_key_name,
261+
self.fd_config.parallel_config.model_name_or_path,
262+
)
263+
up = get_tensor(
264+
state_dict.pop(up_expert_weight_key_name)
265+
if up_expert_weight_key_name in state_dict
266+
else up_expert_weight_key_name,
267+
self.fd_config.parallel_config.model_name_or_path,
268+
)
242269
up_gate_proj_weights.append(paddle.concat([gate, up], axis=-1))
243-
down_proj_weights.append(get_tensor(state_dict.pop(down_proj_expert_weight_key.format(expert_idx))))
270+
down_proj_weights.append(
271+
get_tensor(
272+
state_dict.pop(down_proj_expert_weight_key_name)
273+
if down_proj_expert_weight_key_name in state_dict
274+
else down_proj_expert_weight_key_name,
275+
self.fd_config.parallel_config.model_name_or_path,
276+
)
277+
)
244278
return up_gate_proj_weights, down_proj_weights
245279

246280
def extract_moe_ffn_weights(self, state_dict: dict):

fastdeploy/model_executor/layers/utils.py

Lines changed: 5 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@
3737
"And ensure the Paddle version supports FastDeploy's custom operators"
3838
)
3939

40-
import re
4140

4241
from fastdeploy import envs
4342

@@ -107,7 +106,7 @@ def _set_var_distributed(var: Tensor, split_axis: int):
107106
main_block._find_var_recursive(var.name).is_distributed = True
108107

109108

110-
def get_tensor(input: Union[paddle.Tensor, np.ndarray, str]) -> paddle.Tensor:
109+
def get_tensor(input: Union[paddle.Tensor, np.ndarray, str], model_path=None) -> paddle.Tensor:
111110
"""
112111
Return a corresponding PaddlePaddle tensor based on the type and content of the input.
113112
@@ -125,28 +124,9 @@ def get_tensor(input: Union[paddle.Tensor, np.ndarray, str]) -> paddle.Tensor:
125124
elif isinstance(input, np.ndarray):
126125
return paddle.to_tensor(input)
127126
elif isinstance(input, str):
128-
if ".safetensors" in input:
129-
match = re.match(r"\[(.*?)\](.*)", input)
130-
if match:
131-
key_name = match.group(1)
132-
model_path = match.group(2)
133-
from safetensors import safe_open
134-
135-
with safe_open(model_path, framework="np", device="cpu") as f:
136-
if key_name in f.keys():
137-
weight = f.get_tensor(key_name)
138-
weight = paddle.Tensor(weight, zero_copy=True)
139-
weight = weight._copy_to(paddle.framework._current_expected_place(), False)
140-
return weight
141-
else:
142-
return None
143-
else:
144-
if cache_params != "none":
145-
tmp_key = input.split("/")[-1]
146-
if tmp_key in c8_state_dict:
147-
print(f"Loading {tmp_key} in extra C8_state_dict")
148-
return paddle.to_tensor(c8_state_dict.pop(tmp_key))
149-
return paddle.load(input)
127+
from fastdeploy.model_executor.load_weight_utils import load_reordered_experts
128+
129+
return load_reordered_experts(model_path, input)
150130
else:
151131
return input
152132

@@ -377,6 +357,7 @@ def create_and_set_parameter(layer: nn.Layer, name: str, tensor: paddle.Tensor):
377357
)
378358
getattr(layer, name).set_value(tensor)
379359

360+
380361
@functools.cache
381362
def create_empty_tensor(shape: Tuple[int, ...], dtype: Union[paddle.dtype, str]) -> paddle.Tensor:
382363
"""

fastdeploy/model_executor/load_weight_utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,20 @@
3232
from fastdeploy.platforms import current_platform
3333

3434

35+
def load_reordered_experts(model_path: str, key_name: str):
    """Load a single named weight from a sharded safetensors checkpoint.

    Resolves ``key_name`` through ``model.safetensors.index.json`` (its
    ``weight_map`` maps weight names to shard filenames), opens the shard,
    and copies the tensor to the current Paddle device.

    Args:
        model_path: Directory containing the checkpoint shards and the
            ``model.safetensors.index.json`` index file.
        key_name: Fully-qualified weight name to look up in the index.

    Returns:
        A ``paddle.Tensor`` placed on the current expected device, or
        ``None`` when the shard named by the index does not actually
        contain ``key_name``.

    Raises:
        KeyError: If ``key_name`` is missing from the index's weight_map.
        FileNotFoundError: If the index file or shard file does not exist.
    """
    # Imported lazily so merely importing this module does not require
    # safetensors to be installed.
    from safetensors import safe_open

    with open(os.path.join(model_path, "model.safetensors.index.json"), "r") as f:
        weight_list = json.load(f)["weight_map"]
    safetensor_path = os.path.join(model_path, weight_list[key_name])
    with safe_open(safetensor_path, framework="np", device="cpu") as f:
        if key_name in f.keys():
            weight = f.get_tensor(key_name)
            # zero_copy avoids duplicating the numpy buffer; the tensor is
            # then moved to the active device (GPU/CPU) for immediate use.
            weight = paddle.Tensor(weight, zero_copy=True)
            weight = weight._copy_to(paddle.framework._current_expected_place(), False)
            return weight
    # Index pointed at a shard that lacks the key: previously this fell off
    # the end of the function; return None explicitly to make the miss
    # behavior visible to callers (matches get_tensor's historical contract).
    return None
47+
48+
3549
def load_ep_checkpoint(model_path: str, fd_config: FDConfig, return_numpy: bool = False):
3650
"""
3751
load ep checkpoint

0 commit comments

Comments
 (0)