LLM: enable chatglm3-6b target_model ipex (#10085)
* init
* always make causal_mask
* not return last tensor
* update
* optimize_model = False
* enable optimized=False
* enable optimized_model=true
* speed_up ipex target_model
* remove if True
* use group_size
* update python style
* update
* update
This commit is contained in: parent 177273c1a4, commit f2417e083c.
3 changed files with 204 additions and 71 deletions.
@@ -518,6 +518,16 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
                     f"format......")
     modules_to_not_convert = [] if modules_to_not_convert is None else modules_to_not_convert

+    # using ipex optimizer before changing to bigdl linear
+    _enable_ipex = os.getenv("BIGDL_OPT_IPEX")
+    _enable_ipex = (_enable_ipex is not None) and (_enable_ipex.lower() == "true")
+    _enable_ipex = _enable_ipex and (qtype == ggml_tensor_qtype["bf16"])
+    if (device == "cpu") and (qtype == ggml_tensor_qtype["bf16"]):
+        logger.info(f"BIGDL_OPT_IPEX: {_enable_ipex}")
+        if _enable_ipex:
+            model = _optimize_ipex(model)
+            return model
+
     if optimize_model:
         model = _optimize_pre(model)

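Not part of the diff: a minimal standalone sketch of the gating added above, showing how the BIGDL_OPT_IPEX toggle resolves. The helper name and the literal "bf16" string are hypothetical; the real code compares against ggml_tensor_qtype["bf16"].

import os

def ipex_bf16_enabled(qtype: str, device: str) -> bool:
    # Mirrors the gate above: the env var must be literally "true" (any case),
    # and the IPEX path is only taken for bf16 models running on CPU.
    flag = os.getenv("BIGDL_OPT_IPEX")
    enabled = (flag is not None) and (flag.lower() == "true")
    return enabled and (qtype == "bf16") and (device == "cpu")

print(ipex_bf16_enabled("bf16", "cpu"))  # True only when BIGDL_OPT_IPEX=true is set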
@@ -543,14 +553,6 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
             # Do nothing here for weights are empty.
             pass

-    _enable_ipex = os.getenv("BIGDL_OPT_IPEX")
-    _enable_ipex = (_enable_ipex is not None) and (_enable_ipex.lower() == "true")
-    _enable_ipex = _enable_ipex and (qtype == ggml_tensor_qtype["bf16"])
-    if (device == "cpu") and (qtype == ggml_tensor_qtype["bf16"]):
-        logger.info(f"BIGDL_OPT_IPEX: {_enable_ipex}")
-        if _enable_ipex:
-            model = _optimize_ipex(model)
-            return model
     if optimize_model:
         model = _optimize_post(model, lightweight_bmm)
     return model
@@ -590,13 +592,17 @@ def _optimize_ipex(model):
     from transformers.modeling_attn_mask_utils import AttentionMaskConverter
     from bigdl.llm.transformers.convert_ipex import (
         _ipex_optimize_attention, _ipex_optimize_decoder, _ipex_jit, _make_causal_mask,
-        _ipex_optimize_rmsnorm, _llama_model_forward_4_35
+        _ipex_optimize_rmsnorm, _llama_model_forward_4_35, convert_function, GLM_get_masks
     )

     AttentionMaskConverter._make_causal_mask = _make_causal_mask
     convert_forward(model, transformers.models.llama.modeling_llama.LlamaModel,
                     _llama_model_forward_4_35)
     model = model_convert_reference(model)
+    if model.config.architectures is not None \
+            and model.config.architectures[0] in ["ChatGLMModel", "ChatGLMForConditionalGeneration"]:
+        convert_function(model.transformer, "get_masks", GLM_get_masks)
+    model = ipex.optimize(model.eval(), dtype=torch.bfloat16, inplace=True).eval()
     _ipex_optimize_rmsnorm(model)
     _ipex_optimize_attention(model)
     _ipex_optimize_decoder(model)
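The new convert_function helper (defined in a later hunk) is used here to swap ChatGLM's get_masks for GLM_get_masks on the model's transformer instance. A self-contained sketch of the rebinding trick, with toy names rather than BigDL code:

class Greeter:
    def greet(self):
        return "hello"

def shout(self):
    return "HELLO"

g = Greeter()
# new_function.__get__(obj, cls) returns a bound method, so setattr replaces
# the method on this one instance only; the class and other instances keep theirs.
setattr(g, "greet", shout.__get__(g, Greeter))
print(g.greet())          # HELLO
print(Greeter().greet())  # hello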
@@ -142,6 +142,8 @@ def _ipex_jit(model):
     sample_inputs = (
         get_dummy_input(model, return_dict=True)
    )
+    if "return_last_logit" in sample_inputs:
+        del sample_inputs["return_last_logit"]
     with torch.no_grad(), torch.cpu.amp.autocast(
         enabled=True
     ):
@@ -159,6 +161,47 @@ def _ipex_jit(model):
     return model.eval()


+def convert_function(m, func_name, new_function):
+    bound_method = new_function.__get__(m, m.__class__)
+    setattr(m, func_name, bound_method)
+
+
+def GLM_get_masks(self, input_ids, past_key_values, padding_mask=None):
+    batch_size, seq_length = input_ids.shape
+    full_attention_mask = torch.ones(
+        batch_size, seq_length, seq_length, device=input_ids.device
+    )
+    full_attention_mask.tril_()
+    past_length = 0
+    if past_key_values:
+        if len(past_key_values[0]) != 4:  # not discrete kv cache
+            past_length = past_key_values[0][0].shape[0]
+        else:  # discrete kv cache
+            past_length = past_key_values[0][0].shape[-2]
+
+    import os
+    _enable_ipex = os.getenv("BIGDL_OPT_IPEX")
+    _enable_ipex = (_enable_ipex is not None) and (_enable_ipex.lower() == "true")
+    # always call for jit
+    if past_length or _enable_ipex:
+        full_attention_mask = torch.cat(
+            (
+                torch.ones(
+                    batch_size, seq_length, past_length, device=input_ids.device
+                ),
+                full_attention_mask,
+            ),
+            dim=-1,
+        )
+    if padding_mask is not None:
+        full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
+    # if not past_length and padding_mask is not None:
+    #     full_attention_mask -= padding_mask.unsqueeze(-1) - 1
+    full_attention_mask = (full_attention_mask < 0.5).bool()
+    full_attention_mask.unsqueeze_(1)
+    return full_attention_mask
+
+
 @staticmethod
 def _make_causal_mask(
     input_ids_shape: torch.Size,
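To see what GLM_get_masks produces, here is a standalone toy run (not from the repo) with batch 1, three new tokens and two cached tokens. True marks positions that are masked out; cached positions stay visible to every new token, and new tokens attend causally to each other:

import torch

batch_size, seq_length, past_length = 1, 3, 2
mask = torch.ones(batch_size, seq_length, seq_length).tril_()
mask = torch.cat((torch.ones(batch_size, seq_length, past_length), mask), dim=-1)
mask = (mask < 0.5).bool().unsqueeze(1)   # shape (1, 1, 3, 5), as returned above
print(mask.int().squeeze())
# tensor([[0, 0, 0, 1, 1],
#         [0, 0, 0, 0, 1],
#         [0, 0, 0, 0, 0]], dtype=torch.int32)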
@@ -34,7 +34,7 @@ from bigdl.llm.utils.common import invalidInputError
 # patch GenerationMixin.generate
 from transformers import GenerationMixin
 original_generate = GenerationMixin.generate
+query_group_size = 16
 logger = logging.getLogger("bigdl.llm.speculative")


@@ -131,62 +131,98 @@ def clear_benchmarks(self):
 def _prepare_past_key_values_storage_cpu(self, past_key_values,
                                          max_new_tokens, _enable_ipex=False):
     past_key_values_storage = []
+    # init ipex_past_key_values
     if _enable_ipex:
         ipex_past_key_values = []
         cur_len = past_key_values[0][0].size(1)
-        ipex_past_key_values = [
-            [pkv[1].permute(1, 2, 0, 3)[:, :, :cur_len, :],
-             pkv[2].permute(1, 2, 0, 3)[:, :, :cur_len, :]]
-            for pkv in past_key_values
-        ]
-
-    for i in range(len(past_key_values)):
-        if not _enable_ipex:
-            len0 = past_key_values[i][0].size(0)
-            len1 = past_key_values[i][0].size(1)
-            len2 = past_key_values[i][0].size(2)
-            len3 = past_key_values[i][0].size(3)
+        if self.config.model_type == "chatglm":
+            len0 = past_key_values[0][1].size(0)  # seq max length
+            len1 = past_key_values[0][1].size(1)
+            len2 = past_key_values[0][1].size(2)
+            len3 = past_key_values[0][1].size(3)
+            for pkv in past_key_values:
+                key = pkv[1]
+                value = pkv[2]
+                key = key.permute(1, 2, 0, 3).unsqueeze(-3)
+                key = key.expand(-1, -1, query_group_size, -1, -1)
+                key = key.contiguous().view(len1, len2 * query_group_size,
+                                            len0, len3).permute(2, 0, 1, 3)
+                value = value.permute(1, 2, 0, 3).unsqueeze(-3)
+                value = value.expand(-1, -1, query_group_size, -1, -1)
+                value = value.contiguous().view(len1, len2 * query_group_size,
+                                                len0, len3).permute(2, 0, 1, 3)
+                list = [key[:cur_len, :, :, :], value[:cur_len, :, :, :]]
+                ipex_past_key_values.append(list)
         else:
-            len0 = past_key_values[i][1].size(1)
-            len1 = past_key_values[i][1].size(2)
-            len2 = past_key_values[i][0].size(2)  # seq length
-            len3 = past_key_values[i][1].size(3)
-        if self.config.model_type == "qwen":
-            k0 = torch.ones(len0, len2, len1 + max_new_tokens, len3,
-                            dtype=torch.float32)
-            v0 = torch.ones(len0, len2, len1 + max_new_tokens, len3,
-                            dtype=torch.float32)
-            k0 = k0.transpose(1, 2)
-            v0 = v0.transpose(1, 2)
-            past_key_values_storage.append((k0, v0))
-            past_key_values_storage[i][0][:, :len1, :, :] = past_key_values[i][0].to(
-                torch.float32)
-            past_key_values_storage[i][1][:, :len1, :, :] = past_key_values[i][1].to(
-                torch.float32)
-        elif self.config.model_type == "chatglm":
-            k0 = torch.ones(len1, len2, len0 + max_new_tokens, len3,
-                            dtype=torch.float32)
-            v0 = torch.ones(len1, len2, len0 + max_new_tokens, len3,
-                            dtype=torch.float32)
-            k0 = k0.permute(2, 0, 1, 3)
-            v0 = v0.permute(2, 0, 1, 3)
-            past_key_values_storage.append((k0, v0))
-            past_key_values_storage[i][0][:len0, :, :, :] = past_key_values[i][0].to(
-                torch.float32)
-            past_key_values_storage[i][1][:len0, :, :, :] = past_key_values[i][1].to(
-                torch.float32)
-        else:
-            k0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
-                            dtype=torch.float32)
-            v0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
-                            dtype=torch.float32)
-            past_key_values_storage.append((k0, v0))
-            if not _enable_ipex:
+            ipex_past_key_values = [
+                [pkv[1].permute(1, 2, 0, 3)[:, :, :cur_len, :],
+                 pkv[2].permute(1, 2, 0, 3)[:, :, :cur_len, :]]
+                for pkv in past_key_values
+            ]
+    if not _enable_ipex:
+        len0 = past_key_values[0][0].size(0)
+        len1 = past_key_values[0][0].size(1)
+        len2 = past_key_values[0][0].size(2)
+        len3 = past_key_values[0][0].size(3)
+        for i in range(len(past_key_values)):
+            if self.config.model_type == "qwen":
+                k0 = torch.ones(len0, len2, len1 + max_new_tokens, len3,
+                                dtype=torch.float32)
+                v0 = torch.ones(len0, len2, len1 + max_new_tokens, len3,
+                                dtype=torch.float32)
+                k0 = k0.transpose(1, 2)
+                v0 = v0.transpose(1, 2)
+                past_key_values_storage.append((k0, v0))
+                past_key_values_storage[i][0][:, :len1, :, :] = past_key_values[i][0].to(
+                    torch.float32)
+                past_key_values_storage[i][1][:, :len1, :, :] = past_key_values[i][1].to(
+                    torch.float32)
+            elif self.config.model_type == "chatglm":
+                k0 = torch.ones(len1, len2, len0 + max_new_tokens, len3,
+                                dtype=torch.float32)
+                v0 = torch.ones(len1, len2, len0 + max_new_tokens, len3,
+                                dtype=torch.float32)
+                k0 = k0.permute(2, 0, 1, 3)
+                v0 = v0.permute(2, 0, 1, 3)
+                past_key_values_storage.append((k0, v0))
+                past_key_values_storage[i][0][:len0, :, :, :] = past_key_values[i][0].to(
+                    torch.float32)
+                past_key_values_storage[i][1][:len0, :, :, :] = past_key_values[i][1].to(
+                    torch.float32)
+            else:
+                k0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
+                                dtype=torch.float32)
+                v0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
+                                dtype=torch.float32)
+                past_key_values_storage.append((k0, v0))
                 past_key_values_storage[i][0][:, :, :len2, :] = past_key_values[i][0].to(
                     torch.float32)
                 past_key_values_storage[i][1][:, :, :len2, :] = past_key_values[i][1].to(
                     torch.float32)
+    else:
+        len0 = past_key_values[0][1].size(1)
+        len1 = past_key_values[0][1].size(2)
+        len2 = past_key_values[0][0].size(2)  # seq length
+        len3 = past_key_values[0][1].size(3)
+        for i in range(len(past_key_values)):
+            if self.config.model_type == "chatglm":
+                k0 = torch.ones(len0, len1 * query_group_size, len2 + max_new_tokens, len3,
+                                dtype=torch.float32)
+                v0 = torch.ones(len0, len1 * query_group_size, len2 + max_new_tokens, len3,
+                                dtype=torch.float32)
+                k0 = k0.permute(2, 0, 1, 3)
+                v0 = v0.permute(2, 0, 1, 3)
+                past_key_values_storage.append((k0, v0))
+                past_key_values_storage[i][0][:len2, :, :, :] = ipex_past_key_values[i][0].to(
+                    torch.float32)
+                past_key_values_storage[i][1][:len2, :, :, :] = ipex_past_key_values[i][1].to(
+                    torch.float32)
             else:
+                k0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
+                                dtype=torch.float32)
+                v0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
+                                dtype=torch.float32)
+                past_key_values_storage.append((k0, v0))
                 past_key_values_storage[i][0][:, :, :len2, :] = ipex_past_key_values[i][0].to(
                     torch.float32)
                 past_key_values_storage[i][1][:, :, :len2, :] = ipex_past_key_values[i][1].to(
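The chatglm branch above repeats each multi-query KV head query_group_size times so the stored cache matches the query-head layout expected by the IPEX graph. A shape-only sketch of that expansion, with toy sizes (seq 5, batch 2, 2 KV heads, head dim 4, group size 16 as for chatglm3-6b); not taken from the repo:

import torch

query_group_size = 16
key = torch.randn(5, 2, 2, 4)                        # (seq, batch, kv_heads, head_dim)
key = key.permute(1, 2, 0, 3).unsqueeze(-3)           # (batch, kv_heads, 1, seq, head_dim)
key = key.expand(-1, -1, query_group_size, -1, -1)    # repeat each KV head per query group
key = key.contiguous().view(2, 2 * query_group_size, 5, 4).permute(2, 0, 1, 3)
print(key.shape)                                      # torch.Size([5, 2, 32, 4])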
@@ -195,7 +231,8 @@ def _prepare_past_key_values_storage_cpu(self, past_key_values,
     return past_key_values_storage


-def _prepare_draft_past_key_values_cpu(self, past_key_values, past_key_values_storage):
+def _prepare_draft_past_key_values_cpu(self, past_key_values,
+                                       past_key_values_storage, _enable_ipex):
     tmp_past_key_values = []
     for i in range(len(past_key_values)):
         if self.config.model_type == "qwen":
@@ -204,7 +241,10 @@ def _prepare_draft_past_key_values_cpu(self, past_key_values, past_key_values_st
             v0 = past_key_values_storage[i][1][:, :len1, :, :]
             tmp_past_key_values.append((k0, v0))
         elif self.config.model_type == "chatglm":
-            len0 = past_key_values[0][0].size(0)
+            if not _enable_ipex:
+                len0 = past_key_values[0][0].size(0)
+            else:
+                len0 = past_key_values[0][0].size(1)
             k0 = past_key_values_storage[i][0][:len0, :, :, :]
             v0 = past_key_values_storage[i][1][:len0, :, :, :]
             tmp_past_key_values.append((k0, v0))
@@ -244,15 +284,41 @@ def _update_past_key_values_storage_cpu(self, past_key_values, past_key_values_s
         else:
             size = original_draft_past_key_values[i][0].size(2)
             size1 = past_key_values[i][0].size(1)
-            delta_past_key = \
-                past_key_values[i][1][size:size1, :, :, :].permute(1, 2, 0, 3)
-            delta_past_value = \
-                past_key_values[i][2][size:size1, :, :, :].permute(1, 2, 0, 3)
+            if self.config.model_type == "chatglm":
+                size = original_draft_past_key_values[0][0].size(0)
+                size1 = past_key_values[0][0].size(1)
+                len0 = past_key_values[0][1].size(0)  # seq max_length
+                len1 = past_key_values[0][1].size(1)
+                len2 = past_key_values[0][1].size(2)
+                len3 = past_key_values[0][1].size(3)
+                key0 = torch.ones(size1-size, len1, len2, len3,
+                                  dtype=torch.float32)
+                value0 = torch.ones(size1-size, len1, len2, len3,
+                                    dtype=torch.float32)
+                key0 = past_key_values[i][1][size:size1, :, :, :]
+                value0 = past_key_values[i][2][size:size1, :, :, :]
+                key = key0.permute(1, 2, 0, 3).unsqueeze(-3)
+                key = key.expand(-1, -1, query_group_size, -1, -1)
+                key = key.contiguous().view(len1, len2 * query_group_size, size1-size, len3)
+                key = key.permute(2, 0, 1, 3)
+                value = value0.permute(1, 2, 0, 3).unsqueeze(-3)
+                value = value.expand(-1, -1, query_group_size, -1, -1)
+                value = value.contiguous().view(len1, len2 * query_group_size, size1-size, len3)
+                value = value.permute(2, 0, 1, 3)
+                past_key_values_storage[i][0][size:size1, :, :, :] = \
+                    key.to(torch.float32)
+                past_key_values_storage[i][1][size:size1, :, :, :] = \
+                    value.to(torch.float32)
+            else:
+                delta_past_key = \
+                    past_key_values[i][1][size:size1, :, :, :].permute(1, 2, 0, 3)
+                delta_past_value = \
+                    past_key_values[i][2][size:size1, :, :, :].permute(1, 2, 0, 3)

-            past_key_values_storage[i][0][:, :, size:size1, :] = \
-                delta_past_key.to(torch.float32)
-            past_key_values_storage[i][1][:, :, size:size1, :] = \
-                delta_past_value.to(torch.float32)
+                past_key_values_storage[i][0][:, :, size:size1, :] = \
+                    delta_past_key.to(torch.float32)
+                past_key_values_storage[i][1][:, :, size:size1, :] = \
+                    delta_past_value.to(torch.float32)


 @torch.no_grad()
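The update path writes only the delta of newly verified tokens into the preallocated float32 storage; for chatglm under IPEX the delta is expanded with the same group trick before being written along the sequence dimension. A simplified sketch of such a rolling buffer update (shapes and names are illustrative, not from the repo):

import torch

storage_k = torch.zeros(16, 1, 32, 4)     # preallocated (max_seq, batch, heads, head_dim) buffer
new_k = torch.randn(7, 1, 32, 4)          # cache after a verify step (7 tokens total)
size, size1 = 5, 7                        # 5 tokens already stored, 2 newly accepted
storage_k[size:size1] = new_k[size:size1].to(torch.float32)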
@@ -372,10 +438,14 @@ def speculative_generate(self,
     _enable_ipex = (_enable_ipex is not None) and (_enable_ipex.lower() == "true")
     if _enable_ipex:
         if not ((self.config.model_type == 'baichuan' and self.config.hidden_size == 5120) or
-                ('llama' in self.config.model_type) or
+                ('llama' in self.config.model_type) or ("chatglm" in self.config.model_type) or
                 ("mistral" in self.config.model_type)):
             invalidInputError(False, "BigDL Speculative Decoding with IPEX BF16 only supports \
                               Llama, Baichuan2-13b and Mistral models currently.")
+        if "chatglm" in self.config.model_type:
+            global query_group_size
+            query_group_size = draft_model.config.num_attention_heads // \
+                draft_model.config.multi_query_group_num

     tmp_matchness = 0
     e2e_tic = 0.0
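query_group_size is recomputed here from the draft model's config; for chatglm3-6b (32 attention heads, 2 multi-query groups) it comes out to the module default of 16. A quick check, assuming the Hugging Face config can be loaded (requires trust_remote_code and network or a local copy):

from transformers import AutoConfig

config = AutoConfig.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True)
print(config.num_attention_heads // config.multi_query_group_num)  # 16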
@@ -437,7 +507,7 @@ def speculative_generate(self,
                 # each iter cut off cur_len kv_cache from past_key_values1
                 draft_past_key_values = \
                     _prepare_draft_past_key_values_cpu(self, past_key_values,
-                                                       past_key_values_storage)
+                                                       past_key_values_storage, _enable_ipex)
                 original_draft_past_key_values = draft_past_key_values
             else:
                 draft_past_key_values = past_key_values
@@ -464,7 +534,10 @@ def speculative_generate(self,
                         "use_cache": True,
                     }
                     if self.config.model_type == "chatglm":
-                        past_key_value_len = past_key_values[0][0].shape[0]
+                        if _enable_ipex:
+                            past_key_value_len = past_key_values[0][0].shape[1]
+                        else:
+                            past_key_value_len = past_key_values[0][0].shape[0]
                         position_ids = torch.Tensor([[past_key_value_len + step_draft]]).long()
                         forward_args["position_ids"] = position_ids
                     elif self.config.model_type == "gptj":
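The cached length is read from a different dimension depending on the backend: the eager ChatGLM cache stores keys sequence-first, while the IPEX-traced cache used here keeps the sequence dimension at index 1. A toy illustration (both layouts below are assumptions for demonstration, not copied from the repo):

import torch

eager_key = torch.zeros(12, 1, 2, 128)   # (seq, batch, kv_heads, head_dim) -> shape[0] == 12
ipex_key = torch.zeros(1, 12, 32, 128)   # assumed IPEX layout, seq at dim 1 -> shape[1] == 12
print(eager_key.shape[0], ipex_key.shape[1])   # 12 12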
@@ -533,6 +606,16 @@ def speculative_generate(self,
                                               position_ids=position_ids,
                                               past_key_values=past_key_values,
                                               )
+                    elif "chatglm" in self.config.model_type:
+                        past_key_value_len = past_key_values[0][0].shape[2]
+                        position_ids = torch.arange(drafted_input_ids.shape[1], dtype=torch.long,
+                                                    device=drafted_input_ids.device).unsqueeze(0)
+                        position_ids = position_ids.repeat(1, 1) + past_key_value_len
+                        output = self.trace_graph(input_ids=drafted_input_ids,
+                                                  attention_mask=cur_attention_mask,
+                                                  position_ids=position_ids,
+                                                  # return_last_logit=torch.tensor(False),
+                                                  past_key_values=past_key_values,)
                     elif "mistral" in self.config.model_type:
                         past_key_value_len = past_key_values[0][0].shape[2]
                         seq_len = drafted_input_ids.shape[1]
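For the chatglm verification forward added above, position ids simply continue from the cached length, one per drafted token. A toy equivalent of that construction:

import torch

past_key_value_len, drafted_len = 12, 4
position_ids = torch.arange(drafted_len, dtype=torch.long).unsqueeze(0) + past_key_value_len
print(position_ids)   # tensor([[12, 13, 14, 15]])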
@@ -591,7 +674,8 @@ def speculative_generate(self,
                     self.verify_time.append(toc - tic)
                     self.generate_time.append(self.draft_time[-1] + self.verify_time[-1])

-                    past_key_values = output['past_key_values']
+                    if past_key_values is None:
+                        past_key_values = output['past_key_values']

                     if generation_config.do_sample:
                         draft_tokens = drafted_input_ids[:, 1:].squeeze(0)