Refactor pipeline parallel device config (#11149)

* refactor pipeline parallel device config
* meet comments
* update example
* add warnings and update code doc
parent 62b2d8af6b
commit 33852bd23e

2 changed files with 35 additions and 21 deletions
@@ -62,27 +62,8 @@ if __name__ == '__main__':
                                                  load_in_4bit=True,
                                                  optimize_model=True,
                                                  trust_remote_code=True,
-                                                 use_cache=True)
-
-    model_layers = ['model.embed_tokens']
-    for i in range(model.config.num_hidden_layers):
-        model_layers.append(f'model.layers.{i}')
-    model_layers = model_layers + ['model.norm', 'lm_head']
-
-    device_map = {}
-    split_len = len(model_layers) // args.gpu_num
-    for i in range(args.gpu_num):
-        device_map.update({key: f'xpu:{i}' for key in model_layers[split_len * i: split_len * (i + 1)]})
-        if i == args.gpu_num - 1:
-            device_map.update({key: f'xpu:{i}' for key in model_layers[split_len * (i + 1): ]})
-
-    from accelerate import dispatch_model
-    model = dispatch_model(
-        model,
-        device_map=device_map,
-        offload_dir=None,
-        skip_keys=["past_key_value", "past_key_values"],
-    )
+                                                 use_cache=True,
+                                                 pipeline_parallel_stages=args.gpu_num)
 
     # Load tokenizer
     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
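For reference, a minimal sketch of how the refactored example is expected to load and run a model after this change, with the hand-built device_map and accelerate dispatch_model block replaced by the single pipeline_parallel_stages argument. The model path, prompt, and generation settings are illustrative assumptions, not part of the diff.

    # Usage sketch only; checkpoint path, prompt, and generation settings are assumed.
    import torch
    from ipex_llm.transformers import AutoModelForCausalLM
    from transformers import AutoTokenizer

    model_path = "meta-llama/Llama-2-7b-chat-hf"   # hypothetical checkpoint
    gpu_num = 2                                    # number of XPUs / pipeline stages

    # One kwarg now replaces the manual device_map + dispatch_model code.
    model = AutoModelForCausalLM.from_pretrained(model_path,
                                                 load_in_4bit=True,
                                                 optimize_model=True,
                                                 trust_remote_code=True,
                                                 use_cache=True,
                                                 pipeline_parallel_stages=gpu_num)

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    # Inputs go to the first stage, assuming model.embed_tokens lands on xpu:0.
    input_ids = tokenizer("What is AI?", return_tensors="pt").input_ids.to("xpu:0")
    with torch.inference_mode():
        output = model.generate(input_ids, max_new_tokens=32)
    print(tokenizer.decode(output[0], skip_special_tokens=True))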
@@ -95,6 +95,28 @@ def save_low_bit(self, *args, **kwargs):
         self.to(origin_device)
 
 
+def pipeline_parallel(model, pipeline_parallel_stages):
+    model_layers = ['model.embed_tokens']
+    for i in range(model.config.num_hidden_layers):
+        model_layers.append(f'model.layers.{i}')
+    model_layers = model_layers + ['model.norm', 'lm_head']
+
+    device_map = {}
+    split_len = len(model_layers) // pipeline_parallel_stages
+    for i in range(pipeline_parallel_stages):
+        device_map.update({key: f'xpu:{i}' for key in
+                           model_layers[split_len * i: split_len * (i + 1)]})
+        if i == pipeline_parallel_stages - 1:
+            device_map.update({key: f'xpu:{i}' for key in
+                               model_layers[split_len * (i + 1):]})
+
+    from accelerate import dispatch_model
+    model = dispatch_model(
+        model, device_map=device_map, skip_keys=["past_key_value", "past_key_values"],
+    )
+    return model
+
+
 def _load_pre():
     from transformers import GPTJModel
     from ipex_llm.transformers.models.gptj import gptj_model_new_init
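To make the splitting scheme concrete, here is a small self-contained sketch that mirrors the device_map logic of the added pipeline_parallel function for a hypothetical config with 4 hidden layers and 2 stages (the layer count is illustrative; real models have many more).

    # Stand-alone illustration of the split performed by pipeline_parallel().
    num_hidden_layers = 4          # illustrative value, not from the diff
    pipeline_parallel_stages = 2

    model_layers = ['model.embed_tokens']
    for i in range(num_hidden_layers):
        model_layers.append(f'model.layers.{i}')
    model_layers = model_layers + ['model.norm', 'lm_head']   # 7 entries in total

    device_map = {}
    split_len = len(model_layers) // pipeline_parallel_stages  # 7 // 2 == 3
    for i in range(pipeline_parallel_stages):
        device_map.update({key: f'xpu:{i}' for key in
                           model_layers[split_len * i: split_len * (i + 1)]})
        if i == pipeline_parallel_stages - 1:
            # The last stage also takes any remainder left by the integer split.
            device_map.update({key: f'xpu:{i}' for key in
                               model_layers[split_len * (i + 1):]})

    print(device_map)
    # {'model.embed_tokens': 'xpu:0', 'model.layers.0': 'xpu:0', 'model.layers.1': 'xpu:0',
    #  'model.layers.2': 'xpu:1', 'model.layers.3': 'xpu:1',
    #  'model.norm': 'xpu:1', 'lm_head': 'xpu:1'}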
@@ -157,6 +179,9 @@ class _BaseAutoModelClass:
         :param mixed_precision: boolean value, Whether to use mixed precision quantization.
             Default to be False. If set to True, we will use sym_int8 for lm_head when
             load_in_low_bit is sym_int4 or asym_int4.
+        :param pipeline_parallel_stages: int value, the number of GPUs allocated for
+            pipeline parallel. Default to be ``1``. Please set pipeline_parallel_stages > 1
+            to run pipeline parallel inference on multiple GPUs.
         :return: a model instance
         """
         pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None) \
@@ -190,6 +215,7 @@ class _BaseAutoModelClass:
         optimize_model = kwargs.pop("optimize_model", True)
         user_quantization_config = kwargs.pop("quantization_config", None)
         speculative = kwargs.pop("speculative", False)
+        pipeline_parallel_stages = kwargs.pop("pipeline_parallel_stages", 1)
         torch_dtype = kwargs.pop("torch_dtype", None)
         embedding_qtype = kwargs.pop("embedding_qtype", None)
 
@@ -346,6 +372,13 @@ class _BaseAutoModelClass:
             kwargs["embedding_qtype"] = embedding_qtype
             model = cls.load_convert(q_k, optimize_model, *args, **kwargs)
 
+            if pipeline_parallel_stages > 1:
+                if speculative:
+                    invalidInputError(False,
+                                      f"Please do not set speculative=True"
+                                      f" when using pipeline_parallel_stages")
+                model = pipeline_parallel(model, pipeline_parallel_stages)
+
             if speculative:
                 from .speculative import speculative_generate, clear_benchmarks,\
                     _crop_past_key_values
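For clarity, the guard above means speculative decoding and pipeline parallelism cannot be combined in one from_pretrained call. A hedged sketch of the rejected call, assuming the same ipex_llm AutoModelForCausalLM entry point and an illustrative model path:

    # Illustrative only: this combination is expected to be rejected by the new check,
    # failing via invalidInputError with the message
    # "Please do not set speculative=True when using pipeline_parallel_stages".
    from ipex_llm.transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-chat-hf",   # hypothetical checkpoint path
        load_in_4bit=True,
        speculative=True,                  # speculative decoding
        pipeline_parallel_stages=2,        # > 1 enables pipeline parallel -> conflict
    )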