split weights
jlonge4 committed Jan 18, 2025
1 parent 9c12a46 · commit 329d250
Showing 1 changed file with 10 additions and 7 deletions.
optimum/neuron/models/phi4/model.py (17 changes: 10 additions & 7 deletions)
@@ -144,14 +144,17 @@ def load_weights(self):
             else:
                 is_unit_scale = False

-            # TODO split fused qkv_proj and mlp into separate layers
+            # Split fused qkv_proj and mlp into separate weights
+            fused_attn = attn.qkv_proj.weight.clone().detach()
+            fused_gate_up = mlp.gate_up_proj.weight.clone().detach()
+            q_weight, k_weight, v_weight = torch.chunk(fused_attn, 3, dim=0)
+            gate, up = torch.chunk(fused_gate_up, 2, dim=0)

-            ### END TODO ###
             new_layer = self.decoder_lm_head.new_layer(is_unit_scale=is_unit_scale)
             new_layer.add_pre_attention_layer_norm(layer.input_layernorm.weight.detach(), None)
-            new_layer.add_attention_query(attn.q_proj.weight.detach().T, attn.q_proj.bias.detach())
-            new_layer.add_attention_key(attn.k_proj.weight.detach().T, attn.k_proj.bias.detach())
-            new_layer.add_attention_value(attn.v_proj.weight.detach().T, attn.v_proj.bias.detach())
+            new_layer.add_attention_query(q_weight)
+            new_layer.add_attention_key(k_weight)
+            new_layer.add_attention_value(v_weight)
             if self.neuron_config and self.neuron_config.attn_output_transposed:
                 new_layer.add_attention_output(attn.o_proj.weight.T.detach(), None, sharding=0, transposed=True)
             else:
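
For context, the splitting pattern above relies on torch.chunk slicing the leading dimension into equal parts. A minimal standalone sketch (the sizes are illustrative, not the real Phi-4 dimensions, and it assumes the checkpoint stacks [Q; K; V] along dim 0 with equal widths):

    import torch

    hidden_size = 64  # illustrative only, not the Phi-4 value

    # Fused attention projection as stored in Phi-style checkpoints:
    # the Q, K, and V weights are concatenated along the output dimension.
    qkv_weight = torch.randn(3 * hidden_size, hidden_size)

    # torch.chunk(x, 3, dim=0) yields three equal slices, which is only
    # valid when Q, K, and V have the same output width (multi-head
    # attention with num_kv_heads == num_heads).
    q_weight, k_weight, v_weight = torch.chunk(qkv_weight, 3, dim=0)
    assert q_weight.shape == k_weight.shape == v_weight.shape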
@@ -187,10 +190,10 @@ def load_weights(self):
                 )
             else:
                 new_layer.add_parameter(
-                    mlp.gate_proj.weight.T, sharding=1, allow_pad=True, allow_quantize=True, allow_transform=True
+                    gate, sharding=1, allow_pad=True, allow_quantize=True, allow_transform=True
                 )
                 new_layer.add_parameter(
-                    mlp.up_proj.weight.T, sharding=1, allow_pad=True, allow_quantize=True, allow_transform=True
+                    up, sharding=1, allow_pad=True, allow_quantize=True, allow_transform=True
                 )
                 if self.neuron_config.weight_tiling:
                     new_layer.add_parameter(
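
The fused MLP weight splits the same way; a sketch with illustrative sizes follows. One caveat worth noting: torch.chunk assumes equal sections, which always holds for the gate/up pair but would not hold for a grouped-query qkv_proj, where torch.split with explicit section sizes is the safer choice. The grouped-query layout below is hypothetical, not the Phi-4 configuration.

    import torch

    hidden_size = 64         # illustrative only
    intermediate_size = 128  # illustrative only

    # gate_up_proj stacks [gate; up] along dim 0, so two equal chunks suffice.
    gate_up_weight = torch.randn(2 * intermediate_size, hidden_size)
    gate, up = torch.chunk(gate_up_weight, 2, dim=0)
    assert gate.shape == up.shape == (intermediate_size, hidden_size)

    # Hypothetical grouped-query layout: fewer K/V heads than Q heads means
    # the three sections are unequal, so explicit sizes replace chunk().
    num_heads, num_kv_heads, head_dim = 8, 2, 8
    fused = torch.randn((num_heads + 2 * num_kv_heads) * head_dim, hidden_size)
    q, k, v = torch.split(
        fused,
        [num_heads * head_dim, num_kv_heads * head_dim, num_kv_heads * head_dim],
        dim=0,
    )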
