LLM: Enable qwen target_model ipex (#10232)

* change order

* enable qwen ipex

* update qwen example

* update

* fix style

* update
Wang, Jian4 2024-02-26 16:41:12 +08:00 committed by GitHub
parent 3e6d188553
commit f9b75f900b
4 changed files with 74 additions and 6 deletions


@@ -89,4 +89,31 @@ assistant
Tokens generated 128
E2E Generation time x.xxxxs
First token latency x.xxxxs
```
### 4. Accelerate with BIGDL_OPT_IPEX
To accelerate speculative decoding on CPU, you can install our validated version of [IPEX 2.3.0+git004cd72d](https://github.com/intel/intel-extension-for-pytorch/tree/004cd72db60e87bb0712d42e3120bac9854bd77e) by following the steps below. (Other versions of IPEX may have conflicts and cannot accelerate speculative decoding correctly.)
#### 4.1 Download IPEX installation script
```bash
# Requires Conda and GCC 12.3
wget https://raw.githubusercontent.com/intel/intel-extension-for-pytorch/004cd72db60e87bb0712d42e3120bac9854bd77e/scripts/compile_bundle.sh
```
#### 4.2 Activate your conda environment
```bash
conda activate <your_conda_env>
```
#### 4.3 Set VER_IPEX in compile_bundle.sh to 004cd72db60e87bb0712d42e3120bac9854bd77e
```bash
sed -i 's/VER_IPEX=main/VER_IPEX=004cd72db60e87bb0712d42e3120bac9854bd77e/g' "compile_bundle.sh"
```
#### 4.4 Install IPEX and other dependencies
```bash
# Install IPEX 2.3.0+git004cd72d
bash compile_bundle.sh
# Update transformers
pip install transformers==4.36.2
```
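
After installation, speculative decoding is accelerated by setting the `BIGDL_OPT_IPEX` environment variable from this section's title when launching the example. The command below is only a sketch: the script name, model path, dashed flag spellings and the `true` value are assumptions, while `--th_stop_draft` and `--min_step_draft` mirror the arguments of the updated example.
```bash
# Illustrative run (assumed script and flag names), not a verified command
export BIGDL_OPT_IPEX=true
python ./speculative.py --repo-id-or-model-path Qwen/Qwen-7B-Chat \
                        --prompt "What is AI?" \
                        --th_stop_draft 0.6 \
                        --min_step_draft 1
```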


@@ -54,6 +54,8 @@ if __name__ == '__main__':
help='Max tokens to predict')
parser.add_argument('--th_stop_draft', type=float, default=0.6,
help='draft stop probability')
parser.add_argument('--min_step_draft', type=int, default=1,
help='min number of draft tokens per step')
args = parser.parse_args()
model_path = args.repo_id_or_model_path
@@ -67,13 +69,15 @@ if __name__ == '__main__':
speculative=True,
trust_remote_code=True,
use_cache=True)
model = model.to('cpu')
#model = model.to('cpu')
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
with torch.inference_mode():
prompt = QWEN_PROMPT_FORMAT.format(prompt=args.prompt)
input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(model.device)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
input_ids = inputs.input_ids
attention_mask = inputs.attention_mask.to(model.device)
actual_in_len = input_ids.shape[1]
print("actual input_ids length:" + str(actual_in_len))
@@ -81,6 +85,8 @@ if __name__ == '__main__':
output = model.generate(input_ids,
max_new_tokens=args.n_predict,
th_stop_draft=args.th_stop_draft,
attention_mask=attention_mask,
min_step_draft=args.min_step_draft,
do_sample=False)
output_str = tokenizer.decode(output[0])
@@ -89,6 +95,8 @@ if __name__ == '__main__':
output = model.generate(input_ids,
max_new_tokens=args.n_predict,
th_stop_draft=args.th_stop_draft,
attention_mask=attention_mask,
min_step_draft=args.min_step_draft,
do_sample=False)
output_str = tokenizer.decode(output[0], skip_special_tokens=True)
end = time.perf_counter()


@@ -93,12 +93,10 @@ def _ipex_optimize_attention(model):
def _ipex_optimize_model(model, rms_classes):
from intel_extension_for_pytorch.transformers.models.reference.models import output_hook
_ipex_optimize_rmsnorm(model, rms_classes)
_ipex_optimize_attention(model)
_ipex_optimize_decoder(model)
model.register_forward_hook(output_hook, with_kwargs=True)
def _ipex_jit(model):
@@ -124,6 +122,8 @@ def _ipex_jit(model):
model = _set_optimized_model_for_generation(
model, optimized_model=trace_model
)
from intel_extension_for_pytorch.transformers.models.reference.models import output_hook
model.register_forward_hook(output_hook, with_kwargs=True)
return model


@@ -153,6 +153,12 @@ def _prepare_past_key_values_storage_cpu(self, past_key_values,
len0, len3).permute(2, 0, 1, 3)
list = [key[:cur_len, :, :, :], value[:cur_len, :, :, :]]
ipex_past_key_values.append(list)
elif self.config.model_type == "qwen":
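# Assumption: Qwen's KV cache layout is (batch, seq, heads, head_dim) while the IPEX
# trace keeps the sequence dim first, so swap the first two dims and trim to cur_len.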
ipex_past_key_values = [
[pkv[1].permute(1, 0, 2, 3)[:, :cur_len, :, :],
pkv[2].permute(1, 0, 2, 3)[:, :cur_len, :, :]]
for pkv in past_key_values
]
else:
ipex_past_key_values = [
[pkv[1].permute(1, 2, 0, 3)[:, :, :cur_len, :],
@@ -217,6 +223,18 @@ def _prepare_past_key_values_storage_cpu(self, past_key_values,
torch.float32)
past_key_values_storage[i][1][:len2, :, :, :] = ipex_past_key_values[i][1].to(
torch.float32)
elif self.config.model_type == "qwen":
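# Pre-allocate max_new_tokens of extra room along the sequence dim, then permute so the
# sequence dim sits at index 1, matching the (batch, seq, heads, head_dim) layout assumed for Qwen.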
k0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
dtype=torch.float32)
v0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
dtype=torch.float32)
k0 = k0.permute(0, 2, 1, 3)
v0 = v0.permute(0, 2, 1, 3)
past_key_values_storage.append((k0, v0))
past_key_values_storage[i][0][:, :len2, :, :] = ipex_past_key_values[i][0].to(
torch.float32)
past_key_values_storage[i][1][:, :len2, :, :] = ipex_past_key_values[i][1].to(
torch.float32)
else:
k0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
dtype=torch.float32)
@@ -309,6 +327,16 @@ def _update_past_key_values_storage_cpu(self, past_key_values, past_key_values_s
key.to(torch.float32)
past_key_values_storage[i][1][size:size1, :, :, :] = \
value.to(torch.float32)
elif self.config.model_type == "qwen":
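# Assumption: the Qwen draft cache keeps the sequence dim at index 1 (hence size(1));
# the IPEX target cache is sequence-first, so permute the new slice back before storing.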
size = original_draft_past_key_values[0][0].size(1)
delta_past_key = \
past_key_values[i][1][size:size1, :, :, :].permute(1, 0, 2, 3)
delta_past_value = \
past_key_values[i][2][size:size1, :, :, :].permute(1, 0, 2, 3)
past_key_values_storage[i][0][:, size:size1, :, :] = \
delta_past_key.to(torch.float32)
past_key_values_storage[i][1][:, size:size1, :, :] = \
delta_past_value.to(torch.float32)
else:
delta_past_key = \
past_key_values[i][1][size:size1, :, :, :].permute(1, 2, 0, 3)
@@ -444,9 +472,10 @@ def speculative_generate(self,
if not ((self.config.model_type == 'baichuan') or
('llama' in self.config.model_type) or
("mistral" in self.config.model_type) or
("qwen" in self.config.model_type) or
("chatglm" in self.config.model_type)):
invalidInputError(False, "BigDL Speculative Decoding with IPEX BF16 only supports \
Llama, Baichuan2, Mistral and ChatGLM models currently.")
Llama, Baichuan2, Mistral, ChatGLM and Qwen models currently.")
if "chatglm" in self.config.model_type:
global query_group_size
query_group_size = draft_model.config.num_attention_heads // \
@@ -637,6 +666,10 @@ def speculative_generate(self,
position_ids=position_ids,
# return_last_logit=torch.tensor(False),
past_key_values=past_key_values,)
elif "qwen" in self.config.model_type:
output = self.trace_graph(input_ids=drafted_input_ids,
attention_mask=cur_attention_mask,
past_key_values=past_key_values)
elif "mistral" in self.config.model_type:
past_key_value_len = past_key_values[0][0].shape[2]
seq_len = drafted_input_ids.shape[1]