From d27a8cd08c840aa17694b8d4e806f0f32bd05bb2 Mon Sep 17 00:00:00 2001
From: Xiangyu Tian <109123695+xiangyuT@users.noreply.github.com>
Date: Fri, 19 Jul 2024 13:07:40 +0800
Subject: [PATCH] Fix Pipeline Parallel dtype (#11623)

---
 python/llm/src/ipex_llm/transformers/model.py             | 4 ++--
 python/llm/src/ipex_llm/transformers/pipeline_parallel.py | 4 +++-
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/model.py b/python/llm/src/ipex_llm/transformers/model.py
index e40c56d6..d9dd6354 100644
--- a/python/llm/src/ipex_llm/transformers/model.py
+++ b/python/llm/src/ipex_llm/transformers/model.py
@@ -374,7 +374,7 @@ class _BaseAutoModelClass:
                               "Please make sure you've called `init_pipeline_parallel()` "
                               "and world size is the same as `pipeline_parallel_stages`")
             from .pipeline_parallel import pipeline_parallel, pipeline_parallel_generate
-            model = pipeline_parallel(model, pipeline_parallel_stages)
+            model = pipeline_parallel(model, pipeline_parallel_stages, kwargs["torch_dtype"])
             import types
             # add pipeline_parallel_generate to pretrained model dynamically
             model.pipeline_parallel_generate = types.MethodType(pipeline_parallel_generate,
@@ -788,7 +788,7 @@ class _BaseAutoModelClass:
 
         if pipeline_parallel_stages > 1:
             from .pipeline_parallel import pipeline_parallel, pipeline_parallel_generate
-            model = pipeline_parallel(model, pipeline_parallel_stages)
+            model = pipeline_parallel(model, pipeline_parallel_stages, torch_dtype)
             import types
             # add pipeline_parallel_generate to pretrained model dynamically
             model.pipeline_parallel_generate = types.MethodType(pipeline_parallel_generate,
diff --git a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py
index bbe04f12..e8dff53d 100644
--- a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py
+++ b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py
@@ -162,7 +162,7 @@ def _check_quantize_kv_cache(model, idx, batch_size):
             os.environ["IPEX_LLM_QUANTIZE_KV_CACHE"] = "0"
 
 
-def pipeline_parallel(model, pipeline_parallel_stages):
+def pipeline_parallel(model, pipeline_parallel_stages, torch_dtype=torch.float32):
     global num_layers
     if hasattr(model.config, 'num_hidden_layers'):
         num_layers = model.config.num_hidden_layers
@@ -227,6 +227,8 @@ def pipeline_parallel(model, pipeline_parallel_stages):
     model.layer_start = layer_start
     model.layer_end = layer_end
    model.num_layers = num_layers
+    if torch_dtype == torch.float16:
+        model = model.half()
     model = model.to(f'xpu:{local_rank}')
     return model
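
Note (not part of the patch): a minimal usage sketch of how the fixed dtype propagation is expected to be exercised. The model id, stage count, prompt, and launch command below are illustrative assumptions, and the import path for `init_pipeline_parallel` is assumed; only `pipeline_parallel_stages`, `torch_dtype`, and the dynamically attached `pipeline_parallel_generate` are taken from the diff. With `torch_dtype=torch.float16`, each pipeline stage is now cast via `model.half()` before being moved to its `xpu:{local_rank}` device.

# usage_sketch.py -- illustrative only; assumed to be launched once per
# pipeline stage, e.g. `mpirun -n 2 python usage_sketch.py` (not from the patch).
import torch
from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModelForCausalLM, init_pipeline_parallel

init_pipeline_parallel()  # set up torch.distributed; world size must equal the stage count

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",   # placeholder model id
    load_in_4bit=True,
    torch_dtype=torch.float16,         # now forwarded into pipeline_parallel()
    pipeline_parallel_stages=2,        # split the layers across 2 XPU ranks
)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
inputs = tokenizer("What does pipeline parallelism do?", return_tensors="pt").to("xpu")

# generate() dispatches to the dynamically attached pipeline_parallel_generate
output = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))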