LLM: Update IPEX to 2.2.0+cpu and Refactor for _ipex_optimize (#10189)

Update IPEX to 2.2.0+cpu and refactor for _ipex_optimize.
Xiangyu Tian 2024-02-22 16:01:11 +08:00 committed by GitHub
parent c876d9b5ca
commit f445217d02
6 changed files with 59 additions and 107 deletions


@@ -63,33 +63,20 @@ First token latency x.xxxxs
### 4. Accelerate with BIGDL_OPT_IPEX
To accelerate speculative decoding on CPU, you can optionally install our validated version of [IPEX 2.3.0+git0c63936](https://github.com/intel/intel-extension-for-pytorch/tree/0c63936d7a6740679987920367ae2e0cdb375b2e) by the following steps: (Other versions of IPEX may have some conflicts and cannot accelerate speculative decoding correctly.)
To accelerate speculative decoding on CPU, you can optionally install our validated version of [IPEX 2.2.0+cpu](https://github.com/intel/intel-extension-for-pytorch/tree/v2.2.0%2Bcpu) by referring to [IPEX's installation guide](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=cpu&version=v2.2.0%2Bcpu), or by the following steps: (Other versions of IPEX may have some conflicts and cannot accelerate speculative decoding correctly.)
#### 4.1 Download IPEX installation script
#### 4.1 Install IPEX 2.2.0+cpu
```bash
# Depends on Conda and GCC 12.3
wget https://raw.githubusercontent.com/intel/intel-extension-for-pytorch/0c63936d7a6740679987920367ae2e0cdb375b2e/scripts/compile_bundle.sh
```
#### 4.2 Activate your conda environment
```bash
conda activate <your_conda_env>
```
#### 4.3 Set VER_IPEX in compile_bundle.sh to 0c63936d7a6740679987920367ae2e0cdb375b2e
```bash
sed -i 's/VER_IPEX=main/VER_IPEX=0c63936d7a6740679987920367ae2e0cdb375b2e/g' "compile_bundle.sh"
```
#### 4.4 Install IPEX and other dependencies
```bash
# Install IPEX 2.3.0+git0c63936
bash compile_bundle.sh
python -m pip install torch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cpu
python -m pip install intel-extension-for-pytorch==2.2.0
python -m pip install oneccl_bind_pt==2.2.0 --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/cpu/us/
# If there is any installation problem with oneccl_bind_pt, you can also find a suitable index URL at "https://pytorch-extension.intel.com/release-whl/stable/cpu/cn/" or "https://developer.intel.com/ipex-whl-stable-cpu", depending on your environment.
# Install other dependencies
pip install -r requirements.txt
```
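After the pip installs above, a quick sanity check can confirm that the CPU builds are picked up (a minimal sketch; the exact local version strings may differ):
```bash
# Both torch and IPEX should report a 2.2.0+cpu build.
python -c "import torch; import intel_extension_for_pytorch as ipex; print(torch.__version__, ipex.__version__)"
```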
#### 4.5 Run Baichuan2 Models with IPEX
#### 4.2 Run Baichuan2 Models with IPEX
After installing IPEX, **if your Baichuan2 model is 7B**, replace the `modeling_baichuan.py` file under your model directory with `./baichaun2_7b_opt_ipex/modeling_baichuan.ipex`, for example:
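A minimal sketch of that file replacement (the destination path is a placeholder for wherever your Baichuan2-7B checkpoint lives):
```bash
# Back up the original file first, then drop in the IPEX-optimized variant.
cp <your_model_dir>/modeling_baichuan.py <your_model_dir>/modeling_baichuan.py.bak
cp ./baichaun2_7b_opt_ipex/modeling_baichuan.ipex <your_model_dir>/modeling_baichuan.py
```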


@@ -98,27 +98,14 @@ First token latency xx.xxxxs
### 4. Accelerate with BIGDL_OPT_IPEX
To accelerate speculative decoding on CPU, you can install our validated version of [IPEX 2.3.0+git0c63936](https://github.com/intel/intel-extension-for-pytorch/tree/0c63936d7a6740679987920367ae2e0cdb375b2e) by the following steps: (Other versions of IPEX may have some conflicts and cannot accelerate speculative decoding correctly.)
To accelerate speculative decoding on CPU, you can install our validated version of [IPEX 2.2.0+cpu](https://github.com/intel/intel-extension-for-pytorch/tree/v2.2.0%2Bcpu) by referring to [IPEX's installation guide](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=cpu&version=v2.2.0%2Bcpu), or by the following commands: (Other versions of IPEX may have some conflicts and cannot accelerate speculative decoding correctly.)
#### 4.1 Download IPEX installation script
```bash
# Depends on Conda and GCC 12.3
wget https://raw.githubusercontent.com/intel/intel-extension-for-pytorch/0c63936d7a6740679987920367ae2e0cdb375b2e/scripts/compile_bundle.sh
```
#### 4.2 Activate your conda environment
```bash
conda activate <your_conda_env>
```
#### 4.3 Set VER_IPEX in compile_bundle.sh to 0c63936d7a6740679987920367ae2e0cdb375b2e
```bash
sed -i 's/VER_IPEX=main/VER_IPEX=0c63936d7a6740679987920367ae2e0cdb375b2e/g' "compile_bundle.sh"
```
#### 4.4 Install IPEX and other dependencies
```bash
# Install IPEX 2.3.0+git0c63936
bash compile_bundle.sh
# Install IPEX 2.2.0+cpu
python -m pip install torch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cpu
python -m pip install intel-extension-for-pytorch==2.2.0
python -m pip install oneccl_bind_pt==2.2.0 --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/cpu/us/
# If there is any installation problem with oneccl_bind_pt, you can also find a suitable index URL at "https://pytorch-extension.intel.com/release-whl/stable/cpu/cn/" or "https://developer.intel.com/ipex-whl-stable-cpu", depending on your environment.
# Update transformers
pip install transformers==4.36.2


@@ -84,4 +84,4 @@ First token latency xx.xxxxs
### 4. Accelerate with BIGDL_OPT_IPEX
BIGDL_OPT_IPEX can help accelerate speculative decoding on Mistral; please refer to [here](https://github.com/intel-analytics/BigDL/blob/main/python/llm/example/CPU/Speculative-Decoding/baichuan2/README.md#4-accelerate-with-bigdl_opt_ipex) to give it a try.
BIGDL_OPT_IPEX can help accelerate speculative decoding on Mistral; please refer to [here](https://github.com/intel-analytics/BigDL/blob/main/python/llm/example/CPU/Speculative-Decoding/llama2#4-accelerate-with-bigdl_opt_ipex) to give it a try.
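In practice, enabling it comes down to exporting the flag before launching the example (a hedged sketch; the script name and argument below are placeholders for the linked example's own run command):
```bash
export BIGDL_OPT_IPEX=true   # turn on the IPEX-accelerated speculative-decoding path
python ./speculative.py --repo-id-or-model-path <path_to_mistral_model>   # placeholder invocation
```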


@@ -94,4 +94,4 @@ First token latency xx.xxxxs
### 4. Accelerate with BIGDL_OPT_IPEX
BIGDL_OPT_IPEX can help accelerate speculative decoding to some extent; please refer to [here](https://github.com/intel-analytics/BigDL/blob/main/python/llm/example/CPU/Speculative-Decoding/baichuan2/README.md#4-accelerate-with-bigdl_opt_ipex) to give it a try.
BIGDL_OPT_IPEX can help accelerate speculative decoding to some extent; please refer to [here](https://github.com/intel-analytics/BigDL/blob/main/python/llm/example/CPU/Speculative-Decoding/llama2#4-accelerate-with-bigdl_opt_ipex) to give it a try.


@@ -625,27 +625,38 @@ def replace_func(m, target_m, func_name, new_func):
def _optimize_ipex(model):
    import intel_extension_for_pytorch as ipex
    from intel_extension_for_pytorch.transformers.optimize import model_convert_reference
    from intel_extension_for_pytorch.transformers.models.reference.models import output_hook
    from transformers.modeling_attn_mask_utils import AttentionMaskConverter
    from bigdl.llm.transformers.convert_ipex import (
        _ipex_optimize_attention, _ipex_optimize_decoder, _ipex_jit, _make_causal_mask,
        _ipex_optimize_rmsnorm, _llama_model_forward_4_35, convert_function, GLM_get_masks
        _ipex_optimize_model, _ipex_jit, _make_causal_mask,
        _llama_model_forward_4_35, convert_function, GLM_get_masks
    )
    AttentionMaskConverter._make_causal_mask = _make_causal_mask
    convert_forward(model, transformers.models.llama.modeling_llama.LlamaModel,
                    _llama_model_forward_4_35)
    model = model_convert_reference(model)
    if model.config.architectures is not None \
        and model.config.architectures[0] in ["ChatGLMModel", "ChatGLMForConditionalGeneration"]:
    rms_classes = [
        transformers.models.llama.modeling_llama.LlamaRMSNorm,
    ]
    if 'llama' in model.config.model_type:
        AttentionMaskConverter._make_causal_mask = _make_causal_mask
        convert_forward(model, transformers.models.llama.modeling_llama.LlamaModel,
                        _llama_model_forward_4_35)
    elif "mistral" in model.config.model_type:
        AttentionMaskConverter._make_causal_mask = _make_causal_mask
        convert_forward(model, transformers.models.llama.modeling_llama.LlamaModel,
                        _llama_model_forward_4_35)
    elif model.config.architectures is not None \
        and model.config.architectures[0] in ["ChatGLMModel", "ChatGLMForConditionalGeneration"]:  # noqa
        # for chatglm3-6B
        rms_classes.append(
            type(model.transformer.encoder.layers[0].input_layernorm)
        )
        convert_function(model.transformer, "get_masks", GLM_get_masks)
    elif model.config.model_type == 'baichuan' and model.config.vocab_size == 125696:
        # baichuan2
        rms_classes.append(type(model.model.layers[0].input_layernorm))
    model = ipex.optimize(model.eval(), dtype=torch.bfloat16, inplace=True).eval()
    _ipex_optimize_rmsnorm(model)
    _ipex_optimize_attention(model)
    _ipex_optimize_decoder(model)
    model.register_forward_hook(output_hook, with_kwargs=True)
    _ipex_optimize_model(model, rms_classes)
    return _ipex_jit(model)


@@ -37,59 +37,14 @@
import torch
from bigdl.llm.utils.common import invalidInputError
from typing import List, Optional, Tuple, Union
from intel_extension_for_pytorch.transformers.optimize import (
    lowering_class_cpu,
    convert_class,
)
def lowering_class_cpu(m, target_m, new_class, config, tpp=False, woq=False):
    for name, sub_m in m.named_children():
        if isinstance(sub_m, target_m):
            new_m = new_class(sub_m, config, tpp, woq)
            setattr(m, name, new_m)
        lowering_class_cpu(sub_m, target_m, new_class, config, tpp, woq)
def convert_class(m, target_m, new_class, config, distributed=False):
    for name, sub_m in m.named_children():
        if isinstance(sub_m, target_m):
            new_m = new_class(sub_m, config, distributed)
            setattr(m, name, new_m)
        convert_class(sub_m, target_m, new_class, config, distributed)
def _set_optimized_model_for_generation(
    model,
    optimized_model,
    first_token_optimized_model=None,
):
    from intel_extension_for_pytorch.transformers.models.reference.models import (
        IPEX_LLM_Model_Return
    )
    if first_token_optimized_model is not None:
        model.trace_graph_first = IPEX_LLM_Model_Return(
            model, first_token_optimized_model
        ).forward
    model.trace_graph = IPEX_LLM_Model_Return(model, optimized_model).forward
    print(
        "ipex.llm.optimize has set the optimized or quantization model for model.generate()"
    )
    return model
def _ipex_optimize_rmsnorm(_model):
def _ipex_optimize_rmsnorm(_model, supported_classes):
    from intel_extension_for_pytorch.transformers.models.cpu.fusions.mha_fusion import _IPEXRMSNorm
    import transformers
    supported_classes = [
        transformers.models.llama.modeling_llama.LlamaRMSNorm,
    ]
    if _model.config.architectures[0] == "BaichuanForCausalLM":
        supported_classes.append(type(_model.model.layers[0].input_layernorm))
    if (
        _model.config.architectures[0] == "ChatGLMModel"
        and _model.config.rmsnorm
    ):
        supported_classes.append(
            type(_model.transformer.encoder.layers[0].input_layernorm)
        )
    for supported_class in supported_classes:
        lowering_class_cpu(
            _model,
@@ -137,8 +92,20 @@ def _ipex_optimize_attention(model):
    )
def _ipex_optimize_model(model, rms_classes):
    from intel_extension_for_pytorch.transformers.models.reference.models import output_hook
    _ipex_optimize_rmsnorm(model, rms_classes)
    _ipex_optimize_attention(model)
    _ipex_optimize_decoder(model)
    model.register_forward_hook(output_hook, with_kwargs=True)
def _ipex_jit(model):
    from intel_extension_for_pytorch.transformers.optimize import get_dummy_input
    from intel_extension_for_pytorch.transformers.optimize import (
        get_dummy_input,
        _set_optimized_model_for_generation
    )
    sample_inputs = (
        get_dummy_input(model, return_dict=True)
    )
@@ -158,7 +125,7 @@ def _ipex_jit(model):
        model, optimized_model=trace_model
    )
    return model.eval()
    return model
def convert_function(m, func_name, new_function):
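For reference, a `convert_function` helper of this shape typically rebinds a method on a module instance; a minimal sketch under that assumption (the actual body lies outside this diff):

```python
import types

def convert_function(m, func_name, new_function):
    # Replace the bound method named func_name on instance m with new_function,
    # e.g. convert_function(model.transformer, "get_masks", GLM_get_masks) as used above.
    # Sketch only; the real implementation in convert_ipex.py is not shown in this hunk.
    bound_method = types.MethodType(new_function, m)
    setattr(m, func_name, bound_method)
```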