add optimize model option (#9530)
This commit is contained in:
parent
6bec0faea5
commit
45820cf3b9
2 changed files with 5 additions and 4 deletions
|
|
@ -46,9 +46,10 @@ class BigDLLM(BaseLM):
|
|||
tokenizer=None,
|
||||
batch_size=1,
|
||||
load_in_8bit: Optional[bool] = False,
|
||||
trust_remote_code: Optional[bool] = False,
|
||||
trust_remote_code: Optional[bool] = True,
|
||||
load_in_low_bit=None,
|
||||
dtype: Optional[Union[str, torch.dtype]] = "auto",
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
|
@ -58,8 +59,8 @@ class BigDLLM(BaseLM):
|
|||
import intel_extension_for_pytorch as ipex
|
||||
model = AutoModelForCausalLM.from_pretrained(pretrained,
|
||||
load_in_low_bit=load_in_low_bit,
|
||||
optimize_model=True,
|
||||
trust_remote_code=True,
|
||||
optimize_model=kwargs.get('optimize_model', True),
|
||||
trust_remote_code=trust_remote_code,
|
||||
use_cache=True,
|
||||
torch_dtype=_get_dtype(dtype))
|
||||
print(model) # print model to check precision
|
||||
|
|
|
|||
|
|
@ -90,7 +90,7 @@ def main():
|
|||
prec_arg = parse_precision(prec, args.model)
|
||||
model_args = f"pretrained={args.pretrained},{prec_arg}"
|
||||
if len(args.model_args) > 0:
|
||||
model_args += args.model_args
|
||||
model_args = f"{model_args},{args.model_args}"
|
||||
for task in args.tasks:
|
||||
task_names=task_map.get(task, task).split(',')
|
||||
num_fewshot = task_to_n_few_shots.get(task, args.num_fewshot)
|
||||
|
|
|
|||
Loading…
Reference in a new issue