fix multiple gpu usage (#9459)

This commit is contained in:
Chen, Zhentao 2023-11-14 17:06:27 +08:00 committed by GitHub
parent d19ca21957
commit dbbdb53a18

View file

@ -54,7 +54,7 @@ class BigDLLM(BaseLM):
assert isinstance(pretrained, str) assert isinstance(pretrained, str)
assert isinstance(batch_size, (int,str)) assert isinstance(batch_size, (int,str))
if device == 'xpu': if 'xpu' in device:
import intel_extension_for_pytorch as ipex import intel_extension_for_pytorch as ipex
model = AutoModelForCausalLM.from_pretrained(pretrained, model = AutoModelForCausalLM.from_pretrained(pretrained,
load_in_low_bit=load_in_low_bit, load_in_low_bit=load_in_low_bit,
@ -118,4 +118,4 @@ class BigDLLM(BaseLM):
return res return res
def _model_generate(self, context, max_length, eos_token_id): def _model_generate(self, context, max_length, eos_token_id):
return self.model(context, max_tokens=max_length, stop=["Q:", "\n"], echo=True) return self.model(context, max_tokens=max_length, stop=["Q:", "\n"], echo=True)