diff --git a/python/llm/example/CPU/Speculative-Decoding/baichuan2/README.md b/python/llm/example/CPU/Speculative-Decoding/baichuan2/README.md
index 2b8dcaea..e5895024 100644
--- a/python/llm/example/CPU/Speculative-Decoding/baichuan2/README.md
+++ b/python/llm/example/CPU/Speculative-Decoding/baichuan2/README.md
@@ -63,33 +63,20 @@ First token latency x.xxxxs
 
 ### 4. Accelerate with BIGDL_OPT_IPEX
 
-To accelerate speculative decoding on CPU, optionally, you can install our validated version of [IPEX 2.3.0+git0c63936](https://github.com/intel/intel-extension-for-pytorch/tree/0c63936d7a6740679987920367ae2e0cdb375b2e) by following steps: (Other versions of IPEX may have some conflicts and can not accelerate speculative decoding correctly.)
+To accelerate speculative decoding on CPU, optionally, you can install our validated version of [IPEX 2.2.0+cpu](https://github.com/intel/intel-extension-for-pytorch/tree/v2.2.0%2Bcpu) by referring to [IPEX's installation guide](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=cpu&version=v2.2.0%2Bcpu) or by the following steps: (Other versions of IPEX may have some conflicts and cannot accelerate speculative decoding correctly.)
 
-#### 4.1 Download IPEX installation script
+#### 4.1 Install IPEX 2.2.0+cpu
 ```bash
-# Depend on Conda and GCC 12.3
-wget https://raw.githubusercontent.com/intel/intel-extension-for-pytorch/0c63936d7a6740679987920367ae2e0cdb375b2e/scripts/compile_bundle.sh
-```
-
-#### 4.2 Activate your conda environment
-```bash
-conda activate
-```
-#### 4.3 Set VER_IPEX in compile_bundle.sh to 0c63936d7a6740679987920367ae2e0cdb375b2e
-```bash
-sed -i 's/VER_IPEX=main/VER_IPEX=0c63936d7a6740679987920367ae2e0cdb375b2e/g' "compile_bundle.sh"
-```
-
-#### 4.4 Install IPEX and other dependencies
-```bash
-# Install IPEX 2.3.0+git0c63936
-bash compile_bundle.sh
+python -m pip install torch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cpu
+python -m pip install intel-extension-for-pytorch==2.2.0
+python -m pip install oneccl_bind_pt==2.2.0 --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/cpu/us/
+# If there is any installation problem with oneccl_bind_pt, you can also find a suitable index URL at "https://pytorch-extension.intel.com/release-whl/stable/cpu/cn/" or "https://developer.intel.com/ipex-whl-stable-cpu" according to your environment.
 
 # Install other dependencies
 pip install -r requirements.txt
 ```
 
-#### 4.5 Run Baichuan2 Models with IPEX
+#### 4.2 Run Baichuan2 Models with IPEX
 
 After installed IPEX, **if the size of your Baichuan2 is 7B**, replace `modeling_baichuan.py` file under your model directory with `./baichaun2_7b_opt_ipex/modeling_baichuan.ipex`, like:
 
diff --git a/python/llm/example/CPU/Speculative-Decoding/llama2/README.md b/python/llm/example/CPU/Speculative-Decoding/llama2/README.md
index 71dfddc0..f05ac46b 100644
--- a/python/llm/example/CPU/Speculative-Decoding/llama2/README.md
+++ b/python/llm/example/CPU/Speculative-Decoding/llama2/README.md
@@ -98,27 +98,14 @@ First token latency xx.xxxxs
 
 ### 4. Accelerate with BIGDL_OPT_IPEX
 
-To accelerate speculative decoding on CPU, you can install our validated version of [IPEX 2.3.0+git0c63936](https://github.com/intel/intel-extension-for-pytorch/tree/0c63936d7a6740679987920367ae2e0cdb375b2e) by following steps: (Other versions of IPEX may have some conflicts and can not accelerate speculative decoding correctly.)
+To accelerate speculative decoding on CPU, you can install our validated version of [IPEX 2.2.0+cpu](https://github.com/intel/intel-extension-for-pytorch/tree/v2.2.0%2Bcpu) by referring to [IPEX's installation guide](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=cpu&version=v2.2.0%2Bcpu) or by the following commands: (Other versions of IPEX may have some conflicts and cannot accelerate speculative decoding correctly.)
 
-#### 4.1 Download IPEX installation script
 ```bash
-# Depend on Conda and GCC 12.3
-wget https://raw.githubusercontent.com/intel/intel-extension-for-pytorch/0c63936d7a6740679987920367ae2e0cdb375b2e/scripts/compile_bundle.sh
-```
-
-#### 4.2 Activate your conda environment
-```bash
-conda activate
-```
-#### 4.3 Set VER_IPEX in compile_bundle.sh to 0c63936d7a6740679987920367ae2e0cdb375b2e
-```bash
-sed -i 's/VER_IPEX=main/VER_IPEX=0c63936d7a6740679987920367ae2e0cdb375b2e/g' "compile_bundle.sh"
-```
-
-#### 4.4 Install IPEX and other dependencies
-```bash
-# Install IPEX 2.3.0+git0c63936
-bash compile_bundle.sh
+# Install IPEX 2.2.0+cpu
+python -m pip install torch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cpu
+python -m pip install intel-extension-for-pytorch==2.2.0
+python -m pip install oneccl_bind_pt==2.2.0 --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/cpu/us/
+# If there is any installation problem with oneccl_bind_pt, you can also find a suitable index URL at "https://pytorch-extension.intel.com/release-whl/stable/cpu/cn/" or "https://developer.intel.com/ipex-whl-stable-cpu" according to your environment.
 
 # Update transformers
 pip install transformers==4.36.2
diff --git a/python/llm/example/CPU/Speculative-Decoding/mistral/README.md b/python/llm/example/CPU/Speculative-Decoding/mistral/README.md
index 47b31195..4556159b 100644
--- a/python/llm/example/CPU/Speculative-Decoding/mistral/README.md
+++ b/python/llm/example/CPU/Speculative-Decoding/mistral/README.md
@@ -84,4 +84,4 @@ First token latency xx.xxxxs
 
 ### 4. Accelerate with BIGDL_OPT_IPEX
 
-BIGDL_OPT_IPEX can help to accelerate speculative decoding on Mistral, and please refer to [here](https://github.com/intel-analytics/BigDL/blob/main/python/llm/example/CPU/Speculative-Decoding/baichuan2/README.md#4-accelerate-with-bigdl_opt_ipex) for a try.
+BIGDL_OPT_IPEX can help to accelerate speculative decoding on Mistral, and please refer to [here](https://github.com/intel-analytics/BigDL/blob/main/python/llm/example/CPU/Speculative-Decoding/llama2#4-accelerate-with-bigdl_opt_ipex) for a try.
diff --git a/python/llm/example/CPU/Speculative-Decoding/vicuna/README.md b/python/llm/example/CPU/Speculative-Decoding/vicuna/README.md
index 6efc11bd..b648040f 100644
--- a/python/llm/example/CPU/Speculative-Decoding/vicuna/README.md
+++ b/python/llm/example/CPU/Speculative-Decoding/vicuna/README.md
@@ -94,4 +94,4 @@ First token latency xx.xxxxs
 
 ### 4. Accelerate with BIGDL_OPT_IPEX
 
-BIGDL_OPT_IPEX can help to accelerate speculative decoding to some extend, and please refer to [here](https://github.com/intel-analytics/BigDL/blob/main/python/llm/example/CPU/Speculative-Decoding/baichuan2/README.md#4-accelerate-with-bigdl_opt_ipex) for a try.
+BIGDL_OPT_IPEX can help to accelerate speculative decoding to some extent, and please refer to [here](https://github.com/intel-analytics/BigDL/blob/main/python/llm/example/CPU/Speculative-Decoding/llama2#4-accelerate-with-bigdl_opt_ipex) for a try.
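Note (not part of the patch): the README changes above replace the old source build with pinned CPU wheels. A minimal sanity check, assuming the pinned `torch` and `intel-extension-for-pytorch` wheels installed cleanly, could be:

```python
# Minimal sketch: confirm the environment matches the validated IPEX 2.2.0+cpu setup.
import torch
import intel_extension_for_pytorch as ipex

print("torch:", torch.__version__)  # expected to start with 2.2.0
print("ipex :", ipex.__version__)   # expected to start with 2.2.0
assert torch.__version__.startswith("2.2.0"), "torch does not match the validated 2.2.0+cpu stack"
```
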
diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index a5fd8813..824b02ba 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -625,27 +625,38 @@ def replace_func(m, target_m, func_name, new_func):
 def _optimize_ipex(model):
     import intel_extension_for_pytorch as ipex
     from intel_extension_for_pytorch.transformers.optimize import model_convert_reference
-    from intel_extension_for_pytorch.transformers.models.reference.models import output_hook
     from transformers.modeling_attn_mask_utils import AttentionMaskConverter
     from bigdl.llm.transformers.convert_ipex import (
-        _ipex_optimize_attention, _ipex_optimize_decoder, _ipex_jit, _make_causal_mask,
-        _ipex_optimize_rmsnorm, _llama_model_forward_4_35, convert_function, GLM_get_masks
+        _ipex_optimize_model, _ipex_jit, _make_causal_mask,
+        _llama_model_forward_4_35, convert_function, GLM_get_masks
     )
 
-    AttentionMaskConverter._make_causal_mask = _make_causal_mask
-    convert_forward(model, transformers.models.llama.modeling_llama.LlamaModel,
-                    _llama_model_forward_4_35)
     model = model_convert_reference(model)
-    if model.config.architectures is not None \
-       and model.config.architectures[0] in ["ChatGLMModel", "ChatGLMForConditionalGeneration"]:
+
+    rms_classes = [
+        transformers.models.llama.modeling_llama.LlamaRMSNorm,
+    ]
+    if 'llama' in model.config.model_type:
+        AttentionMaskConverter._make_causal_mask = _make_causal_mask
+        convert_forward(model, transformers.models.llama.modeling_llama.LlamaModel,
+                        _llama_model_forward_4_35)
+    elif "mistral" in model.config.model_type:
+        AttentionMaskConverter._make_causal_mask = _make_causal_mask
+        convert_forward(model, transformers.models.llama.modeling_llama.LlamaModel,
+                        _llama_model_forward_4_35)
+    elif model.config.architectures is not None \
+            and model.config.architectures[0] in ["ChatGLMModel", "ChatGLMForConditionalGeneration"]:  # noqa
+        # for chatglm3-6B
+        rms_classes.append(
+            type(model.transformer.encoder.layers[0].input_layernorm)
+        )
         convert_function(model.transformer, "get_masks", GLM_get_masks)
+    elif model.config.model_type == 'baichuan' and model.config.vocab_size == 125696:
+        # baichuan2
+        rms_classes.append(type(model.model.layers[0].input_layernorm))
+
     model = ipex.optimize(model.eval(), dtype=torch.bfloat16, inplace=True).eval()
-    _ipex_optimize_rmsnorm(model)
-    _ipex_optimize_attention(model)
-    _ipex_optimize_decoder(model)
-
-    model.register_forward_hook(output_hook, with_kwargs=True)
-
+    _ipex_optimize_model(model, rms_classes)
     return _ipex_jit(model)
diff --git a/python/llm/src/bigdl/llm/transformers/convert_ipex.py b/python/llm/src/bigdl/llm/transformers/convert_ipex.py
index 5ece6835..467da52b 100644
--- a/python/llm/src/bigdl/llm/transformers/convert_ipex.py
+++ b/python/llm/src/bigdl/llm/transformers/convert_ipex.py
@@ -37,59 +37,14 @@
 import torch
 from bigdl.llm.utils.common import invalidInputError
 from typing import List, Optional, Tuple, Union
+from intel_extension_for_pytorch.transformers.optimize import (
+    lowering_class_cpu,
+    convert_class,
+)
 
 
-def lowering_class_cpu(m, target_m, new_class, config, tpp=False, woq=False):
-    for name, sub_m in m.named_children():
-        if isinstance(sub_m, target_m):
-            new_m = new_class(sub_m, config, tpp, woq)
-            setattr(m, name, new_m)
-        lowering_class_cpu(sub_m, target_m, new_class, config, tpp, woq)
-
-
-def convert_class(m, target_m, new_class, config, distributed=False):
-    for name, sub_m in m.named_children():
-        if isinstance(sub_m, target_m):
-            new_m = new_class(sub_m, config, distributed)
-            setattr(m, name, new_m)
-        convert_class(sub_m, target_m, new_class, config, distributed)
-
-
-def _set_optimized_model_for_generation(
-    model,
-    optimized_model,
-    first_token_optimized_model=None,
-):
-    from intel_extension_for_pytorch.transformers.models.reference.models import (
-        IPEX_LLM_Model_Return
-    )
-    if first_token_optimized_model is not None:
-        model.trace_graph_first = IPEX_LLM_Model_Return(
-            model, first_token_optimized_model
-        ).forward
-
-    model.trace_graph = IPEX_LLM_Model_Return(model, optimized_model).forward
-    print(
-        "ipex.llm.optimize has set the optimized or quantization model for model.generate()"
-    )
-    return model
-
-
-def _ipex_optimize_rmsnorm(_model):
+def _ipex_optimize_rmsnorm(_model, supported_classes):
     from intel_extension_for_pytorch.transformers.models.cpu.fusions.mha_fusion import _IPEXRMSNorm
-    import transformers
-    supported_classes = [
-        transformers.models.llama.modeling_llama.LlamaRMSNorm,
-    ]
-    if _model.config.architectures[0] == "BaichuanForCausalLM":
-        supported_classes.append(type(_model.model.layers[0].input_layernorm))
-    if (
-        _model.config.architectures[0] == "ChatGLMModel"
-        and _model.config.rmsnorm
-    ):
-        supported_classes.append(
-            type(_model.transformer.encoder.layers[0].input_layernorm)
-        )
     for supported_class in supported_classes:
         lowering_class_cpu(
             _model,
@@ -137,8 +92,20 @@ def _ipex_optimize_attention(model):
     )
 
 
+def _ipex_optimize_model(model, rms_classes):
+    from intel_extension_for_pytorch.transformers.models.reference.models import output_hook
+
+    _ipex_optimize_rmsnorm(model, rms_classes)
+    _ipex_optimize_attention(model)
+    _ipex_optimize_decoder(model)
+    model.register_forward_hook(output_hook, with_kwargs=True)
+
+
 def _ipex_jit(model):
-    from intel_extension_for_pytorch.transformers.optimize import get_dummy_input
+    from intel_extension_for_pytorch.transformers.optimize import (
+        get_dummy_input,
+        _set_optimized_model_for_generation
+    )
     sample_inputs = (
         get_dummy_input(model, return_dict=True)
     )
@@ -158,7 +125,7 @@ def _ipex_jit(model):
         model, optimized_model=trace_model
     )
 
-    return model.eval()
+    return model
 
 
 def convert_function(m, func_name, new_function):
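For context (not part of the patch): the `lowering_class_cpu` / `convert_class` helpers that `convert_ipex.py` now imports from `intel_extension_for_pytorch.transformers.optimize`, instead of re-defining locally, both follow the same recursive class-swap pattern over the module tree. A self-contained sketch of that pattern is below; `swap_modules` and `ScaledLinear` are hypothetical names used only for illustration:

```python
import torch.nn as nn


def swap_modules(module: nn.Module, target_cls, factory):
    # Walk the module tree and replace every instance of `target_cls`
    # with the module produced by `factory(old_module)`.
    for name, child in module.named_children():
        if isinstance(child, target_cls):
            setattr(module, name, factory(child))
        # Recurse into the original child, mirroring lowering_class_cpu/convert_class.
        swap_modules(child, target_cls, factory)


class ScaledLinear(nn.Module):
    # Hypothetical wrapper used only to make the swap visible.
    def __init__(self, old: nn.Linear):
        super().__init__()
        self.inner = old

    def forward(self, x):
        return self.inner(x) * 2.0


if __name__ == "__main__":
    m = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Sequential(nn.Linear(4, 2)))
    swap_modules(m, nn.Linear, ScaledLinear)
    print(m)  # both nn.Linear layers are now wrapped by ScaledLinear
```

In the patched code, `_ipex_optimize_model` applies this kind of replacement through `_ipex_optimize_rmsnorm` (using the per-model `rms_classes` built in `_optimize_ipex`), then fuses attention and decoder layers and registers IPEX's `output_hook` before JIT tracing in `_ipex_jit`.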