support running internlm xcomposer2 on gpu and add sdp optimization (#11115)
This commit is contained in:
		
							parent
							
								
									c5e8b90c8d
								
							
						
					
					
						commit
						37b98a531f
					
				
					 2 changed files with 106 additions and 17 deletions
				
			
		| 
						 | 
					@ -44,11 +44,11 @@ import transformers
 | 
				
			||||||
import importlib.util
 | 
					import importlib.util
 | 
				
			||||||
from ipex_llm.ggml.quantize import ggml_tensor_qtype, gguf_mixed_qtype
 | 
					from ipex_llm.ggml.quantize import ggml_tensor_qtype, gguf_mixed_qtype
 | 
				
			||||||
from .utils import logger, get_cur_qtype_and_imatrix
 | 
					from .utils import logger, get_cur_qtype_and_imatrix
 | 
				
			||||||
from typing import Union
 | 
					 | 
				
			||||||
import numpy as np
 | 
					import numpy as np
 | 
				
			||||||
import os
 | 
					import os
 | 
				
			||||||
from ipex_llm.utils.common import invalidInputError
 | 
					from ipex_llm.utils.common import invalidInputError
 | 
				
			||||||
from typing import List, Optional, Tuple, Union
 | 
					from typing import List, Optional, Tuple, Union
 | 
				
			||||||
 | 
					from types import MethodType
 | 
				
			||||||
import subprocess
 | 
					import subprocess
 | 
				
			||||||
import sys
 | 
					import sys
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1228,6 +1228,8 @@ def _optimize_post(model, lightweight_bmm=False):
 | 
				
			||||||
        convert_forward(model, module.InternLM2Attention, internlm_xcomposser2_attention_forward)
 | 
					        convert_forward(model, module.InternLM2Attention, internlm_xcomposser2_attention_forward)
 | 
				
			||||||
        from ipex_llm.transformers.models.internlm import internlm_xcomposser2_mlp_forward
 | 
					        from ipex_llm.transformers.models.internlm import internlm_xcomposser2_mlp_forward
 | 
				
			||||||
        convert_forward(model, module.InternLM2MLP, internlm_xcomposser2_mlp_forward)
 | 
					        convert_forward(model, module.InternLM2MLP, internlm_xcomposser2_mlp_forward)
 | 
				
			||||||
 | 
					        from ipex_llm.transformers.models.internlm import internlm_xcomposser2_chat
 | 
				
			||||||
 | 
					        model.chat = MethodType(internlm_xcomposser2_chat, model)
 | 
				
			||||||
    elif model.config.model_type == "qwen":
 | 
					    elif model.config.model_type == "qwen":
 | 
				
			||||||
        if hasattr(model.config, "visual"):
 | 
					        if hasattr(model.config, "visual"):
 | 
				
			||||||
            # for Qwen-VL-Chat
 | 
					            # for Qwen-VL-Chat
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -37,7 +37,7 @@
 | 
				
			||||||
# limitations under the License.
 | 
					# limitations under the License.
 | 
				
			||||||
""" PyTorch InternLM model."""
 | 
					""" PyTorch InternLM model."""
 | 
				
			||||||
import math
 | 
					import math
 | 
				
			||||||
from typing import Optional, Tuple
 | 
					from typing import Optional, Tuple, List
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
import torch.utils.checkpoint
 | 
					import torch.utils.checkpoint
 | 
				
			||||||
| 
						 | 
					@ -47,9 +47,13 @@ from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
 | 
				
			||||||
    append_kv_cache, is_enough_kv_cache_room_4_31
 | 
					    append_kv_cache, is_enough_kv_cache_room_4_31
 | 
				
			||||||
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb
 | 
					from ipex_llm.transformers.models.utils import apply_rotary_pos_emb
 | 
				
			||||||
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 | 
					from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 | 
				
			||||||
 | 
					from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
 | 
				
			||||||
 | 
					from ipex_llm.transformers.models.utils import update_past_key_value
 | 
				
			||||||
 | 
					from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
 | 
				
			||||||
from einops import rearrange
 | 
					from einops import rearrange
 | 
				
			||||||
import os
 | 
					import os
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
 | 
					KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -347,6 +351,7 @@ def internlm_xcomposser2_attention_forward(
 | 
				
			||||||
    **kwargs,
 | 
					    **kwargs,
 | 
				
			||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
 | 
					) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
 | 
				
			||||||
    bsz, q_len, _ = hidden_states.size()
 | 
					    bsz, q_len, _ = hidden_states.size()
 | 
				
			||||||
 | 
					    device = hidden_states.device
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    qkv_states = self.wqkv(hidden_states)
 | 
					    qkv_states = self.wqkv(hidden_states)
 | 
				
			||||||
    qkv_states = add_lora(hidden_states, qkv_states, im_mask, self.wqkv_lora_scaling,
 | 
					    qkv_states = add_lora(hidden_states, qkv_states, im_mask, self.wqkv_lora_scaling,
 | 
				
			||||||
| 
						 | 
					@ -375,26 +380,45 @@ def internlm_xcomposser2_attention_forward(
 | 
				
			||||||
    query_states, key_states = apply_rotary_pos_emb(
 | 
					    query_states, key_states = apply_rotary_pos_emb(
 | 
				
			||||||
        query_states, key_states, cos, sin, position_ids, "internlm")
 | 
					        query_states, key_states, cos, sin, position_ids, "internlm")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if past_key_value is not None:
 | 
					    # IPEX-LLM OPT: kv cache and quantzie kv cache
 | 
				
			||||||
        # reuse k, v, self_attention
 | 
					    use_quantize_kv = use_quantize_kv_cache(self.wqkv, hidden_states)
 | 
				
			||||||
        key_states = torch.cat([past_key_value[0], key_states], dim=2)
 | 
					    key_states, value_states = update_past_key_value(
 | 
				
			||||||
        value_states = torch.cat([past_key_value[1], value_states], dim=2)
 | 
					        past_key_value, key_states, value_states,
 | 
				
			||||||
 | 
					        kv_seq_len, use_quantize_kv, device
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
    past_key_value = (key_states, value_states) if use_cache else None
 | 
					    past_key_value = (key_states, value_states) if use_cache else None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    key_states = repeat_kv(key_states, self.num_key_value_groups)
 | 
					    # IPEX-LLM OPT: sdp
 | 
				
			||||||
    value_states = repeat_kv(value_states, self.num_key_value_groups)
 | 
					    if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
 | 
				
			||||||
 | 
					        import linear_q4_0
 | 
				
			||||||
 | 
					        if use_quantize_kv:
 | 
				
			||||||
 | 
					            attn_output = linear_q4_0.sdp_fp8(query_states, key_states, value_states,
 | 
				
			||||||
 | 
					                                              attention_mask)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
 | 
				
			||||||
 | 
					    elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
 | 
				
			||||||
 | 
					        import linear_q4_0
 | 
				
			||||||
 | 
					        if use_quantize_kv:
 | 
				
			||||||
 | 
					            attn_output = linear_q4_0.sdp_fp8_causal(query_states, key_states, value_states)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            attn_output = linear_q4_0.sdp_causal(query_states, key_states, value_states)
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        if use_quantize_kv:
 | 
				
			||||||
 | 
					            key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
 | 
				
			||||||
 | 
					                                                            query_states.dtype)
 | 
				
			||||||
 | 
					        key_states = repeat_kv(key_states, self.num_key_value_groups)
 | 
				
			||||||
 | 
					        value_states = repeat_kv(value_states, self.num_key_value_groups)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    attn_weights = torch.matmul(query_states, key_states.transpose(
 | 
					        attn_weights = torch.matmul(query_states,
 | 
				
			||||||
        2, 3)) / math.sqrt(self.head_dim)
 | 
					                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if attention_mask is not None:
 | 
					        if attention_mask is not None:
 | 
				
			||||||
        attn_weights = attn_weights + attention_mask
 | 
					            attn_weights = attn_weights + attention_mask
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # upcast attention to fp32
 | 
					        # upcast attention to fp32
 | 
				
			||||||
    attn_weights = nn.functional.softmax(
 | 
					        attn_weights = nn.functional.softmax(attn_weights,
 | 
				
			||||||
        attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
 | 
					                                             dim=-1, dtype=torch.float32).to(query_states.dtype)
 | 
				
			||||||
    attn_output = torch.matmul(attn_weights, value_states)
 | 
					        attn_output = torch.matmul(attn_weights, value_states)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    attn_output = attn_output.transpose(1, 2).contiguous()
 | 
					    attn_output = attn_output.transpose(1, 2).contiguous()
 | 
				
			||||||
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 | 
					    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 | 
				
			||||||
| 
						 | 
					@ -423,3 +447,66 @@ def internlm_xcomposser2_mlp_forward(
 | 
				
			||||||
    w2 = self.w2(x)
 | 
					    w2 = self.w2(x)
 | 
				
			||||||
    w2 = add_lora(x, w2, im_mask, self.w2_lora_scaling, self.w2_Plora_A, self.w2_Plora_B)
 | 
					    w2 = add_lora(x, w2, im_mask, self.w2_lora_scaling, self.w2_Plora_A, self.w2_Plora_B)
 | 
				
			||||||
    return w2
 | 
					    return w2
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@torch.no_grad()
 | 
				
			||||||
 | 
					def internlm_xcomposser2_chat(
 | 
				
			||||||
 | 
					    self,
 | 
				
			||||||
 | 
					    tokenizer,
 | 
				
			||||||
 | 
					    query: str,
 | 
				
			||||||
 | 
					    image: torch.Tensor = None,
 | 
				
			||||||
 | 
					    history: List[Tuple[str, str]]=[],
 | 
				
			||||||
 | 
					    streamer=None,
 | 
				
			||||||
 | 
					    max_new_tokens: int = 1024,
 | 
				
			||||||
 | 
					    do_sample: bool = True,
 | 
				
			||||||
 | 
					    temperature: float = 1.0,
 | 
				
			||||||
 | 
					    top_p: float = 0.8,
 | 
				
			||||||
 | 
					    repetition_penalty: float=1.005,
 | 
				
			||||||
 | 
					    meta_instruction:
 | 
				
			||||||
 | 
					    str = ('You are an AI assistant whose name is InternLM-XComposer (浦语·灵笔).\n'
 | 
				
			||||||
 | 
					           '- InternLM-XComposer (浦语·灵笔) is a multi-modality conversational language model'
 | 
				
			||||||
 | 
					           'that is developed by Shanghai AI Laboratory (上海人工智能实验室).'
 | 
				
			||||||
 | 
					           'It is designed to be helpful, honest, and harmless.\n'
 | 
				
			||||||
 | 
					           '- InternLM-XComposer (浦语·灵笔) can understand and communicate fluently in the'
 | 
				
			||||||
 | 
					           'language chosen by the user such as English and 中文.\n'
 | 
				
			||||||
 | 
					           '- InternLM-XComposer (浦语·灵笔) is capable of comprehending and articulating'
 | 
				
			||||||
 | 
					           'responses effectively based on the provided image.'),
 | 
				
			||||||
 | 
					    **kwargs,
 | 
				
			||||||
 | 
					):
 | 
				
			||||||
 | 
					    if image is None:
 | 
				
			||||||
 | 
					        inputs = self.build_inputs(tokenizer, query, history, meta_instruction)
 | 
				
			||||||
 | 
					        im_mask = torch.zeros(inputs['input_ids'].shape[:2]).bool()
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        image = self.encode_img(image)
 | 
				
			||||||
 | 
					        inputs, im_mask = self.interleav_wrap_chat(tokenizer, query, image,
 | 
				
			||||||
 | 
					                                                   history, meta_instruction)
 | 
				
			||||||
 | 
					    inputs = {
 | 
				
			||||||
 | 
					        k: v.to(self.device)
 | 
				
			||||||
 | 
					        for k, v in inputs.items() if torch.is_tensor(v)
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    im_mask = im_mask.to(self.device)
 | 
				
			||||||
 | 
					    # also add end-of-assistant token in eos token id to avoid unnecessary generation
 | 
				
			||||||
 | 
					    eos_token_id = [
 | 
				
			||||||
 | 
					        tokenizer.eos_token_id,
 | 
				
			||||||
 | 
					        tokenizer.convert_tokens_to_ids(['[UNUSED_TOKEN_145]'])[0]
 | 
				
			||||||
 | 
					    ]
 | 
				
			||||||
 | 
					    outputs = self.generate(
 | 
				
			||||||
 | 
					        **inputs,
 | 
				
			||||||
 | 
					        streamer=streamer,
 | 
				
			||||||
 | 
					        max_new_tokens=max_new_tokens,
 | 
				
			||||||
 | 
					        do_sample=do_sample,
 | 
				
			||||||
 | 
					        temperature=temperature,
 | 
				
			||||||
 | 
					        top_p=top_p,
 | 
				
			||||||
 | 
					        eos_token_id=eos_token_id,
 | 
				
			||||||
 | 
					        repetition_penalty=repetition_penalty,
 | 
				
			||||||
 | 
					        im_mask=im_mask,
 | 
				
			||||||
 | 
					        **kwargs,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    if image is None:
 | 
				
			||||||
 | 
					        outputs = outputs[0].cpu().tolist()[len(inputs['input_ids'][0]):]
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        outputs = outputs[0].cpu().tolist()
 | 
				
			||||||
 | 
					    response = tokenizer.decode(outputs, skip_special_tokens=True)
 | 
				
			||||||
 | 
					    response = response.split('[UNUSED_TOKEN_145]')[0]
 | 
				
			||||||
 | 
					    history = history + [(query, response)]
 | 
				
			||||||
 | 
					    return response, history
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue