LLM: Enable BigDL IPEX optimization for int4 (#10319)

Enable BigDL IPEX optimization for int4
Xiangyu Tian 2024-03-12 17:08:50 +08:00 committed by GitHub
parent 5d7e044dbc
commit 0ded0b4b13
5 changed files with 276 additions and 36 deletions


@@ -18,6 +18,8 @@ test_api:
- "optimize_model"
- "pytorch_autocast_bf16"
# - "transformer_autocast_bf16"
# - "bigdl_ipex_bf16"
# - "bigdl_ipex_int4"
# - "ipex_fp16_gpu" # on Intel GPU
# - "bigdl_fp16_gpu" # on Intel GPU
# - "transformer_int4_gpu" # on Intel GPU


@@ -92,6 +92,10 @@ def run_model(repo_id, test_api, in_out_pairs, local_model_hub=None, warm_up=1,
result = run_transformer_int4_loadlowbit_gpu_win(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams, low_bit, cpu_embedding, batch_size, streaming)
elif test_api == 'transformer_autocast_bf16':
result = run_transformer_autocast_bf16(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams, batch_size)
elif test_api == 'bigdl_ipex_bf16':
result = run_bigdl_ipex_bf16(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams, batch_size)
elif test_api == 'bigdl_ipex_int4':
result = run_bigdl_ipex_int4(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams, batch_size)
elif test_api == 'deepspeed_optimize_model_gpu':
result = run_deepspeed_optimize_model_gpu(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams, low_bit, batch_size)
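For context, the two new branches above are selected the same way as the existing ones: through the test_api list in config.yaml shown in the first hunk. Below is a minimal sketch of driving them programmatically, assuming the benchmark module exposes run_model as in this hunk; the repo id and model hub path are placeholders, and the remaining keyword arguments keep their defaults.

# Hypothetical direct call into the benchmark dispatcher extended above;
# only the parameters visible in this hunk's signature are passed.
result = run_model('meta-llama/Llama-2-7b-chat-hf',      # placeholder repo id
                   'bigdl_ipex_int4',                    # new test_api branch
                   ['32-32', '1024-128'],                # in_out_pairs
                   local_model_hub='/path/to/model/hub', # placeholder path
                   warm_up=1)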
@@ -1079,6 +1083,148 @@ def run_transformer_autocast_bf16( repo_id,
actual_in_len, actual_out_len, load_time])
return result
def run_bigdl_ipex_bf16(repo_id,
local_model_hub,
in_out_pairs,
warm_up,
num_trials,
num_beams,
batch_size):
from bigdl.llm.transformers import AutoModel, AutoModelForCausalLM
from transformers import AutoTokenizer, LlamaTokenizer
os.environ["BIGDL_OPT_IPEX"] = "true"
model_path = get_model_path(repo_id, local_model_hub)
# Load the model in bf16,
# which converts the relevant layers of the model into BF16 format
st = time.perf_counter()
if repo_id in CHATGLM_IDS:
model = AutoModel.from_pretrained(model_path, load_in_low_bit='bf16', trust_remote_code=True, torch_dtype=torch.bfloat16,
use_cache=True, torchscript=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
elif repo_id in LLAMA_IDS:
model = AutoModelForCausalLM.from_pretrained(model_path, load_in_low_bit='bf16', trust_remote_code=True, torch_dtype=torch.bfloat16,
use_cache=True, torchscript=True)
tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)
else:
model = AutoModelForCausalLM.from_pretrained(model_path, load_in_low_bit='bf16', trust_remote_code=True, torch_dtype=torch.bfloat16,
use_cache=True, torchscript=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
if not hasattr(model.config, "token_latency"):
model.config.token_latency = True
end = time.perf_counter()
load_time = end - st
print(">> loading of model costs {}s".format(load_time))
result = {}
with torch.inference_mode(), torch.autocast("cpu"):
for in_out in in_out_pairs:
in_out_len = in_out.split("-")
in_len = int(in_out_len[0])
out_len = int(in_out_len[1])
# As different tokenizers have different encodings,
# in_len.txt may be shorter than we need,
# so use a much longer context to make sure the input is long enough
test_length = min(in_len*2, 8192)
while test_length not in [32, 256, 1024, 2048, 8192]:
test_length = test_length * 2
input_str = open(f"prompt/{test_length}.txt", 'r').read()
# As different tokenizers have different encodings,
# slice the input_ids to ensure the prompt has exactly the required length.
input_ids = tokenizer.encode(input_str, return_tensors="pt")
input_ids = input_ids[:, :in_len]
true_str = tokenizer.batch_decode(input_ids)[0]
input_list = [true_str] * batch_size
input_ids = tokenizer(input_list, return_tensors="pt").input_ids
actual_in_len = input_ids.shape[1]
result[in_out] = []
for i in range(num_trials + warm_up):
st = time.perf_counter()
output_ids, total_list = model.generate(input_ids, do_sample=False, max_new_tokens=out_len,
num_beams=num_beams)
end = time.perf_counter()
print("model generate cost: " + str(end - st))
output = tokenizer.batch_decode(output_ids)
print(output[0])
actual_out_len = output_ids.shape[1] - actual_in_len
if i >= warm_up:
result[in_out].append([total_list[0], np.mean(total_list[1:]), 0,
actual_in_len, actual_out_len, load_time])
return result
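Since model.config.token_latency is forced to True above, generate() also returns a list of per-token latencies (total_list), which the loop splits into a first-token figure and a mean next-token figure. A minimal sketch of that bookkeeping with made-up numbers:

import numpy as np

# total_list as consumed above: index 0 is the first-token latency,
# the remaining entries are the latencies of the following tokens (seconds).
total_list = [0.412, 0.031, 0.030, 0.032, 0.029]   # made-up values
first_token = total_list[0]
next_token_mean = float(np.mean(total_list[1:]))
print(f"1st token: {first_token:.3f}s, avg next token: {next_token_mean:.3f}s")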
def run_bigdl_ipex_int4(repo_id,
local_model_hub,
in_out_pairs,
warm_up,
num_trials,
num_beams,
batch_size):
from bigdl.llm.transformers import AutoModel, AutoModelForCausalLM
from transformers import AutoTokenizer, LlamaTokenizer
os.environ["BIGDL_OPT_IPEX"] = "true"
model_path = get_model_path(repo_id, local_model_hub)
st = time.perf_counter()
if repo_id in CHATGLM_IDS:
model = AutoModel.from_pretrained(model_path, load_in_low_bit='sym_int4', trust_remote_code=True, torch_dtype='auto',
use_cache=True, torchscript=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
elif repo_id in LLAMA_IDS:
model = AutoModelForCausalLM.from_pretrained(model_path, load_in_low_bit='sym_int4', trust_remote_code=True, torch_dtype='auto',
use_cache=True, torchscript=True)
tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)
else:
model = AutoModelForCausalLM.from_pretrained(model_path, load_in_low_bit='sym_int4', trust_remote_code=True, torch_dtype='auto',
use_cache=True, torchscript=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
if not hasattr(model.config, "token_latency"):
model.config.token_latency = True
end = time.perf_counter()
load_time = end - st
print(">> loading of model costs {}s".format(load_time))
result = {}
with torch.inference_mode(), torch.autocast("cpu"):
for in_out in in_out_pairs:
in_out_len = in_out.split("-")
in_len = int(in_out_len[0])
out_len = int(in_out_len[1])
# As different tokenizers have different encodings,
# in_len.txt may be shorter than we need,
# so use a much longer context to make sure the input is long enough
test_length = min(in_len*2, 8192)
while test_length not in [32, 256, 1024, 2048, 8192]:
test_length = test_length * 2
input_str = open(f"prompt/{test_length}.txt", 'r').read()
# As different tokenizers have different encodings,
# slice the input_ids to ensure the prompt has exactly the required length.
input_ids = tokenizer.encode(input_str, return_tensors="pt")
input_ids = input_ids[:, :in_len]
true_str = tokenizer.batch_decode(input_ids)[0]
input_list = [true_str] * batch_size
input_ids = tokenizer(input_list, return_tensors="pt").input_ids
actual_in_len = input_ids.shape[1]
result[in_out] = []
for i in range(num_trials + warm_up):
st = time.perf_counter()
output_ids, total_list = model.generate(input_ids, do_sample=False, max_new_tokens=out_len,
num_beams=num_beams)
end = time.perf_counter()
print("model generate cost: " + str(end - st))
output = tokenizer.batch_decode(output_ids)
print(output[0])
actual_out_len = output_ids.shape[1] - actual_in_len
if i >= warm_up:
result[in_out].append([total_list[0], np.mean(total_list[1:]), 0,
actual_in_len, actual_out_len, load_time])
return result
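As a usage sketch for the new entry point (the repo id and hub path are placeholders): each key of the returned dict is an 'in-out' pair, and each row is [first-token latency, mean next-token latency, 0, actual input length, actual output length, model load time], exactly as appended above.

result = run_bigdl_ipex_int4('meta-llama/Llama-2-7b-chat-hf',   # placeholder repo id
                             '/path/to/local_model_hub',        # placeholder hub path
                             in_out_pairs=['32-32', '1024-128'],
                             warm_up=1,
                             num_trials=3,
                             num_beams=1,
                             batch_size=1)
for in_out, rows in result.items():
    print(in_out, rows)   # one row per measured trial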
def run_deepspeed_optimize_model_gpu(repo_id,
local_model_hub,
in_out_pairs,
@@ -1192,6 +1338,7 @@ def run_deepspeed_optimize_model_gpu(repo_id,
torch.xpu.empty_cache()
return result
if __name__ == '__main__':
from omegaconf import OmegaConf
conf = OmegaConf.load(f'{current_dir}/config.yaml')


@@ -611,13 +611,12 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
modules_to_not_convert = [] if modules_to_not_convert is None else modules_to_not_convert
# use the IPEX optimizer before converting to BigDL linear layers
_enable_ipex = os.getenv("BIGDL_OPT_IPEX")
_enable_ipex = (_enable_ipex is not None) and (_enable_ipex.lower() == "true")
_enable_ipex = _enable_ipex and (qtype == ggml_tensor_qtype["bf16"])
if (device == "cpu") and (qtype == ggml_tensor_qtype["bf16"]):
_enable_ipex = get_enable_ipex()
if device == "cpu":
logger.info(f"BIGDL_OPT_IPEX: {_enable_ipex}")
if _enable_ipex:
model = _optimize_ipex(model)
model = _optimize_ipex(model, qtype)
return model
if optimize_model:
@@ -686,12 +685,19 @@ def replace_func(m, target_m, func_name, new_func):
replace_func(sub_m, target_m, func_name, new_func)
def _optimize_ipex(model):
def get_enable_ipex():
_enable_ipex = os.getenv("BIGDL_OPT_IPEX")
_enable_ipex = (_enable_ipex is not None) and (_enable_ipex.lower() == "true")
return _enable_ipex
def _optimize_ipex(model, qtype=ggml_tensor_qtype["bf16"]):
import intel_extension_for_pytorch as ipex
from intel_extension_for_pytorch.transformers.optimize import model_convert_reference
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from bigdl.llm.transformers.convert_ipex import (
_ipex_optimize_model, _ipex_jit, _make_causal_mask,
_llama_model_forward_4_35, convert_function, GLM_get_masks
_llama_model_forward_4_35, convert_function, GLM_get_masks,
)
model = model_convert_reference(model)
@@ -718,7 +724,7 @@ def _optimize_ipex(model):
# baichuan2
rms_classes.append(type(model.model.layers[0].input_layernorm))
_ipex_optimize_model(model, rms_classes)
model = _ipex_optimize_model(model, rms_classes, qtype)
return _ipex_jit(model)
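From a user's point of view, the whole path above is gated by one environment variable plus the low-bit load option. A minimal sketch mirroring the benchmark code earlier in this diff (the model path is a placeholder, and generation follows the same inference_mode/autocast pattern as the benchmark functions):

import os
import torch

os.environ["BIGDL_OPT_IPEX"] = "true"   # routes ggml_convert_low_bit into _optimize_ipex

from bigdl.llm.transformers import AutoModelForCausalLM
from transformers import AutoTokenizer

# 'sym_int4' now takes the IPEX weight-only INT4 path, 'bf16' keeps the TPP/BF16 path;
# torchscript=True is needed so _ipex_jit can trace the model.
model = AutoModelForCausalLM.from_pretrained("/path/to/model",   # placeholder path
                                             load_in_low_bit='sym_int4',
                                             torch_dtype='auto',
                                             trust_remote_code=True,
                                             use_cache=True,
                                             torchscript=True)
tokenizer = AutoTokenizer.from_pretrained("/path/to/model", trust_remote_code=True)

with torch.inference_mode(), torch.autocast("cpu"):
    input_ids = tokenizer("What is AI?", return_tensors="pt").input_ids
    output_ids = model.generate(input_ids, do_sample=False, max_new_tokens=32)
    print(tokenizer.batch_decode(output_ids)[0])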


@@ -44,10 +44,14 @@ from intel_extension_for_pytorch.transformers.optimize import (
from intel_extension_for_pytorch.cpu._auto_kernel_selection import (
_enable_tpp,
_using_tpp,
_disable_tpp
)
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
from bigdl.llm.transformers.convert import get_enable_ipex
import os
def _ipex_optimize_rmsnorm(_model, supported_classes):
def _ipex_optimize_rmsnorm(_model, supported_classes, is_tpp=False, is_woq=False):
from intel_extension_for_pytorch.transformers.models.cpu.fusions.mha_fusion import _IPEXRMSNorm
for supported_class in supported_classes:
lowering_class_cpu(
@@ -55,12 +59,12 @@ def _ipex_optimize_rmsnorm(_model, supported_classes):
supported_class,
_IPEXRMSNorm,
_model.config,
tpp=False,
woq=False,
tpp=is_tpp,
woq=is_woq,
)
def _ipex_optimize_decoder(model):
def _ipex_optimize_decoder(model, is_tpp=False, is_woq=False):
from intel_extension_for_pytorch.transformers.models.reference.modules.decoder import (
_IPEXDecoderLayerRef
)
@@ -73,12 +77,12 @@ def _ipex_optimize_decoder(model):
supported_mlp_class,
_IPEXDecoderLayerCPU,
model.config,
tpp=True if _using_tpp() else False,
woq=False,
tpp=is_tpp,
woq=is_woq,
)
def _ipex_optimize_attention(model):
def _ipex_optimize_attention(model, is_tpp=False, is_woq=False):
from intel_extension_for_pytorch.transformers.models.reference.modules.attentions import (
_IPEXAttentionRef
)
@@ -91,18 +95,47 @@ def _ipex_optimize_attention(model):
supported_mha_class,
_IPEXAttentionCPU,
model.config,
tpp=True if _using_tpp() else False,
woq=False,
tpp=is_tpp,
woq=is_woq,
)
def _ipex_optimize_model(model, rms_classes):
_enable_tpp()
def _ipex_optimize_model(model, rms_classes, qtype):
import intel_extension_for_pytorch as ipex
ipex.optimize(model.eval(), dtype=torch.bfloat16, inplace=True).eval()
_ipex_optimize_rmsnorm(model, rms_classes)
_ipex_optimize_attention(model)
_ipex_optimize_decoder(model)
from intel_extension_for_pytorch.transformers.models.reference.models import output_hook
from intel_extension_for_pytorch.transformers.optimize import ipex_quantization_flow
is_woq = False
is_quantization = False
_disable_tpp()
if qtype == ggml_tensor_qtype["bf16"]:
_enable_tpp()
model = ipex.optimize(model.eval(), dtype=torch.bfloat16, inplace=True).eval()
elif qtype == ggml_tensor_qtype["sym_int4"]:
is_quantization = True
is_woq = True
act_quant_mode_dict = {
"PER_TENSOR": ipex.quantization.WoqActQuantMode.PER_TENSOR,
"PER_IC_BLOCK": ipex.quantization.WoqActQuantMode.PER_IC_BLOCK,
"PER_BATCH": ipex.quantization.WoqActQuantMode.PER_BATCH,
"PER_BATCH_IC_BLOCK": ipex.quantization.WoqActQuantMode.PER_BATCH_IC_BLOCK,
}
qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(
weight_dtype=torch.quint4x2, # INT4
lowp_mode=ipex.quantization.WoqLowpMode.INT8,
act_quant_mode=act_quant_mode_dict["PER_IC_BLOCK"],
group_size=-1,
)
model = ipex_quantization_flow(model, torch.bfloat16, None, qconfig, None)
is_tpp = _using_tpp()
_ipex_optimize_rmsnorm(model, rms_classes, is_tpp=is_tpp, is_woq=is_woq)
_ipex_optimize_attention(model, is_tpp=is_tpp, is_woq=is_woq)
_ipex_optimize_decoder(model, is_tpp=is_tpp, is_woq=is_woq)
model.register_forward_hook(output_hook, with_kwargs=True)
return model
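The INT4 branch above relies on IPEX's weight-only quantization recipe. The glosses below summarize the options as used here and follow the IPEX documentation informally; treat this as a sketch, not an authoritative description.

import torch
import intel_extension_for_pytorch as ipex

# Same recipe as the sym_int4 branch above:
#  - torch.quint4x2 packs two INT4 weight values per byte
#  - lowp_mode=INT8 selects the INT8 compute path for the quantized linears
#  - PER_IC_BLOCK quantizes activations per input-channel block
#  - group_size=-1 uses per-channel weight scales (no sub-channel groups)
qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(
    weight_dtype=torch.quint4x2,
    lowp_mode=ipex.quantization.WoqLowpMode.INT8,
    act_quant_mode=ipex.quantization.WoqActQuantMode.PER_IC_BLOCK,
    group_size=-1,
)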
def _ipex_jit(model):
@@ -152,9 +185,7 @@ def GLM_get_masks(self, input_ids, past_key_values, padding_mask=None):
else: # discrete kv cache
past_length = past_key_values[0][0].shape[-2]
import os
_enable_ipex = os.getenv("BIGDL_OPT_IPEX")
_enable_ipex = (_enable_ipex is not None) and (_enable_ipex.lower() == "true")
_enable_ipex = get_enable_ipex()
# always call for jit
if past_length or _enable_ipex:
full_attention_mask = torch.cat(
@@ -191,9 +222,7 @@ def _make_causal_mask(
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
import os
_enable_ipex = os.getenv("BIGDL_OPT_IPEX")
_enable_ipex = (_enable_ipex is not None) and (_enable_ipex.lower() == "true")
_enable_ipex = get_enable_ipex()
if _enable_ipex or past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) # noqa
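The repeated env-var parsing in the two mask helpers is now centralized in get_enable_ipex; a minimal sketch of the gate it implements:

import os
from bigdl.llm.transformers.convert import get_enable_ipex

os.environ["BIGDL_OPT_IPEX"] = "true"
assert get_enable_ipex()          # "true"/"True"/"TRUE" all count as enabled

os.environ["BIGDL_OPT_IPEX"] = "0"
assert not get_enable_ipex()      # anything other than "true" (or unset) disables it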


@@ -30,6 +30,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Un
from transformers import top_k_top_p_filtering, GenerationConfig, \
LogitsProcessorList, StoppingCriteriaList
from bigdl.llm.utils.common import invalidInputError
from transformers.modeling_outputs import CausalLMOutputWithPast
# patch GenerationMixin.generate
from transformers import GenerationMixin
@@ -533,15 +534,16 @@ def speculative_generate(self,
past_key_values = None
past_key_values_storage = []
_enable_ipex = os.getenv("BIGDL_OPT_IPEX")
_enable_ipex = (_enable_ipex is not None) and (_enable_ipex.lower() == "true")
from bigdl.llm.transformers.convert import get_enable_ipex
_enable_ipex = get_enable_ipex()
if _enable_ipex:
if not ((self.config.model_type == 'baichuan') or
('llama' in self.config.model_type) or
("mistral" in self.config.model_type) or
("qwen" in self.config.model_type) or
("chatglm" in self.config.model_type)):
invalidInputError(False, "BigDL Speculative Decoding with IPEX BF16 only supports \
invalidInputError(False, "BigDL Speculative Decoding with IPEX only supports \
Llama, Baichuan2, Mistral, ChatGLM and Qwen models currently.")
if "chatglm" in self.config.model_type:
global query_group_size
@@ -579,6 +581,11 @@ def speculative_generate(self,
attention_mask=attention_mask,
return_dict=True,
use_cache=True)
if _enable_ipex:
output = CausalLMOutputWithPast(
logits=output[0],
past_key_values=output[1],
)
logits = output['logits']
logits = logits[:, -1:]
logits[:, -1, :] = logits_processor(current_input_ids, logits[:, -1, :])
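With IPEX enabled, the traced target model returns a plain (logits, past_key_values) tuple instead of a ModelOutput, so the wrapper above restores the dict-style access used by the rest of speculative_generate. A minimal sketch with dummy tensors:

import torch
from transformers.modeling_outputs import CausalLMOutputWithPast

raw = (torch.randn(1, 8, 32000), None)       # (logits, past_key_values) as a bare tuple
output = CausalLMOutputWithPast(logits=raw[0], past_key_values=raw[1])
logits = output['logits'][:, -1:]            # dict-style access works again
print(logits.shape)                          # torch.Size([1, 1, 32000])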
@@ -602,7 +609,7 @@ def speculative_generate(self,
draft_current_input_ids = current_input_ids
# Target model KV cache to draft model
if self.device.type == 'cpu':
if self.device.type == 'cpu' and (not _enable_ipex):
# init past_key_values_storage and assign initial fp32 value
if step == 1:
past_key_values_storage = \
@@ -652,7 +659,57 @@ def speculative_generate(self,
past_length = draft_past_key_values[0][0].size(2)
position_ids = torch.Tensor([[past_length]]).long().to(self.device)
forward_args["position_ids"] = position_ids
draft_output = draft_model(**forward_args)
if _enable_ipex:
if any(keyword in self.config.model_type
for keyword in ["llama", "chatglm", "mistral"]):
past_key_value_len = draft_past_key_values[0][0].shape[2]
position_ids = torch.Tensor([[past_key_value_len + step_draft]]).long()
position_ids = position_ids[:, :-draft_current_input_ids.size(0)]
draft_output = draft_model.trace_graph(
input_ids=draft_current_input_ids,
attention_mask=draft_attention_mask,
position_ids=position_ids,
past_key_values=draft_past_key_values,
)
elif self.config.model_type == "baichuan":
if self.config.hidden_size == 4096:
past_key_value_len = draft_past_key_values[0][0].shape[2]
seq_len = draft_current_input_ids.shape[1]
seq_len_with_past = seq_len + past_key_value_len
position_ids = torch.arange(past_key_value_len,
seq_len_with_past,
dtype=torch.long,
device=draft_current_input_ids.device)
position_ids = position_ids.unsqueeze(0).view(-1, seq_len)
draft_output = draft_model.trace_graph(
input_ids=draft_current_input_ids,
attention_mask=draft_attention_mask,
position_ids=position_ids,
past_key_values=draft_past_key_values,
)
elif self.config.hidden_size == 5120:
draft_output = draft_model.trace_graph(
input_ids=draft_current_input_ids,
attention_mask=draft_attention_mask,
past_key_values=draft_past_key_values,
)
elif "qwen" in self.config.model_type:
draft_output = draft_model.trace_graph(
input_ids=draft_current_input_ids,
attention_mask=draft_attention_mask,
past_key_values=draft_past_key_values,
)
else:
invalidInputError(False, "BigDL Speculative Decoding with IPEX only supports \
Llama, Baichuan2, Mistral, ChatGLM and Qwen models currently.")
draft_output = CausalLMOutputWithPast(
logits=draft_output[0],
past_key_values=draft_output[1],
)
else:
draft_output = draft_model(**forward_args)
temp_input_ids = torch.cat((input_ids, generate_ids[:, :step],
draft_generate_ids[:, 1:step_draft+1]), dim=-1)
logits = draft_output['logits']
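The traced draft graphs take explicit position_ids; a minimal sketch of the arange-based construction used in the Baichuan-7B branch above, with made-up lengths:

import torch

past_key_value_len = 37   # made-up: tokens already held in the draft KV cache
seq_len = 1               # one new draft token per step
position_ids = torch.arange(past_key_value_len,
                            past_key_value_len + seq_len,
                            dtype=torch.long)
position_ids = position_ids.unsqueeze(0).view(-1, seq_len)
print(position_ids)       # tensor([[37]])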
@@ -848,7 +905,6 @@ def speculative_generate(self,
# Clean up target model KV cache
if max_of_max_matched != max_matched:
output_ids = output_ids[:, :max_matched]
# For Qwen
if _enable_ipex:
cur_len = past_key_values[0][0].size(1)
delta = max_of_max_matched - max_matched
@@ -890,7 +946,7 @@ def speculative_generate(self,
]
# Each iter assign new_matched kv_cache to past_key_values1
if self.device.type == 'cpu':
if self.device.type == 'cpu' and (not _enable_ipex):
_update_past_key_values_storage_cpu(self, past_key_values, past_key_values_storage,
original_draft_past_key_values,
_enable_ipex)