Refactor pipeline parallel device config (#11149)

* refactor pipeline parallel device config
* meet comments
* update example
* add warnings and update code doc

parent 62b2d8af6b
commit 33852bd23e
2 changed files with 35 additions and 21 deletions
@@ -62,27 +62,8 @@ if __name__ == '__main__':
                                                  load_in_4bit=True,
                                                  optimize_model=True,
                                                  trust_remote_code=True,
-                                                 use_cache=True)
-
-    model_layers = ['model.embed_tokens']
-    for i in range(model.config.num_hidden_layers):
-        model_layers.append(f'model.layers.{i}')
-    model_layers = model_layers + ['model.norm', 'lm_head']
-
-    device_map = {}
-    split_len = len(model_layers) // args.gpu_num
-    for i in range(args.gpu_num):
-        device_map.update({key: f'xpu:{i}' for key in model_layers[split_len * i: split_len * (i + 1)]})
-        if i == args.gpu_num - 1:
-            device_map.update({key: f'xpu:{i}' for key in model_layers[split_len * (i + 1): ]})
-
-    from accelerate import dispatch_model
-    model = dispatch_model(
-        model,
-        device_map=device_map,
-        offload_dir=None,
-        skip_keys=["past_key_value", "past_key_values"],
-    )
+                                                 use_cache=True,
+                                                 pipeline_parallel_stages=args.gpu_num)

     # Load tokenizer
     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
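With this change the example script no longer builds a device_map or calls dispatch_model itself; placement is handled inside from_pretrained. Below is a minimal sketch of the simplified caller-side flow, assuming a LLaMA-style model; the model id, prompt, and gpu_num values are placeholders, not taken from the example:

# Sketch of the simplified example flow; model id, prompt and gpu_num are placeholders.
from ipex_llm.transformers import AutoModelForCausalLM
from transformers import AutoTokenizer

model_path = "meta-llama/Llama-2-7b-chat-hf"   # hypothetical model id
gpu_num = 2                                    # number of XPU devices to split across

# Device placement is now handled inside from_pretrained via pipeline_parallel_stages.
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_4bit=True,
                                             optimize_model=True,
                                             trust_remote_code=True,
                                             use_cache=True,
                                             pipeline_parallel_stages=gpu_num)

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# Inputs go to the first pipeline stage; the generated ids come back from the last one.
input_ids = tokenizer("What is AI?", return_tensors="pt").input_ids.to("xpu:0")
output = model.generate(input_ids, max_new_tokens=32)
print(tokenizer.decode(output[0].cpu(), skip_special_tokens=True))

The old example duplicated the placement logic inline; moving it behind pipeline_parallel_stages keeps the example focused on generation.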
@@ -95,6 +95,28 @@ def save_low_bit(self, *args, **kwargs):
     self.to(origin_device)


+def pipeline_parallel(model, pipeline_parallel_stages):
+    model_layers = ['model.embed_tokens']
+    for i in range(model.config.num_hidden_layers):
+        model_layers.append(f'model.layers.{i}')
+    model_layers = model_layers + ['model.norm', 'lm_head']
+
+    device_map = {}
+    split_len = len(model_layers) // pipeline_parallel_stages
+    for i in range(pipeline_parallel_stages):
+        device_map.update({key: f'xpu:{i}' for key in
+                           model_layers[split_len * i: split_len * (i + 1)]})
+        if i == pipeline_parallel_stages - 1:
+            device_map.update({key: f'xpu:{i}' for key in
+                               model_layers[split_len * (i + 1):]})
+
+    from accelerate import dispatch_model
+    model = dispatch_model(
+        model, device_map=device_map, skip_keys=["past_key_value", "past_key_values"],
+    )
+    return model
+
+
 def _load_pre():
     from transformers import GPTJModel
     from ipex_llm.transformers.models.gptj import gptj_model_new_init
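The new helper splits the flat layer list evenly across the requested stages and assigns any remainder to the last stage. A standalone sketch of just that split logic, using an illustrative 32-layer model and 3 stages (these numbers are not from the patch):

# Standalone sketch of the even split used by pipeline_parallel (values are illustrative).
num_hidden_layers = 32          # hypothetical model depth
pipeline_parallel_stages = 3    # hypothetical number of XPU devices

model_layers = ['model.embed_tokens']
for i in range(num_hidden_layers):
    model_layers.append(f'model.layers.{i}')
model_layers = model_layers + ['model.norm', 'lm_head']

device_map = {}
split_len = len(model_layers) // pipeline_parallel_stages   # 35 // 3 == 11
for i in range(pipeline_parallel_stages):
    device_map.update({key: f'xpu:{i}' for key in
                       model_layers[split_len * i: split_len * (i + 1)]})
    if i == pipeline_parallel_stages - 1:
        # The last stage also takes the remainder (here 'model.norm' and 'lm_head').
        device_map.update({key: f'xpu:{i}' for key in
                           model_layers[split_len * (i + 1):]})

print(device_map['model.embed_tokens'])   # xpu:0
print(device_map['lm_head'])              # xpu:2

Because split_len rounds down, the last stage can hold up to pipeline_parallel_stages - 1 extra entries on top of its even share.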
@@ -157,6 +179,9 @@ class _BaseAutoModelClass:
         :param mixed_precision: boolean value, Whether to use mixed precision quantization.
                Default to be False. If set to True, we will use sym_int8 for lm_head when
                load_in_low_bit is sym_int4 or asym_int4.
+        :param pipeline_parallel_stages: int value, the number of GPUs allocated for
+               pipeline parallel. Default to be ``1``. Please set pipeline_parallel_stages > 1
+               to run pipeline parallel inference on multiple GPUs.
         :return: a model instance
         """
         pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None) \
@@ -190,6 +215,7 @@ class _BaseAutoModelClass:
         optimize_model = kwargs.pop("optimize_model", True)
         user_quantization_config = kwargs.pop("quantization_config", None)
         speculative = kwargs.pop("speculative", False)
+        pipeline_parallel_stages = kwargs.pop("pipeline_parallel_stages", 1)
         torch_dtype = kwargs.pop("torch_dtype", None)
         embedding_qtype = kwargs.pop("embedding_qtype", None)

@@ -346,6 +372,13 @@ class _BaseAutoModelClass:
             kwargs["embedding_qtype"] = embedding_qtype
             model = cls.load_convert(q_k, optimize_model, *args, **kwargs)

+        if pipeline_parallel_stages > 1:
+            if speculative:
+                invalidInputError(False,
+                                  f"Please do not set speculative=True"
+                                  f" when using pipeline_parallel_stages")
+            model = pipeline_parallel(model, pipeline_parallel_stages)
+
         if speculative:
             from .speculative import speculative_generate, clear_benchmarks,\
                 _crop_past_key_values
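The added guard rejects speculative decoding whenever pipeline parallelism is requested. A minimal sketch of the same check in isolation, using a plain ValueError in place of ipex-llm's invalidInputError helper (the function name below is hypothetical):

# Minimal sketch of the new compatibility check; the helper name is hypothetical
# and ValueError stands in for ipex-llm's invalidInputError.
def check_pipeline_parallel_kwargs(pipeline_parallel_stages: int, speculative: bool) -> None:
    if pipeline_parallel_stages > 1 and speculative:
        raise ValueError("Please do not set speculative=True "
                         "when using pipeline_parallel_stages")

check_pipeline_parallel_kwargs(pipeline_parallel_stages=2, speculative=False)   # ok
# check_pipeline_parallel_kwargs(pipeline_parallel_stages=2, speculative=True)  # raises ValueError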