add xpu support (#10419)

This commit is contained in:
dingbaorong 2024-03-14 17:13:48 +08:00 committed by GitHub
parent 7d29765092
commit 1c0f7ed3fa

View file

@ -83,6 +83,7 @@ class TransformersEmbeddings(BaseModel, Embeddings):
cls,
model_id: str,
model_kwargs: Optional[dict] = None,
device_map: str = 'cpu',
**kwargs: Any,
):
"""
@ -117,6 +118,10 @@ class TransformersEmbeddings(BaseModel, Embeddings):
model = AutoModel.from_pretrained(model_id, load_in_4bit=True, **_model_kwargs)
# TODO: may refactor this code in the future
if 'xpu' in device_map:
model = model.to(device_map)
if "trust_remote_code" in _model_kwargs:
_model_kwargs = {
k: v for k, v in _model_kwargs.items() if k != "trust_remote_code"
@ -145,7 +150,8 @@ class TransformersEmbeddings(BaseModel, Embeddings):
List of embeddings, one for each text.
"""
input_ids = self.tokenizer.encode(text, return_tensors="pt", **kwargs) # shape: [1, T]
embeddings = self.model(input_ids, return_dict=False)[0] # shape: [1, T, N]
input_ids = input_ids.to(self.model.device)
embeddings = self.model(input_ids, return_dict=False)[0].cpu() # shape: [1, T, N]
embeddings = embeddings.squeeze(0).detach().numpy()
embeddings = np.mean(embeddings, axis=0)
return embeddings