Support deepspeed AutoTP (#9230)

* Support deepspeed

* add test script

* refactor convert

* refine example

* refine

* refine example

* fix style

* refine example and adapt to latest ipex

* fix style
Yang Wang 2023-10-25 14:46:28 +08:00 committed by GitHub
parent a6a8afc47e
commit 067c7e8098
5 changed files with 190 additions and 9 deletions

View file

@ -0,0 +1,103 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import torch
import transformers
import deepspeed
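# torchrun sets LOCAL_RANK and WORLD_SIZE for every process it spawns: LOCAL_RANK picks
# the XPU device used by this rank, and WORLD_SIZE is the tensor-parallel degree passed
# to deepspeed.init_inference below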
local_rank = int(os.getenv("LOCAL_RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))

from bigdl.llm import optimize_model
import intel_extension_for_pytorch as ipex
import time
import argparse
from transformers import AutoModelForCausalLM  # import AutoModelForCausalLM from transformers so that DeepSpeed uses it
from transformers import LlamaTokenizer, AutoTokenizer
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for Llama2 model')
    parser.add_argument('--repo-id-or-model-path', type=str, default="meta-llama/Llama-2-7b-chat-hf",
                        help='The huggingface repo id for the Llama2 (e.g. `meta-llama/Llama-2-7b-chat-hf` and `meta-llama/Llama-2-13b-chat-hf`) to be downloaded'
                             ', or the path to the huggingface checkpoint folder')
    parser.add_argument('--prompt', type=str, default="Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun",
                        help='Prompt to infer')
    parser.add_argument('--n-predict', type=int, default=32,
                        help='Max tokens to predict')

    args = parser.parse_args()
    model_path = args.repo_id_or_model_path

    model = AutoModelForCausalLM.from_pretrained(args.repo_id_or_model_path,
                                                 low_cpu_mem_usage=True,
                                                 torch_dtype=torch.float16,
                                                 trust_remote_code=True,
                                                 use_cache=True)
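    # deepspeed.init_inference with replace_method="auto" applies AutoTP: the attention/MLP
    # Linear layers of the float16 model are sharded across the WORLD_SIZE ranks and wrapped
    # as DeepSpeed LinearLayer / LinearAllreduce modules, which bigdl-llm's convert logic
    # (see is_linear_module below) knows how to turn into low-bit layers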
    model = deepspeed.init_inference(
        model,
        mp_size=world_size,
        dtype=torch.float16,
        replace_method="auto",
    )

    # move model to cpu and use bigdl-llm `optimize_model` to convert the
    # model into optimized low bit format
    # convert the rest of the model into float16 to reduce allreduce traffic
    model = optimize_model(model.module.to('cpu'), low_bit='sym_int4').to(torch.float16)

    # move model back to xpu
    model = model.to(f'xpu:{local_rank}')

    print(model)
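    # the printed module tree should now show bigdl-llm low-bit linear layers holding the
    # DeepSpeed tensor-parallel weight shards in place of the original nn.Linear modules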
    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    # Generate predicted tokens
    with torch.inference_mode():
        # prompt = get_prompt(args.prompt, [], system_prompt=DEFAULT_SYSTEM_PROMPT)
        prompt = args.prompt
        # input_str = f"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{prompt}\n\n### Response:\n"
        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(f'xpu:{local_rank}')

        # the ipex model needs a warmup run first so that the timed inference below is accurate
        output = model.generate(input_ids,
                                max_new_tokens=args.n_predict,
                                use_cache=True)

        # start inference
        st = time.time()
        # if your selected model is capable of utilizing previous key/value attentions
        # to enhance decoding speed, but has `"use_cache": false` in its model config,
        # it is important to set `use_cache=True` explicitly in the `generate` function
        # to obtain optimal performance with BigDL-LLM INT4 optimizations
        output = model.generate(input_ids,
                                do_sample=False,
                                max_new_tokens=args.n_predict)
        torch.xpu.synchronize()
        end = time.time()

        if local_rank == 0:
            output = output.cpu()
            output_str = tokenizer.decode(output[0], skip_special_tokens=True)
            print(f'Inference time: {end-st} s')
            print('-'*20, 'Prompt', '-'*20)
            print(prompt)
            print('-'*20, 'Output', '-'*20)
            print(output_str)

View file

@ -0,0 +1,12 @@
source bigdl-llm-init -t -g
export MASTER_ADDR=127.0.0.1
export CCL_ZE_IPC_EXCHANGE=sockets
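# give each of the 4 ranks launched below a quarter of the CPU threads so the
# ranks do not oversubscribe the available cores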
if [[ -n $OMP_NUM_THREADS ]]; then
export OMP_NUM_THREADS=$(($OMP_NUM_THREADS / 4))
else
export OMP_NUM_THREADS=$(($(nproc) / 4))
fi
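# launch 4 tensor-parallel ranks on this single node; torchrun provides the
# LOCAL_RANK and WORLD_SIZE environment variables consumed by deepspeed_autotp.py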
torchrun --standalone \
--nnodes=1 \
--nproc-per-node 4 \
deepspeed_autotp.py --repo-id-or-model-path "meta-llama/Llama-2-7b-hf"

View file

@ -45,6 +45,42 @@ from bigdl.llm.ggml.quantize import ggml_tensor_qtype
from .utils import logger
def is_deepspeed_available():
    return importlib.util.find_spec("deepspeed") is not None
def is_linear_module(module):
    in_features = None
    out_features = None
    mp_group = None
    if isinstance(module, nn.Linear):
        in_features = module.in_features
        out_features = module.out_features
        mp_group = None
        result = True
    else:
        if is_deepspeed_available():
            from deepspeed.module_inject.layers import LinearLayer, LinearAllreduce
            if isinstance(module, LinearLayer):
                in_features = module.weight.shape[1]
                out_features = module.weight.shape[0]
                mp_group = None
                result = True
            elif isinstance(module, LinearAllreduce):
                in_features = module.weight.shape[1]
                out_features = module.weight.shape[0]
                mp_group = module.mp_group
                result = True
            else:
                result = False
        else:
            result = False
    return result, (in_features, out_features, mp_group)
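# note: for the DeepSpeed wrappers, in_features/out_features above are read from the
# per-rank weight shard, and mp_group is non-None only for LinearAllreduce, whose partial
# output must be summed across the tensor-parallel group (see mp_group in LowBitLinear)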
def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                                 current_key_name=None, convert_shape_only=False):
    from bigdl.llm.transformers.low_bit_linear import LowBitLinear, FP4Params, FP16Linear
@ -54,17 +90,20 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
        if current_key_name is None:
            current_key_name = []

        if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
        is_linear, linear_args = is_linear_module(module)
        if is_linear and name not in modules_to_not_convert:
            # Check if the current key is not in the `modules_to_not_convert`
            if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
                in_features, out_features, mp_group = linear_args
                with init_empty_weights():
                    new_linear = None
                    if qtype != ggml_tensor_qtype["fp16"]:
                        new_linear = LowBitLinear(
                            module.in_features,
                            module.out_features,
                            in_features,
                            out_features,
                            qtype,
                            module.bias is not None,
                            mp_group=mp_group,
                        )
                        device_type = module.weight.data.device.type
@ -82,10 +121,11 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                        if module.in_features in [4096, 11008]:
                            # esimd fp16 path
                            new_linear = FP16Linear(
                                module.in_features,
                                module.out_features,
                                in_features,
                                out_features,
                                qtype,
                                module.bias is not None,
                                mp_group=mp_group,
                            )
                            device_type = module.weight.data.device.type

View file

@ -328,7 +328,7 @@ class MatMulLowBit(torch.autograd.Function):
class LowBitLinear(nn.Linear):
    def __init__(self, input_features, output_features, qtype, bias=True,
                 conver_to_half=True):
                 conver_to_half=True, mp_group=None):
        super().__init__(input_features, output_features, bias)
        self.weight = FP4Params(self.weight.data,
                                requires_grad=False,
@ -339,6 +339,7 @@ class LowBitLinear(nn.Linear):
        self.weight_length = self.out_len * self.in_len
        self.qtype = qtype
        self.conver_to_half = conver_to_half
        self.mp_group = mp_group

    def forward(self, x: torch.Tensor):
        if self.bias is not None and self.bias.dtype != x.dtype:
@ -378,6 +379,9 @@ class LowBitLinear(nn.Linear):
                                             input_seq_size)
            new_shape = x_shape[:-1] + (self.out_len,)
            result = result.view(new_shape)
            if self.mp_group is not None:
                from deepspeed import comm as dist
                dist.inference_all_reduce(result, group=self.mp_group)
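            # with tensor parallelism each rank computes only a partial matmul result, so the
            # shards are summed across mp_group before the bias is added below; otherwise the
            # bias would effectively be accumulated once per rank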
            if self.bias is not None:
                result += self.bias
        else:
@ -400,7 +404,7 @@ class LowBitLinear(nn.Linear):
class FP16Linear(nn.Linear):
    def __init__(self, input_features, output_features, qtype, bias=True,
                 conver_to_half=True):
                 conver_to_half=True, mp_group=None):
        super().__init__(input_features, output_features, bias)
        self.in_len = input_features
        self.out_len = output_features
@ -408,6 +412,7 @@ class FP16Linear(nn.Linear):
        self.weight_length = self.out_len * self.in_len
        self.qtype = qtype
        self.conver_to_half = conver_to_half
        self.mp_group = mp_group

    def forward(self, x: torch.Tensor):
        if self.bias is not None and self.bias.dtype != x.dtype:
@ -442,6 +447,9 @@ class FP16Linear(nn.Linear):
            new_shape = x_shape[:-1] + (self.out_len,)
            result = result.view(new_shape)
            if self.mp_group is not None:
                from deepspeed import comm as dist
                dist.inference_all_reduce(result, group=self.mp_group)
            if self.bias is not None:
                result += self.bias

View file

@ -32,6 +32,7 @@
# limitations under the License.
import torch
import importlib
import torch.nn as nn
from typing import Optional, Tuple
import math
@ -58,10 +59,27 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
KV_CACHE_ALLOC_BLOCK_LENGTH = 256


def get_ipex_version():
    if importlib.util.find_spec("intel_extension_for_pytorch") is not None:
        import intel_extension_for_pytorch as ipex
        return ipex.__version__
    else:
        return None


ipex_version = get_ipex_version()


def llama_rms_norm_forward(self, hidden_states):
    if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
        hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
                                                         [self.weight.size(0)], self.weight)
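        # for ipex 2.0.110+xpu the rms_norm op is called without an explicit epsilon, while
        # later ipex releases take variance_epsilon as an extra argument, hence the version
        # check introduced below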
        if ipex_version == "2.0.110+xpu":
            hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
                                                             [self.weight.size(0)], self.weight)
        else:
            hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
                                                             [self.weight.size(0)], self.weight,
                                                             self.variance_epsilon)
    else:
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)