add gptq option for ppl test (#11921)

* feat: add gptq for ppl

* fix: add an empty line

* fix: add an empty line

* fix: remove an empty line

* Resolve comments

* Resolve comments

* Resolve comments
Chu,Youcheng 2024-08-30 13:43:48 +08:00 committed by GitHub
parent 1e8c87050f
commit ae7302a654

@@ -38,12 +38,24 @@ args = parser.parse_args()
 if args.precision == "fp16":  # ipex fp16
     from transformers import AutoModelForCausalLM
-    model = AutoModelForCausalLM.from_pretrained(args.model_path, use_cache=args.use_cache, trust_remote_code=True)
+    model = AutoModelForCausalLM.from_pretrained(args.model_path,
+                                                 use_cache=args.use_cache,
+                                                 trust_remote_code=True)
     model = model.half()
+elif 'gptq' in args.model_path.lower():  # ipex-llm gptq
+    from ipex_llm.transformers import AutoModelForCausalLM
+    model = AutoModelForCausalLM.from_pretrained(args.model_path,
+                                                 load_in_4bit=True,
+                                                 torch_dtype=torch.float,
+                                                 use_cache=args.use_cache,
+                                                 trust_remote_code=True)
 else:  # ipex-llm
     from ipex_llm.transformers import AutoModelForCausalLM
-    model = AutoModelForCausalLM.from_pretrained(args.model_path, load_in_low_bit=args.precision,
-                                                 use_cache=args.use_cache, trust_remote_code=True, mixed_precision= args.mixed_precision)
+    model = AutoModelForCausalLM.from_pretrained(args.model_path,
+                                                 load_in_low_bit=args.precision,
+                                                 use_cache=args.use_cache,
+                                                 trust_remote_code=True,
+                                                 mixed_precision=args.mixed_precision)
     model = model.half()
 model = model.to(args.device)
 model = model.eval()
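
For reference, below is a self-contained sketch of the loading dispatch this diff produces. The branching and the from_pretrained calls mirror the diff; the argparse defaults (sample GPTQ model path, precision, device) are illustrative assumptions, not values taken from the script.

# Standalone sketch of the post-change model-loading logic.
# All argparse defaults here are hypothetical placeholders.
import argparse

import torch

parser = argparse.ArgumentParser()
parser.add_argument("--model_path", default="TheBloke/Llama-2-7B-GPTQ")  # hypothetical path
parser.add_argument("--precision", default="sym_int4")
parser.add_argument("--use_cache", action="store_true")
parser.add_argument("--mixed_precision", action="store_true")
parser.add_argument("--device", default="cpu")  # assumption; the real script may target xpu
args = parser.parse_args()

if args.precision == "fp16":  # ipex fp16
    from transformers import AutoModelForCausalLM
    model = AutoModelForCausalLM.from_pretrained(args.model_path,
                                                 use_cache=args.use_cache,
                                                 trust_remote_code=True)
    model = model.half()
elif 'gptq' in args.model_path.lower():  # ipex-llm gptq
    # GPTQ checkpoints are already quantized, so the new branch loads
    # them with load_in_4bit instead of re-quantizing via load_in_low_bit.
    from ipex_llm.transformers import AutoModelForCausalLM
    model = AutoModelForCausalLM.from_pretrained(args.model_path,
                                                 load_in_4bit=True,
                                                 torch_dtype=torch.float,
                                                 use_cache=args.use_cache,
                                                 trust_remote_code=True)
else:  # ipex-llm low-bit quantization
    from ipex_llm.transformers import AutoModelForCausalLM
    model = AutoModelForCausalLM.from_pretrained(args.model_path,
                                                 load_in_low_bit=args.precision,
                                                 use_cache=args.use_cache,
                                                 trust_remote_code=True,
                                                 mixed_precision=args.mixed_precision)
    model = model.half()
model = model.to(args.device)
model = model.eval()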