Chatglm2 rope optimization on xpu (#9350)
This commit is contained in:
		
							parent
							
								
									833e4dbc8d
								
							
						
					
					
						commit
						1420e45cc0
					
				
					 2 changed files with 96 additions and 7 deletions
				
			
		| 
						 | 
				
			
			@ -284,6 +284,7 @@ def _optimize_post(model):
 | 
			
		|||
            from bigdl.llm.transformers.models.chatglm2 import chatglm2_attention_forward_8eb45c
 | 
			
		||||
            from bigdl.llm.transformers.models.chatglm2 import core_attn_forward_8eb45c
 | 
			
		||||
            from bigdl.llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
 | 
			
		||||
            from bigdl.llm.transformers.models.chatglm2 import chatglm2_model_forward
 | 
			
		||||
            convert_forward(model,
 | 
			
		||||
                            module.SelfAttention,
 | 
			
		||||
                            chatglm2_attention_forward_8eb45c
 | 
			
		||||
| 
						 | 
				
			
			@ -291,6 +292,9 @@ def _optimize_post(model):
 | 
			
		|||
            convert_forward(model,
 | 
			
		||||
                            module.CoreAttention,
 | 
			
		||||
                            core_attn_forward_8eb45c)
 | 
			
		||||
            convert_forward(model,
 | 
			
		||||
                            module.ChatGLMModel,
 | 
			
		||||
                            chatglm2_model_forward)
 | 
			
		||||
            convert_forward(model,
 | 
			
		||||
                            module.RMSNorm,
 | 
			
		||||
                            chatglm_rms_norm_forward)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -18,8 +18,9 @@
 | 
			
		|||
#
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from typing import Optional, Tuple, Union, List, Callable, Dict, Any
 | 
			
		||||
from typing import Optional, Tuple, List
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
from transformers.modeling_outputs import BaseModelOutputWithPast
 | 
			
		||||
from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -54,7 +55,7 @@ def split_tensor_along_last_dim(
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
@torch.jit.script
 | 
			
		||||
def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
 | 
			
		||||
def apply_rotary_pos_emb_chatglm(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
 | 
			
		||||
    # x: [sq, b, np, hn]
 | 
			
		||||
    sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
 | 
			
		||||
    rot_dim = rope_cache.shape[-2] * 2
 | 
			
		||||
| 
						 | 
				
			
			@ -87,6 +88,77 @@ def chatglm_rms_norm_forward(self, hidden_states):
 | 
			
		|||
    return hidden_states
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def chatglm2_model_forward(
 | 
			
		||||
        self,
 | 
			
		||||
        input_ids,
 | 
			
		||||
        position_ids: Optional[torch.Tensor]=None,
 | 
			
		||||
        attention_mask: Optional[torch.BoolTensor]=None,
 | 
			
		||||
        full_attention_mask: Optional[torch.BoolTensor]=None,
 | 
			
		||||
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]=None,
 | 
			
		||||
        inputs_embeds: Optional[torch.Tensor]=None,
 | 
			
		||||
        use_cache: Optional[bool]=None,
 | 
			
		||||
        output_hidden_states: Optional[bool]=None,
 | 
			
		||||
        return_dict: Optional[bool]=None,
 | 
			
		||||
):
 | 
			
		||||
    output_hidden_states = (
 | 
			
		||||
        output_hidden_states if output_hidden_states is not None
 | 
			
		||||
        else self.config.output_hidden_states
 | 
			
		||||
    )
 | 
			
		||||
    use_cache = use_cache if use_cache is not None else self.config.use_cache
 | 
			
		||||
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 | 
			
		||||
 | 
			
		||||
    batch_size, seq_length = input_ids.shape
 | 
			
		||||
 | 
			
		||||
    if inputs_embeds is None:
 | 
			
		||||
        inputs_embeds = self.embedding(input_ids)
 | 
			
		||||
 | 
			
		||||
    if full_attention_mask is None:
 | 
			
		||||
        if (attention_mask is not None and not attention_mask.all()) or (
 | 
			
		||||
                past_key_values and seq_length != 1):
 | 
			
		||||
            full_attention_mask = self.get_masks(input_ids,
 | 
			
		||||
                                                 past_key_values,
 | 
			
		||||
                                                 padding_mask=attention_mask)
 | 
			
		||||
 | 
			
		||||
    use_fuse_rope = input_ids.device.type == "xpu"
 | 
			
		||||
    use_fuse_rope = use_fuse_rope and not self.training
 | 
			
		||||
 | 
			
		||||
    # Rotary positional embeddings
 | 
			
		||||
    rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
 | 
			
		||||
    if position_ids is not None:
 | 
			
		||||
        rotary_pos_emb = rotary_pos_emb[position_ids]
 | 
			
		||||
    else:
 | 
			
		||||
        rotary_pos_emb = rotary_pos_emb[None, :seq_length]
 | 
			
		||||
    if use_fuse_rope:
 | 
			
		||||
        # Repeat cos sin here, call only once for each token.
 | 
			
		||||
        # Chatglm2's rotary embedding is similar to gptj's, is rotate_every_two.
 | 
			
		||||
        # If put this to attension forward, it will generate too many times.
 | 
			
		||||
        cos, sin = rotary_pos_emb.split(rotary_pos_emb.shape[-1] // 2, dim=-1)
 | 
			
		||||
        cos = cos.squeeze(-1)
 | 
			
		||||
        sin = sin.squeeze(-1)
 | 
			
		||||
        cos = torch.repeat_interleave(cos[:, :, None, :], 2, 3)
 | 
			
		||||
        sin = torch.repeat_interleave(sin[:, :, None, :], 2, 3)
 | 
			
		||||
        rotary_pos_emb = (cos, sin)
 | 
			
		||||
    else:
 | 
			
		||||
        rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
 | 
			
		||||
 | 
			
		||||
    # Run encoder.
 | 
			
		||||
    hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
 | 
			
		||||
        inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb,
 | 
			
		||||
        kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    if not return_dict:
 | 
			
		||||
        return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions]
 | 
			
		||||
                     if v is not None)
 | 
			
		||||
 | 
			
		||||
    return BaseModelOutputWithPast(
 | 
			
		||||
        last_hidden_state=hidden_states,
 | 
			
		||||
        past_key_values=presents,
 | 
			
		||||
        hidden_states=all_hidden_states,
 | 
			
		||||
        attentions=all_self_attentions,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def chatglm2_attention_forward_8eb45c(
 | 
			
		||||
        self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
 | 
			
		||||
):
 | 
			
		||||
| 
						 | 
				
			
			@ -132,12 +204,26 @@ def chatglm2_attention_forward_8eb45c(
 | 
			
		|||
        # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
 | 
			
		||||
        (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
 | 
			
		||||
 | 
			
		||||
    cur_length, batch_size = query_layer.shape[0], query_layer.shape[1]
 | 
			
		||||
 | 
			
		||||
    # apply relative positional encoding (rotary embedding)
 | 
			
		||||
    if rotary_pos_emb is not None:
 | 
			
		||||
        query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
 | 
			
		||||
        key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)
 | 
			
		||||
 | 
			
		||||
    cur_length, batch_size = query_layer.shape[0], query_layer.shape[1]
 | 
			
		||||
        if len(rotary_pos_emb) == 2:  # use_fuse_rope, see chatglm2_model_forward
 | 
			
		||||
            cos, sin = rotary_pos_emb
 | 
			
		||||
            rot_dim = cos.shape[-1]
 | 
			
		||||
            query_layer = query_layer.transpose(0, 1)
 | 
			
		||||
            key_layer = key_layer.transpose(0, 1)
 | 
			
		||||
            query_layer_cur = query_layer[..., :rot_dim]
 | 
			
		||||
            key_layer_cur = key_layer[..., :rot_dim]
 | 
			
		||||
            # ipex's apply_rotary_embedding can change the origin storage, so query_layer will get
 | 
			
		||||
            # the result directly.
 | 
			
		||||
            torch.ops.torch_ipex.apply_rotary_embedding(query_layer_cur, sin, cos, query_layer_cur)
 | 
			
		||||
            torch.ops.torch_ipex.apply_rotary_embedding(key_layer_cur, sin, cos, key_layer_cur)
 | 
			
		||||
            query_layer = query_layer.transpose(0, 1)
 | 
			
		||||
            key_layer = key_layer.transpose(0, 1)
 | 
			
		||||
        else:
 | 
			
		||||
            query_layer = apply_rotary_pos_emb_chatglm(query_layer, rotary_pos_emb)
 | 
			
		||||
            key_layer = apply_rotary_pos_emb_chatglm(key_layer, rotary_pos_emb)
 | 
			
		||||
 | 
			
		||||
    if self.multi_query_attention:
 | 
			
		||||
        key_length = key_layer.size(0)
 | 
			
		||||
| 
						 | 
				
			
			@ -200,7 +286,6 @@ def chatglm2_attention_forward_8eb45c(
 | 
			
		|||
    # ==================================
 | 
			
		||||
    # core attention computation
 | 
			
		||||
    # ==================================
 | 
			
		||||
 | 
			
		||||
    context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)
 | 
			
		||||
 | 
			
		||||
    # =================
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue