diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py index bdca4092..69e005af 100644 --- a/python/llm/src/ipex_llm/transformers/convert.py +++ b/python/llm/src/ipex_llm/transformers/convert.py @@ -44,11 +44,11 @@ import transformers import importlib.util from ipex_llm.ggml.quantize import ggml_tensor_qtype, gguf_mixed_qtype from .utils import logger, get_cur_qtype_and_imatrix -from typing import Union import numpy as np import os from ipex_llm.utils.common import invalidInputError from typing import List, Optional, Tuple, Union +from types import MethodType import subprocess import sys @@ -1228,6 +1228,8 @@ def _optimize_post(model, lightweight_bmm=False): convert_forward(model, module.InternLM2Attention, internlm_xcomposser2_attention_forward) from ipex_llm.transformers.models.internlm import internlm_xcomposser2_mlp_forward convert_forward(model, module.InternLM2MLP, internlm_xcomposser2_mlp_forward) + from ipex_llm.transformers.models.internlm import internlm_xcomposser2_chat + model.chat = MethodType(internlm_xcomposser2_chat, model) elif model.config.model_type == "qwen": if hasattr(model.config, "visual"): # for Qwen-VL-Chat diff --git a/python/llm/src/ipex_llm/transformers/models/internlm.py b/python/llm/src/ipex_llm/transformers/models/internlm.py index 4b19faf1..ab5a5864 100644 --- a/python/llm/src/ipex_llm/transformers/models/internlm.py +++ b/python/llm/src/ipex_llm/transformers/models/internlm.py @@ -37,7 +37,7 @@ # limitations under the License. """ PyTorch InternLM model.""" import math -from typing import Optional, Tuple +from typing import Optional, Tuple, List import torch import torch.utils.checkpoint @@ -47,9 +47,13 @@ from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \ append_kv_cache, is_enough_kv_cache_room_4_31 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu +from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache +from ipex_llm.transformers.models.utils import update_past_key_value +from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal from einops import rearrange import os + KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)) @@ -347,6 +351,7 @@ def internlm_xcomposser2_attention_forward( **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() + device = hidden_states.device qkv_states = self.wqkv(hidden_states) qkv_states = add_lora(hidden_states, qkv_states, im_mask, self.wqkv_lora_scaling, @@ -375,26 +380,45 @@ def internlm_xcomposser2_attention_forward( query_states, key_states = apply_rotary_pos_emb( query_states, key_states, cos, sin, position_ids, "internlm") - if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - + # IPEX-LLM OPT: kv cache and quantzie kv cache + use_quantize_kv = use_quantize_kv_cache(self.wqkv, hidden_states) + key_states, value_states = update_past_key_value( + past_key_value, key_states, value_states, + kv_seq_len, use_quantize_kv, device + ) past_key_value = (key_states, value_states) if use_cache else None - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) + # IPEX-LLM OPT: sdp + if use_sdp(q_len, kv_seq_len, self.head_dim, query_states): + import linear_q4_0 + if use_quantize_kv: + attn_output = linear_q4_0.sdp_fp8(query_states, key_states, value_states, + attention_mask) + else: + attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask) + elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training): + import linear_q4_0 + if use_quantize_kv: + attn_output = linear_q4_0.sdp_fp8_causal(query_states, key_states, value_states) + else: + attn_output = linear_q4_0.sdp_causal(query_states, key_states, value_states) + else: + if use_quantize_kv: + key_states, value_states = restore_fp8_kv_cache(key_states, value_states, + query_states.dtype) + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) - attn_weights = torch.matmul(query_states, key_states.transpose( - 2, 3)) / math.sqrt(self.head_dim) + attn_weights = torch.matmul(query_states, + key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - if attention_mask is not None: - attn_weights = attn_weights + attention_mask + if attention_mask is not None: + attn_weights = attn_weights + attention_mask - # upcast attention to fp32 - attn_weights = nn.functional.softmax( - attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_output = torch.matmul(attn_weights, value_states) + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, + dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) @@ -423,3 +447,66 @@ def internlm_xcomposser2_mlp_forward( w2 = self.w2(x) w2 = add_lora(x, w2, im_mask, self.w2_lora_scaling, self.w2_Plora_A, self.w2_Plora_B) return w2 + + +@torch.no_grad() +def internlm_xcomposser2_chat( + self, + tokenizer, + query: str, + image: torch.Tensor = None, + history: List[Tuple[str, str]]=[], + streamer=None, + max_new_tokens: int = 1024, + do_sample: bool = True, + temperature: float = 1.0, + top_p: float = 0.8, + repetition_penalty: float=1.005, + meta_instruction: + str = ('You are an AI assistant whose name is InternLM-XComposer (浦语·灵笔).\n' + '- InternLM-XComposer (浦语·灵笔) is a multi-modality conversational language model' + 'that is developed by Shanghai AI Laboratory (上海人工智能实验室).' + 'It is designed to be helpful, honest, and harmless.\n' + '- InternLM-XComposer (浦语·灵笔) can understand and communicate fluently in the' + 'language chosen by the user such as English and 中文.\n' + '- InternLM-XComposer (浦语·灵笔) is capable of comprehending and articulating' + 'responses effectively based on the provided image.'), + **kwargs, +): + if image is None: + inputs = self.build_inputs(tokenizer, query, history, meta_instruction) + im_mask = torch.zeros(inputs['input_ids'].shape[:2]).bool() + else: + image = self.encode_img(image) + inputs, im_mask = self.interleav_wrap_chat(tokenizer, query, image, + history, meta_instruction) + inputs = { + k: v.to(self.device) + for k, v in inputs.items() if torch.is_tensor(v) + } + im_mask = im_mask.to(self.device) + # also add end-of-assistant token in eos token id to avoid unnecessary generation + eos_token_id = [ + tokenizer.eos_token_id, + tokenizer.convert_tokens_to_ids(['[UNUSED_TOKEN_145]'])[0] + ] + outputs = self.generate( + **inputs, + streamer=streamer, + max_new_tokens=max_new_tokens, + do_sample=do_sample, + temperature=temperature, + top_p=top_p, + eos_token_id=eos_token_id, + repetition_penalty=repetition_penalty, + im_mask=im_mask, + **kwargs, + ) + if image is None: + outputs = outputs[0].cpu().tolist()[len(inputs['input_ids'][0]):] + else: + outputs = outputs[0].cpu().tolist() + response = tokenizer.decode(outputs, skip_special_tokens=True) + response = response.split('[UNUSED_TOKEN_145]')[0] + history = history + [(query, response)] + return response, history