fix chatglm3 npu output (#11590)
This commit is contained in:
parent
06930ab258
commit
5837bc0014
1 changed files with 24 additions and 27 deletions
|
|
@ -64,7 +64,16 @@ def chatglm2_model_forward(
|
||||||
rotary_pos_emb = rotary_pos_emb[position_ids]
|
rotary_pos_emb = rotary_pos_emb[position_ids]
|
||||||
else:
|
else:
|
||||||
rotary_pos_emb = rotary_pos_emb[None, :seq_length]
|
rotary_pos_emb = rotary_pos_emb[None, :seq_length]
|
||||||
rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
|
# ipex-llm change start: change rope cache shape
|
||||||
|
# rotary_pos_emb: [bsz, seq_len, rot_dim//2, 2]
|
||||||
|
cos, sin = rotary_pos_emb.permute(3, 0, 1, 2).chunk(2, dim=0)
|
||||||
|
cos = cos.squeeze(0).unsqueeze(1)
|
||||||
|
sin = sin.squeeze(0).unsqueeze(1)
|
||||||
|
cos = cos.repeat_interleave(2, dim=-1)
|
||||||
|
sin = sin.repeat_interleave(2, dim=-1)
|
||||||
|
# cos, sin: [bsz, 1, seq_len, rot_dim]
|
||||||
|
rotary_pos_emb = (cos, sin)
|
||||||
|
# ipex-llm change end
|
||||||
|
|
||||||
# ipex-llm changes begin:
|
# ipex-llm changes begin:
|
||||||
# generate `causal_mask` and replace `full_attention_mask` with it
|
# generate `causal_mask` and replace `full_attention_mask` with it
|
||||||
|
|
@ -76,14 +85,6 @@ def chatglm2_model_forward(
|
||||||
dtype=inputs_embeds.dtype, device=inputs_embeds.device)
|
dtype=inputs_embeds.dtype, device=inputs_embeds.device)
|
||||||
mask_value = torch.finfo(inputs_embeds.dtype).min
|
mask_value = torch.finfo(inputs_embeds.dtype).min
|
||||||
causal_mask.masked_fill_(full_attention_mask, mask_value)
|
causal_mask.masked_fill_(full_attention_mask, mask_value)
|
||||||
elif self.training or (inputs_embeds.device.type != "xpu" and past_key_values is None):
|
|
||||||
full_attention_mask = self.get_masks(input_ids,
|
|
||||||
past_key_values,
|
|
||||||
padding_mask=attention_mask)
|
|
||||||
causal_mask = torch.zeros([batch_size, 1, seq_length, full_attention_mask.size(-1)],
|
|
||||||
dtype=inputs_embeds.dtype, device=inputs_embeds.device)
|
|
||||||
mask_value = torch.finfo(inputs_embeds.dtype).min
|
|
||||||
causal_mask.masked_fill_(full_attention_mask, mask_value)
|
|
||||||
else:
|
else:
|
||||||
causal_mask = None
|
causal_mask = None
|
||||||
|
|
||||||
|
|
@ -174,24 +175,20 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||||
|
|
||||||
|
|
||||||
@torch.jit.script
|
@torch.jit.script
|
||||||
def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
|
def rotate_every_two(x: torch.Tensor):
|
||||||
# x: [sq, b, np, hn]
|
x1 = x[:, :, :, ::2]
|
||||||
sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
|
x2 = x[:, :, :, 1::2]
|
||||||
rot_dim = rope_cache.shape[-2] * 2
|
x = torch.stack((-x2, x1), dim=-1)
|
||||||
|
return x.flatten(-2)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: Tuple[torch.Tensor]) -> torch.Tensor:
|
||||||
|
# x: [bsz, n_head, seq_len, head_dim]
|
||||||
|
cos, sin = rope_cache
|
||||||
|
rot_dim = cos.size(-1)
|
||||||
x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
|
x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
|
||||||
# truncate to support variable sizes
|
x_out = x * cos + rotate_every_two(x) * sin
|
||||||
rope_cache = rope_cache[:sq]
|
return torch.cat([x_out, x_pass], dim=-1)
|
||||||
xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
|
|
||||||
rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
|
|
||||||
x_out2 = torch.stack(
|
|
||||||
[
|
|
||||||
xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
|
|
||||||
xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
|
|
||||||
],
|
|
||||||
-1,
|
|
||||||
)
|
|
||||||
x_out2 = x_out2.flatten(3)
|
|
||||||
return torch.cat((x_out2, x_pass), dim=-1)
|
|
||||||
|
|
||||||
|
|
||||||
def chatglm2_attention_forward(
|
def chatglm2_attention_forward(
|
||||||
|
|
@ -246,7 +243,7 @@ def chatglm2_attention_forward(
|
||||||
key_states,
|
key_states,
|
||||||
value_states,
|
value_states,
|
||||||
attn_mask=attention_mask,
|
attn_mask=attention_mask,
|
||||||
is_causal=q_len > 1 and bsz == 1,
|
is_causal=attention_mask is None and q_len > 1 and bsz == 1,
|
||||||
)
|
)
|
||||||
attn_weights = None
|
attn_weights = None
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue