Refactor pipeline parallel device config (#11149)
* refactor pipeline parallel device config
* meet comments
* update example
* add warnings and update code doc
This commit is contained in:
parent 62b2d8af6b
commit 33852bd23e
2 changed files with 35 additions and 21 deletions
@@ -62,27 +62,8 @@ if __name__ == '__main__':
                                                  load_in_4bit=True,
                                                  optimize_model=True,
                                                  trust_remote_code=True,
-                                                 use_cache=True)
-
-    model_layers = ['model.embed_tokens']
-    for i in range(model.config.num_hidden_layers):
-        model_layers.append(f'model.layers.{i}')
-    model_layers = model_layers + ['model.norm', 'lm_head']
-
-    device_map = {}
-    split_len = len(model_layers) // args.gpu_num
-    for i in range(args.gpu_num):
-        device_map.update({key: f'xpu:{i}' for key in model_layers[split_len * i: split_len * (i + 1)]})
-        if i == args.gpu_num - 1:
-            device_map.update({key: f'xpu:{i}' for key in model_layers[split_len * (i + 1): ]})
-
-    from accelerate import dispatch_model
-    model = dispatch_model(
-        model,
-        device_map=device_map,
-        offload_dir=None,
-        skip_keys=["past_key_value", "past_key_values"],
-    )
+                                                 use_cache=True,
+                                                 pipeline_parallel_stages=args.gpu_num)
 
     # Load tokenizer
     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
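After this change the example no longer builds a device_map by hand: the whole pipeline setup collapses into a single keyword argument. A minimal usage sketch of the new call, assuming the example's model comes from ipex_llm.transformers.AutoModelForCausalLM (the import and checkpoint name below are illustrative, not part of the diff):

    from ipex_llm.transformers import AutoModelForCausalLM  # assumed import

    # Illustrative checkpoint; any supported decoder-only model works the same way.
    model = AutoModelForCausalLM.from_pretrained('meta-llama/Llama-2-7b-chat-hf',
                                                 load_in_4bit=True,
                                                 optimize_model=True,
                                                 trust_remote_code=True,
                                                 use_cache=True,
                                                 pipeline_parallel_stages=2)  # shard layers across xpu:0 and xpu:1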
@@ -95,6 +95,28 @@ def save_low_bit(self, *args, **kwargs):
     self.to(origin_device)
 
 
+def pipeline_parallel(model, pipeline_parallel_stages):
+    model_layers = ['model.embed_tokens']
+    for i in range(model.config.num_hidden_layers):
+        model_layers.append(f'model.layers.{i}')
+    model_layers = model_layers + ['model.norm', 'lm_head']
+
+    device_map = {}
+    split_len = len(model_layers) // pipeline_parallel_stages
+    for i in range(pipeline_parallel_stages):
+        device_map.update({key: f'xpu:{i}' for key in
+                           model_layers[split_len * i: split_len * (i + 1)]})
+        if i == pipeline_parallel_stages - 1:
+            device_map.update({key: f'xpu:{i}' for key in
+                               model_layers[split_len * (i + 1):]})
+
+    from accelerate import dispatch_model
+    model = dispatch_model(
+        model, device_map=device_map, skip_keys=["past_key_value", "past_key_values"],
+    )
+    return model
+
+
 def _load_pre():
     from transformers import GPTJModel
     from ipex_llm.transformers.models.gptj import gptj_model_new_init
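The partitioning in pipeline_parallel floors the per-stage count, so every leftover entry falls onto the last stage. A standalone sketch of just that split, using a hypothetical 32-layer model over 3 stages (the numbers are illustrative):

    num_hidden_layers = 32          # hypothetical model depth
    pipeline_parallel_stages = 3

    model_layers = ['model.embed_tokens']
    model_layers += [f'model.layers.{i}' for i in range(num_hidden_layers)]
    model_layers += ['model.norm', 'lm_head']   # 35 entries in total

    device_map = {}
    split_len = len(model_layers) // pipeline_parallel_stages   # 35 // 3 == 11
    for i in range(pipeline_parallel_stages):
        device_map.update({key: f'xpu:{i}' for key in
                           model_layers[split_len * i: split_len * (i + 1)]})
        if i == pipeline_parallel_stages - 1:
            # the 2 leftover entries (indices 33 and 34) also land on the last stage
            device_map.update({key: f'xpu:{i}' for key in
                               model_layers[split_len * (i + 1):]})

    # xpu:0 and xpu:1 each hold 11 entries; xpu:2 holds 13.

The skip_keys argument passed to accelerate's dispatch_model tells it not to move those entries of the inputs/outputs between devices, which keeps each stage's cached past_key_values on its own GPU; that is why the refactored helper still forwards it.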
@@ -157,6 +179,9 @@ class _BaseAutoModelClass:
         :param mixed_precision: boolean value, Whether to use mixed precision quantization.
                                 Default to be False. If set to True, we will use sym_int8 for lm_head when
                                 load_in_low_bit is sym_int4 or asym_int4.
+        :param pipeline_parallel_stages: int value, the number of GPUs allocated for
+               pipeline parallel. Default to be ``1``. Please set pipeline_parallel_stages > 1
+               to run pipeline parallel inference on multiple GPUs.
         :return: a model instance
         """
         pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None) \
@@ -190,6 +215,7 @@ class _BaseAutoModelClass:
         optimize_model = kwargs.pop("optimize_model", True)
         user_quantization_config = kwargs.pop("quantization_config", None)
         speculative = kwargs.pop("speculative", False)
+        pipeline_parallel_stages = kwargs.pop("pipeline_parallel_stages", 1)
         torch_dtype = kwargs.pop("torch_dtype", None)
         embedding_qtype = kwargs.pop("embedding_qtype", None)
 
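Because the new keyword is popped with a default of 1, callers that never pass it see no behavior change; pipeline parallelism stays opt-in. A two-line illustration of that fallback:

    kwargs = {'load_in_4bit': True}                      # caller did not opt in
    stages = kwargs.pop('pipeline_parallel_stages', 1)   # -> 1, single-GPU path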
@@ -346,6 +372,13 @@ class _BaseAutoModelClass:
             kwargs["embedding_qtype"] = embedding_qtype
         model = cls.load_convert(q_k, optimize_model, *args, **kwargs)
 
+        if pipeline_parallel_stages > 1:
+            if speculative:
+                invalidInputError(False,
+                                  f"Please do not set speculative=True"
+                                  f" when using pipeline_parallel_stages")
+            model = pipeline_parallel(model, pipeline_parallel_stages)
+
         if speculative:
             from .speculative import speculative_generate, clear_benchmarks,\
                 _crop_past_key_values
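The new guard makes speculative decoding and pipeline parallelism mutually exclusive at load time; invalidInputError with a False condition raises immediately, as the diff relies on. A hypothetical call that would now fail fast (model_path is a placeholder):

    model = AutoModelForCausalLM.from_pretrained(model_path,
                                                 load_in_4bit=True,
                                                 speculative=True,
                                                 pipeline_parallel_stages=2)
    # raises: Please do not set speculative=True when using pipeline_parallel_stages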