LLM: Enable qwen target_model ipex (#10232)

* change order
* enable qwen ipex
* update qwen example
* update
* fix style
* update
parent 3e6d188553
commit f9b75f900b

4 changed files with 74 additions and 6 deletions
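In short, the Qwen speculative-decoding example now builds an explicit attention mask and a `min_step_draft` option and passes both into `generate()`, and the IPEX BF16 path accepts Qwen as a target model. Below is a condensed usage sketch assembled from the example diff; the `bigdl.llm.transformers` import path, the model id and the prompt are illustrative assumptions, and only the keyword arguments visible in the diff come from this commit:

```python
# Hypothetical condensed view of the updated Qwen example; not a drop-in script.
import torch
from transformers import AutoTokenizer
from bigdl.llm.transformers import AutoModelForCausalLM  # assumed import path

model_path = "Qwen/Qwen-7B-Chat"  # assumed model id
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             speculative=True,
                                             trust_remote_code=True,
                                             use_cache=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

with torch.inference_mode():
    # The example now keeps the attention mask and forwards it to generate().
    inputs = tokenizer("What is AI?", return_tensors="pt").to(model.device)
    output = model.generate(inputs.input_ids,
                            attention_mask=inputs.attention_mask,
                            max_new_tokens=128,
                            th_stop_draft=0.6,   # draft stop probability
                            min_step_draft=1,    # newly exposed option
                            do_sample=False)
    print(tokenizer.decode(output[0], skip_special_tokens=True))
```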
@@ -90,3 +90,30 @@ Tokens generated 128
 E2E Generation time x.xxxxs
 First token latency x.xxxxs
 ```
+
+### 4. Accelerate with BIGDL_OPT_IPEX
+
+To accelerate speculative decoding on CPU, you can install our validated version of [IPEX 2.3.0+git004cd72d](https://github.com/intel/intel-extension-for-pytorch/tree/004cd72db60e87bb0712d42e3120bac9854bd77e) by following these steps. (Other versions of IPEX may have conflicts and cannot accelerate speculative decoding correctly.)
+
+#### 4.1 Download the IPEX installation script
+```bash
+# Depends on Conda and GCC 12.3
+wget https://raw.githubusercontent.com/intel/intel-extension-for-pytorch/004cd72db60e87bb0712d42e3120bac9854bd77e/scripts/compile_bundle.sh
+```
+
+#### 4.2 Activate your conda environment
+```bash
+conda activate <your_conda_env>
+```
+#### 4.3 Set VER_IPEX in compile_bundle.sh to 004cd72db60e87bb0712d42e3120bac9854bd77e
+```bash
+sed -i 's/VER_IPEX=main/VER_IPEX=004cd72db60e87bb0712d42e3120bac9854bd77e/g' "compile_bundle.sh"
+```
+
+#### 4.4 Install IPEX and other dependencies
+```bash
+# Install IPEX 2.3.0+git004cd72d
+bash compile_bundle.sh
+
+# Update transformers
+pip install transformers==4.36.2
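Not part of the diff, but after step 4.4 a quick sanity check is to confirm that the pinned IPEX build, rather than a stock wheel, is the one Python picks up:

```python
# Suggested post-install check; not part of this PR.
import torch
import intel_extension_for_pytorch as ipex

print(torch.__version__)  # compile_bundle.sh builds a matching PyTorch
print(ipex.__version__)   # should report the pinned 2.3.0+git004cd72d build
```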
@@ -54,6 +54,8 @@ if __name__ == '__main__':
                         help='Max tokens to predict')
     parser.add_argument('--th_stop_draft', type=float, default=0.6,
                         help='draft stop probility')
+    parser.add_argument('--min_step_draft', type=int, default=1,
+                        help='min tokens per step draft')
 
     args = parser.parse_args()
     model_path = args.repo_id_or_model_path
@@ -67,13 +69,15 @@ if __name__ == '__main__':
                                                  speculative=True,
                                                  trust_remote_code=True,
                                                  use_cache=True)
-    model = model.to('cpu')
+    #model = model.to('cpu')
 
     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 
     with torch.inference_mode():
         prompt = QWEN_PROMPT_FORMAT.format(prompt=args.prompt)
-        input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(model.device)
+        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+        input_ids = inputs.input_ids
+        attention_mask = inputs.attention_mask.to(model.device)
         actual_in_len = input_ids.shape[1]
         print("actual input_ids length:" + str(actual_in_len))
 
@@ -81,6 +85,8 @@ if __name__ == '__main__':
         output = model.generate(input_ids,
                                 max_new_tokens=args.n_predict,
                                 th_stop_draft=args.th_stop_draft,
+                                attention_mask=attention_mask,
+                                min_step_draft=args.min_step_draft,
                                 do_sample=False)
         output_str = tokenizer.decode(output[0])
 
@@ -89,6 +95,8 @@ if __name__ == '__main__':
         output = model.generate(input_ids,
                                 max_new_tokens=args.n_predict,
                                 th_stop_draft=args.th_stop_draft,
+                                attention_mask=attention_mask,
+                                min_step_draft=args.min_step_draft,
                                 do_sample=False)
         output_str = tokenizer.decode(output[0], skip_special_tokens=True)
         end = time.perf_counter()
@@ -93,12 +93,10 @@ def _ipex_optimize_attention(model):
 
 
 def _ipex_optimize_model(model, rms_classes):
-    from intel_extension_for_pytorch.transformers.models.reference.models import output_hook
 
     _ipex_optimize_rmsnorm(model, rms_classes)
     _ipex_optimize_attention(model)
     _ipex_optimize_decoder(model)
-    model.register_forward_hook(output_hook, with_kwargs=True)
 
 
 def _ipex_jit(model):
@@ -124,6 +122,8 @@ def _ipex_jit(model):
     model = _set_optimized_model_for_generation(
         model, optimized_model=trace_model
     )
+    from intel_extension_for_pytorch.transformers.models.reference.models import output_hook
+    model.register_forward_hook(output_hook, with_kwargs=True)
 
     return model
 
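The two hunks above move the `output_hook` registration out of `_ipex_optimize_model` and into `_ipex_jit`, presumably so the hook is only attached once the traced model has been set for generation (the "change order" item in the commit message). For readers unfamiliar with the API, `register_forward_hook(hook, with_kwargs=True)` passes keyword arguments through to the hook and lets it rewrite the module output; a toy illustration, unrelated to IPEX's actual `output_hook`:

```python
# Minimal demonstration of register_forward_hook(..., with_kwargs=True).
import torch
import torch.nn as nn

def doubling_hook(module, args, kwargs, output):
    # With with_kwargs=True the hook receives (module, args, kwargs, output)
    # and may return a replacement output.
    return output * 2

layer = nn.Linear(4, 4)
layer.register_forward_hook(doubling_hook, with_kwargs=True)
print(layer(torch.ones(1, 4)).shape)  # torch.Size([1, 4])
```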
@@ -153,6 +153,12 @@ def _prepare_past_key_values_storage_cpu(self, past_key_values,
                                          len0, len3).permute(2, 0, 1, 3)
             list = [key[:cur_len, :, :, :], value[:cur_len, :, :, :]]
             ipex_past_key_values.append(list)
+        elif self.config.model_type == "qwen":
+            ipex_past_key_values = [
+                [pkv[1].permute(1, 0, 2, 3)[:, :cur_len, :, :],
+                 pkv[2].permute(1, 0, 2, 3)[:, :cur_len, :, :]]
+                for pkv in past_key_values
+            ]
         else:
             ipex_past_key_values = [
                 [pkv[1].permute(1, 2, 0, 3)[:, :, :cur_len, :],
@@ -217,6 +223,18 @@ def _prepare_past_key_values_storage_cpu(self, past_key_values,
                 torch.float32)
             past_key_values_storage[i][1][:len2, :, :, :] = ipex_past_key_values[i][1].to(
                 torch.float32)
+        elif self.config.model_type == "qwen":
+            k0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
+                            dtype=torch.float32)
+            v0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
+                            dtype=torch.float32)
+            k0 = k0.permute(0, 2, 1, 3)
+            v0 = v0.permute(0, 2, 1, 3)
+            past_key_values_storage.append((k0, v0))
+            past_key_values_storage[i][0][:, :len2, :, :] = ipex_past_key_values[i][0].to(
+                torch.float32)
+            past_key_values_storage[i][1][:, :len2, :, :] = ipex_past_key_values[i][1].to(
+                torch.float32)
         else:
             k0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
                             dtype=torch.float32)
@@ -309,6 +327,16 @@ def _update_past_key_values_storage_cpu(self, past_key_values, past_key_values_storage,
                 key.to(torch.float32)
             past_key_values_storage[i][1][size:size1, :, :, :] = \
                 value.to(torch.float32)
+        elif self.config.model_type == "qwen":
+            size = original_draft_past_key_values[0][0].size(1)
+            delta_past_key = \
+                past_key_values[i][1][size:size1, :, :, :].permute(1, 0, 2, 3)
+            delta_past_value = \
+                past_key_values[i][2][size:size1, :, :, :].permute(1, 0, 2, 3)
+            past_key_values_storage[i][0][:, size:size1, :, :] = \
+                delta_past_key.to(torch.float32)
+            past_key_values_storage[i][1][:, size:size1, :, :] = \
+                delta_past_value.to(torch.float32)
         else:
             delta_past_key = \
                 past_key_values[i][1][size:size1, :, :, :].permute(1, 2, 0, 3)
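The three hunks above all encode the same layout detail: for Qwen, the IPEX past key/values are permuted with (1, 0, 2, 3) and sliced on the second dimension, whereas the existing default branch permutes with (1, 2, 0, 3) and slices the third; in other words, the Qwen cache is kept as [batch, seq, heads, head_dim] instead of [batch, heads, seq, head_dim]. A toy shape walk-through follows; the sizes are invented and the meaning of len0..len3 is inferred from the surrounding code rather than stated in the diff:

```python
# Toy shape check for the Qwen KV-cache handling above (illustrative only).
import torch

seq, batch, heads, head_dim = 16, 1, 32, 128
max_new_tokens, cur_len = 8, 16

# Assumed IPEX layout of a past key tensor: [seq, batch, heads, head_dim].
pkv_key = torch.zeros(seq, batch, heads, head_dim)

# Qwen branch: batch first, sequence second.
qwen_key = pkv_key.permute(1, 0, 2, 3)[:, :cur_len, :, :]
print(qwen_key.shape)     # torch.Size([1, 16, 32, 128])

# Default (Llama-style) branch: batch, heads, sequence.
default_key = pkv_key.permute(1, 2, 0, 3)[:, :, :cur_len, :]
print(default_key.shape)  # torch.Size([1, 32, 16, 128])

# Qwen storage is allocated [batch, heads, seq + new, head_dim] and then viewed
# as [batch, seq + new, heads, head_dim] before the prefix is copied in.
k0 = torch.ones(batch, heads, seq + max_new_tokens, head_dim).permute(0, 2, 1, 3)
k0[:, :cur_len, :, :] = qwen_key.to(torch.float32)
print(k0.shape)           # torch.Size([1, 24, 32, 128])
```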
@@ -444,9 +472,10 @@ def speculative_generate(self,
         if not ((self.config.model_type == 'baichuan') or
                 ('llama' in self.config.model_type) or
                 ("mistral" in self.config.model_type) or
+                ("qwen" in self.config.model_type) or
                 ("chatglm" in self.config.model_type)):
             invalidInputError(False, "BigDL Speculative Decoding with IPEX BF16 only supports \
-                              Llama, Baichuan2, Mistral and ChatGLM models currently.")
+                              Llama, Baichuan2, Mistral and ChatGLM and Qwen models currently.")
         if "chatglm" in self.config.model_type:
             global query_group_size
             query_group_size = draft_model.config.num_attention_heads // \
@@ -637,6 +666,10 @@ def speculative_generate(self,
                                               position_ids=position_ids,
                                               # return_last_logit=torch.tensor(False),
                                               past_key_values=past_key_values,)
+                elif "qwen" in self.config.model_type:
+                    output = self.trace_graph(input_ids=drafted_input_ids,
+                                              attention_mask=cur_attention_mask,
+                                              past_key_values=past_key_values)
                 elif "mistral" in self.config.model_type:
                     past_key_value_len = past_key_values[0][0].shape[2]
                     seq_len = drafted_input_ids.shape[1]