Refactor pipeline parallel device config (#11149)
* refactor pipeline parallel device config
* meet comments
* update example
* add warnings and update code doc
parent 62b2d8af6b
commit 33852bd23e

2 changed files with 35 additions and 21 deletions
File 1 of 2 (pipeline parallel example script):

@@ -62,27 +62,8 @@ if __name__ == '__main__':
                                                  load_in_4bit=True,
                                                  optimize_model=True,
                                                  trust_remote_code=True,
-                                                 use_cache=True)
-
-    model_layers = ['model.embed_tokens']
-    for i in range(model.config.num_hidden_layers):
-        model_layers.append(f'model.layers.{i}')
-    model_layers = model_layers + ['model.norm', 'lm_head']
-
-    device_map = {}
-    split_len = len(model_layers) // args.gpu_num
-    for i in range(args.gpu_num):
-        device_map.update({key: f'xpu:{i}' for key in model_layers[split_len * i: split_len * (i + 1)]})
-        if i == args.gpu_num - 1:
-            device_map.update({key: f'xpu:{i}' for key in model_layers[split_len * (i + 1): ]})
-
-    from accelerate import dispatch_model
-    model = dispatch_model(
-        model,
-        device_map=device_map,
-        offload_dir=None,
-        skip_keys=["past_key_value", "past_key_values"],
-    )
+                                                 use_cache=True,
+                                                 pipeline_parallel_stages=args.gpu_num)
 
     # Load tokenizer
     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
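With this change, the example no longer builds a device_map by hand and dispatches the model itself; passing pipeline_parallel_stages to from_pretrained takes care of splitting the model across GPUs. A minimal sketch of the resulting call, assuming the usual ipex-llm import path; the model path and the gpu_num value below are illustrative stand-ins, not taken from the diff:

from ipex_llm.transformers import AutoModelForCausalLM   # assumed import path
from transformers import AutoTokenizer

model_path = "meta-llama/Llama-2-7b-chat-hf"   # hypothetical model path
gpu_num = 2                                    # stands in for args.gpu_num

# One from_pretrained call now handles low-bit loading and pipeline parallel dispatch.
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_4bit=True,
                                             optimize_model=True,
                                             trust_remote_code=True,
                                             use_cache=True,
                                             pipeline_parallel_stages=gpu_num)

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)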
File 2 of 2 (ipex_llm transformers model-loading code):

@@ -95,6 +95,28 @@ def save_low_bit(self, *args, **kwargs):
         self.to(origin_device)
 
 
+def pipeline_parallel(model, pipeline_parallel_stages):
+    model_layers = ['model.embed_tokens']
+    for i in range(model.config.num_hidden_layers):
+        model_layers.append(f'model.layers.{i}')
+    model_layers = model_layers + ['model.norm', 'lm_head']
+
+    device_map = {}
+    split_len = len(model_layers) // pipeline_parallel_stages
+    for i in range(pipeline_parallel_stages):
+        device_map.update({key: f'xpu:{i}' for key in
+                           model_layers[split_len * i: split_len * (i + 1)]})
+        if i == pipeline_parallel_stages - 1:
+            device_map.update({key: f'xpu:{i}' for key in
+                               model_layers[split_len * (i + 1):]})
+
+    from accelerate import dispatch_model
+    model = dispatch_model(
+        model, device_map=device_map, skip_keys=["past_key_value", "past_key_values"],
+    )
+    return model
+
+
 def _load_pre():
     from transformers import GPTJModel
     from ipex_llm.transformers.models.gptj import gptj_model_new_init
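The new pipeline_parallel helper splits the flat list of layer names into equal chunks, one per stage, maps each chunk to an xpu:i device, and lets the last stage absorb any remainder before handing the device_map to accelerate's dispatch_model. A standalone sketch of just the mapping logic, using an assumed 32-layer model and 2 stages (example values, not taken from the diff):

# Illustration of the splitting logic; 32 layers and 2 stages are assumed values.
num_hidden_layers = 32
pipeline_parallel_stages = 2

model_layers = ['model.embed_tokens']
for i in range(num_hidden_layers):
    model_layers.append(f'model.layers.{i}')
model_layers = model_layers + ['model.norm', 'lm_head']    # 35 names in total

device_map = {}
split_len = len(model_layers) // pipeline_parallel_stages  # 35 // 2 = 17
for i in range(pipeline_parallel_stages):
    device_map.update({key: f'xpu:{i}' for key in
                       model_layers[split_len * i: split_len * (i + 1)]})
    if i == pipeline_parallel_stages - 1:
        # the last stage also takes the leftover names (here just 'lm_head')
        device_map.update({key: f'xpu:{i}' for key in
                           model_layers[split_len * (i + 1):]})

# xpu:0 -> embed_tokens, layers 0-15; xpu:1 -> layers 16-31, model.norm, lm_head
print(device_map['model.embed_tokens'], device_map['lm_head'])  # xpu:0 xpu:1

The skip_keys passed to dispatch_model presumably keep the cached past_key_values from being shuffled between devices on each forward pass, though that behaviour comes from accelerate rather than from this change.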
@@ -157,6 +179,9 @@ class _BaseAutoModelClass:
         :param mixed_precision: boolean value, Whether to use mixed precision quantization.
             Default to be False. If set to True, we will use sym_int8 for lm_head when
             load_in_low_bit is sym_int4 or asym_int4.
+        :param pipeline_parallel_stages: int value, the number of GPUs allocated for
+            pipeline parallel. Default to be ``1``. Please set pipeline_parallel_stages > 1
+            to run pipeline parallel inference on multiple GPUs.
         :return: a model instance
         """
         pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None) \
@@ -190,6 +215,7 @@ class _BaseAutoModelClass:
         optimize_model = kwargs.pop("optimize_model", True)
         user_quantization_config = kwargs.pop("quantization_config", None)
         speculative = kwargs.pop("speculative", False)
+        pipeline_parallel_stages = kwargs.pop("pipeline_parallel_stages", 1)
         torch_dtype = kwargs.pop("torch_dtype", None)
         embedding_qtype = kwargs.pop("embedding_qtype", None)
 
@@ -346,6 +372,13 @@ class _BaseAutoModelClass:
             kwargs["embedding_qtype"] = embedding_qtype
             model = cls.load_convert(q_k, optimize_model, *args, **kwargs)
 
+            if pipeline_parallel_stages > 1:
+                if speculative:
+                    invalidInputError(False,
+                                      f"Please do not set speculative=True"
+                                      f" when using pipeline_parallel_stages")
+                model = pipeline_parallel(model, pipeline_parallel_stages)
+
             if speculative:
                 from .speculative import speculative_generate, clear_benchmarks,\
                     _crop_past_key_values
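The guard above makes speculative decoding and pipeline parallel mutually exclusive at load time. A sketch of the user-facing effect (the call shape mirrors the example; the exact exception type raised by invalidInputError is not visible in this diff):

from ipex_llm.transformers import AutoModelForCausalLM   # assumed import path

model_path = "meta-llama/Llama-2-7b-chat-hf"              # hypothetical model path

# Expected to fail during loading with the message
# "Please do not set speculative=True when using pipeline_parallel_stages".
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_4bit=True,
                                             speculative=True,
                                             pipeline_parallel_stages=2)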