fix chatglm3 npu output (#11590)

parent 06930ab258
commit 5837bc0014

1 changed file with 24 additions and 27 deletions
@@ -64,7 +64,16 @@ def chatglm2_model_forward(
         rotary_pos_emb = rotary_pos_emb[position_ids]
     else:
         rotary_pos_emb = rotary_pos_emb[None, :seq_length]
-    rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
+    # ipex-llm change start: change rope cache shape
+    # rotary_pos_emb: [bsz, seq_len, rot_dim//2, 2]
+    cos, sin = rotary_pos_emb.permute(3, 0, 1, 2).chunk(2, dim=0)
+    cos = cos.squeeze(0).unsqueeze(1)
+    sin = sin.squeeze(0).unsqueeze(1)
+    cos = cos.repeat_interleave(2, dim=-1)
+    sin = sin.repeat_interleave(2, dim=-1)
+    # cos, sin: [bsz, 1, seq_len, rot_dim]
+    rotary_pos_emb = (cos, sin)
+    # ipex-llm change end
 
     # ipex-llm changes begin:
     # generate `causal_mask` and replace `full_attention_mask` with it
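
A quick shape check of the new rope-cache layout in the hunk above. This is a standalone sketch with made-up sizes, not part of the commit:

import torch

# Hypothetical sizes, only to verify the shape math documented in the comments.
bsz, seq_len, rot_dim = 2, 8, 64

# rotary_pos_emb as produced by the model: [bsz, seq_len, rot_dim//2, 2]
rotary_pos_emb = torch.randn(bsz, seq_len, rot_dim // 2, 2)

cos, sin = rotary_pos_emb.permute(3, 0, 1, 2).chunk(2, dim=0)
cos = cos.squeeze(0).unsqueeze(1)       # [bsz, 1, seq_len, rot_dim//2]
sin = sin.squeeze(0).unsqueeze(1)
cos = cos.repeat_interleave(2, dim=-1)  # [bsz, 1, seq_len, rot_dim]
sin = sin.repeat_interleave(2, dim=-1)

assert cos.shape == (bsz, 1, seq_len, rot_dim)
assert sin.shape == (bsz, 1, seq_len, rot_dim)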
@@ -76,14 +85,6 @@ def chatglm2_model_forward(
                                   dtype=inputs_embeds.dtype, device=inputs_embeds.device)
         mask_value = torch.finfo(inputs_embeds.dtype).min
         causal_mask.masked_fill_(full_attention_mask, mask_value)
-    elif self.training or (inputs_embeds.device.type != "xpu" and past_key_values is None):
-        full_attention_mask = self.get_masks(input_ids,
-                                             past_key_values,
-                                             padding_mask=attention_mask)
-        causal_mask = torch.zeros([batch_size, 1, seq_length, full_attention_mask.size(-1)],
-                                  dtype=inputs_embeds.dtype, device=inputs_embeds.device)
-        mask_value = torch.finfo(inputs_embeds.dtype).min
-        causal_mask.masked_fill_(full_attention_mask, mask_value)
     else:
         causal_mask = None
 
@@ -174,24 +175,20 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 
 
-@torch.jit.script
-def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
-    # x: [sq, b, np, hn]
-    sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
-    rot_dim = rope_cache.shape[-2] * 2
+def rotate_every_two(x: torch.Tensor):
+    x1 = x[:, :, :, ::2]
+    x2 = x[:, :, :, 1::2]
+    x = torch.stack((-x2, x1), dim=-1)
+    return x.flatten(-2)
+
+
+def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: Tuple[torch.Tensor]) -> torch.Tensor:
+    # x: [bsz, n_head, seq_len, head_dim]
+    cos, sin = rope_cache
+    rot_dim = cos.size(-1)
     x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
-    # truncate to support variable sizes
-    rope_cache = rope_cache[:sq]
-    xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
-    rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
-    x_out2 = torch.stack(
-        [
-            xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
-            xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
-        ],
-        -1,
-    )
-    x_out2 = x_out2.flatten(3)
-    return torch.cat((x_out2, x_pass), dim=-1)
+    x_out = x * cos + rotate_every_two(x) * sin
+    return torch.cat([x_out, x_pass], dim=-1)
 
 
 def chatglm2_attention_forward(
@@ -246,7 +243,7 @@ def chatglm2_attention_forward(
             key_states,
             value_states,
             attn_mask=attention_mask,
-            is_causal=q_len > 1 and bsz == 1,
+            is_causal=attention_mask is None and q_len > 1 and bsz == 1,
         )
         attn_weights = None
     else:
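
The guard keeps `is_causal` from being claimed while an explicit mask is also passed; when a mask is supplied, the mask itself should encode causality. A minimal sketch of the guarded call with dummy tensors (names mirror the diff, shapes are assumptions):

import torch
import torch.nn.functional as F

# Hypothetical shapes: [bsz, n_head, q_len, head_dim].
bsz, n_head, q_len, head_dim = 1, 2, 5, 16
q = torch.randn(bsz, n_head, q_len, head_dim)
k = torch.randn(bsz, n_head, q_len, head_dim)
v = torch.randn(bsz, n_head, q_len, head_dim)
attention_mask = None  # the model-side causal_mask may be None on this path

# Only set is_causal when no explicit mask is supplied.
out = F.scaled_dot_product_attention(
    q, k, v,
    attn_mask=attention_mask,
    is_causal=attention_mask is None and q_len > 1 and bsz == 1,
)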