LLM: Enable BigDL IPEX optimization for int4 (#10319)
Enable BigDL IPEX optimization for int4
This commit is contained in:
parent
5d7e044dbc
commit
0ded0b4b13
5 changed files with 276 additions and 36 deletions
|
|
@ -18,6 +18,8 @@ test_api:
|
||||||
- "optimize_model"
|
- "optimize_model"
|
||||||
- "pytorch_autocast_bf16"
|
- "pytorch_autocast_bf16"
|
||||||
# - "transformer_autocast_bf16"
|
# - "transformer_autocast_bf16"
|
||||||
|
# - "bigdl_ipex_bf16"
|
||||||
|
# - "bigdl_ipex_int4"
|
||||||
# - "ipex_fp16_gpu" # on Intel GPU
|
# - "ipex_fp16_gpu" # on Intel GPU
|
||||||
# - "bigdl_fp16_gpu" # on Intel GPU
|
# - "bigdl_fp16_gpu" # on Intel GPU
|
||||||
# - "transformer_int4_gpu" # on Intel GPU
|
# - "transformer_int4_gpu" # on Intel GPU
|
||||||
|
|
|
||||||
|
|
@ -92,6 +92,10 @@ def run_model(repo_id, test_api, in_out_pairs, local_model_hub=None, warm_up=1,
|
||||||
result = run_transformer_int4_loadlowbit_gpu_win(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams, low_bit, cpu_embedding, batch_size, streaming)
|
result = run_transformer_int4_loadlowbit_gpu_win(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams, low_bit, cpu_embedding, batch_size, streaming)
|
||||||
elif test_api == 'transformer_autocast_bf16':
|
elif test_api == 'transformer_autocast_bf16':
|
||||||
result = run_transformer_autocast_bf16(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams, batch_size)
|
result = run_transformer_autocast_bf16(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams, batch_size)
|
||||||
|
elif test_api == 'bigdl_ipex_bf16':
|
||||||
|
result = run_bigdl_ipex_bf16(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams, batch_size)
|
||||||
|
elif test_api == 'bigdl_ipex_int4':
|
||||||
|
result = run_bigdl_ipex_int4(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams, batch_size)
|
||||||
elif test_api == 'deepspeed_optimize_model_gpu':
|
elif test_api == 'deepspeed_optimize_model_gpu':
|
||||||
result = run_deepspeed_optimize_model_gpu(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams, low_bit, batch_size)
|
result = run_deepspeed_optimize_model_gpu(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams, low_bit, batch_size)
|
||||||
|
|
||||||
|
|
@ -1079,6 +1083,148 @@ def run_transformer_autocast_bf16( repo_id,
|
||||||
actual_in_len, actual_out_len, load_time])
|
actual_in_len, actual_out_len, load_time])
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def run_bigdl_ipex_bf16(repo_id,
|
||||||
|
local_model_hub,
|
||||||
|
in_out_pairs,
|
||||||
|
warm_up,
|
||||||
|
num_trials,
|
||||||
|
num_beams,
|
||||||
|
batch_size):
|
||||||
|
from bigdl.llm.transformers import AutoModel, AutoModelForCausalLM
|
||||||
|
from transformers import AutoTokenizer, LlamaTokenizer
|
||||||
|
|
||||||
|
os.environ["BIGDL_OPT_IPEX"] = "true"
|
||||||
|
|
||||||
|
model_path = get_model_path(repo_id, local_model_hub)
|
||||||
|
# Load model in bf16,
|
||||||
|
# which convert the relevant layers in the model into BF16 format
|
||||||
|
st = time.perf_counter()
|
||||||
|
if repo_id in CHATGLM_IDS:
|
||||||
|
model = AutoModel.from_pretrained(model_path, load_in_low_bit='bf16', trust_remote_code=True, torch_dtype=torch.bfloat16,
|
||||||
|
use_cache=True, torchscript=True)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||||
|
elif repo_id in LLAMA_IDS:
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(model_path, load_in_low_bit='bf16', trust_remote_code=True, torch_dtype=torch.bfloat16,
|
||||||
|
use_cache=True, torchscript=True)
|
||||||
|
tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||||
|
else:
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(model_path, load_in_low_bit='bf16', trust_remote_code=True, torch_dtype=torch.bfloat16,
|
||||||
|
use_cache=True, torchscript=True)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||||
|
if not hasattr(model.config, "token_latency"):
|
||||||
|
model.config.token_latency = True
|
||||||
|
end = time.perf_counter()
|
||||||
|
load_time = end - st
|
||||||
|
print(">> loading of model costs {}s".format(load_time))
|
||||||
|
|
||||||
|
result = {}
|
||||||
|
with torch.inference_mode(), torch.autocast("cpu"):
|
||||||
|
for in_out in in_out_pairs:
|
||||||
|
in_out_len = in_out.split("-")
|
||||||
|
in_len = int(in_out_len[0])
|
||||||
|
out_len = int(in_out_len[1])
|
||||||
|
# As different tokenizer has different encodings,
|
||||||
|
# in_len.txt maybe shorter than we need,
|
||||||
|
# use much longer context to make sure input length
|
||||||
|
test_length = min(in_len*2, 8192)
|
||||||
|
while test_length not in [32, 256, 1024, 2048, 8192]:
|
||||||
|
test_length = test_length * 2
|
||||||
|
input_str = open(f"prompt/{test_length}.txt", 'r').read()
|
||||||
|
# As different tokenizer has different encodings,
|
||||||
|
# slice the input_ids to ensure the prompt length is required length.
|
||||||
|
input_ids = tokenizer.encode(input_str, return_tensors="pt")
|
||||||
|
input_ids = input_ids[:, :in_len]
|
||||||
|
true_str = tokenizer.batch_decode(input_ids)[0]
|
||||||
|
input_list = [true_str] * batch_size
|
||||||
|
input_ids = tokenizer(input_list, return_tensors="pt").input_ids
|
||||||
|
actual_in_len = input_ids.shape[1]
|
||||||
|
result[in_out] = []
|
||||||
|
for i in range(num_trials + warm_up):
|
||||||
|
st = time.perf_counter()
|
||||||
|
output_ids, total_list = model.generate(input_ids, do_sample=False, max_new_tokens=out_len,
|
||||||
|
num_beams=num_beams)
|
||||||
|
end = time.perf_counter()
|
||||||
|
print("model generate cost: " + str(end - st))
|
||||||
|
output = tokenizer.batch_decode(output_ids)
|
||||||
|
print(output[0])
|
||||||
|
actual_out_len = output_ids.shape[1] - actual_in_len
|
||||||
|
if i >= warm_up:
|
||||||
|
result[in_out].append([total_list[0], np.mean(total_list[1:]), 0,
|
||||||
|
actual_in_len, actual_out_len, load_time])
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def run_bigdl_ipex_int4(repo_id,
|
||||||
|
local_model_hub,
|
||||||
|
in_out_pairs,
|
||||||
|
warm_up,
|
||||||
|
num_trials,
|
||||||
|
num_beams,
|
||||||
|
batch_size):
|
||||||
|
from bigdl.llm.transformers import AutoModel, AutoModelForCausalLM
|
||||||
|
from transformers import AutoTokenizer, LlamaTokenizer
|
||||||
|
|
||||||
|
os.environ["BIGDL_OPT_IPEX"] = "true"
|
||||||
|
|
||||||
|
model_path = get_model_path(repo_id, local_model_hub)
|
||||||
|
|
||||||
|
st = time.perf_counter()
|
||||||
|
if repo_id in CHATGLM_IDS:
|
||||||
|
model = AutoModel.from_pretrained(model_path, load_in_low_bit='sym_int4', trust_remote_code=True, torch_dtype='auto',
|
||||||
|
use_cache=True, torchscript=True)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||||
|
elif repo_id in LLAMA_IDS:
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(model_path, load_in_low_bit='sym_int4', trust_remote_code=True, torch_dtype='auto',
|
||||||
|
use_cache=True, torchscript=True)
|
||||||
|
tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||||
|
else:
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(model_path, load_in_low_bit='sym_int4', trust_remote_code=True, torch_dtype='auto',
|
||||||
|
use_cache=True, torchscript=True)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||||
|
if not hasattr(model.config, "token_latency"):
|
||||||
|
model.config.token_latency = True
|
||||||
|
end = time.perf_counter()
|
||||||
|
load_time = end - st
|
||||||
|
print(">> loading of model costs {}s".format(load_time))
|
||||||
|
|
||||||
|
result = {}
|
||||||
|
with torch.inference_mode(), torch.autocast("cpu"):
|
||||||
|
for in_out in in_out_pairs:
|
||||||
|
in_out_len = in_out.split("-")
|
||||||
|
in_len = int(in_out_len[0])
|
||||||
|
out_len = int(in_out_len[1])
|
||||||
|
# As different tokenizer has different encodings,
|
||||||
|
# in_len.txt maybe shorter than we need,
|
||||||
|
# use much longer context to make sure input length
|
||||||
|
test_length = min(in_len*2, 8192)
|
||||||
|
while test_length not in [32, 256, 1024, 2048, 8192]:
|
||||||
|
test_length = test_length * 2
|
||||||
|
input_str = open(f"prompt/{test_length}.txt", 'r').read()
|
||||||
|
# As different tokenizer has different encodings,
|
||||||
|
# slice the input_ids to ensure the prompt length is required length.
|
||||||
|
input_ids = tokenizer.encode(input_str, return_tensors="pt")
|
||||||
|
input_ids = input_ids[:, :in_len]
|
||||||
|
true_str = tokenizer.batch_decode(input_ids)[0]
|
||||||
|
input_list = [true_str] * batch_size
|
||||||
|
input_ids = tokenizer(input_list, return_tensors="pt").input_ids
|
||||||
|
actual_in_len = input_ids.shape[1]
|
||||||
|
result[in_out] = []
|
||||||
|
for i in range(num_trials + warm_up):
|
||||||
|
st = time.perf_counter()
|
||||||
|
output_ids, total_list = model.generate(input_ids, do_sample=False, max_new_tokens=out_len,
|
||||||
|
num_beams=num_beams)
|
||||||
|
end = time.perf_counter()
|
||||||
|
print("model generate cost: " + str(end - st))
|
||||||
|
output = tokenizer.batch_decode(output_ids)
|
||||||
|
print(output[0])
|
||||||
|
actual_out_len = output_ids.shape[1] - actual_in_len
|
||||||
|
if i >= warm_up:
|
||||||
|
result[in_out].append([total_list[0], np.mean(total_list[1:]), 0,
|
||||||
|
actual_in_len, actual_out_len, load_time])
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
def run_deepspeed_optimize_model_gpu(repo_id,
|
def run_deepspeed_optimize_model_gpu(repo_id,
|
||||||
local_model_hub,
|
local_model_hub,
|
||||||
in_out_pairs,
|
in_out_pairs,
|
||||||
|
|
@ -1192,6 +1338,7 @@ def run_deepspeed_optimize_model_gpu(repo_id,
|
||||||
torch.xpu.empty_cache()
|
torch.xpu.empty_cache()
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
conf = OmegaConf.load(f'{current_dir}/config.yaml')
|
conf = OmegaConf.load(f'{current_dir}/config.yaml')
|
||||||
|
|
|
||||||
|
|
@ -611,13 +611,12 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
|
||||||
modules_to_not_convert = [] if modules_to_not_convert is None else modules_to_not_convert
|
modules_to_not_convert = [] if modules_to_not_convert is None else modules_to_not_convert
|
||||||
|
|
||||||
# using ipex optimizer before changing to bigdl linear
|
# using ipex optimizer before changing to bigdl linear
|
||||||
_enable_ipex = os.getenv("BIGDL_OPT_IPEX")
|
_enable_ipex = get_enable_ipex()
|
||||||
_enable_ipex = (_enable_ipex is not None) and (_enable_ipex.lower() == "true")
|
|
||||||
_enable_ipex = _enable_ipex and (qtype == ggml_tensor_qtype["bf16"])
|
if device == "cpu":
|
||||||
if (device == "cpu") and (qtype == ggml_tensor_qtype["bf16"]):
|
|
||||||
logger.info(f"BIGDL_OPT_IPEX: {_enable_ipex}")
|
logger.info(f"BIGDL_OPT_IPEX: {_enable_ipex}")
|
||||||
if _enable_ipex:
|
if _enable_ipex:
|
||||||
model = _optimize_ipex(model)
|
model = _optimize_ipex(model, qtype)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
if optimize_model:
|
if optimize_model:
|
||||||
|
|
@ -686,12 +685,19 @@ def replace_func(m, target_m, func_name, new_func):
|
||||||
replace_func(sub_m, target_m, func_name, new_func)
|
replace_func(sub_m, target_m, func_name, new_func)
|
||||||
|
|
||||||
|
|
||||||
def _optimize_ipex(model):
|
def get_enable_ipex():
|
||||||
|
_enable_ipex = os.getenv("BIGDL_OPT_IPEX")
|
||||||
|
_enable_ipex = (_enable_ipex is not None) and (_enable_ipex.lower() == "true")
|
||||||
|
return _enable_ipex
|
||||||
|
|
||||||
|
|
||||||
|
def _optimize_ipex(model, qtype=ggml_tensor_qtype["bf16"]):
|
||||||
|
import intel_extension_for_pytorch as ipex
|
||||||
from intel_extension_for_pytorch.transformers.optimize import model_convert_reference
|
from intel_extension_for_pytorch.transformers.optimize import model_convert_reference
|
||||||
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
||||||
from bigdl.llm.transformers.convert_ipex import (
|
from bigdl.llm.transformers.convert_ipex import (
|
||||||
_ipex_optimize_model, _ipex_jit, _make_causal_mask,
|
_ipex_optimize_model, _ipex_jit, _make_causal_mask,
|
||||||
_llama_model_forward_4_35, convert_function, GLM_get_masks
|
_llama_model_forward_4_35, convert_function, GLM_get_masks,
|
||||||
)
|
)
|
||||||
|
|
||||||
model = model_convert_reference(model)
|
model = model_convert_reference(model)
|
||||||
|
|
@ -718,7 +724,7 @@ def _optimize_ipex(model):
|
||||||
# baichuan2
|
# baichuan2
|
||||||
rms_classes.append(type(model.model.layers[0].input_layernorm))
|
rms_classes.append(type(model.model.layers[0].input_layernorm))
|
||||||
|
|
||||||
_ipex_optimize_model(model, rms_classes)
|
model = _ipex_optimize_model(model, rms_classes, qtype)
|
||||||
return _ipex_jit(model)
|
return _ipex_jit(model)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -44,10 +44,14 @@ from intel_extension_for_pytorch.transformers.optimize import (
|
||||||
from intel_extension_for_pytorch.cpu._auto_kernel_selection import (
|
from intel_extension_for_pytorch.cpu._auto_kernel_selection import (
|
||||||
_enable_tpp,
|
_enable_tpp,
|
||||||
_using_tpp,
|
_using_tpp,
|
||||||
|
_disable_tpp
|
||||||
)
|
)
|
||||||
|
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
|
||||||
|
from bigdl.llm.transformers.convert import get_enable_ipex
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
def _ipex_optimize_rmsnorm(_model, supported_classes):
|
def _ipex_optimize_rmsnorm(_model, supported_classes, is_tpp=False, is_woq=False):
|
||||||
from intel_extension_for_pytorch.transformers.models.cpu.fusions.mha_fusion import _IPEXRMSNorm
|
from intel_extension_for_pytorch.transformers.models.cpu.fusions.mha_fusion import _IPEXRMSNorm
|
||||||
for supported_class in supported_classes:
|
for supported_class in supported_classes:
|
||||||
lowering_class_cpu(
|
lowering_class_cpu(
|
||||||
|
|
@ -55,12 +59,12 @@ def _ipex_optimize_rmsnorm(_model, supported_classes):
|
||||||
supported_class,
|
supported_class,
|
||||||
_IPEXRMSNorm,
|
_IPEXRMSNorm,
|
||||||
_model.config,
|
_model.config,
|
||||||
tpp=False,
|
tpp=is_tpp,
|
||||||
woq=False,
|
woq=is_woq,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _ipex_optimize_decoder(model):
|
def _ipex_optimize_decoder(model, is_tpp=False, is_woq=False):
|
||||||
from intel_extension_for_pytorch.transformers.models.reference.modules.decoder import (
|
from intel_extension_for_pytorch.transformers.models.reference.modules.decoder import (
|
||||||
_IPEXDecoderLayerRef
|
_IPEXDecoderLayerRef
|
||||||
)
|
)
|
||||||
|
|
@ -73,12 +77,12 @@ def _ipex_optimize_decoder(model):
|
||||||
supported_mlp_class,
|
supported_mlp_class,
|
||||||
_IPEXDecoderLayerCPU,
|
_IPEXDecoderLayerCPU,
|
||||||
model.config,
|
model.config,
|
||||||
tpp=True if _using_tpp() else False,
|
tpp=is_tpp,
|
||||||
woq=False,
|
woq=is_woq,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _ipex_optimize_attention(model):
|
def _ipex_optimize_attention(model, is_tpp=False, is_woq=False):
|
||||||
from intel_extension_for_pytorch.transformers.models.reference.modules.attentions import (
|
from intel_extension_for_pytorch.transformers.models.reference.modules.attentions import (
|
||||||
_IPEXAttentionRef
|
_IPEXAttentionRef
|
||||||
)
|
)
|
||||||
|
|
@ -91,18 +95,47 @@ def _ipex_optimize_attention(model):
|
||||||
supported_mha_class,
|
supported_mha_class,
|
||||||
_IPEXAttentionCPU,
|
_IPEXAttentionCPU,
|
||||||
model.config,
|
model.config,
|
||||||
tpp=True if _using_tpp() else False,
|
tpp=is_tpp,
|
||||||
woq=False,
|
woq=is_woq,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _ipex_optimize_model(model, rms_classes):
|
def _ipex_optimize_model(model, rms_classes, qtype):
|
||||||
_enable_tpp()
|
|
||||||
import intel_extension_for_pytorch as ipex
|
import intel_extension_for_pytorch as ipex
|
||||||
ipex.optimize(model.eval(), dtype=torch.bfloat16, inplace=True).eval()
|
from intel_extension_for_pytorch.transformers.models.reference.models import output_hook
|
||||||
_ipex_optimize_rmsnorm(model, rms_classes)
|
from intel_extension_for_pytorch.transformers.optimize import ipex_quantization_flow
|
||||||
_ipex_optimize_attention(model)
|
|
||||||
_ipex_optimize_decoder(model)
|
is_woq = False
|
||||||
|
is_quantization = False
|
||||||
|
_disable_tpp()
|
||||||
|
if qtype == ggml_tensor_qtype["bf16"]:
|
||||||
|
_enable_tpp()
|
||||||
|
model = ipex.optimize(model.eval(), dtype=torch.bfloat16, inplace=True).eval()
|
||||||
|
elif qtype == ggml_tensor_qtype["sym_int4"]:
|
||||||
|
is_quantization = True
|
||||||
|
is_woq = True
|
||||||
|
act_quant_mode_dict = {
|
||||||
|
"PER_TENSOR": ipex.quantization.WoqActQuantMode.PER_TENSOR,
|
||||||
|
"PER_IC_BLOCK": ipex.quantization.WoqActQuantMode.PER_IC_BLOCK,
|
||||||
|
"PER_BATCH": ipex.quantization.WoqActQuantMode.PER_BATCH,
|
||||||
|
"PER_BATCH_IC_BLOCK": ipex.quantization.WoqActQuantMode.PER_BATCH_IC_BLOCK,
|
||||||
|
}
|
||||||
|
qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(
|
||||||
|
weight_dtype=torch.quint4x2, # INT4
|
||||||
|
lowp_mode=ipex.quantization.WoqLowpMode.INT8,
|
||||||
|
act_quant_mode=act_quant_mode_dict["PER_IC_BLOCK"],
|
||||||
|
group_size=-1,
|
||||||
|
)
|
||||||
|
model = ipex_quantization_flow(model, torch.bfloat16, None, qconfig, None)
|
||||||
|
|
||||||
|
is_tpp = _using_tpp()
|
||||||
|
|
||||||
|
_ipex_optimize_rmsnorm(model, rms_classes, is_tpp=is_tpp, is_woq=is_woq)
|
||||||
|
_ipex_optimize_attention(model, is_tpp=is_tpp, is_woq=is_woq)
|
||||||
|
_ipex_optimize_decoder(model, is_tpp=is_tpp, is_woq=is_woq)
|
||||||
|
|
||||||
|
model.register_forward_hook(output_hook, with_kwargs=True)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
def _ipex_jit(model):
|
def _ipex_jit(model):
|
||||||
|
|
@ -152,9 +185,7 @@ def GLM_get_masks(self, input_ids, past_key_values, padding_mask=None):
|
||||||
else: # discrete kv cache
|
else: # discrete kv cache
|
||||||
past_length = past_key_values[0][0].shape[-2]
|
past_length = past_key_values[0][0].shape[-2]
|
||||||
|
|
||||||
import os
|
_enable_ipex = get_enable_ipex()
|
||||||
_enable_ipex = os.getenv("BIGDL_OPT_IPEX")
|
|
||||||
_enable_ipex = (_enable_ipex is not None) and (_enable_ipex.lower() == "true")
|
|
||||||
# always call for jit
|
# always call for jit
|
||||||
if past_length or _enable_ipex:
|
if past_length or _enable_ipex:
|
||||||
full_attention_mask = torch.cat(
|
full_attention_mask = torch.cat(
|
||||||
|
|
@ -191,9 +222,7 @@ def _make_causal_mask(
|
||||||
mask_cond = torch.arange(mask.size(-1), device=device)
|
mask_cond = torch.arange(mask.size(-1), device=device)
|
||||||
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
||||||
|
|
||||||
import os
|
_enable_ipex = get_enable_ipex()
|
||||||
_enable_ipex = os.getenv("BIGDL_OPT_IPEX")
|
|
||||||
_enable_ipex = (_enable_ipex is not None) and (_enable_ipex.lower() == "true")
|
|
||||||
if _enable_ipex or past_key_values_length > 0:
|
if _enable_ipex or past_key_values_length > 0:
|
||||||
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) # noqa
|
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) # noqa
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -30,6 +30,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Un
|
||||||
from transformers import top_k_top_p_filtering, GenerationConfig, \
|
from transformers import top_k_top_p_filtering, GenerationConfig, \
|
||||||
LogitsProcessorList, StoppingCriteriaList
|
LogitsProcessorList, StoppingCriteriaList
|
||||||
from bigdl.llm.utils.common import invalidInputError
|
from bigdl.llm.utils.common import invalidInputError
|
||||||
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||||
|
|
||||||
# patch GenerationMixin.generate
|
# patch GenerationMixin.generate
|
||||||
from transformers import GenerationMixin
|
from transformers import GenerationMixin
|
||||||
|
|
@ -533,15 +534,16 @@ def speculative_generate(self,
|
||||||
past_key_values = None
|
past_key_values = None
|
||||||
past_key_values_storage = []
|
past_key_values_storage = []
|
||||||
|
|
||||||
_enable_ipex = os.getenv("BIGDL_OPT_IPEX")
|
from bigdl.llm.transformers.convert import get_enable_ipex
|
||||||
_enable_ipex = (_enable_ipex is not None) and (_enable_ipex.lower() == "true")
|
_enable_ipex = get_enable_ipex()
|
||||||
|
|
||||||
if _enable_ipex:
|
if _enable_ipex:
|
||||||
if not ((self.config.model_type == 'baichuan') or
|
if not ((self.config.model_type == 'baichuan') or
|
||||||
('llama' in self.config.model_type) or
|
('llama' in self.config.model_type) or
|
||||||
("mistral" in self.config.model_type) or
|
("mistral" in self.config.model_type) or
|
||||||
("qwen" in self.config.model_type) or
|
("qwen" in self.config.model_type) or
|
||||||
("chatglm" in self.config.model_type)):
|
("chatglm" in self.config.model_type)):
|
||||||
invalidInputError(False, "BigDL Speculative Decoding with IPEX BF16 only supports \
|
invalidInputError(False, "BigDL Speculative Decoding with IPEX only supports \
|
||||||
Llama, Baichuan2, Mistral, ChatGLM and Qwen models currently.")
|
Llama, Baichuan2, Mistral, ChatGLM and Qwen models currently.")
|
||||||
if "chatglm" in self.config.model_type:
|
if "chatglm" in self.config.model_type:
|
||||||
global query_group_size
|
global query_group_size
|
||||||
|
|
@ -579,6 +581,11 @@ def speculative_generate(self,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
return_dict=True,
|
return_dict=True,
|
||||||
use_cache=True)
|
use_cache=True)
|
||||||
|
if _enable_ipex:
|
||||||
|
output = CausalLMOutputWithPast(
|
||||||
|
logits=output[0],
|
||||||
|
past_key_values=output[1],
|
||||||
|
)
|
||||||
logits = output['logits']
|
logits = output['logits']
|
||||||
logits = logits[:, -1:]
|
logits = logits[:, -1:]
|
||||||
logits[:, -1, :] = logits_processor(current_input_ids, logits[:, -1, :])
|
logits[:, -1, :] = logits_processor(current_input_ids, logits[:, -1, :])
|
||||||
|
|
@ -602,7 +609,7 @@ def speculative_generate(self,
|
||||||
draft_current_input_ids = current_input_ids
|
draft_current_input_ids = current_input_ids
|
||||||
# Target model KV cache to draft model
|
# Target model KV cache to draft model
|
||||||
|
|
||||||
if self.device.type == 'cpu':
|
if self.device.type == 'cpu' and (not _enable_ipex):
|
||||||
# init past_key_values_storage and assign initial fp32 value
|
# init past_key_values_storage and assign initial fp32 value
|
||||||
if step == 1:
|
if step == 1:
|
||||||
past_key_values_storage = \
|
past_key_values_storage = \
|
||||||
|
|
@ -652,6 +659,56 @@ def speculative_generate(self,
|
||||||
past_length = draft_past_key_values[0][0].size(2)
|
past_length = draft_past_key_values[0][0].size(2)
|
||||||
position_ids = torch.Tensor([[past_length]]).long().to(self.device)
|
position_ids = torch.Tensor([[past_length]]).long().to(self.device)
|
||||||
forward_args["position_ids"] = position_ids
|
forward_args["position_ids"] = position_ids
|
||||||
|
|
||||||
|
if _enable_ipex:
|
||||||
|
if any(keyword in self.config.model_type
|
||||||
|
for keyword in ["llama", "chatglm", "mistral"]):
|
||||||
|
past_key_value_len = draft_past_key_values[0][0].shape[2]
|
||||||
|
position_ids = torch.Tensor([[past_key_value_len + step_draft]]).long()
|
||||||
|
position_ids = position_ids[:, :-draft_current_input_ids.size(0)]
|
||||||
|
draft_output = draft_model.trace_graph(
|
||||||
|
input_ids=draft_current_input_ids,
|
||||||
|
attention_mask=draft_attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=draft_past_key_values,
|
||||||
|
)
|
||||||
|
elif self.config.model_type == "baichuan":
|
||||||
|
if self.config.hidden_size == 4096:
|
||||||
|
past_key_value_len = draft_past_key_values[0][0].shape[2]
|
||||||
|
seq_len = draft_current_input_ids.shape[1]
|
||||||
|
seq_len_with_past = seq_len + past_key_value_len
|
||||||
|
position_ids = torch.arange(past_key_value_len,
|
||||||
|
seq_len_with_past,
|
||||||
|
dtype=torch.long,
|
||||||
|
device=draft_current_input_ids.device)
|
||||||
|
position_ids = position_ids.unsqueeze(0).view(-1, seq_len)
|
||||||
|
draft_output = draft_model.trace_graph(
|
||||||
|
input_ids=draft_current_input_ids,
|
||||||
|
attention_mask=draft_attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=draft_past_key_values,
|
||||||
|
)
|
||||||
|
elif self.config.hidden_size == 5120:
|
||||||
|
draft_output = draft_model.trace_graph(
|
||||||
|
input_ids=draft_current_input_ids,
|
||||||
|
attention_mask=draft_attention_mask,
|
||||||
|
past_key_values=draft_past_key_values,
|
||||||
|
)
|
||||||
|
elif "qwen" in self.config.model_type:
|
||||||
|
draft_output = draft_model.trace_graph(
|
||||||
|
input_ids=draft_current_input_ids,
|
||||||
|
attention_mask=draft_attention_mask,
|
||||||
|
past_key_values=draft_past_key_values,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
invalidInputError(False, "BigDL Speculative Decoding with IPEX only supports \
|
||||||
|
Llama, Baichuan2, Mistral, ChatGLM and Qwen models currently.")
|
||||||
|
|
||||||
|
draft_output = CausalLMOutputWithPast(
|
||||||
|
logits=draft_output[0],
|
||||||
|
past_key_values=draft_output[1],
|
||||||
|
)
|
||||||
|
else:
|
||||||
draft_output = draft_model(**forward_args)
|
draft_output = draft_model(**forward_args)
|
||||||
temp_input_ids = torch.cat((input_ids, generate_ids[:, :step],
|
temp_input_ids = torch.cat((input_ids, generate_ids[:, :step],
|
||||||
draft_generate_ids[:, 1:step_draft+1]), dim=-1)
|
draft_generate_ids[:, 1:step_draft+1]), dim=-1)
|
||||||
|
|
@ -848,7 +905,6 @@ def speculative_generate(self,
|
||||||
# Clean up target model KV cache
|
# Clean up target model KV cache
|
||||||
if max_of_max_matched != max_matched:
|
if max_of_max_matched != max_matched:
|
||||||
output_ids = output_ids[:, :max_matched]
|
output_ids = output_ids[:, :max_matched]
|
||||||
# For Qwen
|
|
||||||
if _enable_ipex:
|
if _enable_ipex:
|
||||||
cur_len = past_key_values[0][0].size(1)
|
cur_len = past_key_values[0][0].size(1)
|
||||||
delta = max_of_max_matched - max_matched
|
delta = max_of_max_matched - max_matched
|
||||||
|
|
@ -890,7 +946,7 @@ def speculative_generate(self,
|
||||||
]
|
]
|
||||||
|
|
||||||
# Each iter assign new_matched kv_cache to past_key_values1
|
# Each iter assign new_matched kv_cache to past_key_values1
|
||||||
if self.device.type == 'cpu':
|
if self.device.type == 'cpu' and (not _enable_ipex):
|
||||||
_update_past_key_values_storage_cpu(self, past_key_values, past_key_values_storage,
|
_update_past_key_values_storage_cpu(self, past_key_values, past_key_values_storage,
|
||||||
original_draft_past_key_values,
|
original_draft_past_key_values,
|
||||||
_enable_ipex)
|
_enable_ipex)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue