[LLM] Support CPU deepspeed distributed inference (#9259)

* [LLM] Support CPU Deepspeed distributed inference

* Update run_deepspeed.py

* Rename

* fix style

* add new codes

* refine

* remove annotated codes

* refine

* Update README.md

* refine doc and example code
Heyang Sun 2023-11-06 17:56:42 +08:00 committed by GitHub
parent f9bf5382ff
commit af94058203
8 changed files with 346 additions and 4 deletions

View file

@@ -17,4 +17,6 @@ test_api:
- "pytorch_autocast_bf16"
# - "ipex_fp16_gpu" # on Intel GPU
# - "transformer_int4_gpu" # on Intel GPU
# - "optimize_model_gpu" # on Intel GPU
# - "deepspeed_transformer_int4_cpu" # on Intel SPR Server

View file

@@ -0,0 +1,18 @@
#!/bin/bash
source bigdl-nano-init
unset OMP_NUM_THREADS # deepspeed will set it for each instance automatically
source /opt/intel/oneccl/env/setvars.sh
export WORLD_SIZE=2 # run 1 instance per SPR socket, thus 2 instances on 2 sockets, 96 cores
export MASTER_ADDR=127.0.0.1
export CCL_ZE_IPC_EXCHANGE=sockets
export DS_ACCELERATOR="cpu"
export CCL_WORKER_AFFINITY=auto
unset KMP_AFFINITY # deepspeed will set it for each instance automatically
export FI_PROVIDER=tcp
export CCL_ATL_TRANSPORT=ofi
export CCL_PROCESS_LAUNCHER=none
deepspeed \
--bind_cores_to_rank \
--bind_core_list 0-95 \
run.py

View file

@@ -55,6 +55,8 @@ def run_model(repo_id, test_api, in_out_pairs, local_model_hub=None, warm_up=1,
        result = run_pytorch_autocast_bf16(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams)
    elif test_api == 'ipex_fp16_gpu':
        result = run_ipex_fp16_gpu(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams)
    elif test_api == 'deepspeed_transformer_int4_cpu':
        result = run_deepspeed_transformer_int4_cpu(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams, low_bit)

    for in_out_pair in in_out_pairs:
        if result:
@@ -540,6 +542,92 @@ def run_ipex_fp16_gpu(repo_id,
    torch.xpu.empty_cache()
    return result


def run_deepspeed_transformer_int4_cpu(repo_id,
                                       local_model_hub,
                                       in_out_pairs,
                                       warm_up,
                                       num_trials,
                                       num_beams,
                                       low_bit):
    from transformers import AutoModel, AutoModelForCausalLM, LlamaTokenizer, AutoTokenizer
    import deepspeed
    from bigdl.llm import optimize_model
    import argparse
    # the parser handles the inline --local_rank argument passed to each deepspeed subprocess
    parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for Llama2 model')
    parser.add_argument('--local_rank', type=int, default=0,
                        help='this is automatically set when using deepspeed launcher')
    args = parser.parse_args()
    local_rank = int(os.getenv("RANK", "-1"))  # RANK is set per process by the CCL distributed backend
    if local_rank == -1:  # fall back to the value injected by the deepspeed launcher
        local_rank = args.local_rank
    world_size = int(os.getenv("WORLD_SIZE", "1"))
    model_path = get_model_path(repo_id, local_model_hub)
    st = time.perf_counter()
    # Note: only tested cpu Llama2-7b
    # Native Huggingface transformers loading to enable deepspeed init
    if repo_id in ['THUDM/chatglm-6b', 'THUDM/chatglm2-6b']:
        model = AutoModel.from_pretrained(model_path, trust_remote_code=True, use_cache=True)
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    elif repo_id in LLAMA_IDS:
        model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True,
                                                     use_cache=True)
        tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)
    else:
        model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, use_cache=True)
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    # Parallelize model on deepspeed
    model = deepspeed.init_inference(model, mp_size=world_size,
                                     dtype=torch.float16,
                                     replace_method="auto")

    # Apply BigDL-LLM INT4 optimization to enable BenchmarkWrapper
    # Note: only tested sym_int4
    model = optimize_model(model.module.to(f'cpu'), low_bit=low_bit)
    model = model.to(f'cpu:{local_rank}')

    end = time.perf_counter()
    print(">> loading of model costs {}s".format(end - st))

    model = BenchmarkWrapper(model)
    result = {}
    with torch.inference_mode():
        for in_out in in_out_pairs:
            in_out_len = in_out.split("-")
            in_len = int(in_out_len[0])
            out_len = int(in_out_len[1])
            # As different tokenizers have different encodings,
            # in_len.txt may be shorter than we need,
            # so use a much longer context to make sure the input is long enough
            test_length = min(in_len*2, 8192)
            while test_length not in [32, 256, 1024, 2048, 8192]:
                test_length = test_length * 2
            input_str = open(f"prompt/{test_length}.txt", 'r').read()
            # As different tokenizers have different encodings,
            # slice the input_ids to ensure the prompt has the required length.
            input_ids = tokenizer.encode(input_str, return_tensors="pt")
            input_ids = input_ids[:, :in_len]
            true_str = tokenizer.batch_decode(input_ids)[0]
            input_ids = tokenizer.encode(true_str, return_tensors="pt")
            actual_in_len = input_ids.shape[1]
            result[in_out] = []
            for i in range(num_trials + warm_up):
                st = time.perf_counter()
                output_ids = model.generate(input_ids, do_sample=False, max_new_tokens=out_len,
                                            num_beams=num_beams)
                end = time.perf_counter()
                if local_rank == 0:
                    print("model generate cost: " + str(end - st))
                output = tokenizer.batch_decode(output_ids)
                if local_rank == 0:
                    print(output[0])
                actual_out_len = output_ids.shape[1] - actual_in_len
                if i >= warm_up:
                    result[in_out].append([model.first_cost, model.rest_cost_mean, model.encoder_time,
                                           actual_in_len, actual_out_len])
    return result


if __name__ == '__main__':
    from omegaconf import OmegaConf

View file

@@ -0,0 +1,69 @@
### Run Tensor-Parallel BigDL Transformers INT4 Inference with Deepspeed
#### 1. Install Dependencies
Install the necessary packages (Python 3.9 is our test environment):
```bash
bash install.sh
```
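If you want to verify the environment before launching, a quick sanity check can be run first (a minimal sketch; the version pins in the comments follow `install.sh` in this change and are otherwise assumptions):
```python
# Sanity-check the CPU inference environment prepared by install.sh.
import torch
import deepspeed

print("torch:", torch.__version__)          # install.sh pins torch==2.1.0
print("deepspeed:", deepspeed.__version__)  # install.sh pins deepspeed==0.11.1
print("CPU threads:", torch.get_num_threads())
```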
#### 2. Initialize Deepspeed Distributed Context
As shown in the example code `deepspeed_autotp.py`, you can construct a parallel model with the Python API:
```python
# Load in HuggingFace Transformers' model
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(...)

# Parallelize model on deepspeed
import deepspeed
model = deepspeed.init_inference(
    model,                  # an AutoModel of Transformers
    mp_size=world_size,     # instance (process) count
    dtype=torch.float16,
    replace_method="auto")
```
The returned model is then converted into a DeepSpeed `InferenceEngine` type.
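For reference, the `world_size` and `local_rank` values used above come from the launcher environment. A minimal, slightly condensed sketch of how `deepspeed_autotp.py` in this example obtains them:
```python
import argparse
import os

# --local_rank is injected into each subprocess by the deepspeed launcher
parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=0)
args, _ = parser.parse_known_args()

world_size = int(os.getenv("WORLD_SIZE", "1"))  # exported in run.sh
local_rank = int(os.getenv("RANK", "-1"))       # set per process by the CCL distributed backend
if local_rank == -1:
    local_rank = args.local_rank
```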
#### 3. Optimize Model with BigDL-LLM Low Bit
The distributed model managed by DeepSpeed can be further optimized with the BigDL-LLM low-bit Python API, e.g. `sym_int4`:
```python
# Apply BigDL-LLM INT4 optimizations on transformers
from bigdl.llm import optimize_model
model = optimize_model(model.module.to(f'cpu'), low_bit='sym_int4')
model = model.to(f'cpu:{local_rank}') # move partial model to local rank
```
A BigDL-LLM optimized transformers model is then returned, which can serve inference in parallel through the native transformers APIs, as in the sketch below.
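For example, generation then works on each rank just like with a plain transformers model (a minimal sketch following `deepspeed_autotp.py`; `model`, `model_path` and `local_rank` are assumed from the previous steps):
```python
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
with torch.inference_mode():
    input_ids = tokenizer.encode("Once upon a time", return_tensors="pt").to(f'cpu:{local_rank}')
    output = model.generate(input_ids, do_sample=False, max_new_tokens=32)
    if local_rank == 0:  # every rank runs generate(), only rank 0 prints
        print(tokenizer.decode(output[0], skip_special_tokens=True))
```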
#### 4. Start Python Code
You can try DeepSpeed with BigDL-LLM by running:
```bash
bash run.sh
```
If you want to run your own application, note the **necessary configurations in the script**, which can also be ported to your custom deepspeed application:
```bash
# run.sh
source bigdl-nano-init
unset OMP_NUM_THREADS # deepspeed will set it for each instance automatically
source /opt/intel/oneccl/env/setvars.sh
......
export FI_PROVIDER=tcp
export CCL_ATL_TRANSPORT=ofi
export CCL_PROCESS_LAUNCHER=none
```
Please set the above configurations before running `deepspeed` to ensure correct parallel communication and good performance.

View file

@@ -0,0 +1,125 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file are adapted from
# https://github.com/TimDettmers/bitsandbytes/blob/0.39.1/bitsandbytes/nn/modules.py
# which is licensed under the MIT license:
#
# MIT License
#
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import os
import torch
from transformers import AutoModelForCausalLM, LlamaTokenizer, AutoTokenizer
import deepspeed
from bigdl.llm import optimize_model
import intel_extension_for_pytorch as ipex
import time
import argparse
from benchmark_util import BenchmarkWrapper

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for Llama2 model')
    parser.add_argument('--repo-id-or-model-path', type=str, default="meta-llama/Llama-2-7b-chat-hf",
                        help='The huggingface repo id for the Llama2 (e.g. `meta-llama/Llama-2-7b-chat-hf` and `meta-llama/Llama-2-13b-chat-hf`) to be downloaded'
                             ', or the path to the huggingface checkpoint folder')
    parser.add_argument('--prompt', type=str, default="Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun",
                        help='Prompt to infer')
    parser.add_argument('--n-predict', type=int, default=32,
                        help='Max tokens to predict')
    parser.add_argument('--local_rank', type=int, default=0,
                        help='this is automatically set when using deepspeed launcher')
    args = parser.parse_args()
    model_path = args.repo_id_or_model_path
    world_size = int(os.getenv("WORLD_SIZE", "1"))
    local_rank = int(os.getenv("RANK", "-1"))  # RANK is automatically set by the CCL distributed backend
    if local_rank == -1:  # args.local_rank is automatically set by the deepspeed subprocess command
        local_rank = args.local_rank

    # Native Huggingface transformers loading
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map={"": "cpu"},
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
        trust_remote_code=True,
        use_cache=True
    )

    # Parallelize model on deepspeed
    model = deepspeed.init_inference(
        model,
        mp_size=world_size,
        dtype=torch.float16,
        replace_method="auto"
    )

    # Apply BigDL-LLM INT4 optimizations on transformers
    model = optimize_model(model.module.to(f'cpu'), low_bit='sym_int4')
    model = model.to(f'cpu:{local_rank}')
    print(model)

    model = BenchmarkWrapper(model, do_print=True)

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
with torch.inference_mode():
# Batch tokenizing
prompt = args.prompt
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(f'cpu:{local_rank}')
# ipex model needs a warmup, then inference time can be accurate
output = model.generate(input_ids,
max_new_tokens=args.n_predict,
use_cache=True)
# start inference
start = time.time()
# if your selected model is capable of utilizing previous key/value attentions
# to enhance decoding speed, but has `"use_cache": false` in its model config,
# it is important to set `use_cache=True` explicitly in the `generate` function
# to obtain optimal performance with BigDL-LLM INT4 optimizations
output = model.generate(input_ids,
do_sample=False,
max_new_tokens=args.n_predict)
end = time.time()
if local_rank == 0:
output_str = tokenizer.decode(output[0], skip_special_tokens=True)
print('-'*20, 'Output', '-'*20)
print(output_str)
print(f'Inference time: {end - start} s')

View file

@@ -0,0 +1,9 @@
#!/bin/bash
# install torch
pip install torch==2.1.0
# install torchccl
pip install https://intel-extension-for-pytorch.s3.amazonaws.com/torch_ccl/cpu/oneccl_bind_pt-2.1.0%2Bcpu-cp39-cp39-linux_x86_64.whl
# install deepspeed
pip install deepspeed==0.11.1
# remove the intel deepspeed extension if present, since it is only for XPU
pip uninstall -y intel-extension-for-deepspeed

View file

@@ -0,0 +1,18 @@
#!/bin/bash
source bigdl-nano-init
unset OMP_NUM_THREADS # deepspeed will set it for each instance automatically
source /opt/intel/oneccl/env/setvars.sh
export WORLD_SIZE=2 # run 1 instance per SPR socket, thus 2 instances on 2 sockets, 96 cores
export MASTER_ADDR=127.0.0.1
export CCL_ZE_IPC_EXCHANGE=sockets
export DS_ACCELERATOR="cpu"
export CCL_WORKER_AFFINITY=auto
unset KMP_AFFINITY # deepspeed will set it for each instance automatically
export FI_PROVIDER=tcp
export CCL_ATL_TRANSPORT=ofi
export CCL_PROCESS_LAUNCHER=none
deepspeed \
--bind_cores_to_rank \
--bind_core_list 0-95 \
deepspeed_autotp.py

View file

@@ -464,17 +464,30 @@ class LowBitLinear(nn.Linear):
" supported on CPU")
        if self.training and x.requires_grad:
            result = MatMulLowBitCPU.apply(x, self.weight)
            if self.bias is not None:
                result = result + self.bias
        else:
            if IS_SERVER and (not IS_SPR) and \
                    self.qtype == SYM_INT4 and x_2d.shape[0] >= TORCH_LINEAR_THRESHOLD:
                x0_fp32 = ggml_int4_convert_fp32(x0, self.weight_shape, self.weight_length)
                if self.mp_group is None:
                    # non-distributed mode
                    result = F.linear(x, x0_fp32, self.bias)
                else:
                    result = F.linear(x, x0_fp32)
                    from deepspeed import comm as dist
                    # a parallel F.linear with the bias on every rank must be avoided,
                    # so deepspeed all-reduces the bias-free result after the matmul
                    dist.inference_all_reduce(result, group=self.mp_group)
                    if self.bias is not None:
                        result += self.bias
            else:
                result = ggml_matmul_src1_x_src0_t(x0, x_2d, self.weight_shape, self.qtype)
                new_shape = x_shape[:-1] + (self.out_len,)
                result = result.view(new_shape)
                # the bias is identical on every instance, so deepspeed all-reduces
                # the result without the bias to reduce communication
                if self.mp_group is not None:
                    from deepspeed import comm as dist
                    dist.inference_all_reduce(result, group=self.mp_group)
                if self.bias is not None:
                    result += self.bias
        return result
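For intuition on the bias handling above, here is a small standalone sketch (illustrative only; it mimics row-parallel sharding in plain PyTorch rather than calling DeepSpeed) showing why the bias must be added exactly once, after the all-reduce:
```python
import torch

world_size = 2
x = torch.randn(4, 8)   # (batch, in_features)
w = torch.randn(6, 8)   # (out_features, in_features)
bias = torch.randn(6)

# Row-parallel split: each rank holds a slice of the input features and weights,
# computes a partial matmul, and the all-reduce sums the partial results.
partials = [xs @ ws.t() for xs, ws in zip(x.chunk(world_size, dim=1),
                                          w.chunk(world_size, dim=1))]
reduced = sum(partials)  # what dist.inference_all_reduce produces across ranks

assert torch.allclose(reduced + bias, x @ w.t() + bias, atol=1e-5)
# If every rank added the bias before the all-reduce, the summed result would
# contain world_size * bias instead of a single bias term.
```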