fix multiple gpu usage (#9459)
This commit is contained in:
parent
d19ca21957
commit
dbbdb53a18
1 changed files with 2 additions and 2 deletions
|
|
@ -54,7 +54,7 @@ class BigDLLM(BaseLM):
|
|||
|
||||
assert isinstance(pretrained, str)
|
||||
assert isinstance(batch_size, (int,str))
|
||||
if device == 'xpu':
|
||||
if 'xpu' in device:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
model = AutoModelForCausalLM.from_pretrained(pretrained,
|
||||
load_in_low_bit=load_in_low_bit,
|
||||
|
|
@ -118,4 +118,4 @@ class BigDLLM(BaseLM):
|
|||
return res
|
||||
|
||||
def _model_generate(self, context, max_length, eos_token_id):
|
||||
return self.model(context, max_tokens=max_length, stop=["Q:", "\n"], echo=True)
|
||||
return self.model(context, max_tokens=max_length, stop=["Q:", "\n"], echo=True)
|
||||
|
|
|
|||
Loading…
Reference in a new issue