Fix Pipeline Parallel dtype (#11623)

commit d27a8cd08c (parent d020ad6397)
2 changed files with 5 additions and 3 deletions

@@ -374,7 +374,7 @@ class _BaseAutoModelClass:
                               "Please make sure you've called `init_pipeline_parallel()` "
                               "and world size is the same as `pipeline_parallel_stages`")
             from .pipeline_parallel import pipeline_parallel, pipeline_parallel_generate
-            model = pipeline_parallel(model, pipeline_parallel_stages)
+            model = pipeline_parallel(model, pipeline_parallel_stages, kwargs["torch_dtype"])
             import types
             # add pipeline_parallel_generate to pretrained model dynamically
             model.pipeline_parallel_generate = types.MethodType(pipeline_parallel_generate,
@@ -788,7 +788,7 @@ class _BaseAutoModelClass:
         if pipeline_parallel_stages > 1:
             from .pipeline_parallel import pipeline_parallel, pipeline_parallel_generate
-            model = pipeline_parallel(model, pipeline_parallel_stages)
+            model = pipeline_parallel(model, pipeline_parallel_stages, torch_dtype)
             import types
             # add pipeline_parallel_generate to pretrained model dynamically
             model.pipeline_parallel_generate = types.MethodType(pipeline_parallel_generate,
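Both call sites above now forward the caller's dtype into pipeline_parallel() instead of dropping it. A minimal usage sketch of how the fix surfaces, assuming the usual ipex_llm.transformers entry points (the model path, stage count, and load options here are illustrative, not taken from this commit):

    import torch
    from ipex_llm.transformers import AutoModelForCausalLM, init_pipeline_parallel

    # Set up the distributed environment first, as the error message above
    # requires; world size must equal pipeline_parallel_stages.
    init_pipeline_parallel()

    # With this fix, torch_dtype is forwarded to pipeline_parallel(), so each
    # pipeline stage is converted to fp16 before being moved to its XPU device.
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-chat-hf",  # illustrative model id
        load_in_4bit=True,                # illustrative load option
        torch_dtype=torch.float16,        # now respected by pipeline parallel
        pipeline_parallel_stages=2,
    )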

@@ -162,7 +162,7 @@ def _check_quantize_kv_cache(model, idx, batch_size):
         os.environ["IPEX_LLM_QUANTIZE_KV_CACHE"] = "0"


-def pipeline_parallel(model, pipeline_parallel_stages):
+def pipeline_parallel(model, pipeline_parallel_stages, torch_dtype=torch.float32):
     global num_layers
     if hasattr(model.config, 'num_hidden_layers'):
         num_layers = model.config.num_hidden_layers
@@ -227,6 +227,8 @@ def pipeline_parallel(model, pipeline_parallel_stages):
     model.layer_start = layer_start
     model.layer_end = layer_end
     model.num_layers = num_layers
+    if torch_dtype == torch.float16:
+        model = model.half()
     model = model.to(f'xpu:{local_rank}')
     return model
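The core of the fix is the new fp16 branch above: a stage's weights are now halved before the stage is placed on its device, matching the dtype the caller asked for. A self-contained sketch of that dtype handling, using a toy module and a hypothetical place_stage() helper with a CPU fallback so it runs without an XPU:

    import torch

    def place_stage(model: torch.nn.Module, torch_dtype=torch.float32, local_rank: int = 0):
        # Mirror the new branch: convert weights to half precision only when
        # the caller requested torch.float16.
        if torch_dtype == torch.float16:
            model = model.half()
        # The real code moves the stage to f'xpu:{local_rank}'; fall back to
        # CPU here so the sketch runs anywhere.
        use_xpu = hasattr(torch, "xpu") and torch.xpu.is_available()
        return model.to(f"xpu:{local_rank}" if use_xpu else "cpu")

    stage = place_stage(torch.nn.Linear(8, 8), torch_dtype=torch.float16)
    print(next(stage.parameters()).dtype)  # torch.float16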