Refactor pipeline parallel device config (#11149)

* refactor pipeline parallel device config

* address review comments

* update example

* add warnings and update code doc
SONG Ge 2024-05-28 16:52:46 +08:00 committed by GitHub
parent 62b2d8af6b
commit 33852bd23e
2 changed files with 35 additions and 21 deletions


@@ -62,27 +62,8 @@ if __name__ == '__main__':
                                                  load_in_4bit=True,
                                                  optimize_model=True,
                                                  trust_remote_code=True,
-                                                 use_cache=True)
-
-    model_layers = ['model.embed_tokens']
-    for i in range(model.config.num_hidden_layers):
-        model_layers.append(f'model.layers.{i}')
-    model_layers = model_layers + ['model.norm', 'lm_head']
-
-    device_map = {}
-    split_len = len(model_layers) // args.gpu_num
-    for i in range(args.gpu_num):
-        device_map.update({key: f'xpu:{i}' for key in model_layers[split_len * i: split_len * (i + 1)]})
-        if i == args.gpu_num - 1:
-            device_map.update({key: f'xpu:{i}' for key in model_layers[split_len * (i + 1): ]})
-
-    from accelerate import dispatch_model
-    model = dispatch_model(
-        model,
-        device_map=device_map,
-        offload_dir=None,
-        skip_keys=["past_key_value", "past_key_values"],
-    )
+                                                 use_cache=True,
+                                                 pipeline_parallel_stages=args.gpu_num)

     # Load tokenizer
     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
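With this refactor the example no longer builds a device_map by hand: passing pipeline_parallel_stages to from_pretrained is enough. The sketch below shows the simplified loading path; the checkpoint path and the gpu_num value are placeholders standing in for the example's command-line arguments, while the from_pretrained keyword arguments mirror the updated example above.

# Minimal sketch of the updated example flow (placeholder model path and GPU count).
from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModelForCausalLM

model_path = "meta-llama/Llama-2-7b-chat-hf"   # placeholder checkpoint
gpu_num = 2                                    # stands in for args.gpu_num

model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_4bit=True,
                                             optimize_model=True,
                                             trust_remote_code=True,
                                             use_cache=True,
                                             pipeline_parallel_stages=gpu_num)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)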


@@ -95,6 +95,28 @@ def save_low_bit(self, *args, **kwargs):
     self.to(origin_device)


+def pipeline_parallel(model, pipeline_parallel_stages):
+    model_layers = ['model.embed_tokens']
+    for i in range(model.config.num_hidden_layers):
+        model_layers.append(f'model.layers.{i}')
+    model_layers = model_layers + ['model.norm', 'lm_head']
+
+    device_map = {}
+    split_len = len(model_layers) // pipeline_parallel_stages
+    for i in range(pipeline_parallel_stages):
+        device_map.update({key: f'xpu:{i}' for key in
+                           model_layers[split_len * i: split_len * (i + 1)]})
+        if i == pipeline_parallel_stages - 1:
+            device_map.update({key: f'xpu:{i}' for key in
+                               model_layers[split_len * (i + 1):]})
+
+    from accelerate import dispatch_model
+    model = dispatch_model(
+        model, device_map=device_map, skip_keys=["past_key_value", "past_key_values"],
+    )
+    return model
+
+
 def _load_pre():
     from transformers import GPTJModel
     from ipex_llm.transformers.models.gptj import gptj_model_new_init
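The new pipeline_parallel() helper lists the model's modules (embedding, every decoder layer, final norm, lm_head), splits that list evenly across the requested number of stages with any remainder going to the last stage, and hands the resulting device_map to accelerate's dispatch_model. The self-contained sketch below reproduces just the splitting arithmetic so the placement can be inspected without any XPU devices; the 32-layer count and 3 stages are made up for illustration.

# Standalone illustration of the device_map split used by pipeline_parallel();
# the layer count and stage count here are hypothetical.
num_hidden_layers = 32
pipeline_parallel_stages = 3

model_layers = ['model.embed_tokens']
for i in range(num_hidden_layers):
    model_layers.append(f'model.layers.{i}')
model_layers = model_layers + ['model.norm', 'lm_head']   # 35 modules in total

device_map = {}
split_len = len(model_layers) // pipeline_parallel_stages  # 35 // 3 == 11
for i in range(pipeline_parallel_stages):
    device_map.update({key: f'xpu:{i}' for key in
                       model_layers[split_len * i: split_len * (i + 1)]})
    if i == pipeline_parallel_stages - 1:
        # the remainder (here 35 % 3 == 2 modules) lands on the last stage
        device_map.update({key: f'xpu:{i}' for key in
                           model_layers[split_len * (i + 1):]})

for stage in range(pipeline_parallel_stages):
    modules = [k for k, v in device_map.items() if v == f'xpu:{stage}']
    print(f"xpu:{stage}: {len(modules)} modules ({modules[0]} .. {modules[-1]})")
# xpu:0: 11 modules (model.embed_tokens .. model.layers.9)
# xpu:1: 11 modules (model.layers.10 .. model.layers.20)
# xpu:2: 13 modules (model.layers.21 .. lm_head)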
@@ -157,6 +179,9 @@ class _BaseAutoModelClass:
         :param mixed_precision: boolean value, Whether to use mixed precision quantization.
                                 Default to be False. If set to True, we will use sym_int8 for lm_head when
                                 load_in_low_bit is sym_int4 or asym_int4.
+        :param pipeline_parallel_stages: int value, the number of GPUs allocated for
+                                         pipeline parallel. Default to be ``1``. Please set pipeline_parallel_stages > 1
+                                         to run pipeline parallel inference on multiple GPUs.
         :return: a model instance
         """
         pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None) \
@@ -190,6 +215,7 @@ class _BaseAutoModelClass:
         optimize_model = kwargs.pop("optimize_model", True)
         user_quantization_config = kwargs.pop("quantization_config", None)
         speculative = kwargs.pop("speculative", False)
+        pipeline_parallel_stages = kwargs.pop("pipeline_parallel_stages", 1)

         torch_dtype = kwargs.pop("torch_dtype", None)
         embedding_qtype = kwargs.pop("embedding_qtype", None)
@@ -346,6 +372,13 @@ class _BaseAutoModelClass:
             kwargs["embedding_qtype"] = embedding_qtype
             model = cls.load_convert(q_k, optimize_model, *args, **kwargs)

+            if pipeline_parallel_stages > 1:
+                if speculative:
+                    invalidInputError(False,
+                                      f"Please do not set speculative=True"
+                                      f" when using pipeline_parallel_stages")
+                model = pipeline_parallel(model, pipeline_parallel_stages)
+
             if speculative:
                 from .speculative import speculative_generate, clear_benchmarks,\
                     _crop_past_key_values
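After this change, speculative decoding and pipeline parallelism are mutually exclusive: from_pretrained reports an error when both speculative=True and pipeline_parallel_stages > 1 are passed. A hypothetical caller-side guard with the same intent is sketched below; check_pp_options is not part of ipex-llm, it simply restates the rule enforced above.

# Hypothetical helper (not part of ipex-llm) mirroring the new guard in from_pretrained:
# speculative decoding cannot be combined with pipeline parallel inference.
def check_pp_options(pipeline_parallel_stages: int, speculative: bool) -> None:
    if pipeline_parallel_stages > 1 and speculative:
        raise ValueError("Please do not set speculative=True "
                         "when using pipeline_parallel_stages")

check_pp_options(pipeline_parallel_stages=2, speculative=False)   # fine
# check_pp_options(pipeline_parallel_stages=2, speculative=True)  # raises ValueError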