#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file are adapted from
# https://huggingface.co/internlm/internlm-chat-7b/blob/659ed911eec1e26810f9854f19c5ec27854e9cf3/modeling_internlm.py
# which is licensed under Apache License 2.0:
#
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""" PyTorch InternLM model."""
from typing import Optional, Tuple, List

import torch
import torch.utils.checkpoint
from ipex_llm.utils.common.log4Error import invalidInputError
from ipex_llm.transformers.models.common import merge_qkv_base
from ipex_llm.transformers.models.common import scaled_dot_product_attention
from ipex_llm.transformers.models.utils import should_use_fuse_rope, apply_rotary_pos_emb
from ipex_llm.transformers.models.utils import use_quantize_kv_cache
from ipex_llm.transformers.models.utils import update_past_key_value
from einops import rearrange


def merge_qkv(module: torch.nn.Module):
    merge_qkv_base(module, "InternLMAttention")
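

# `merge_qkv` relies on `merge_qkv_base` to fuse the separate q/k/v projections of
# `InternLMAttention` into the single `qkv_proj` linear used by
# `internlm_attention_forward` below.  The helper below is a minimal illustrative
# sketch of that idea (equal-sized q/k/v projections assumed); it is hypothetical and
# never called by IPEX-LLM -- the real merge lives in
# ipex_llm.transformers.models.common.merge_qkv_base.
def _illustrative_merge_qkv(attn: torch.nn.Module) -> torch.nn.Linear:
    q, k, v = attn.q_proj, attn.k_proj, attn.v_proj
    qkv = torch.nn.Linear(q.in_features,
                          q.out_features + k.out_features + v.out_features,
                          bias=q.bias is not None)
    with torch.no_grad():
        # stack the three weight matrices along the output dimension
        qkv.weight.copy_(torch.cat([q.weight, k.weight, v.weight], dim=0))
        if q.bias is not None:
            qkv.bias.copy_(torch.cat([q.bias, k.bias, v.bias], dim=0))
    return qkv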


def internlm_attention_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    bsz, q_len, _ = hidden_states.size()

    qkv = self.qkv_proj(hidden_states)
    qkv = qkv.view(bsz, q_len, self.num_heads * 3, self.head_dim)
    qkv = qkv.transpose(1, 2)
    query_states, key_states, value_states = qkv.split([self.num_heads,
                                                        self.num_heads,
                                                        self.num_heads], dim=1)

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[-2]

    # IPEX-LLM OPT: fuse rope
    if should_use_fuse_rope(hidden_states, position_ids, self.training):
        import xe_addons
        xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
                                       query_states, key_states)
    else:
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin, position_ids, "internlm"
        )

    # IPEX-LLM OPT: kv cache and quantize kv cache
    use_quantize_kv = use_quantize_kv_cache(self.qkv_proj, hidden_states,
                                            self.num_heads, self.num_heads)
    key_states, value_states = update_past_key_value(
        past_key_value, key_states, value_states,
        kv_seq_len, use_quantize_kv, hidden_states.device
    )
    past_key_value = (key_states, value_states) if use_cache else None

    # IPEX-LLM OPT: sdp
    attn_weights = None
    attn_output = scaled_dot_product_attention(
        query_states, key_states, value_states,
        attention_mask, q_len == kv_seq_len
    )

    attn_output = attn_output.transpose(1, 2)
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
    attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value


def internlm2_attention_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    bsz, q_len, _ = hidden_states.size()

    qkv_states = self.wqkv(hidden_states)
    qkv_states = rearrange(
        qkv_states,
        "b q (h gs d) -> b q h gs d",
        gs=2 + self.num_key_value_groups,
        d=self.head_dim,
    )
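    # Shape walkthrough for the packed `wqkv` output (illustrative numbers only,
    # assuming a GQA config with num_attention_heads=32, num_key_value_heads=8 and
    # head_dim=128): num_key_value_groups=4, so gs=6 and wqkv emits 8 * 6 * 128 = 6144
    # features per token, viewed as (b, q, 8, 6, 128).  Along the `gs` axis the first
    # 4 slots are the query heads of the group, slot -2 is the shared key head and
    # slot -1 is the shared value head.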
    query_states = qkv_states[..., : self.num_key_value_groups, :]
    query_states = rearrange(query_states, "b q h gs d -> b q (h gs) d")
    key_states = qkv_states[..., -2, :]
    value_states = qkv_states[..., -1, :]

    query_states = query_states.transpose(1, 2)
    key_states = key_states.transpose(1, 2)
    value_states = value_states.transpose(1, 2)

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[-2]

    # IPEX-LLM OPT: fuse rope
    if should_use_fuse_rope(hidden_states, position_ids, self.training):
        import xe_addons
        xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
                                       query_states, key_states)
    else:
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin, position_ids, "internlm"
        )

    # IPEX-LLM OPT: kv cache and quantize kv cache
    use_quantize_kv = use_quantize_kv_cache(self.wqkv, hidden_states,
                                            self.num_heads, self.num_key_value_heads)
    key_states, value_states = update_past_key_value(
        past_key_value, key_states, value_states,
        kv_seq_len, use_quantize_kv, hidden_states.device
    )
    past_key_value = (key_states, value_states) if use_cache else None

    # IPEX-LLM OPT: sdp
    attn_weights = None
    attn_output = scaled_dot_product_attention(
        query_states, key_states, value_states,
        attention_mask, q_len == kv_seq_len
    )

    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
    attn_output = self.wo(attn_output)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value


def pre_process_attn_and_mlp(module: torch.nn.Module):
    if module.__class__.__name__ == "InternLM2Attention":
        module.wqkv_lora_scaling = module.wqkv.lora_scaling
        module.wqkv_Plora_A = module.wqkv.Plora_A
        module.wqkv_Plora_B = module.wqkv.Plora_B
        del module.wqkv.Plora_A
        del module.wqkv.Plora_B

        module.wo_lora_scaling = module.wo.lora_scaling
        module.wo_Plora_A = module.wo.Plora_A
        module.wo_Plora_B = module.wo.Plora_B
        del module.wo.Plora_A
        del module.wo.Plora_B
    elif module.__class__.__name__ == "InternLM2MLP":
        module.w1_lora_scaling = module.w1.lora_scaling
        module.w1_Plora_A = module.w1.Plora_A
        module.w1_Plora_B = module.w1.Plora_B
        del module.w1.Plora_A
        del module.w1.Plora_B

        module.w2_lora_scaling = module.w2.lora_scaling
        module.w2_Plora_A = module.w2.Plora_A
        module.w2_Plora_B = module.w2.Plora_B
        del module.w2.Plora_A
        del module.w2.Plora_B

        module.w3_lora_scaling = module.w3.lora_scaling
        module.w3_Plora_A = module.w3.Plora_A
        module.w3_Plora_B = module.w3.Plora_B
        del module.w3.Plora_A
        del module.w3.Plora_B


def add_lora(x: torch.Tensor, result: torch.Tensor,
             im_mask: torch.Tensor = None, lora_scaling: float = 0,
             Plora_A: torch.nn.Linear = None, Plora_B: torch.nn.Linear = None):
    invalidInputError(x.dim() == 3 and result.dim() == 3,
                      "`x` and `result` should have 3 dims")
    if isinstance(im_mask, torch.Tensor) or len(im_mask) == 0:
        return result
    else:
        for start_idx, end_idx in im_mask:
            result[:, start_idx:end_idx, :] += Plora_B(
                Plora_A(x[:, start_idx:end_idx, :]) * lora_scaling
            )
        return result
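

# `add_lora` implements the partial LoRA (PLoRA) used by InternLM-XComposer2: the
# Plora_A/Plora_B branch is added only on the image-token spans listed in `im_mask`,
# while text tokens keep the plain linear output.  The function below is a minimal,
# self-contained usage sketch with made-up shapes; it is illustrative only and is
# never called by IPEX-LLM.
def _illustrative_add_lora_usage():
    hidden, rank = 16, 4
    Plora_A = torch.nn.Linear(hidden, rank, bias=False)
    Plora_B = torch.nn.Linear(rank, hidden, bias=False)
    x = torch.randn(1, 8, hidden)          # (batch, seq, hidden)
    base_out = torch.randn(1, 8, hidden)   # stand-in for a base linear projection of x
    with torch.no_grad():
        # only tokens 2..4 (a pretend image span) receive the LoRA delta
        return add_lora(x, base_out, im_mask=[(2, 5)], lora_scaling=1.0,
                        Plora_A=Plora_A, Plora_B=Plora_B)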


def internlm_xcomposser2_model_forward_wrapper(origin_forward):
    def internlm_xcomposser2_model_forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):
        im_mask = kwargs.get('im_mask', None)
        if im_mask is None or im_mask.size(-1) <= 1 or im_mask.sum() == 0:
            # decoding or no image input, `im_mask` is not needed
            kwargs['im_mask'] = []
        else:
            # replace im_mask with start_idx and end_idx to improve performance
            im_mask = im_mask.cpu().flatten().tolist()
            length = len(im_mask)
            new_mask = []
            i = 0
            while i < length:
                while i < length and not im_mask[i]:
                    i = i + 1
                start_idx = i
                while i < length and im_mask[i]:
                    i = i + 1
                end_idx = i
                if start_idx != end_idx:
                    new_mask.append((start_idx, end_idx))
            kwargs['im_mask'] = new_mask
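            # Example of the conversion above (hypothetical mask): a flattened im_mask
            # of [0, 0, 1, 1, 1, 0, 1, 1] becomes [(2, 5), (6, 8)], i.e. half-open
            # [start_idx, end_idx) image-token ranges that `add_lora` can slice directly.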

        return origin_forward(
            self=self,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs
        )
    return internlm_xcomposser2_model_forward


def internlm_xcomposser2_attention_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    im_mask: Optional[Tuple[torch.Tensor]] = None,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    bsz, q_len, _ = hidden_states.size()
    device = hidden_states.device

    qkv_states = self.wqkv(hidden_states)
    qkv_states = add_lora(hidden_states, qkv_states, im_mask, self.wqkv_lora_scaling,
                          self.wqkv_Plora_A, self.wqkv_Plora_B)

    qkv_states = rearrange(
        qkv_states,
        'b q (h gs d) -> b q h gs d',
        gs=2 + self.num_key_value_groups,
        d=self.head_dim,
    )
    query_states = qkv_states[..., :self.num_key_value_groups, :]
    query_states = rearrange(query_states, 'b q h gs d -> b q (h gs) d')
    key_states = qkv_states[..., -2, :]
    value_states = qkv_states[..., -1, :]

    query_states = query_states.transpose(1, 2)
    key_states = key_states.transpose(1, 2)
    value_states = value_states.transpose(1, 2)

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[-2]

    # IPEX-LLM OPT: fuse rope
    if should_use_fuse_rope(hidden_states, position_ids, self.training):
        # The fused rope gives wrong results if context_length > max_position_embeddings
        # (32768); we assume context_length <= 32768 here.
        import xe_addons
        xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
                                       query_states, key_states)
    else:
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
                                                        cos, sin, position_ids, "internlm")

    # IPEX-LLM OPT: kv cache and quantize kv cache
    use_quantize_kv = use_quantize_kv_cache(self.wqkv, hidden_states,
                                            self.num_heads, self.num_key_value_heads)
    key_states, value_states = update_past_key_value(
        past_key_value, key_states, value_states,
        kv_seq_len, use_quantize_kv, device
    )
    past_key_value = (key_states, value_states) if use_cache else None

    # IPEX-LLM OPT: sdp
    attn_weights = None
    attn_output = scaled_dot_product_attention(
        query_states, key_states, value_states,
        attention_mask, q_len == kv_seq_len
    )

    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

    attn_output_2 = self.wo(attn_output)
    attn_output = add_lora(attn_output, attn_output_2, im_mask, self.wo_lora_scaling,
                           self.wo_Plora_A, self.wo_Plora_B)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value


def internlm_xcomposser2_mlp_forward(
    self,
    x: torch.Tensor,
    im_mask: Optional[Tuple[torch.Tensor]] = None,
):
    w1 = self.w1(x)
    w1 = add_lora(x, w1, im_mask, self.w1_lora_scaling, self.w1_Plora_A, self.w1_Plora_B)
    w3 = self.w3(x)
    w3 = add_lora(x, w3, im_mask, self.w3_lora_scaling, self.w3_Plora_A, self.w3_Plora_B)
    x = self.act_fn(w1) * w3
    w2 = self.w2(x)
    w2 = add_lora(x, w2, im_mask, self.w2_lora_scaling, self.w2_Plora_A, self.w2_Plora_B)
    return w2


@torch.no_grad()
def internlm_xcomposser2_chat(
    self,
    tokenizer,
    query: str,
    image: torch.Tensor = None,
    history: List[Tuple[str, str]] = [],
    streamer=None,
    max_new_tokens: int = 1024,
    do_sample: bool = True,
    temperature: float = 1.0,
    top_p: float = 0.8,
    repetition_penalty: float = 1.005,
    meta_instruction: str = (
        'You are an AI assistant whose name is InternLM-XComposer (浦语·灵笔).\n'
        '- InternLM-XComposer (浦语·灵笔) is a multi-modality conversational language model '
        'that is developed by Shanghai AI Laboratory (上海人工智能实验室). '
        'It is designed to be helpful, honest, and harmless.\n'
        '- InternLM-XComposer (浦语·灵笔) can understand and communicate fluently in the '
        'language chosen by the user such as English and 中文.\n'
        '- InternLM-XComposer (浦语·灵笔) is capable of comprehending and articulating '
        'responses effectively based on the provided image.'
    ),
    **kwargs,
):
    # ipex-llm changes start: fix device and dtype conversion
    if image is None:
        inputs = self.build_inputs(tokenizer, query, history, meta_instruction)
        im_mask = torch.zeros(inputs['input_ids'].shape[:2]).bool()
    else:
        image = self.encode_img(image)
        inputs, im_mask = self.interleav_wrap_chat(tokenizer, query, image,
                                                   history, meta_instruction)

    new_inputs = {}
    for k, v in inputs.items():
        if torch.is_tensor(v):
            if v.dtype.is_floating_point:
                new_inputs[k] = v.to(device=self.device, dtype=self.dtype)
            else:
                # input_ids, don't convert its dtype
                new_inputs[k] = v.to(device=self.device)
        else:
            new_inputs[k] = v
    inputs = new_inputs
    im_mask = im_mask.to(self.device)
    # ipex-llm changes end

    # also add the end-of-assistant token to eos_token_id to avoid unnecessary generation
    eos_token_id = [
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids(['[UNUSED_TOKEN_145]'])[0]
    ]
    outputs = self.generate(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        temperature=temperature,
        top_p=top_p,
        eos_token_id=eos_token_id,
        repetition_penalty=repetition_penalty,
        im_mask=im_mask,
        **kwargs,
    )
    if image is None:
        outputs = outputs[0].cpu().tolist()[len(inputs['input_ids'][0]):]
    else:
        outputs = outputs[0].cpu().tolist()
    response = tokenizer.decode(outputs, skip_special_tokens=True)
    response = response.split('[UNUSED_TOKEN_145]')[0]
    history = history + [(query, response)]
    return response, history
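

# ---------------------------------------------------------------------------
# Illustrative end-to-end usage of the patched chat path (a sketch, not part of
# this module; the model id, load options and image argument are assumptions --
# adjust them to your environment):
#
#     from transformers import AutoTokenizer
#     from ipex_llm.transformers import AutoModelForCausalLM
#
#     ckpt = "internlm/internlm-xcomposer2-vl-7b"
#     model = AutoModelForCausalLM.from_pretrained(ckpt, load_in_4bit=True,
#                                                  trust_remote_code=True)
#     tokenizer = AutoTokenizer.from_pretrained(ckpt, trust_remote_code=True)
#
#     # `image` is whatever the upstream `encode_img` accepts (e.g. an image path)
#     response, history = model.chat(tokenizer, "Describe this image.",
#                                    image="./example.jpg", history=[])
# ---------------------------------------------------------------------------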