Add embeddings (#10931)

parent d649236321
commit 49ab5a2b0e

2 changed files with 179 additions and 1 deletion
@@ -19,6 +19,9 @@ A model worker that executes the model based on BigDL-LLM.
 Relies on load_model method
 """
 
+import torch
+import torch.nn.functional as F
+import gc
 import argparse
 import asyncio
 import atexit
@@ -31,6 +34,7 @@ from fastapi.concurrency import run_in_threadpool
 from fastapi.responses import StreamingResponse, JSONResponse
 import uvicorn
 
+from fastchat.constants import ErrorCode, SERVER_ERROR_MSG
 from fastchat.serve.base_model_worker import BaseModelWorker
 from fastchat.serve.model_worker import (
     logger,
@@ -63,6 +67,7 @@ class BigDLLLMWorker(BaseModelWorker):
         device: str = "cpu",
         no_register: bool = False,
         trust_remote_code: bool = False,
+        embed_in_truncate: bool = False,
         speculative: bool = False,
         stream_interval: int = 4,
     ):
@@ -93,9 +98,158 @@ class BigDLLLMWorker(BaseModelWorker):
         )
         self.stream_interval = stream_interval
         self.context_len = get_context_length(self.model.config)
+        self.embed_in_truncate = embed_in_truncate
         if not no_register:
             self.init_heart_beat()
 
+    def __process_embed_chunk(self, input_ids, attention_mask, **model_type_dict):
+        if model_type_dict.get("is_bert"):
+            model_output = self.model(input_ids)
+            if model_type_dict.get("is_robert"):
+                data = model_output.last_hidden_state
+            else:
+                data = model_output[0]
+        elif model_type_dict.get("is_t5"):
+            model_output = self.model(input_ids, decoder_input_ids=input_ids)
+            data = model_output.encoder_last_hidden_state
+        else:
+            model_output = self.model(input_ids, output_hidden_states=True)
+            if model_type_dict.get("is_chatglm"):
+                data = model_output.hidden_states[-1].transpose(0, 1)
+            else:
+                data = model_output.hidden_states[-1]
+
+        if hasattr(self.model, "use_cls_pooling") and self.model.use_cls_pooling:
+            sum_embeddings = data[:, 0]
+        else:
+            mask = attention_mask.unsqueeze(-1).expand(data.size()).float()
+            masked_embeddings = data * mask
+            sum_embeddings = torch.sum(masked_embeddings, dim=1)
+        token_num = torch.sum(attention_mask).item()
+
+        return sum_embeddings, token_num
+
+    @torch.inference_mode()
+    def get_embeddings(self, params):
+        self.call_ct += 1
+
+        try:
+            # Get tokenizer
+            tokenizer = self.tokenizer
+            ret = {"embedding": [], "token_num": 0}
+
+            # Based on conditions of different model_type
+            model_type_dict = {
+                "is_llama": "llama" in str(type(self.model)),
+                "is_t5": "t5" in str(type(self.model)),
+                "is_chatglm": "chatglm" in str(type(self.model)),
+                "is_bert": "bert" in str(type(self.model)),
+                "is_robert": "robert" in str(type(self.model)),
+            }
+
+            if self.embed_in_truncate:
+                encoding = tokenizer.batch_encode_plus(
+                    params["input"],
+                    padding=True,
+                    truncation="longest_first",
+                    return_tensors="pt",
+                    max_length=self.context_len,
+                )
+            else:
+                encoding = tokenizer.batch_encode_plus(
+                    params["input"], padding=True, return_tensors="pt"
+                )
+            input_ids = encoding["input_ids"].to(self.device)
+            # Check if we need attention_mask or not.
+            attention_mask = input_ids != tokenizer.pad_token_id
+
+            base64_encode = params.get("encoding_format", None)
+
+            if self.embed_in_truncate:
+                embedding, token_num = self.__process_embed_chunk(
+                    input_ids, attention_mask, **model_type_dict
+                )
+                if (
+                    not hasattr(self.model, "use_cls_pooling")
+                    or not self.model.use_cls_pooling
+                ):
+                    embedding = embedding / token_num
+                normalized_embeddings = F.normalize(embedding, p=2, dim=1)
+                ret["token_num"] = token_num
+            else:
+                all_embeddings = []
+                all_token_num = 0
+                for i in range(0, input_ids.size(1), self.context_len):
+                    chunk_input_ids = input_ids[:, i:i + self.context_len]
+                    chunk_attention_mask = attention_mask[:, i:i + self.context_len]
+
+                    # add cls token and mask to get cls embedding
+                    if (
+                        hasattr(self.model, "use_cls_pooling")
+                        and self.model.use_cls_pooling
+                    ):
+                        cls_tokens = (
+                            torch.zeros(
+                                (chunk_input_ids.size(0), 1),
+                                dtype=chunk_input_ids.dtype,
+                                device=chunk_input_ids.device,
+                            )
+                            + tokenizer.cls_token_id
+                        )
+                        chunk_input_ids = torch.cat(
+                            [cls_tokens, chunk_input_ids], dim=-1
+                        )
+                        mask = torch.ones(
+                            (chunk_attention_mask.size(0), 1),
+                            dtype=chunk_attention_mask.dtype,
+                            device=chunk_attention_mask.device,
+                        )
+                        chunk_attention_mask = torch.cat(
+                            [mask, chunk_attention_mask], dim=-1
+                        )
+
+                    chunk_embeddings, token_num = self.__process_embed_chunk(
+                        chunk_input_ids, chunk_attention_mask, **model_type_dict
+                    )
+                    if (
+                        hasattr(self.model, "use_cls_pooling")
+                        and self.model.use_cls_pooling
+                    ):
+                        all_embeddings.append(chunk_embeddings * token_num)
+                    else:
+                        all_embeddings.append(chunk_embeddings)
+                    all_token_num += token_num
+
+                all_embeddings_tensor = torch.stack(all_embeddings)
+                embedding = torch.sum(all_embeddings_tensor, dim=0) / all_token_num
+                normalized_embeddings = F.normalize(embedding, p=2, dim=1)
+
+                ret["token_num"] = all_token_num
+
+            if base64_encode == "base64":
+                out_embeddings = self.__encode_base64(normalized_embeddings)
+            else:
+                out_embeddings = normalized_embeddings.tolist()
+            ret["embedding"] = out_embeddings
+
+            gc.collect()
+            torch.cuda.empty_cache()
+            if self.device == "xpu":
+                torch.xpu.empty_cache()
+            if self.device == "npu":
+                torch.npu.empty_cache()
+        except torch.cuda.OutOfMemoryError as e:
+            ret = {
+                "text": f"{SERVER_ERROR_MSG}\n\n({e})",
+                "error_code": ErrorCode.CUDA_OUT_OF_MEMORY,
+            }
+        except (ValueError, RuntimeError) as e:
+            ret = {
+                "text": f"{SERVER_ERROR_MSG}\n\n({e})",
+                "error_code": ErrorCode.INTERNAL_ERROR,
+            }
+        return ret
+
     def generate_stream_gate(self, params):
         self.call_ct += 1
         # context length is self.context_length
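Note on the chunked path above: __process_embed_chunk returns per-chunk *sum* embeddings plus a token count, and dividing the stacked sums by the total token count yields a mean over every unmasked token before L2 normalization (the CLS-pooling branch instead takes a token-count-weighted mean of CLS vectors). A minimal standalone sketch of the non-CLS pooling arithmetic, using hypothetical tensors and no model:

    import torch
    import torch.nn.functional as F

    # Two hypothetical chunks of summed token embeddings (batch=1, dim=4),
    # as __process_embed_chunk would return without CLS pooling.
    chunk_sums = [torch.tensor([[4.0, 0.0, 0.0, 0.0]]),
                  torch.tensor([[0.0, 2.0, 0.0, 0.0]])]
    token_counts = [4, 2]  # unmasked tokens contributing to each chunk

    # Sum across chunks, divide by total tokens: a mean over all tokens.
    embedding = torch.stack(chunk_sums).sum(dim=0) / sum(token_counts)
    normalized = F.normalize(embedding, p=2, dim=1)  # unit length, as served
    print(normalized)  # tensor([[0.8944, 0.4472, 0.0000, 0.0000]])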
@@ -277,6 +431,15 @@ async def api_get_status(request: Request):
     return worker.get_status()
 
 
+@app.post("/worker_get_embeddings")
+async def api_get_embeddings(request: Request):
+    params = await request.json()
+    await acquire_worker_semaphore()
+    embedding = worker.get_embeddings(params)
+    release_worker_semaphore()
+    return JSONResponse(content=embedding)
+
+
 @app.post("/count_token")
 async def api_count_token(request: Request):
     params = await request.json()
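For illustration, the new endpoint can be exercised with a plain HTTP client. The host and port here are assumptions (21002 is FastChat's usual worker default); the payload mirrors what get_embeddings reads from params:

    import requests

    # Hypothetical call against a locally running worker.
    resp = requests.post(
        "http://localhost:21002/worker_get_embeddings",
        json={"input": ["The quick brown fox"]},
    )
    result = resp.json()
    # On success: a list of unit-normalized vectors plus the token count.
    print(result["token_num"], len(result["embedding"][0]))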
@@ -332,6 +495,7 @@ if __name__ == "__main__":
         help="Trust remote code (e.g., from HuggingFace) when"
         "downloading the model and tokenizer.",
     )
+    parser.add_argument("--embed-in-truncate", action="store_true")
 
     args = parser.parse_args()
     worker = BigDLLLMWorker(
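With --embed-in-truncate set, get_embeddings truncates each input to the model's context length in a single batch_encode_plus call; without it, longer inputs are split into context-length chunks and pooled as sketched earlier. A hypothetical launch command (module path and other flags are assumptions):

    python -m ipex_llm.serving.fastchat.ipex_llm_worker --model-path <model> --embed-in-truncate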
@@ -346,6 +510,7 @@ if __name__ == "__main__":
         args.device,
         args.no_register,
         args.trust_remote_code,
+        args.embed_in_truncate,
         args.speculative,
     )
     uvicorn.run(app, host=args.host, port=args.port, log_level="info")
@@ -49,6 +49,10 @@ import numpy as np
 import os
 from ipex_llm.utils.common import invalidInputError
 from typing import List, Optional, Tuple, Union
+import subprocess
+import sys
+
+_IS_VLLM_AVAILABLE = None
 
 
 def is_auto_gptq_available():
@@ -60,7 +64,16 @@ def is_auto_awq_available():
 
 
 def is_vllm_available():
-    return importlib.util.find_spec("vllm") is not None
+    global _IS_VLLM_AVAILABLE
+    if _IS_VLLM_AVAILABLE is not None:
+        return _IS_VLLM_AVAILABLE
+    reqs = subprocess.check_output([sys.executable, '-m', 'pip', 'list'])
+    installed_packages = [r.decode().split(' ')[0] for r in reqs.split()]
+    if 'vllm' in installed_packages:
+        _IS_VLLM_AVAILABLE = True
+    else:
+        _IS_VLLM_AVAILABLE = False
+    return _IS_VLLM_AVAILABLE
 
 
 def is_torch_distributed_initialized():
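The rewritten check consults the installed-distribution list via a pip subprocess instead of the import machinery, and memoizes the result in the module-level _IS_VLLM_AVAILABLE, so the subprocess cost is paid at most once per process. A usage sketch, where the import path is an assumption for illustration:

    # First call shells out to `pip list`; later calls return the cached flag.
    from ipex_llm.transformers.convert import is_vllm_available  # path assumed

    print(is_vllm_available())  # pays the subprocess cost once
    print(is_vllm_available())  # served from _IS_VLLM_AVAILABLE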