LLM: Enable qwen target_model ipex (#10232)
* change order * enable qwen ipex * update qwen example * update * fix style * update
This commit is contained in:
parent
3e6d188553
commit
f9b75f900b
4 changed files with 74 additions and 6 deletions
|
|
@ -89,4 +89,31 @@ assistant
|
|||
Tokens generated 128
|
||||
E2E Generation time x.xxxxs
|
||||
First token latency x.xxxxs
|
||||
```
|
||||
```
|
||||
|
||||
### 4. Accelerate with BIGDL_OPT_IPEX
|
||||
|
||||
To accelerate speculative decoding on CPU, you can install our validated version of [IPEX 2.3.0+git004cd72d](https://github.com/intel/intel-extension-for-pytorch/tree/004cd72db60e87bb0712d42e3120bac9854bd77e) by following steps: (Other versions of IPEX may have some conflicts and can not accelerate speculative decoding correctly.)
|
||||
|
||||
#### 4.1 Download IPEX installation script
|
||||
```bash
|
||||
# Depend on Conda and GCC 12.3
|
||||
wget https://raw.githubusercontent.com/intel/intel-extension-for-pytorch/004cd72db60e87bb0712d42e3120bac9854bd77e/scripts/compile_bundle.sh
|
||||
```
|
||||
|
||||
#### 4.2 Activate your conda environment
|
||||
```bash
|
||||
conda activate <your_conda_env>
|
||||
```
|
||||
#### 4.3 Set VER_IPEX in compile_bundle.sh to 004cd72db60e87bb0712d42e3120bac9854bd77e
|
||||
```bash
|
||||
sed -i 's/VER_IPEX=main/VER_IPEX=004cd72db60e87bb0712d42e3120bac9854bd77e/g' "compile_bundle.sh"
|
||||
```
|
||||
|
||||
#### 4.4 Install IPEX and other dependencies
|
||||
```bash
|
||||
# Install IPEX 2.3.0+git004cd72d
|
||||
bash compile_bundle.sh
|
||||
|
||||
# Update transformers
|
||||
pip install transformers==4.36.2
|
||||
|
|
@ -54,6 +54,8 @@ if __name__ == '__main__':
|
|||
help='Max tokens to predict')
|
||||
parser.add_argument('--th_stop_draft', type=float, default=0.6,
|
||||
help='draft stop probility')
|
||||
parser.add_argument('--min_step_draft', type=int, default=1,
|
||||
help='min tokens per step draft')
|
||||
|
||||
args = parser.parse_args()
|
||||
model_path = args.repo_id_or_model_path
|
||||
|
|
@ -67,13 +69,15 @@ if __name__ == '__main__':
|
|||
speculative=True,
|
||||
trust_remote_code=True,
|
||||
use_cache=True)
|
||||
model = model.to('cpu')
|
||||
#model = model.to('cpu')
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||
|
||||
with torch.inference_mode():
|
||||
prompt = QWEN_PROMPT_FORMAT.format(prompt=args.prompt)
|
||||
input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(model.device)
|
||||
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
||||
input_ids = inputs.input_ids
|
||||
attention_mask = inputs.attention_mask.to(model.device)
|
||||
actual_in_len = input_ids.shape[1]
|
||||
print("actual input_ids length:" + str(actual_in_len))
|
||||
|
||||
|
|
@ -81,6 +85,8 @@ if __name__ == '__main__':
|
|||
output = model.generate(input_ids,
|
||||
max_new_tokens=args.n_predict,
|
||||
th_stop_draft=args.th_stop_draft,
|
||||
attention_mask=attention_mask,
|
||||
min_step_draft=args.min_step_draft,
|
||||
do_sample=False)
|
||||
output_str = tokenizer.decode(output[0])
|
||||
|
||||
|
|
@ -89,6 +95,8 @@ if __name__ == '__main__':
|
|||
output = model.generate(input_ids,
|
||||
max_new_tokens=args.n_predict,
|
||||
th_stop_draft=args.th_stop_draft,
|
||||
attention_mask=attention_mask,
|
||||
min_step_draft=args.min_step_draft,
|
||||
do_sample=False)
|
||||
output_str = tokenizer.decode(output[0], skip_special_tokens=True)
|
||||
end = time.perf_counter()
|
||||
|
|
|
|||
|
|
@ -93,12 +93,10 @@ def _ipex_optimize_attention(model):
|
|||
|
||||
|
||||
def _ipex_optimize_model(model, rms_classes):
|
||||
from intel_extension_for_pytorch.transformers.models.reference.models import output_hook
|
||||
|
||||
_ipex_optimize_rmsnorm(model, rms_classes)
|
||||
_ipex_optimize_attention(model)
|
||||
_ipex_optimize_decoder(model)
|
||||
model.register_forward_hook(output_hook, with_kwargs=True)
|
||||
|
||||
|
||||
def _ipex_jit(model):
|
||||
|
|
@ -124,6 +122,8 @@ def _ipex_jit(model):
|
|||
model = _set_optimized_model_for_generation(
|
||||
model, optimized_model=trace_model
|
||||
)
|
||||
from intel_extension_for_pytorch.transformers.models.reference.models import output_hook
|
||||
model.register_forward_hook(output_hook, with_kwargs=True)
|
||||
|
||||
return model
|
||||
|
||||
|
|
|
|||
|
|
@ -153,6 +153,12 @@ def _prepare_past_key_values_storage_cpu(self, past_key_values,
|
|||
len0, len3).permute(2, 0, 1, 3)
|
||||
list = [key[:cur_len, :, :, :], value[:cur_len, :, :, :]]
|
||||
ipex_past_key_values.append(list)
|
||||
elif self.config.model_type == "qwen":
|
||||
ipex_past_key_values = [
|
||||
[pkv[1].permute(1, 0, 2, 3)[:, :cur_len, :, :],
|
||||
pkv[2].permute(1, 0, 2, 3)[:, :cur_len, :, :]]
|
||||
for pkv in past_key_values
|
||||
]
|
||||
else:
|
||||
ipex_past_key_values = [
|
||||
[pkv[1].permute(1, 2, 0, 3)[:, :, :cur_len, :],
|
||||
|
|
@ -217,6 +223,18 @@ def _prepare_past_key_values_storage_cpu(self, past_key_values,
|
|||
torch.float32)
|
||||
past_key_values_storage[i][1][:len2, :, :, :] = ipex_past_key_values[i][1].to(
|
||||
torch.float32)
|
||||
elif self.config.model_type == "qwen":
|
||||
k0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
|
||||
dtype=torch.float32)
|
||||
v0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
|
||||
dtype=torch.float32)
|
||||
k0 = k0.permute(0, 2, 1, 3)
|
||||
v0 = v0.permute(0, 2, 1, 3)
|
||||
past_key_values_storage.append((k0, v0))
|
||||
past_key_values_storage[i][0][:, :len2, :, :] = ipex_past_key_values[i][0].to(
|
||||
torch.float32)
|
||||
past_key_values_storage[i][1][:, :len2, :, :] = ipex_past_key_values[i][1].to(
|
||||
torch.float32)
|
||||
else:
|
||||
k0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
|
||||
dtype=torch.float32)
|
||||
|
|
@ -309,6 +327,16 @@ def _update_past_key_values_storage_cpu(self, past_key_values, past_key_values_s
|
|||
key.to(torch.float32)
|
||||
past_key_values_storage[i][1][size:size1, :, :, :] = \
|
||||
value.to(torch.float32)
|
||||
elif self.config.model_type == "qwen":
|
||||
size = original_draft_past_key_values[0][0].size(1)
|
||||
delta_past_key = \
|
||||
past_key_values[i][1][size:size1, :, :, :].permute(1, 0, 2, 3)
|
||||
delta_past_value = \
|
||||
past_key_values[i][2][size:size1, :, :, :].permute(1, 0, 2, 3)
|
||||
past_key_values_storage[i][0][:, size:size1, :, :] = \
|
||||
delta_past_key.to(torch.float32)
|
||||
past_key_values_storage[i][1][:, size:size1, :, :] = \
|
||||
delta_past_value.to(torch.float32)
|
||||
else:
|
||||
delta_past_key = \
|
||||
past_key_values[i][1][size:size1, :, :, :].permute(1, 2, 0, 3)
|
||||
|
|
@ -444,9 +472,10 @@ def speculative_generate(self,
|
|||
if not ((self.config.model_type == 'baichuan') or
|
||||
('llama' in self.config.model_type) or
|
||||
("mistral" in self.config.model_type) or
|
||||
("qwen" in self.config.model_type) or
|
||||
("chatglm" in self.config.model_type)):
|
||||
invalidInputError(False, "BigDL Speculative Decoding with IPEX BF16 only supports \
|
||||
Llama, Baichuan2, Mistral and ChatGLM models currently.")
|
||||
Llama, Baichuan2, Mistral and ChatGLM and Qwen models currently.")
|
||||
if "chatglm" in self.config.model_type:
|
||||
global query_group_size
|
||||
query_group_size = draft_model.config.num_attention_heads // \
|
||||
|
|
@ -637,6 +666,10 @@ def speculative_generate(self,
|
|||
position_ids=position_ids,
|
||||
# return_last_logit=torch.tensor(False),
|
||||
past_key_values=past_key_values,)
|
||||
elif "qwen" in self.config.model_type:
|
||||
output = self.trace_graph(input_ids=drafted_input_ids,
|
||||
attention_mask=cur_attention_mask,
|
||||
past_key_values=past_key_values)
|
||||
elif "mistral" in self.config.model_type:
|
||||
past_key_value_len = past_key_values[0][0].shape[2]
|
||||
seq_len = drafted_input_ids.shape[1]
|
||||
|
|
|
|||
Loading…
Reference in a new issue