add optimize model option (#9530)
This commit is contained in:
parent
6bec0faea5
commit
45820cf3b9
2 changed files with 5 additions and 4 deletions
|
|
@@ -46,9 +46,10 @@ class BigDLLM(BaseLM):
|
||||||
tokenizer=None,
|
tokenizer=None,
|
||||||
batch_size=1,
|
batch_size=1,
|
||||||
load_in_8bit: Optional[bool] = False,
|
load_in_8bit: Optional[bool] = False,
|
||||||
trust_remote_code: Optional[bool] = False,
|
trust_remote_code: Optional[bool] = True,
|
||||||
load_in_low_bit=None,
|
load_in_low_bit=None,
|
||||||
dtype: Optional[Union[str, torch.dtype]] = "auto",
|
dtype: Optional[Union[str, torch.dtype]] = "auto",
|
||||||
|
**kwargs
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
|
@@ -58,8 +59,8 @@ class BigDLLM(BaseLM):
|
||||||
import intel_extension_for_pytorch as ipex
|
import intel_extension_for_pytorch as ipex
|
||||||
model = AutoModelForCausalLM.from_pretrained(pretrained,
|
model = AutoModelForCausalLM.from_pretrained(pretrained,
|
||||||
load_in_low_bit=load_in_low_bit,
|
load_in_low_bit=load_in_low_bit,
|
||||||
optimize_model=True,
|
optimize_model=kwargs.get('optimize_model', True),
|
||||||
trust_remote_code=True,
|
trust_remote_code=trust_remote_code,
|
||||||
use_cache=True,
|
use_cache=True,
|
||||||
torch_dtype=_get_dtype(dtype))
|
torch_dtype=_get_dtype(dtype))
|
||||||
print(model) # print model to check precision
|
print(model) # print model to check precision
|
||||||
|
|
|
||||||
|
|
@@ -90,7 +90,7 @@ def main():
|
||||||
prec_arg = parse_precision(prec, args.model)
|
prec_arg = parse_precision(prec, args.model)
|
||||||
model_args = f"pretrained={args.pretrained},{prec_arg}"
|
model_args = f"pretrained={args.pretrained},{prec_arg}"
|
||||||
if len(args.model_args) > 0:
|
if len(args.model_args) > 0:
|
||||||
model_args += args.model_args
|
model_args = f"{model_args},{args.model_args}"
|
||||||
for task in args.tasks:
|
for task in args.tasks:
|
||||||
task_names=task_map.get(task, task).split(',')
|
task_names=task_map.get(task, task).split(',')
|
||||||
num_fewshot = task_to_n_few_shots.get(task, args.num_fewshot)
|
num_fewshot = task_to_n_few_shots.get(task, args.num_fewshot)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue