support running internlm xcomposer2 on gpu and add sdp optimization (#11115)

Yishuo Wang 2024-05-23 17:26:24 +08:00 committed by GitHub
parent c5e8b90c8d
commit 37b98a531f
2 changed files with 106 additions and 17 deletions

View file

@@ -44,11 +44,11 @@ import transformers
 import importlib.util
 from ipex_llm.ggml.quantize import ggml_tensor_qtype, gguf_mixed_qtype
 from .utils import logger, get_cur_qtype_and_imatrix
-from typing import Union
 import numpy as np
 import os
 from ipex_llm.utils.common import invalidInputError
 from typing import List, Optional, Tuple, Union
+from types import MethodType
 import subprocess
 import sys
@@ -1228,6 +1228,8 @@ def _optimize_post(model, lightweight_bmm=False):
         convert_forward(model, module.InternLM2Attention, internlm_xcomposser2_attention_forward)
         from ipex_llm.transformers.models.internlm import internlm_xcomposser2_mlp_forward
         convert_forward(model, module.InternLM2MLP, internlm_xcomposser2_mlp_forward)
+        from ipex_llm.transformers.models.internlm import internlm_xcomposser2_chat
+        model.chat = MethodType(internlm_xcomposser2_chat, model)
     elif model.config.model_type == "qwen":
         if hasattr(model.config, "visual"):
             # for Qwen-VL-Chat
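
Note: the `model.chat = MethodType(internlm_xcomposser2_chat, model)` line added above swaps the GPU-friendly chat loop onto this one model instance. A minimal, self-contained sketch of that patching pattern, with toy class and function names that are not part of the ipex-llm API:

from types import MethodType

class ToyModel:
    def chat(self, query):
        return f"original: {query}"

def patched_chat(self, query):
    # `self` is the instance the function was bound to via MethodType
    return f"patched: {query}"

model = ToyModel()
model.chat = MethodType(patched_chat, model)   # override on this instance only
print(model.chat("hello"))                     # -> patched: hello
print(ToyModel().chat("hello"))                # a fresh instance keeps the original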

View file

@@ -37,7 +37,7 @@
 # limitations under the License.
 """ PyTorch InternLM model."""
 import math
-from typing import Optional, Tuple
+from typing import Optional, Tuple, List
 import torch
 import torch.utils.checkpoint
@@ -47,9 +47,13 @@ from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
     append_kv_cache, is_enough_kv_cache_room_4_31
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
+from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
+from ipex_llm.transformers.models.utils import update_past_key_value
+from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
 from einops import rearrange
 import os
 KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
@@ -347,6 +351,7 @@ def internlm_xcomposser2_attention_forward(
     **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     bsz, q_len, _ = hidden_states.size()
+    device = hidden_states.device
     qkv_states = self.wqkv(hidden_states)
     qkv_states = add_lora(hidden_states, qkv_states, im_mask, self.wqkv_lora_scaling,
@@ -375,26 +380,45 @@ def internlm_xcomposser2_attention_forward(
     query_states, key_states = apply_rotary_pos_emb(
         query_states, key_states, cos, sin, position_ids, "internlm")
-    if past_key_value is not None:
-        # reuse k, v, self_attention
-        key_states = torch.cat([past_key_value[0], key_states], dim=2)
-        value_states = torch.cat([past_key_value[1], value_states], dim=2)
+    # IPEX-LLM OPT: kv cache and quantize kv cache
+    use_quantize_kv = use_quantize_kv_cache(self.wqkv, hidden_states)
+    key_states, value_states = update_past_key_value(
+        past_key_value, key_states, value_states,
+        kv_seq_len, use_quantize_kv, device
+    )
     past_key_value = (key_states, value_states) if use_cache else None
-    key_states = repeat_kv(key_states, self.num_key_value_groups)
-    value_states = repeat_kv(value_states, self.num_key_value_groups)
-    attn_weights = torch.matmul(query_states, key_states.transpose(
-        2, 3)) / math.sqrt(self.head_dim)
-    if attention_mask is not None:
-        attn_weights = attn_weights + attention_mask
-    # upcast attention to fp32
-    attn_weights = nn.functional.softmax(
-        attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
-    attn_output = torch.matmul(attn_weights, value_states)
+    # IPEX-LLM OPT: sdp
+    if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
+        import linear_q4_0
+        if use_quantize_kv:
+            attn_output = linear_q4_0.sdp_fp8(query_states, key_states, value_states,
+                                              attention_mask)
+        else:
+            attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
+    elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
+        import linear_q4_0
+        if use_quantize_kv:
+            attn_output = linear_q4_0.sdp_fp8_causal(query_states, key_states, value_states)
+        else:
+            attn_output = linear_q4_0.sdp_causal(query_states, key_states, value_states)
+    else:
+        if use_quantize_kv:
+            key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
+                                                            query_states.dtype)
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+        attn_weights = torch.matmul(query_states,
+                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+        if attention_mask is not None:
+            attn_weights = attn_weights + attention_mask
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights,
+                                             dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_output = torch.matmul(attn_weights, value_states)
     attn_output = attn_output.transpose(1, 2).contiguous()
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
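
Note: the final else branch above is the textbook scaled-dot-product attention that the fused linear_q4_0.sdp* kernels replace on the fast paths; in math form, with M the additive attention mask:

\[
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_{\mathrm{head}}}} + M\right) V
\]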
@@ -423,3 +447,66 @@ def internlm_xcomposser2_mlp_forward(
     w2 = self.w2(x)
     w2 = add_lora(x, w2, im_mask, self.w2_lora_scaling, self.w2_Plora_A, self.w2_Plora_B)
     return w2
+
+
+@torch.no_grad()
+def internlm_xcomposser2_chat(
+    self,
+    tokenizer,
+    query: str,
+    image: torch.Tensor = None,
+    history: List[Tuple[str, str]] = [],
+    streamer=None,
+    max_new_tokens: int = 1024,
+    do_sample: bool = True,
+    temperature: float = 1.0,
+    top_p: float = 0.8,
+    repetition_penalty: float = 1.005,
+    meta_instruction:
+    str = ('You are an AI assistant whose name is InternLM-XComposer (浦语·灵笔).\n'
+           '- InternLM-XComposer (浦语·灵笔) is a multi-modality conversational language model'
+           'that is developed by Shanghai AI Laboratory (上海人工智能实验室).'
+           'It is designed to be helpful, honest, and harmless.\n'
+           '- InternLM-XComposer (浦语·灵笔) can understand and communicate fluently in the'
+           'language chosen by the user such as English and 中文.\n'
+           '- InternLM-XComposer (浦语·灵笔) is capable of comprehending and articulating'
+           'responses effectively based on the provided image.'),
+    **kwargs,
+):
+    if image is None:
+        inputs = self.build_inputs(tokenizer, query, history, meta_instruction)
+        im_mask = torch.zeros(inputs['input_ids'].shape[:2]).bool()
+    else:
+        image = self.encode_img(image)
+        inputs, im_mask = self.interleav_wrap_chat(tokenizer, query, image,
+                                                   history, meta_instruction)
+    inputs = {
+        k: v.to(self.device)
+        for k, v in inputs.items() if torch.is_tensor(v)
+    }
+    im_mask = im_mask.to(self.device)
+    # also add end-of-assistant token in eos token id to avoid unnecessary generation
+    eos_token_id = [
+        tokenizer.eos_token_id,
+        tokenizer.convert_tokens_to_ids(['[UNUSED_TOKEN_145]'])[0]
+    ]
+    outputs = self.generate(
+        **inputs,
+        streamer=streamer,
+        max_new_tokens=max_new_tokens,
+        do_sample=do_sample,
+        temperature=temperature,
+        top_p=top_p,
+        eos_token_id=eos_token_id,
+        repetition_penalty=repetition_penalty,
+        im_mask=im_mask,
+        **kwargs,
+    )
+    if image is None:
+        outputs = outputs[0].cpu().tolist()[len(inputs['input_ids'][0]):]
+    else:
+        outputs = outputs[0].cpu().tolist()
+    response = tokenizer.decode(outputs, skip_special_tokens=True)
+    response = response.split('[UNUSED_TOKEN_145]')[0]
+    history = history + [(query, response)]
+    return response, history
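
Taken together, the two changed files let the low-bit model run its multi-modal chat loop on an Intel GPU with the fused SDP path. A hedged end-to-end sketch of how the patched chat might be driven; the checkpoint id, the auto class and loading flags, and the image argument format are assumptions based on the usual ipex-llm recipe, not part of this commit:

from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModelForCausalLM  # assumed loading class

model_path = "internlm/internlm-xcomposer2-vl-7b"  # assumed checkpoint id
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# load_in_4bit applies the low-bit conversion; _optimize_post then patches the
# attention/MLP forwards and binds the chat method shown in this diff
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_4bit=True,
                                             trust_remote_code=True)
model = model.to("xpu")  # run on an Intel GPU

query = "<ImageHere> Please describe this image in detail."
# `image` may be a path or a preprocessed tensor, depending on the model's encode_img
response, history = model.chat(tokenizer, query=query, image="demo.jpg", history=[])
print(response)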