Fix Pipeline Parallel dtype (#11623)
This commit is contained in:
parent
d020ad6397
commit
d27a8cd08c
2 changed files with 5 additions and 3 deletions
|
|
@ -374,7 +374,7 @@ class _BaseAutoModelClass:
|
|||
"Please make sure you've called `init_pipeline_parallel()` "
|
||||
"and world size is the same as `pipeline_parallel_stages`")
|
||||
from .pipeline_parallel import pipeline_parallel, pipeline_parallel_generate
|
||||
model = pipeline_parallel(model, pipeline_parallel_stages)
|
||||
model = pipeline_parallel(model, pipeline_parallel_stages, kwargs["torch_dtype"])
|
||||
import types
|
||||
# add pipeline_parallel_generate to pretrained model dynamically
|
||||
model.pipeline_parallel_generate = types.MethodType(pipeline_parallel_generate,
|
||||
|
|
@ -788,7 +788,7 @@ class _BaseAutoModelClass:
|
|||
|
||||
if pipeline_parallel_stages > 1:
|
||||
from .pipeline_parallel import pipeline_parallel, pipeline_parallel_generate
|
||||
model = pipeline_parallel(model, pipeline_parallel_stages)
|
||||
model = pipeline_parallel(model, pipeline_parallel_stages, torch_dtype)
|
||||
import types
|
||||
# add pipeline_parallel_generate to pretrained model dynamically
|
||||
model.pipeline_parallel_generate = types.MethodType(pipeline_parallel_generate,
|
||||
|
|
|
|||
|
|
@ -162,7 +162,7 @@ def _check_quantize_kv_cache(model, idx, batch_size):
|
|||
os.environ["IPEX_LLM_QUANTIZE_KV_CACHE"] = "0"
|
||||
|
||||
|
||||
def pipeline_parallel(model, pipeline_parallel_stages):
|
||||
def pipeline_parallel(model, pipeline_parallel_stages, torch_dtype=torch.float32):
|
||||
global num_layers
|
||||
if hasattr(model.config, 'num_hidden_layers'):
|
||||
num_layers = model.config.num_hidden_layers
|
||||
|
|
@ -227,6 +227,8 @@ def pipeline_parallel(model, pipeline_parallel_stages):
|
|||
model.layer_start = layer_start
|
||||
model.layer_end = layer_end
|
||||
model.num_layers = num_layers
|
||||
if torch_dtype == torch.float16:
|
||||
model = model.half()
|
||||
model = model.to(f'xpu:{local_rank}')
|
||||
return model
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue