Fix Pipeline Parallel dtype (#11623)
commit d27a8cd08c
parent d020ad6397

2 changed files with 5 additions and 3 deletions
@@ -374,7 +374,7 @@ class _BaseAutoModelClass:
                                   "Please make sure you've called `init_pipeline_parallel()` "
                                   "and world size is the same as `pipeline_parallel_stages`")
                 from .pipeline_parallel import pipeline_parallel, pipeline_parallel_generate
-                model = pipeline_parallel(model, pipeline_parallel_stages)
+                model = pipeline_parallel(model, pipeline_parallel_stages, kwargs["torch_dtype"])
                 import types
                 # add pipeline_parallel_generate to pretrained model dynamically
                 model.pipeline_parallel_generate = types.MethodType(pipeline_parallel_generate,
@@ -788,7 +788,7 @@ class _BaseAutoModelClass:
 
         if pipeline_parallel_stages > 1:
             from .pipeline_parallel import pipeline_parallel, pipeline_parallel_generate
-            model = pipeline_parallel(model, pipeline_parallel_stages)
+            model = pipeline_parallel(model, pipeline_parallel_stages, torch_dtype)
             import types
             # add pipeline_parallel_generate to pretrained model dynamically
             model.pipeline_parallel_generate = types.MethodType(pipeline_parallel_generate,
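Both call sites above follow the same pattern: the dtype requested at load time is now forwarded into the partitioning helper instead of being dropped. A minimal, self-contained sketch of that flow (the wrapper name and the use of kwargs.get() are illustrative; only pipeline_parallel, torch_dtype, and the float32 default come from this diff):

import torch
import torch.nn as nn


def pipeline_parallel(model: nn.Module, pipeline_parallel_stages: int,
                      torch_dtype=torch.float32) -> nn.Module:
    # ... slicing of this stage's layers omitted ...
    if torch_dtype == torch.float16:
        model = model.half()  # convert before moving to the device, as in the commit
    return model


def load_and_split(model: nn.Module, pipeline_parallel_stages: int, **kwargs) -> nn.Module:
    # Before the fix, the dtype in kwargs never reached pipeline_parallel(),
    # so every stage stayed in the checkpoint's original precision.
    return pipeline_parallel(model, pipeline_parallel_stages,
                             kwargs.get("torch_dtype", torch.float32))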
@@ -162,7 +162,7 @@ def _check_quantize_kv_cache(model, idx, batch_size):
         os.environ["IPEX_LLM_QUANTIZE_KV_CACHE"] = "0"
 
 
-def pipeline_parallel(model, pipeline_parallel_stages):
+def pipeline_parallel(model, pipeline_parallel_stages, torch_dtype=torch.float32):
     global num_layers
     if hasattr(model.config, 'num_hidden_layers'):
         num_layers = model.config.num_hidden_layers
@@ -227,6 +227,8 @@ def pipeline_parallel(model, pipeline_parallel_stages):
     model.layer_start = layer_start
     model.layer_end = layer_end
     model.num_layers = num_layers
+    if torch_dtype == torch.float16:
+        model = model.half()
     model = model.to(f'xpu:{local_rank}')
     return model
 
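A usage sketch of the path this commit fixes, assuming the ipex-llm pipeline-parallel entry points (init_pipeline_parallel and the pipeline_parallel_stages argument of from_pretrained); the model id, load_in_4bit, and the two-process launch are illustrative rather than taken from this diff:

# launched with one process per stage, e.g.:  mpirun -n 2 python run.py
import torch
from ipex_llm.transformers import AutoModelForCausalLM, init_pipeline_parallel

init_pipeline_parallel()

# With this fix the torch_dtype below is passed through to pipeline_parallel(),
# so each stage is converted to float16 before it is moved to its xpu device.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",   # illustrative model id
    load_in_4bit=True,
    torch_dtype=torch.float16,
    pipeline_parallel_stages=2,
)
# generation then proceeds through model.generate() as usual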