LLM: support int8 and tmp_path for from_pretrained (#8338)
				
					
				
			This commit is contained in:
		
							parent
							
								
									b30aa49c4e
								
							
						
					
					
						commit
						f7f4e65788
					
				
					 1 changed files with 10 additions and 4 deletions
				
			
		| 
						 | 
					@ -36,6 +36,7 @@ class AutoModelForCausalLM:
 | 
				
			||||||
                        model_family: str = 'llama',
 | 
					                        model_family: str = 'llama',
 | 
				
			||||||
                        dtype: str = 'int4',
 | 
					                        dtype: str = 'int4',
 | 
				
			||||||
                        cache_dir: str = './',
 | 
					                        cache_dir: str = './',
 | 
				
			||||||
 | 
					                        tmp_path: str = None,
 | 
				
			||||||
                        **kwargs):
 | 
					                        **kwargs):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        :param pretrained_model_name_or_path: We support 3 kinds of pretrained model checkpoint
 | 
					        :param pretrained_model_name_or_path: We support 3 kinds of pretrained model checkpoint
 | 
				
			||||||
| 
						 | 
					@ -50,10 +51,14 @@ class AutoModelForCausalLM:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        :param model_family: the model family of the pretrained checkpoint.
 | 
					        :param model_family: the model family of the pretrained checkpoint.
 | 
				
			||||||
               Currently we support ``"llama"``, ``"bloom"``, ``"gptneox"``.
 | 
					               Currently we support ``"llama"``, ``"bloom"``, ``"gptneox"``.
 | 
				
			||||||
        :param dtype: (optional) the data type for weight. Currently we only support ``"int4"``
 | 
					        :param dtype: Which quantized precision will be converted.
 | 
				
			||||||
 | 
					                Now only `int4` and `int8` are supported, and `int8` only works for `llama`
 | 
				
			||||||
 | 
					                and `gptneox`.
 | 
				
			||||||
        :param cache_dir: (optional) this parameter will only be used when
 | 
					        :param cache_dir: (optional) this parameter will only be used when
 | 
				
			||||||
               ``pretrained_model_name_or_path`` is a hugginface checkpoint or hub repo id.
 | 
					               ``pretrained_model_name_or_path`` is a hugginface checkpoint or hub repo id.
 | 
				
			||||||
               It indicates the saving path for the converted low precision model.
 | 
					               It indicates the saving path for the converted low precision model.
 | 
				
			||||||
 | 
					        :param tmp_path: (optional) Which path to store the intermediate fp16 model during the
 | 
				
			||||||
 | 
					               conversion process. Default to `None` so that intermediate model will not be saved.
 | 
				
			||||||
        :param **kwargs: keyword arguments which will be passed to the model instance
 | 
					        :param **kwargs: keyword arguments which will be passed to the model instance
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        :return: a model instance
 | 
					        :return: a model instance
 | 
				
			||||||
| 
						 | 
					@ -61,8 +66,8 @@ class AutoModelForCausalLM:
 | 
				
			||||||
        invalidInputError(model_family in ['llama', 'gptneox', 'bloom'],
 | 
					        invalidInputError(model_family in ['llama', 'gptneox', 'bloom'],
 | 
				
			||||||
                          "Now we only support model family: 'llama', 'gptneox', 'bloom', "
 | 
					                          "Now we only support model family: 'llama', 'gptneox', 'bloom', "
 | 
				
			||||||
                          "'{}' is not in the list.".format(model_family))
 | 
					                          "'{}' is not in the list.".format(model_family))
 | 
				
			||||||
        invalidInputError(dtype.lower() == 'int4',
 | 
					        invalidInputError(dtype.lower() in ['int4', 'int8'],
 | 
				
			||||||
                          "Now we only support int4 as date type for weight")
 | 
					                          "Now we only support int4 and int8 as date type for weight")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # check whether pretrained_model_name_or_path exists.
 | 
					        # check whether pretrained_model_name_or_path exists.
 | 
				
			||||||
        # if not, it is likely that the user wants to pass in the repo id.
 | 
					        # if not, it is likely that the user wants to pass in the repo id.
 | 
				
			||||||
| 
						 | 
					@ -93,7 +98,8 @@ class AutoModelForCausalLM:
 | 
				
			||||||
            ggml_model_path = convert_model(input_path=pretrained_model_name_or_path,
 | 
					            ggml_model_path = convert_model(input_path=pretrained_model_name_or_path,
 | 
				
			||||||
                                            output_path=cache_dir,
 | 
					                                            output_path=cache_dir,
 | 
				
			||||||
                                            model_family=model_family,
 | 
					                                            model_family=model_family,
 | 
				
			||||||
                                            dtype=dtype)
 | 
					                                            dtype=dtype,
 | 
				
			||||||
 | 
					                                            tmp_path=tmp_path)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if model_family == 'llama':
 | 
					        if model_family == 'llama':
 | 
				
			||||||
            from bigdl.llm.ggml.model.llama import Llama
 | 
					            from bigdl.llm.ggml.model.llama import Llama
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue