diff --git a/python/llm/example/CPU/Speculative-Decoding/qwen/README.md b/python/llm/example/CPU/Speculative-Decoding/qwen/README.md
index 4c6f1d32..85548b55 100644
--- a/python/llm/example/CPU/Speculative-Decoding/qwen/README.md
+++ b/python/llm/example/CPU/Speculative-Decoding/qwen/README.md
@@ -89,4 +89,33 @@ assistant
 Tokens generated 128
 E2E Generation time x.xxxxs
 First token latency x.xxxxs
-```
\ No newline at end of file
+```
+
+### 4. Accelerate with BIGDL_OPT_IPEX
+
+To accelerate speculative decoding on CPU, you can install our validated version of [IPEX 2.3.0+git004cd72d](https://github.com/intel/intel-extension-for-pytorch/tree/004cd72db60e87bb0712d42e3120bac9854bd77e) by following the steps below. (Other versions of IPEX may have conflicts and cannot accelerate speculative decoding correctly.)
+
+#### 4.1 Download the IPEX installation script
+```bash
+# Requires Conda and GCC 12.3
+wget https://raw.githubusercontent.com/intel/intel-extension-for-pytorch/004cd72db60e87bb0712d42e3120bac9854bd77e/scripts/compile_bundle.sh
+```
+
+#### 4.2 Activate your conda environment
+```bash
+conda activate
+```
+
+#### 4.3 Set VER_IPEX in compile_bundle.sh to 004cd72db60e87bb0712d42e3120bac9854bd77e
+```bash
+sed -i 's/VER_IPEX=main/VER_IPEX=004cd72db60e87bb0712d42e3120bac9854bd77e/g' "compile_bundle.sh"
+```
+
+#### 4.4 Install IPEX and other dependencies
+```bash
+# Install IPEX 2.3.0+git004cd72d
+bash compile_bundle.sh
+
+# Update transformers
+pip install transformers==4.36.2
+```
\ No newline at end of file
diff --git a/python/llm/example/CPU/Speculative-Decoding/qwen/speculative.py b/python/llm/example/CPU/Speculative-Decoding/qwen/speculative.py
index 87d15688..d3e71739 100644
--- a/python/llm/example/CPU/Speculative-Decoding/qwen/speculative.py
+++ b/python/llm/example/CPU/Speculative-Decoding/qwen/speculative.py
@@ -54,6 +54,8 @@ if __name__ == '__main__':
                         help='Max tokens to predict')
     parser.add_argument('--th_stop_draft', type=float, default=0.6,
                         help='draft stop probility')
+    parser.add_argument('--min_step_draft', type=int, default=1,
+                        help='minimum draft tokens per step')
 
     args = parser.parse_args()
     model_path = args.repo_id_or_model_path
@@ -67,13 +69,15 @@ if __name__ == '__main__':
                                                  speculative=True,
                                                  trust_remote_code=True,
                                                  use_cache=True)
-    model = model.to('cpu')
+    # model = model.to('cpu')
 
     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 
     with torch.inference_mode():
         prompt = QWEN_PROMPT_FORMAT.format(prompt=args.prompt)
-        input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(model.device)
+        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+        input_ids = inputs.input_ids
+        attention_mask = inputs.attention_mask.to(model.device)
         actual_in_len = input_ids.shape[1]
         print("actual input_ids length:" + str(actual_in_len))
 
@@ -81,6 +85,8 @@ if __name__ == '__main__':
         output = model.generate(input_ids,
                                 max_new_tokens=args.n_predict,
                                 th_stop_draft=args.th_stop_draft,
+                                attention_mask=attention_mask,
+                                min_step_draft=args.min_step_draft,
                                 do_sample=False)
         output_str = tokenizer.decode(output[0])
 
@@ -89,6 +95,8 @@ if __name__ == '__main__':
         output = model.generate(input_ids,
                                 max_new_tokens=args.n_predict,
                                 th_stop_draft=args.th_stop_draft,
+                                attention_mask=attention_mask,
+                                min_step_draft=args.min_step_draft,
                                 do_sample=False)
         output_str = tokenizer.decode(output[0], skip_special_tokens=True)
         end = time.perf_counter()
diff --git a/python/llm/src/bigdl/llm/transformers/convert_ipex.py b/python/llm/src/bigdl/llm/transformers/convert_ipex.py
index 467da52b..067c3698 100644
--- a/python/llm/src/bigdl/llm/transformers/convert_ipex.py
+++ b/python/llm/src/bigdl/llm/transformers/convert_ipex.py
@@ -93,12 +93,10 @@ def _ipex_optimize_attention(model):
 
 
 def _ipex_optimize_model(model, rms_classes):
-    from intel_extension_for_pytorch.transformers.models.reference.models import output_hook
     _ipex_optimize_rmsnorm(model, rms_classes)
     _ipex_optimize_attention(model)
     _ipex_optimize_decoder(model)
-    model.register_forward_hook(output_hook, with_kwargs=True)
 
 
 def _ipex_jit(model):
@@ -124,6 +122,8 @@ def _ipex_jit(model):
     model = _set_optimized_model_for_generation(
         model, optimized_model=trace_model
     )
+    from intel_extension_for_pytorch.transformers.models.reference.models import output_hook
+    model.register_forward_hook(output_hook, with_kwargs=True)
 
     return model
 
diff --git a/python/llm/src/bigdl/llm/transformers/speculative.py b/python/llm/src/bigdl/llm/transformers/speculative.py
index aa9324f1..be77fdf8 100644
--- a/python/llm/src/bigdl/llm/transformers/speculative.py
+++ b/python/llm/src/bigdl/llm/transformers/speculative.py
@@ -153,6 +153,12 @@ def _prepare_past_key_values_storage_cpu(self, past_key_values,
                                          len0, len3).permute(2, 0, 1, 3)
                 list = [key[:cur_len, :, :, :], value[:cur_len, :, :, :]]
                 ipex_past_key_values.append(list)
+        elif self.config.model_type == "qwen":
+            ipex_past_key_values = [
+                [pkv[1].permute(1, 0, 2, 3)[:, :cur_len, :, :],
+                 pkv[2].permute(1, 0, 2, 3)[:, :cur_len, :, :]]
+                for pkv in past_key_values
+            ]
         else:
             ipex_past_key_values = [
                 [pkv[1].permute(1, 2, 0, 3)[:, :, :cur_len, :],
@@ -217,6 +223,18 @@ def _prepare_past_key_values_storage_cpu(self, past_key_values,
                     torch.float32)
                 past_key_values_storage[i][1][:len2, :, :, :] = ipex_past_key_values[i][1].to(
                     torch.float32)
+            elif self.config.model_type == "qwen":
+                k0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
+                                dtype=torch.float32)
+                v0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
+                                dtype=torch.float32)
+                k0 = k0.permute(0, 2, 1, 3)
+                v0 = v0.permute(0, 2, 1, 3)
+                past_key_values_storage.append((k0, v0))
+                past_key_values_storage[i][0][:, :len2, :, :] = ipex_past_key_values[i][0].to(
+                    torch.float32)
+                past_key_values_storage[i][1][:, :len2, :, :] = ipex_past_key_values[i][1].to(
+                    torch.float32)
             else:
                 k0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
                                 dtype=torch.float32)
@@ -309,6 +327,16 @@ def _update_past_key_values_storage_cpu(self, past_key_values, past_key_values_storage,
                     key.to(torch.float32)
                 past_key_values_storage[i][1][size:size1, :, :, :] = \
                     value.to(torch.float32)
+            elif self.config.model_type == "qwen":
+                size = original_draft_past_key_values[0][0].size(1)
+                delta_past_key = \
+                    past_key_values[i][1][size:size1, :, :, :].permute(1, 0, 2, 3)
+                delta_past_value = \
+                    past_key_values[i][2][size:size1, :, :, :].permute(1, 0, 2, 3)
+                past_key_values_storage[i][0][:, size:size1, :, :] = \
+                    delta_past_key.to(torch.float32)
+                past_key_values_storage[i][1][:, size:size1, :, :] = \
+                    delta_past_value.to(torch.float32)
             else:
                 delta_past_key = \
                     past_key_values[i][1][size:size1, :, :, :].permute(1, 2, 0, 3)
@@ -444,9 +472,10 @@ def speculative_generate(self,
         if not ((self.config.model_type == 'baichuan') or
                 ('llama' in self.config.model_type) or
                 ("mistral" in self.config.model_type) or
+                ("qwen" in self.config.model_type) or
                 ("chatglm" in self.config.model_type)):
             invalidInputError(False, "BigDL Speculative Decoding with IPEX BF16 only supports \
-                              Llama, Baichuan2, Mistral and ChatGLM models currently.")
+                              Llama, Baichuan2, Mistral, ChatGLM and Qwen models currently.")
         if "chatglm" in self.config.model_type:
             global query_group_size
             query_group_size = draft_model.config.num_attention_heads // \
@@ -637,6 +666,10 @@ def speculative_generate(self,
                             position_ids=position_ids,
                             # return_last_logit=torch.tensor(False),
                             past_key_values=past_key_values,)
+                elif "qwen" in self.config.model_type:
+                    output = self.trace_graph(input_ids=drafted_input_ids,
+                                              attention_mask=cur_attention_mask,
+                                              past_key_values=past_key_values)
                 elif "mistral" in self.config.model_type:
                     past_key_value_len = past_key_values[0][0].shape[2]
                     seq_len = drafted_input_ids.shape[1]
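
The Qwen-specific KV-cache branches above mirror the existing llama-style path and differ only in tensor layout. Judging from the permutes in the patch, the IPEX cache tensors are laid out as `(seq, batch, heads, head_dim)`, while Qwen keeps the sequence dimension in position 1, i.e. `(batch, seq, heads, head_dim)`, so the qwen branches use `permute(1, 0, 2, 3)` and slice/copy along dim 1 instead of `permute(1, 2, 0, 3)` and dim 2. Below is a minimal standalone sketch of that layout handling; the shapes and names (`ipex_key`, `storage_key`, etc.) are made-up toy values for illustration, not identifiers from the library:

```python
import torch

# Hypothetical toy shapes for illustration only.
batch, heads, head_dim = 1, 32, 128
cur_len, max_new_tokens = 16, 32

# IPEX-style cached key: (seq, batch, heads, head_dim), bf16.
ipex_key = torch.randn(cur_len, batch, heads, head_dim, dtype=torch.bfloat16)

# Qwen keeps its KV cache as (batch, seq, heads, head_dim), so the qwen branch
# permutes with (1, 0, 2, 3) and slices the sequence dimension (dim 1) ...
qwen_key = ipex_key.permute(1, 0, 2, 3)[:, :cur_len, :, :]
assert qwen_key.shape == (batch, cur_len, heads, head_dim)

# ... while the llama-style default permutes with (1, 2, 0, 3) and slices dim 2.
llama_key = ipex_key.permute(1, 2, 0, 3)[:, :, :cur_len, :]
assert llama_key.shape == (batch, heads, cur_len, head_dim)

# fp32 draft storage pre-allocated with headroom for max_new_tokens in the qwen
# layout; the prompt portion of the cache fills the first cur_len positions.
storage_key = torch.zeros(batch, cur_len + max_new_tokens, heads, head_dim,
                          dtype=torch.float32)
storage_key[:, :cur_len, :, :] = qwen_key.to(torch.float32)
```

Note that `k0.permute(0, 2, 1, 3)` in the patch yields a non-contiguous view with this same logical layout; the sketch simply allocates the storage directly in that layout, which behaves the same for the slice assignment.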