Fix Pipeline Parallel dtype (#11623)

Xiangyu Tian authored 2024-07-19 13:07:40 +08:00; committed by GitHub
parent d020ad6397
commit d27a8cd08c
2 changed files with 5 additions and 3 deletions

@@ -374,7 +374,7 @@ class _BaseAutoModelClass:
                 "Please make sure you've called `init_pipeline_parallel()` "
                 "and world size is the same as `pipeline_parallel_stages`")
             from .pipeline_parallel import pipeline_parallel, pipeline_parallel_generate
-            model = pipeline_parallel(model, pipeline_parallel_stages)
+            model = pipeline_parallel(model, pipeline_parallel_stages, kwargs["torch_dtype"])
             import types
             # add pipeline_parallel_generate to pretrained model dynamically
             model.pipeline_parallel_generate = types.MethodType(pipeline_parallel_generate,
@@ -788,7 +788,7 @@ class _BaseAutoModelClass:
         if pipeline_parallel_stages > 1:
             from .pipeline_parallel import pipeline_parallel, pipeline_parallel_generate
-            model = pipeline_parallel(model, pipeline_parallel_stages)
+            model = pipeline_parallel(model, pipeline_parallel_stages, torch_dtype)
             import types
             # add pipeline_parallel_generate to pretrained model dynamically
             model.pipeline_parallel_generate = types.MethodType(pipeline_parallel_generate,
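For orientation, here is a minimal caller-side sketch of how these two call sites are reached: the point of the change is that the user-supplied `torch_dtype` now travels from `from_pretrained()` into `pipeline_parallel()` instead of being dropped. The `ipex_llm.transformers` import path, the `init_pipeline_parallel()` entry point, the launch setup (one process per pipeline stage), and the model id are assumptions inferred from the surrounding code, not part of this commit.

# Caller-side sketch only; names marked as placeholders are illustrative.
import torch
from ipex_llm.transformers import AutoModelForCausalLM, init_pipeline_parallel

init_pipeline_parallel()  # must be called before requesting pipeline parallel stages

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",   # placeholder checkpoint id
    load_in_4bit=True,
    torch_dtype=torch.float16,         # with this fix, forwarded to pipeline_parallel()
    pipeline_parallel_stages=2,
)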

@@ -162,7 +162,7 @@ def _check_quantize_kv_cache(model, idx, batch_size):
         os.environ["IPEX_LLM_QUANTIZE_KV_CACHE"] = "0"
-def pipeline_parallel(model, pipeline_parallel_stages):
+def pipeline_parallel(model, pipeline_parallel_stages, torch_dtype=torch.float32):
     global num_layers
     if hasattr(model.config, 'num_hidden_layers'):
         num_layers = model.config.num_hidden_layers
@@ -227,6 +227,8 @@ def pipeline_parallel(model, pipeline_parallel_stages):
     model.layer_start = layer_start
     model.layer_end = layer_end
     model.num_layers = num_layers
+    if torch_dtype == torch.float16:
+        model = model.half()
     model = model.to(f'xpu:{local_rank}')
     return model
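The two added lines are what make the forwarded dtype matter: when the caller asked for `torch.float16`, each stage's slice of the model is now converted with `model.half()` before being moved to its XPU device, rather than keeping its load-time precision. A rough illustration of the per-tensor footprint this saves; the 4096x4096 weight below is hypothetical and not taken from the library.

import torch

# Hypothetical weight as it would sit on a stage without the new half() branch...
w_fp32 = torch.empty(4096, 4096, dtype=torch.float32)
# ...and the same weight after the torch.float16 branch converts it.
w_fp16 = w_fp32.half()

print(w_fp32.nelement() * w_fp32.element_size())  # 67108864 bytes (~64 MiB)
print(w_fp16.nelement() * w_fp16.element_size())  # 33554432 bytes (~32 MiB)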