#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file are adapted from
# https://huggingface.co/internlm/internlm-chat-7b/blob/659ed911eec1e26810f9854f19c5ec27854e9cf3/modeling_internlm.py
# which is licensed under Apache License 2.0:
#
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch InternLM model."""
from typing import Optional, Tuple, List

import torch
import torch.utils.checkpoint
from ipex_llm.utils.common.log4Error import invalidInputError
from ipex_llm.transformers.models.common import merge_qkv_base
from ipex_llm.transformers.models.common import scaled_dot_product_attention
from ipex_llm.transformers.models.utils import should_use_fuse_rope, apply_rotary_pos_emb
from ipex_llm.transformers.models.utils import use_quantize_kv_cache
from ipex_llm.transformers.models.utils import update_past_key_value
from einops import rearrange


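# Note: `merge_qkv_base` (from ipex_llm.transformers.models.common) is expected to fuse
# the separate q/k/v projections of every module whose class name is "InternLMAttention"
# into a single `qkv_proj` linear layer, which is why `internlm_attention_forward` below
# reads `self.qkv_proj` instead of the original three projections.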
def merge_qkv(module: torch.nn.Module):
    merge_qkv_base(module, "InternLMAttention")


def internlm_attention_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    bsz, q_len, _ = hidden_states.size()

    qkv = self.qkv_proj(hidden_states)
    qkv = qkv.view(bsz, q_len, self.num_heads * 3, self.head_dim)
    qkv = qkv.transpose(1, 2)
    query_states, key_states, value_states = qkv.split([self.num_heads,
                                                        self.num_heads,
                                                        self.num_heads], dim=1)

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[-2]

    # IPEX-LLM OPT: fuse rope
    if should_use_fuse_rope(hidden_states, position_ids, self.training):
        import xe_addons
        xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
                                       query_states, key_states)
    else:
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin, position_ids, "internlm"
        )

    # IPEX-LLM OPT: kv cache and quantize kv cache
    use_quantize_kv = use_quantize_kv_cache(self.qkv_proj, hidden_states,
                                            self.num_heads, self.num_heads)
    key_states, value_states = update_past_key_value(
        past_key_value, key_states, value_states,
        kv_seq_len, use_quantize_kv, hidden_states.device
    )
    past_key_value = (key_states, value_states) if use_cache else None

    # IPEX-LLM OPT: sdp
    attn_weights = None
    attn_output = scaled_dot_product_attention(
        query_states, key_states, value_states,
        attention_mask, q_len == kv_seq_len
    )

    attn_output = attn_output.transpose(1, 2)
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

    attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value


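# InternLM2 uses grouped-query attention: `wqkv` packs, for each KV head, its
# `num_key_value_groups` query heads plus one key head and one value head, hence
# gs = 2 + num_key_value_groups in the rearrange below. A hypothetical walk-through,
# assuming a 7B-style config with 32 query heads, 8 KV heads and head_dim 128:
#   num_key_value_groups = 32 // 8 = 4, so gs = 6
#   wqkv output: (bsz, q_len, 8 * 6 * 128) -> rearranged to (bsz, q_len, 8, 6, 128)
#   [..., :4, :] -> the 32 query heads, [..., -2, :] -> 8 key heads, [..., -1, :] -> 8 value heads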
def internlm2_attention_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    bsz, q_len, _ = hidden_states.size()

    qkv_states = self.wqkv(hidden_states)
    qkv_states = rearrange(
        qkv_states,
        "b q (h gs d) -> b q h gs d",
        gs=2 + self.num_key_value_groups,
        d=self.head_dim,
    )

    query_states = qkv_states[..., : self.num_key_value_groups, :]
    query_states = rearrange(query_states, "b q h gs d -> b q (h gs) d")
    key_states = qkv_states[..., -2, :]
    value_states = qkv_states[..., -1, :]

    query_states = query_states.transpose(1, 2)
    key_states = key_states.transpose(1, 2)
    value_states = value_states.transpose(1, 2)

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[-2]

    # IPEX-LLM OPT: fuse rope
    if should_use_fuse_rope(hidden_states, position_ids, self.training):
        import xe_addons
        xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
                                       query_states, key_states)
    else:
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin, position_ids, "internlm"
        )

    # IPEX-LLM OPT: kv cache and quantize kv cache
    use_quantize_kv = use_quantize_kv_cache(self.wqkv, hidden_states,
                                            self.num_heads, self.num_key_value_heads)
    key_states, value_states = update_past_key_value(
        past_key_value, key_states, value_states,
        kv_seq_len, use_quantize_kv, hidden_states.device
    )
    past_key_value = (key_states, value_states) if use_cache else None

    # IPEX-LLM OPT: sdp
    attn_weights = None
    attn_output = scaled_dot_product_attention(
        query_states, key_states, value_states,
        attention_mask, q_len == kv_seq_len
    )

    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

    attn_output = self.wo(attn_output)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value


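# The XComposer2 checkpoints attach partial-LoRA ("PLoRA") weights (`Plora_A`, `Plora_B`,
# `lora_scaling`) to their Linear layers. `pre_process_attn_and_mlp` hoists those weights
# onto the parent attention/MLP module and deletes them from the Linear itself, presumably
# so the base Linear can be converted to a low-bit layer while `add_lora` below applies
# the LoRA delta manually, and only on the image-token spans.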
def pre_process_attn_and_mlp(module: torch.nn.Module):
    if module.__class__.__name__ == "InternLM2Attention":
        module.wqkv_lora_scaling = module.wqkv.lora_scaling
        module.wqkv_Plora_A = module.wqkv.Plora_A
        module.wqkv_Plora_B = module.wqkv.Plora_B
        del module.wqkv.Plora_A
        del module.wqkv.Plora_B

        module.wo_lora_scaling = module.wo.lora_scaling
        module.wo_Plora_A = module.wo.Plora_A
        module.wo_Plora_B = module.wo.Plora_B
        del module.wo.Plora_A
        del module.wo.Plora_B

    elif module.__class__.__name__ == "InternLM2MLP":
        module.w1_lora_scaling = module.w1.lora_scaling
        module.w1_Plora_A = module.w1.Plora_A
        module.w1_Plora_B = module.w1.Plora_B
        del module.w1.Plora_A
        del module.w1.Plora_B

        module.w2_lora_scaling = module.w2.lora_scaling
        module.w2_Plora_A = module.w2.Plora_A
        module.w2_Plora_B = module.w2.Plora_B
        del module.w2.Plora_A
        del module.w2.Plora_B

        module.w3_lora_scaling = module.w3.lora_scaling
        module.w3_Plora_A = module.w3.Plora_A
        module.w3_Plora_B = module.w3.Plora_B
        del module.w3.Plora_A
        del module.w3.Plora_B


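# `add_lora` adds the PLoRA delta Plora_B(Plora_A(x) * lora_scaling) to `result`, but only
# on the (start_idx, end_idx) token spans listed in `im_mask` (as produced by
# `internlm_xcomposser2_model_forward` below). An empty list, or an unprocessed raw tensor
# mask, leaves the base Linear output unchanged.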
def add_lora(x: torch.Tensor, result: torch.Tensor,
             im_mask: torch.Tensor = None, lora_scaling: float = 0,
             Plora_A: torch.nn.Linear = None, Plora_B: torch.nn.Linear = None):
    invalidInputError(x.dim() == 3 and result.dim() == 3,
                      "`x` and `result` should have 3 dims")
    if isinstance(im_mask, torch.Tensor) or len(im_mask) == 0:
        return result
    else:
        for start_idx, end_idx in im_mask:
            result[:, start_idx:end_idx, :] += Plora_B(
                Plora_A(x[:, start_idx:end_idx, :]) * lora_scaling
            )
        return result


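# `internlm_xcomposser2_model_forward_wrapper` pre-processes the boolean image-token mask
# once per forward call: instead of passing the full `im_mask` tensor down to every
# attention/MLP layer, it collapses it into a short list of (start_idx, end_idx) pairs
# covering the runs of image tokens. For example, a (hypothetical) flattened mask
# [0, 0, 1, 1, 1, 0, 1, 1] becomes [(2, 5), (6, 8)]; during decoding or for text-only
# prompts the list is simply [].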
def internlm_xcomposser2_model_forward_wrapper(origin_forward):
    def internlm_xcomposser2_model_forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):
        im_mask = kwargs.get('im_mask', None)
        if im_mask is None or im_mask.size(-1) <= 1 or im_mask.sum() == 0:
            # decoding or no image input, `im_mask` is not needed
            kwargs['im_mask'] = []
        else:
            # replace `im_mask` with (start_idx, end_idx) pairs to improve performance
            im_mask = im_mask.cpu().flatten().tolist()
            length = len(im_mask)
            new_mask = []
            i = 0
            while i < length:
                while i < length and not im_mask[i]:
                    i = i + 1
                start_idx = i
                while i < length and im_mask[i]:
                    i = i + 1
                end_idx = i
                if start_idx != end_idx:
                    new_mask.append((start_idx, end_idx))
            kwargs['im_mask'] = new_mask
        return origin_forward(
            self=self,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs
        )
    return internlm_xcomposser2_model_forward


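# Same flow as `internlm2_attention_forward` above, except that the PLoRA deltas are
# applied via `add_lora` to the `wqkv` and `wo` outputs, using the (start_idx, end_idx)
# spans prepared by the model forward wrapper.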
def internlm_xcomposser2_attention_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    im_mask: Optional[Tuple[torch.Tensor]] = None,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    bsz, q_len, _ = hidden_states.size()
    device = hidden_states.device

    qkv_states = self.wqkv(hidden_states)
    qkv_states = add_lora(hidden_states, qkv_states, im_mask, self.wqkv_lora_scaling,
                          self.wqkv_Plora_A, self.wqkv_Plora_B)

    qkv_states = rearrange(
        qkv_states,
        'b q (h gs d) -> b q h gs d',
        gs=2 + self.num_key_value_groups,
        d=self.head_dim,
    )

    query_states = qkv_states[..., :self.num_key_value_groups, :]
    query_states = rearrange(query_states, 'b q h gs d -> b q (h gs) d')
    key_states = qkv_states[..., -2, :]
    value_states = qkv_states[..., -1, :]

    query_states = query_states.transpose(1, 2)
    key_states = key_states.transpose(1, 2)
    value_states = value_states.transpose(1, 2)

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[-2]

    # IPEX-LLM OPT: fuse rope
    if should_use_fuse_rope(hidden_states, position_ids, self.training):
        # The fused rope kernel gives wrong results if context_length exceeds
        # max_position_embeddings (32768); we assume context_length <= 32768 here.
        import xe_addons
        xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
                                       query_states, key_states)
    else:
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin, position_ids, "internlm")

    # IPEX-LLM OPT: kv cache and quantize kv cache
    use_quantize_kv = use_quantize_kv_cache(self.wqkv, hidden_states,
                                            self.num_heads, self.num_key_value_heads)
    key_states, value_states = update_past_key_value(
        past_key_value, key_states, value_states,
        kv_seq_len, use_quantize_kv, device
    )
    past_key_value = (key_states, value_states) if use_cache else None

    # IPEX-LLM OPT: sdp
    attn_weights = None
    attn_output = scaled_dot_product_attention(
        query_states, key_states, value_states,
        attention_mask, q_len == kv_seq_len
    )

    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

    attn_output_2 = self.wo(attn_output)

    attn_output = add_lora(attn_output, attn_output_2, im_mask, self.wo_lora_scaling,
                           self.wo_Plora_A, self.wo_Plora_B)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value


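# Gated MLP with PLoRA: w2(act_fn(w1(x)) * w3(x)), where each of the w1/w2/w3
# projections gets its LoRA delta added on the image-token spans via `add_lora`.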
def internlm_xcomposser2_mlp_forward(
    self,
    x: torch.Tensor,
    im_mask: Optional[Tuple[torch.Tensor]] = None,
):
    w1 = self.w1(x)
    w1 = add_lora(x, w1, im_mask, self.w1_lora_scaling, self.w1_Plora_A, self.w1_Plora_B)
    w3 = self.w3(x)
    w3 = add_lora(x, w3, im_mask, self.w3_lora_scaling, self.w3_Plora_A, self.w3_Plora_B)
    x = self.act_fn(w1) * w3
    w2 = self.w2(x)
    w2 = add_lora(x, w2, im_mask, self.w2_lora_scaling, self.w2_Plora_A, self.w2_Plora_B)
    return w2


@torch.no_grad()
def internlm_xcomposser2_chat(
    self,
    tokenizer,
    query: str,
    image: torch.Tensor = None,
    history: List[Tuple[str, str]] = [],
    streamer=None,
    max_new_tokens: int = 1024,
    do_sample: bool = True,
    temperature: float = 1.0,
    top_p: float = 0.8,
    repetition_penalty: float = 1.005,
    meta_instruction:
    str = ('You are an AI assistant whose name is InternLM-XComposer (浦语·灵笔).\n'
           '- InternLM-XComposer (浦语·灵笔) is a multi-modality conversational language model '
           'that is developed by Shanghai AI Laboratory (上海人工智能实验室).'
           'It is designed to be helpful, honest, and harmless.\n'
           '- InternLM-XComposer (浦语·灵笔) can understand and communicate fluently in the '
           'language chosen by the user such as English and 中文.\n'
           '- InternLM-XComposer (浦语·灵笔) is capable of comprehending and articulating '
           'responses effectively based on the provided image.'),
    **kwargs,
):
    # ipex-llm changes start: fix device and dtype conversion
    if image is None:
        inputs = self.build_inputs(tokenizer, query, history, meta_instruction)
        im_mask = torch.zeros(inputs['input_ids'].shape[:2]).bool()
    else:
        image = self.encode_img(image)
        inputs, im_mask = self.interleav_wrap_chat(tokenizer, query, image,
                                                   history, meta_instruction)

    new_inputs = {}
    for k, v in inputs.items():
        if torch.is_tensor(v):
            if v.dtype.is_floating_point:
                new_inputs[k] = v.to(device=self.device, dtype=self.dtype)
            else:
                # input_ids: keep its integer dtype, only move it to the target device
                new_inputs[k] = v.to(device=self.device)
        else:
            new_inputs[k] = v
    inputs = new_inputs
    im_mask = im_mask.to(self.device)
    # ipex-llm changes end

    # also include the end-of-assistant token in eos_token_id to avoid unnecessary generation
    eos_token_id = [
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids(['[UNUSED_TOKEN_145]'])[0]
    ]
    outputs = self.generate(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        temperature=temperature,
        top_p=top_p,
        eos_token_id=eos_token_id,
        repetition_penalty=repetition_penalty,
        im_mask=im_mask,
        **kwargs,
    )
    if image is None:
        outputs = outputs[0].cpu().tolist()[len(inputs['input_ids'][0]):]
    else:
        outputs = outputs[0].cpu().tolist()
    response = tokenizer.decode(outputs, skip_special_tokens=True)
    response = response.split('[UNUSED_TOKEN_145]')[0]
    history = history + [(query, response)]
    return response, history