refactor internlm and internlm2 (#11274)
This commit is contained in:
		
							parent
							
								
									fac49f15e3
								
							
						
					
					
						commit
						10e480ee96
					
				
					 2 changed files with 124 additions and 160 deletions
				
			
		| 
						 | 
				
			
			@ -719,6 +719,10 @@ def _optimize_pre(model):
 | 
			
		|||
        # For stablelm-zephyr-3b and stablelm-2-zephyr-1_6b
 | 
			
		||||
        from ipex_llm.transformers.models.stablelm import merge_qkv
 | 
			
		||||
        model.apply(merge_qkv)
 | 
			
		||||
    # for internlm
 | 
			
		||||
    if model.config.model_type == "internlm":
 | 
			
		||||
        from ipex_llm.transformers.models.internlm import merge_qkv
 | 
			
		||||
        model.apply(merge_qkv)
 | 
			
		||||
    # for internlm-xcomposer2-vl
 | 
			
		||||
    if model.config.model_type == "internlmxcomposer2":
 | 
			
		||||
        from ipex_llm.transformers.models.internlm import pre_process_attn_and_mlp
 | 
			
		||||
| 
						 | 
				
			
			@ -1167,27 +1171,14 @@ def _optimize_post(model, lightweight_bmm=False):
 | 
			
		|||
        modeling_module_name = model.__class__.__module__
 | 
			
		||||
        module = importlib.import_module(modeling_module_name)
 | 
			
		||||
        from ipex_llm.transformers.models.internlm import internlm_attention_forward
 | 
			
		||||
        convert_forward(model, module.InternLMAttention, internlm_attention_forward)
 | 
			
		||||
        convert_forward(model, module.InternLMRMSNorm, llama_rms_norm_forward)
 | 
			
		||||
    elif model.config.model_type == "internlm2":
 | 
			
		||||
        modeling_module_name = model.__class__.__module__
 | 
			
		||||
        module = importlib.import_module(modeling_module_name)
 | 
			
		||||
        from ipex_llm.transformers.models.internlm import internlm2_attention_forward
 | 
			
		||||
        try:
 | 
			
		||||
            convert_forward(model,
 | 
			
		||||
                            module.InternLM2Attention,
 | 
			
		||||
                            internlm2_attention_forward
 | 
			
		||||
                            )
 | 
			
		||||
        except:
 | 
			
		||||
            convert_forward(model,
 | 
			
		||||
                            module.InternLMAttention,
 | 
			
		||||
                            internlm_attention_forward
 | 
			
		||||
                            )
 | 
			
		||||
        try:
 | 
			
		||||
            convert_forward(model,
 | 
			
		||||
                            module.InternLM2RMSNorm,
 | 
			
		||||
                            llama_rms_norm_forward
 | 
			
		||||
                            )
 | 
			
		||||
        except:
 | 
			
		||||
            convert_forward(model,
 | 
			
		||||
                            module.InternLMRMSNorm,
 | 
			
		||||
                            llama_rms_norm_forward
 | 
			
		||||
                            )
 | 
			
		||||
        convert_forward(model, module.InternLM2Attention, internlm2_attention_forward)
 | 
			
		||||
        convert_forward(model, module.InternLM2RMSNorm, llama_rms_norm_forward)
 | 
			
		||||
    elif model.config.model_type == "internlmxcomposer2":
 | 
			
		||||
        modeling_module_name = model.model.__class__.__module__
 | 
			
		||||
        module = importlib.import_module(modeling_module_name)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -42,20 +42,35 @@ from typing import Optional, Tuple, List
 | 
			
		|||
import torch
 | 
			
		||||
import torch.utils.checkpoint
 | 
			
		||||
from torch import nn
 | 
			
		||||
from ipex_llm.utils.common import invalidInputError
 | 
			
		||||
from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
 | 
			
		||||
    append_kv_cache, is_enough_kv_cache_room_4_31
 | 
			
		||||
from ipex_llm.transformers.models.utils import should_use_fuse_rope, apply_rotary_pos_emb
 | 
			
		||||
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 | 
			
		||||
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
 | 
			
		||||
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
 | 
			
		||||
from ipex_llm.transformers.models.utils import update_past_key_value
 | 
			
		||||
from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
 | 
			
		||||
from einops import rearrange
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
 | 
			
		||||
def merge_qkv(module: torch.nn.Module):
 | 
			
		||||
    if module.__class__.__name__ == "InternLMAttention":
 | 
			
		||||
        new_weight = torch.cat([
 | 
			
		||||
            module.q_proj.weight.data,
 | 
			
		||||
            module.k_proj.weight.data,
 | 
			
		||||
            module.v_proj.weight.data,
 | 
			
		||||
        ], dim=0)
 | 
			
		||||
        new_bias = torch.cat([
 | 
			
		||||
            module.q_proj.bias.data,
 | 
			
		||||
            module.k_proj.bias.data,
 | 
			
		||||
            module.v_proj.bias.data,
 | 
			
		||||
        ], dim=-1)
 | 
			
		||||
 | 
			
		||||
        qkv_proj = torch.nn.Linear(0, 0, bias=True)
 | 
			
		||||
        qkv_proj.weight = torch.nn.Parameter(new_weight, requires_grad=False)
 | 
			
		||||
        qkv_proj.bias = torch.nn.Parameter(new_bias, requires_grad=False)
 | 
			
		||||
        qkv_proj.in_features = new_weight.size(1)
 | 
			
		||||
        qkv_proj.out_features = new_weight.size(0)
 | 
			
		||||
        module.qkv_proj = qkv_proj
 | 
			
		||||
 | 
			
		||||
        del module.q_proj, module.k_proj, module.v_proj
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def internlm_attention_forward(
 | 
			
		||||
| 
						 | 
				
			
			@ -68,109 +83,69 @@ def internlm_attention_forward(
 | 
			
		|||
    use_cache: bool=False,
 | 
			
		||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
 | 
			
		||||
    bsz, q_len, _ = hidden_states.size()
 | 
			
		||||
    device = hidden_states.device
 | 
			
		||||
    query_states = self.q_proj(hidden_states) \
 | 
			
		||||
        .view(bsz, q_len, self.num_heads, self.head_dim) \
 | 
			
		||||
        .transpose(1, 2)
 | 
			
		||||
    key_states = self.k_proj(hidden_states) \
 | 
			
		||||
        .view(bsz, q_len, self.num_heads, self.head_dim) \
 | 
			
		||||
        .transpose(1, 2)
 | 
			
		||||
    value_states = self.v_proj(hidden_states) \
 | 
			
		||||
        .view(bsz, q_len, self.num_heads, self.head_dim) \
 | 
			
		||||
        .transpose(1, 2)
 | 
			
		||||
 | 
			
		||||
    qkv = self.qkv_proj(hidden_states)
 | 
			
		||||
    qkv = qkv.view(bsz, q_len, self.num_heads * 3, self.head_dim)
 | 
			
		||||
    qkv = qkv.transpose(1, 2)
 | 
			
		||||
    query_states, key_states, value_states = qkv.split([self.num_heads,
 | 
			
		||||
                                                        self.num_heads,
 | 
			
		||||
                                                        self.num_heads], dim=1)
 | 
			
		||||
 | 
			
		||||
    kv_seq_len = key_states.shape[-2]
 | 
			
		||||
    enough_kv_room = True
 | 
			
		||||
    if past_key_value is not None:
 | 
			
		||||
        enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=kv_seq_len)
 | 
			
		||||
        kv_seq_len += past_key_value[0].shape[-2]
 | 
			
		||||
 | 
			
		||||
    # IPEX-LLM OPT: fuse rope
 | 
			
		||||
    if should_use_fuse_rope(hidden_states, position_ids, self.training):
 | 
			
		||||
        query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
 | 
			
		||||
                                                                     key_states,
 | 
			
		||||
                                                                     position_ids,
 | 
			
		||||
                                                                     "internlm")
 | 
			
		||||
        import xe_addons
 | 
			
		||||
        xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
 | 
			
		||||
                                       query_states, key_states)
 | 
			
		||||
    else:
 | 
			
		||||
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 | 
			
		||||
        query_states, key_states = apply_rotary_pos_emb(
 | 
			
		||||
            query_states,
 | 
			
		||||
            key_states,
 | 
			
		||||
            cos,
 | 
			
		||||
            sin,
 | 
			
		||||
            position_ids,
 | 
			
		||||
            "internlm")
 | 
			
		||||
    # [bsz, nh, t, hd]
 | 
			
		||||
 | 
			
		||||
    if past_key_value is not None:
 | 
			
		||||
        # reuse k, v, self_attention
 | 
			
		||||
        cache_k = past_key_value[0]
 | 
			
		||||
        cache_v = past_key_value[1]
 | 
			
		||||
        if not enough_kv_room:
 | 
			
		||||
            # allocate new
 | 
			
		||||
            new_cache_k, new_cache_v = extend_kv_cache(
 | 
			
		||||
                bsz,
 | 
			
		||||
                self.num_heads,
 | 
			
		||||
                self.head_dim,
 | 
			
		||||
                cache_k.size(2),
 | 
			
		||||
                kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
 | 
			
		||||
                dtype=cache_k.dtype,
 | 
			
		||||
                device=device
 | 
			
		||||
            )
 | 
			
		||||
            new_cache_k[:] = cache_k
 | 
			
		||||
            new_cache_v[:] = cache_v
 | 
			
		||||
            cache_k = new_cache_k
 | 
			
		||||
            cache_v = new_cache_v
 | 
			
		||||
 | 
			
		||||
        key_states, value_states = append_kv_cache(cache_k, cache_v, key_states, value_states)
 | 
			
		||||
 | 
			
		||||
    elif use_cache:
 | 
			
		||||
        max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
 | 
			
		||||
        new_key_states, new_value_states = init_kv_cache(
 | 
			
		||||
            bsz,
 | 
			
		||||
            self.num_heads,
 | 
			
		||||
            self.head_dim,
 | 
			
		||||
            kv_seq_len,
 | 
			
		||||
            max_cache_length,
 | 
			
		||||
            dtype=key_states.dtype,
 | 
			
		||||
            device=device
 | 
			
		||||
            query_states, key_states, cos, sin, position_ids, "internlm"
 | 
			
		||||
        )
 | 
			
		||||
        new_key_states[:] = key_states
 | 
			
		||||
        new_value_states[:] = value_states
 | 
			
		||||
        key_states = new_key_states
 | 
			
		||||
        value_states = new_value_states
 | 
			
		||||
 | 
			
		||||
    # IPEX-LLM OPT: kv cache and quantzie kv cache
 | 
			
		||||
    use_quantize_kv = use_quantize_kv_cache(self.qkv_proj, hidden_states)
 | 
			
		||||
    key_states, value_states = update_past_key_value(
 | 
			
		||||
        past_key_value, key_states, value_states,
 | 
			
		||||
        kv_seq_len, use_quantize_kv, hidden_states.device
 | 
			
		||||
    )
 | 
			
		||||
    past_key_value = (key_states, value_states) if use_cache else None
 | 
			
		||||
 | 
			
		||||
    attn_weights = torch.matmul(query_states,
 | 
			
		||||
                                key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 | 
			
		||||
    # IPEX-LLM OPT: sdp
 | 
			
		||||
    attn_weights = None
 | 
			
		||||
    if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
 | 
			
		||||
        import xe_addons
 | 
			
		||||
        if use_quantize_kv:
 | 
			
		||||
            attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
 | 
			
		||||
                                            attention_mask)
 | 
			
		||||
        else:
 | 
			
		||||
            attn_output = xe_addons.sdp(query_states, key_states, value_states, attention_mask)
 | 
			
		||||
    elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
 | 
			
		||||
        import xe_addons
 | 
			
		||||
        if use_quantize_kv:
 | 
			
		||||
            attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
 | 
			
		||||
                                                   value_states, attention_mask)
 | 
			
		||||
        else:
 | 
			
		||||
            attn_output = xe_addons.sdp_causal(query_states, key_states,
 | 
			
		||||
                                               value_states, attention_mask)
 | 
			
		||||
    else:
 | 
			
		||||
        if use_quantize_kv:
 | 
			
		||||
            key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
 | 
			
		||||
                                                            query_states.dtype)
 | 
			
		||||
 | 
			
		||||
    if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
 | 
			
		||||
        invalidInputError(
 | 
			
		||||
            False,
 | 
			
		||||
            f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, "
 | 
			
		||||
            f"but is {attn_weights.size()}"
 | 
			
		||||
        )
 | 
			
		||||
        attn_weights = torch.matmul(query_states,
 | 
			
		||||
                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 | 
			
		||||
 | 
			
		||||
    if attention_mask is not None:
 | 
			
		||||
        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
 | 
			
		||||
            invalidInputError(
 | 
			
		||||
                False,
 | 
			
		||||
                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, "
 | 
			
		||||
                f"but is {attention_mask.size()}"
 | 
			
		||||
            )
 | 
			
		||||
        attn_weights = attn_weights + attention_mask
 | 
			
		||||
        attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
 | 
			
		||||
        if attention_mask is not None:
 | 
			
		||||
            attn_weights = attn_weights + attention_mask
 | 
			
		||||
 | 
			
		||||
    # upcast attention to fp32
 | 
			
		||||
    attn_weights = nn.functional.softmax(attn_weights,
 | 
			
		||||
                                         dim=-1, dtype=torch.float32).to(query_states.dtype)
 | 
			
		||||
    attn_output = torch.matmul(attn_weights, value_states)
 | 
			
		||||
 | 
			
		||||
    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
 | 
			
		||||
        invalidInputError(
 | 
			
		||||
            False,
 | 
			
		||||
            f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, "
 | 
			
		||||
            f"but is {attn_output.size()}"
 | 
			
		||||
        )
 | 
			
		||||
        # upcast attention to fp32
 | 
			
		||||
        attn_weights = nn.functional.softmax(attn_weights,
 | 
			
		||||
                                             dim=-1, dtype=torch.float32).to(query_states.dtype)
 | 
			
		||||
        attn_output = torch.matmul(attn_weights, value_states)
 | 
			
		||||
 | 
			
		||||
    attn_output = attn_output.transpose(1, 2)
 | 
			
		||||
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 | 
			
		||||
| 
						 | 
				
			
			@ -229,62 +204,60 @@ def internlm2_attention_forward(
 | 
			
		|||
    kv_seq_len = key_states.shape[-2]
 | 
			
		||||
    if past_key_value is not None:
 | 
			
		||||
        kv_seq_len += past_key_value[0].shape[-2]
 | 
			
		||||
 | 
			
		||||
    # IPEX-LLM OPT: fuse rope
 | 
			
		||||
    if should_use_fuse_rope(hidden_states, position_ids, self.training):
 | 
			
		||||
        query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
 | 
			
		||||
                                                                     key_states,
 | 
			
		||||
                                                                     position_ids,
 | 
			
		||||
                                                                     "internlm")
 | 
			
		||||
        import xe_addons
 | 
			
		||||
        xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
 | 
			
		||||
                                       query_states, key_states)
 | 
			
		||||
    else:
 | 
			
		||||
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 | 
			
		||||
    # query_states, key_states = apply_rotary_pos_emb(query_states,
 | 
			
		||||
    #                               key_states, cos, sin, position_ids)
 | 
			
		||||
        query_states, key_states = apply_rotary_pos_emb(
 | 
			
		||||
            query_states,
 | 
			
		||||
            key_states,
 | 
			
		||||
            cos,
 | 
			
		||||
            sin,
 | 
			
		||||
            position_ids,
 | 
			
		||||
            "internlm")
 | 
			
		||||
 | 
			
		||||
    if past_key_value is not None:
 | 
			
		||||
        # reuse k, v, self_attention
 | 
			
		||||
        key_states = torch.cat([past_key_value[0], key_states], dim=2)
 | 
			
		||||
        value_states = torch.cat([past_key_value[1], value_states], dim=2)
 | 
			
		||||
            query_states, key_states, cos, sin, position_ids, "internlm"
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    # IPEX-LLM OPT: kv cache and quantzie kv cache
 | 
			
		||||
    use_quantize_kv = use_quantize_kv_cache(self.wqkv, hidden_states)
 | 
			
		||||
    key_states, value_states = update_past_key_value(
 | 
			
		||||
        past_key_value, key_states, value_states,
 | 
			
		||||
        kv_seq_len, use_quantize_kv, hidden_states.device
 | 
			
		||||
    )
 | 
			
		||||
    past_key_value = (key_states, value_states) if use_cache else None
 | 
			
		||||
 | 
			
		||||
    key_states = repeat_kv(key_states, self.num_key_value_groups)
 | 
			
		||||
    value_states = repeat_kv(value_states, self.num_key_value_groups)
 | 
			
		||||
    # IPEX-LLM OPT: sdp
 | 
			
		||||
    attn_weights = None
 | 
			
		||||
    if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
 | 
			
		||||
        import xe_addons
 | 
			
		||||
        if use_quantize_kv:
 | 
			
		||||
            attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
 | 
			
		||||
                                            attention_mask)
 | 
			
		||||
        else:
 | 
			
		||||
            attn_output = xe_addons.sdp(query_states, key_states, value_states, attention_mask)
 | 
			
		||||
    elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
 | 
			
		||||
        import xe_addons
 | 
			
		||||
        if use_quantize_kv:
 | 
			
		||||
            attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
 | 
			
		||||
                                                   value_states, attention_mask)
 | 
			
		||||
        else:
 | 
			
		||||
            attn_output = xe_addons.sdp_causal(query_states, key_states,
 | 
			
		||||
                                               value_states, attention_mask)
 | 
			
		||||
    else:
 | 
			
		||||
        if use_quantize_kv:
 | 
			
		||||
            key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
 | 
			
		||||
                                                            query_states.dtype)
 | 
			
		||||
        key_states = repeat_kv(key_states, self.num_key_value_groups)
 | 
			
		||||
        value_states = repeat_kv(value_states, self.num_key_value_groups)
 | 
			
		||||
 | 
			
		||||
    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 | 
			
		||||
        attn_weights = torch.matmul(query_states,
 | 
			
		||||
                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 | 
			
		||||
 | 
			
		||||
    if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
 | 
			
		||||
        invalidInputError(
 | 
			
		||||
            False,
 | 
			
		||||
            f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, "
 | 
			
		||||
            f"but is {attn_weights.size()}"
 | 
			
		||||
        )
 | 
			
		||||
        if attention_mask is not None:
 | 
			
		||||
            attn_weights = attn_weights + attention_mask
 | 
			
		||||
 | 
			
		||||
    if attention_mask is not None:
 | 
			
		||||
        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
 | 
			
		||||
            invalidInputError(
 | 
			
		||||
                False,
 | 
			
		||||
                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, "
 | 
			
		||||
                f"but is {attention_mask.size()}"
 | 
			
		||||
            )
 | 
			
		||||
        attn_weights = attn_weights + attention_mask
 | 
			
		||||
 | 
			
		||||
    # upcast attention to fp32
 | 
			
		||||
    attn_weights = nn.functional.softmax(attn_weights,
 | 
			
		||||
                                         dim=-1, dtype=torch.float32).to(query_states.dtype)
 | 
			
		||||
    attn_output = torch.matmul(attn_weights, value_states)
 | 
			
		||||
 | 
			
		||||
    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
 | 
			
		||||
        invalidInputError(
 | 
			
		||||
            False,
 | 
			
		||||
            f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, "
 | 
			
		||||
            f"but is {attn_output.size()}"
 | 
			
		||||
        )
 | 
			
		||||
        # upcast attention to fp32
 | 
			
		||||
        attn_weights = nn.functional.softmax(attn_weights,
 | 
			
		||||
                                             dim=-1, dtype=torch.float32).to(query_states.dtype)
 | 
			
		||||
        attn_output = torch.matmul(attn_weights, value_states)
 | 
			
		||||
 | 
			
		||||
    attn_output = attn_output.transpose(1, 2).contiguous()
 | 
			
		||||
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue