From 067c7e8098690a0c3edb73df4dc1dbf80b58c218 Mon Sep 17 00:00:00 2001 From: Yang Wang Date: Wed, 25 Oct 2023 14:46:28 +0800 Subject: [PATCH] Support deepspeed AutoTP (#9230) * Support deepspeed * add test script * refactor convert * refine example * refine * refine example * fix style * refine example and adapte latest ipex * fix style --- .../GPU/Deepspeed-AutoTP/deepspeed_autotp.py | 103 ++++++++++++++++++ .../llm/example/GPU/Deepspeed-AutoTP/run.sh | 12 ++ .../llm/src/bigdl/llm/transformers/convert.py | 50 ++++++++- .../bigdl/llm/transformers/low_bit_linear.py | 12 +- .../bigdl/llm/transformers/models/llama.py | 22 +++- 5 files changed, 190 insertions(+), 9 deletions(-) create mode 100644 python/llm/example/GPU/Deepspeed-AutoTP/deepspeed_autotp.py create mode 100644 python/llm/example/GPU/Deepspeed-AutoTP/run.sh diff --git a/python/llm/example/GPU/Deepspeed-AutoTP/deepspeed_autotp.py b/python/llm/example/GPU/Deepspeed-AutoTP/deepspeed_autotp.py new file mode 100644 index 00000000..6b1309a7 --- /dev/null +++ b/python/llm/example/GPU/Deepspeed-AutoTP/deepspeed_autotp.py @@ -0,0 +1,103 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import torch +import transformers +import deepspeed + +local_rank = int(os.getenv("LOCAL_RANK", "0")) +world_size = int(os.getenv("WORLD_SIZE", "1")) + +from bigdl.llm import optimize_model + +import torch +import intel_extension_for_pytorch as ipex +import time +import argparse + +from transformers import AutoModelForCausalLM # export AutoModelForCausalLM from transformers so that deepspeed use it +from transformers import LlamaTokenizer, AutoTokenizer + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for Llama2 model') + parser.add_argument('--repo-id-or-model-path', type=str, default="meta-llama/Llama-2-7b-chat-hf", + help='The huggingface repo id for the Llama2 (e.g. `meta-llama/Llama-2-7b-chat-hf` and `meta-llama/Llama-2-13b-chat-hf`) to be downloaded' + ', or the path to the huggingface checkpoint folder') + parser.add_argument('--prompt', type=str, default="Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun", + help='Prompt to infer') + parser.add_argument('--n-predict', type=int, default=32, + help='Max tokens to predict') + + args = parser.parse_args() + model_path = args.repo_id_or_model_path + + model = AutoModelForCausalLM.from_pretrained(args.repo_id_or_model_path, + low_cpu_mem_usage=True, + torch_dtype=torch.float16, + trust_remote_code=True, + use_cache=True) + + model = deepspeed.init_inference( + model, + mp_size=world_size, + dtype=torch.float16, + replace_method="auto", + ) + + # move model to cpu and use bigdl-llm `optimize_model` to convert the + # model into optimized low bit format + # convert the rest of the model into float16 to reduce allreduce traffic + model = optimize_model(model.module.to(f'cpu'), low_bit='sym_int4').to(torch.float16) + + # move model back to xpu + model = model.to(f'xpu:{local_rank}') + + print(model) + + # Load tokenizer + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + + # Generate predicted tokens + with torch.inference_mode(): + # prompt = get_prompt(args.prompt, [], system_prompt=DEFAULT_SYSTEM_PROMPT) + prompt = args.prompt + # input_str = f"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{prompt}\n\n### Response:\n" + input_ids = tokenizer.encode(prompt, return_tensors="pt").to(f'xpu:{local_rank}') + # ipex model needs a warmup, then inference time can be accurate + output = model.generate(input_ids, + max_new_tokens=args.n_predict, + use_cache=True) + + # start inference + st = time.time() + # if your selected model is capable of utilizing previous key/value attentions + # to enhance decoding speed, but has `"use_cache": false` in its model config, + # it is important to set `use_cache=True` explicitly in the `generate` function + # to obtain optimal performance with BigDL-LLM INT4 optimizations + output = model.generate(input_ids, + do_sample=False, + max_new_tokens=args.n_predict) + torch.xpu.synchronize() + end = time.time() + if local_rank == 0: + output = output.cpu() + output_str = tokenizer.decode(output[0], skip_special_tokens=True) + print(f'Inference time: {end-st} s') + print('-'*20, 'Prompt', '-'*20) + print(prompt) + print('-'*20, 'Output', '-'*20) + print(output_str) diff --git a/python/llm/example/GPU/Deepspeed-AutoTP/run.sh b/python/llm/example/GPU/Deepspeed-AutoTP/run.sh new file mode 100644 index 00000000..972e8c9d --- /dev/null +++ b/python/llm/example/GPU/Deepspeed-AutoTP/run.sh @@ -0,0 +1,12 @@ +source bigdl-llm-init -t -g +export MASTER_ADDR=127.0.0.1 +export CCL_ZE_IPC_EXCHANGE=sockets +if [[ -n $OMP_NUM_THREADS ]]; then + export OMP_NUM_THREADS=$(($OMP_NUM_THREADS / 4)) +else + export OMP_NUM_THREADS=$(($(nproc) / 4)) +fi +torchrun --standalone \ + --nnodes=1 \ + --nproc-per-node 4 \ + deepspeed_autotp.py --repo-id-or-model-path "meta-llama/Llama-2-7b-hf" diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py index a8cd05df..2acab799 100644 --- a/python/llm/src/bigdl/llm/transformers/convert.py +++ b/python/llm/src/bigdl/llm/transformers/convert.py @@ -45,6 +45,42 @@ from bigdl.llm.ggml.quantize import ggml_tensor_qtype from .utils import logger +def is_deepspeed_available(): + return importlib.util.find_spec("deepspeed") is not None + + +def is_linear_module(module): + + in_features = None + out_features = None + mp_group = None + + if isinstance(module, nn.Linear): + in_features = module.in_features + out_features = module.out_features + mp_group = None + result = True + else: + if is_deepspeed_available(): + from deepspeed.module_inject.layers import LinearLayer, LinearAllreduce + if isinstance(module, LinearLayer): + in_features = module.weight.shape[1] + out_features = module.weight.shape[0] + mp_group = None + result = True + elif isinstance(module, LinearAllreduce): + in_features = module.weight.shape[1] + out_features = module.weight.shape[0] + mp_group = module.mp_group + result = True + else: + result = False + else: + result = False + + return result, (in_features, out_features, mp_group) + + def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None, current_key_name=None, convert_shape_only=False): from bigdl.llm.transformers.low_bit_linear import LowBitLinear, FP4Params, FP16Linear @@ -54,17 +90,20 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None, if current_key_name is None: current_key_name = [] - if isinstance(module, nn.Linear) and name not in modules_to_not_convert: + is_linear, linear_args = is_linear_module(module) + if is_linear and name not in modules_to_not_convert: # Check if the current key is not in the `modules_to_not_convert` if not any(key in ".".join(current_key_name) for key in modules_to_not_convert): + in_features, out_features, mp_group = linear_args with init_empty_weights(): new_linear = None if qtype != ggml_tensor_qtype["fp16"]: new_linear = LowBitLinear( - module.in_features, - module.out_features, + in_features, + out_features, qtype, module.bias is not None, + mp_group=mp_group, ) device_type = module.weight.data.device.type @@ -82,10 +121,11 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None, if module.in_features in [4096, 11008]: # esimd fp16 path new_linear = FP16Linear( - module.in_features, - module.out_features, + in_features, + out_features, qtype, module.bias is not None, + mp_group=mp_group, ) device_type = module.weight.data.device.type diff --git a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py index b1026ec7..c5b85312 100644 --- a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py +++ b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py @@ -328,7 +328,7 @@ class MatMulLowBit(torch.autograd.Function): class LowBitLinear(nn.Linear): def __init__(self, input_features, output_features, qtype, bias=True, - conver_to_half=True): + conver_to_half=True, mp_group=None): super().__init__(input_features, output_features, bias) self.weight = FP4Params(self.weight.data, requires_grad=False, @@ -339,6 +339,7 @@ class LowBitLinear(nn.Linear): self.weight_length = self.out_len * self.in_len self.qtype = qtype self.conver_to_half = conver_to_half + self.mp_group = mp_group def forward(self, x: torch.Tensor): if self.bias is not None and self.bias.dtype != x.dtype: @@ -378,6 +379,9 @@ class LowBitLinear(nn.Linear): input_seq_size) new_shape = x_shape[:-1] + (self.out_len,) result = result.view(new_shape) + if self.mp_group is not None: + from deepspeed import comm as dist + dist.inference_all_reduce(result, group=self.mp_group) if self.bias is not None: result += self.bias else: @@ -400,7 +404,7 @@ class LowBitLinear(nn.Linear): class FP16Linear(nn.Linear): def __init__(self, input_features, output_features, qtype, bias=True, - conver_to_half=True): + conver_to_half=True, mp_group=None): super().__init__(input_features, output_features, bias) self.in_len = input_features self.out_len = output_features @@ -408,6 +412,7 @@ class FP16Linear(nn.Linear): self.weight_length = self.out_len * self.in_len self.qtype = qtype self.conver_to_half = conver_to_half + self.mp_group = mp_group def forward(self, x: torch.Tensor): if self.bias is not None and self.bias.dtype != x.dtype: @@ -442,6 +447,9 @@ class FP16Linear(nn.Linear): new_shape = x_shape[:-1] + (self.out_len,) result = result.view(new_shape) + if self.mp_group is not None: + from deepspeed import comm as dist + dist.inference_all_reduce(result, group=self.mp_group) if self.bias is not None: result += self.bias diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py index 0dd39ae6..94515ea0 100644 --- a/python/llm/src/bigdl/llm/transformers/models/llama.py +++ b/python/llm/src/bigdl/llm/transformers/models/llama.py @@ -32,6 +32,7 @@ # limitations under the License. import torch +import importlib import torch.nn as nn from typing import Optional, Tuple import math @@ -58,10 +59,27 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: KV_CACHE_ALLOC_BLOCK_LENGTH = 256 +def get_ipex_version(): + + if importlib.util.find_spec("intel_extension_for_pytorch") is not None: + import intel_extension_for_pytorch as ipex + return ipex.__version__ + else: + return None + + +ipex_version = get_ipex_version() + + def llama_rms_norm_forward(self, hidden_states): if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad): - hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states, - [self.weight.size(0)], self.weight) + if ipex_version == "2.0.110+xpu": + hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states, + [self.weight.size(0)], self.weight) + else: + hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states, + [self.weight.size(0)], self.weight, + self.variance_epsilon) else: input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32)