add xpu support (#10419)
This commit is contained in:
parent
7d29765092
commit
1c0f7ed3fa
1 changed file with 7 additions and 1 deletion
|
|
@ -83,6 +83,7 @@ class TransformersEmbeddings(BaseModel, Embeddings):
|
||||||
cls,
|
cls,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
model_kwargs: Optional[dict] = None,
|
model_kwargs: Optional[dict] = None,
|
||||||
|
device_map: str = 'cpu',
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
|
@ -117,6 +118,10 @@ class TransformersEmbeddings(BaseModel, Embeddings):
|
||||||
|
|
||||||
model = AutoModel.from_pretrained(model_id, load_in_4bit=True, **_model_kwargs)
|
model = AutoModel.from_pretrained(model_id, load_in_4bit=True, **_model_kwargs)
|
||||||
|
|
||||||
|
# TODO: may refactor this code in the future
|
||||||
|
if 'xpu' in device_map:
|
||||||
|
model = model.to(device_map)
|
||||||
|
|
||||||
if "trust_remote_code" in _model_kwargs:
|
if "trust_remote_code" in _model_kwargs:
|
||||||
_model_kwargs = {
|
_model_kwargs = {
|
||||||
k: v for k, v in _model_kwargs.items() if k != "trust_remote_code"
|
k: v for k, v in _model_kwargs.items() if k != "trust_remote_code"
|
||||||
|
|
@ -145,7 +150,8 @@ class TransformersEmbeddings(BaseModel, Embeddings):
|
||||||
List of embeddings, one for each text.
|
List of embeddings, one for each text.
|
||||||
"""
|
"""
|
||||||
input_ids = self.tokenizer.encode(text, return_tensors="pt", **kwargs) # shape: [1, T]
|
input_ids = self.tokenizer.encode(text, return_tensors="pt", **kwargs) # shape: [1, T]
|
||||||
embeddings = self.model(input_ids, return_dict=False)[0] # shape: [1, T, N]
|
input_ids = input_ids.to(self.model.device)
|
||||||
|
embeddings = self.model(input_ids, return_dict=False)[0].cpu() # shape: [1, T, N]
|
||||||
embeddings = embeddings.squeeze(0).detach().numpy()
|
embeddings = embeddings.squeeze(0).detach().numpy()
|
||||||
embeddings = np.mean(embeddings, axis=0)
|
embeddings = np.mean(embeddings, axis=0)
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue