Add embeddings (#10931)
parent d649236321
commit 49ab5a2b0e
2 changed files with 179 additions and 1 deletion
@@ -19,6 +19,9 @@ A model worker that executes the model based on BigDL-LLM.
 Relies on load_model method
 """
 
+import torch
+import torch.nn.functional as F
+import gc
 import argparse
 import asyncio
 import atexit
@@ -31,6 +34,7 @@ from fastapi.concurrency import run_in_threadpool
 from fastapi.responses import StreamingResponse, JSONResponse
 import uvicorn
 
+from fastchat.constants import ErrorCode, SERVER_ERROR_MSG
 from fastchat.serve.base_model_worker import BaseModelWorker
 from fastchat.serve.model_worker import (
     logger,
@@ -63,6 +67,7 @@ class BigDLLLMWorker(BaseModelWorker):
         device: str = "cpu",
         no_register: bool = False,
         trust_remote_code: bool = False,
+        embed_in_truncate: bool = False,
         speculative: bool = False,
         stream_interval: int = 4,
     ):
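The new embed_in_truncate flag (stored on the worker below and exposed later as --embed-in-truncate) chooses between truncating each request to the model's context length in a single pass and walking the full input in context-length chunks. A minimal sketch of the two tokenization paths, assuming a Hugging Face tokenizer; the model name and context_len value here are placeholders, not part of the commit:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # placeholder model
texts = ["hello world", "a much longer passage that may exceed the context window"]
context_len = 512  # stands in for get_context_length(self.model.config)

# embed_in_truncate=True: one pass, inputs hard-truncated to the context length
truncated = tokenizer.batch_encode_plus(
    texts,
    padding=True,
    truncation="longest_first",
    max_length=context_len,
    return_tensors="pt",
)

# embed_in_truncate=False: no truncation; get_embeddings() later slices
# input_ids into context_len-sized chunks and averages the chunk embeddings
full = tokenizer.batch_encode_plus(texts, padding=True, return_tensors="pt")
print(truncated["input_ids"].shape, full["input_ids"].shape)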
@@ -93,9 +98,158 @@ class BigDLLLMWorker(BaseModelWorker):
         )
         self.stream_interval = stream_interval
         self.context_len = get_context_length(self.model.config)
+        self.embed_in_truncate = embed_in_truncate
         if not no_register:
             self.init_heart_beat()
 
+    def __process_embed_chunk(self, input_ids, attention_mask, **model_type_dict):
+        if model_type_dict.get("is_bert"):
+            model_output = self.model(input_ids)
+            if model_type_dict.get("is_robert"):
+                data = model_output.last_hidden_state
+            else:
+                data = model_output[0]
+        elif model_type_dict.get("is_t5"):
+            model_output = self.model(input_ids, decoder_input_ids=input_ids)
+            data = model_output.encoder_last_hidden_state
+        else:
+            model_output = self.model(input_ids, output_hidden_states=True)
+            if model_type_dict.get("is_chatglm"):
+                data = model_output.hidden_states[-1].transpose(0, 1)
+            else:
+                data = model_output.hidden_states[-1]
+
+        if hasattr(self.model, "use_cls_pooling") and self.model.use_cls_pooling:
+            sum_embeddings = data[:, 0]
+        else:
+            mask = attention_mask.unsqueeze(-1).expand(data.size()).float()
+            masked_embeddings = data * mask
+            sum_embeddings = torch.sum(masked_embeddings, dim=1)
+        token_num = torch.sum(attention_mask).item()
+
+        return sum_embeddings, token_num
+
+    @torch.inference_mode()
+    def get_embeddings(self, params):
+        self.call_ct += 1
+
+        try:
+            # Get tokenizer
+            tokenizer = self.tokenizer
+            ret = {"embedding": [], "token_num": 0}
+
+            # Based on conditions of different model_type
+            model_type_dict = {
+                "is_llama": "llama" in str(type(self.model)),
+                "is_t5": "t5" in str(type(self.model)),
+                "is_chatglm": "chatglm" in str(type(self.model)),
+                "is_bert": "bert" in str(type(self.model)),
+                "is_robert": "robert" in str(type(self.model)),
+            }
+
+            if self.embed_in_truncate:
+                encoding = tokenizer.batch_encode_plus(
+                    params["input"],
+                    padding=True,
+                    truncation="longest_first",
+                    return_tensors="pt",
+                    max_length=self.context_len,
+                )
+            else:
+                encoding = tokenizer.batch_encode_plus(
+                    params["input"], padding=True, return_tensors="pt"
+                )
+            input_ids = encoding["input_ids"].to(self.device)
+            # Check if we need attention_mask or not.
+            attention_mask = input_ids != tokenizer.pad_token_id
+
+            base64_encode = params.get("encoding_format", None)
+
+            if self.embed_in_truncate:
+                embedding, token_num = self.__process_embed_chunk(
+                    input_ids, attention_mask, **model_type_dict
+                )
+                if (
+                    not hasattr(self.model, "use_cls_pooling")
+                    or not self.model.use_cls_pooling
+                ):
+                    embedding = embedding / token_num
+                normalized_embeddings = F.normalize(embedding, p=2, dim=1)
+                ret["token_num"] = token_num
+            else:
+                all_embeddings = []
+                all_token_num = 0
+                for i in range(0, input_ids.size(1), self.context_len):
+                    chunk_input_ids = input_ids[:, i:i + self.context_len]
+                    chunk_attention_mask = attention_mask[:, i:i + self.context_len]
+
+                    # add cls token and mask to get cls embedding
+                    if (
+                        hasattr(self.model, "use_cls_pooling")
+                        and self.model.use_cls_pooling
+                    ):
+                        cls_tokens = (
+                            torch.zeros(
+                                (chunk_input_ids.size(0), 1),
+                                dtype=chunk_input_ids.dtype,
+                                device=chunk_input_ids.device,
+                            )
+                            + tokenizer.cls_token_id
+                        )
+                        chunk_input_ids = torch.cat(
+                            [cls_tokens, chunk_input_ids], dim=-1
+                        )
+                        mask = torch.ones(
+                            (chunk_attention_mask.size(0), 1),
+                            dtype=chunk_attention_mask.dtype,
+                            device=chunk_attention_mask.device,
+                        )
+                        chunk_attention_mask = torch.cat(
+                            [mask, chunk_attention_mask], dim=-1
+                        )
+
+                    chunk_embeddings, token_num = self.__process_embed_chunk(
+                        chunk_input_ids, chunk_attention_mask, **model_type_dict
+                    )
+                    if (
+                        hasattr(self.model, "use_cls_pooling")
+                        and self.model.use_cls_pooling
+                    ):
+                        all_embeddings.append(chunk_embeddings * token_num)
+                    else:
+                        all_embeddings.append(chunk_embeddings)
+                    all_token_num += token_num
+
+                all_embeddings_tensor = torch.stack(all_embeddings)
+                embedding = torch.sum(all_embeddings_tensor, dim=0) / all_token_num
+                normalized_embeddings = F.normalize(embedding, p=2, dim=1)
+
+                ret["token_num"] = all_token_num
+
+            if base64_encode == "base64":
+                out_embeddings = self.__encode_base64(normalized_embeddings)
+            else:
+                out_embeddings = normalized_embeddings.tolist()
+            ret["embedding"] = out_embeddings
+
+            gc.collect()
+            torch.cuda.empty_cache()
+            if self.device == "xpu":
+                torch.xpu.empty_cache()
+            if self.device == "npu":
+                torch.npu.empty_cache()
+        except torch.cuda.OutOfMemoryError as e:
+            ret = {
+                "text": f"{SERVER_ERROR_MSG}\n\n({e})",
+                "error_code": ErrorCode.CUDA_OUT_OF_MEMORY,
+            }
+        except (ValueError, RuntimeError) as e:
+            ret = {
+                "text": f"{SERVER_ERROR_MSG}\n\n({e})",
+                "error_code": ErrorCode.INTERNAL_ERROR,
+            }
+        return ret
+
     def generate_stream_gate(self, params):
         self.call_ct += 1
         # context length is self.context_length
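The non-truncating path above amounts to masked mean pooling over non-padding tokens followed by L2 normalization (with CLS pooling as a special case when the model sets use_cls_pooling). A standalone sketch of that pooling step, assuming a [batch, seq_len, hidden] hidden-state tensor; this is an illustration, not code from the commit:

import torch
import torch.nn.functional as F

def masked_mean_pool(hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # hidden_states: [batch, seq_len, hidden]; attention_mask: [batch, seq_len] of 0/1
    mask = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
    summed = (hidden_states * mask).sum(dim=1)  # sum embeddings of real (non-pad) tokens
    token_num = attention_mask.sum()            # total real tokens, as in __process_embed_chunk
    return F.normalize(summed / token_num, p=2, dim=1)  # unit-length embeddings

hidden = torch.randn(2, 5, 8)
mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
print(masked_mean_pool(hidden, mask).shape)  # torch.Size([2, 8])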
@@ -277,6 +431,15 @@ async def api_get_status(request: Request):
     return worker.get_status()
 
 
+@app.post("/worker_get_embeddings")
+async def api_get_embeddings(request: Request):
+    params = await request.json()
+    await acquire_worker_semaphore()
+    embedding = worker.get_embeddings(params)
+    release_worker_semaphore()
+    return JSONResponse(content=embedding)
+
+
 @app.post("/count_token")
 async def api_count_token(request: Request):
     params = await request.json()
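The new /worker_get_embeddings route follows the pattern of the other worker endpoints: read the JSON body, take the worker semaphore, compute, release, return. A possible client-side call, assuming the worker listens on localhost:21002 (host and port are placeholders); the request shape mirrors what get_embeddings() reads (params["input"], optional "encoding_format"):

import requests

resp = requests.post(
    "http://localhost:21002/worker_get_embeddings",  # placeholder worker address
    json={"input": ["What is BigDL-LLM?", "Sentence embeddings demo"]},
    timeout=60,
)
result = resp.json()
print(result["token_num"], len(result["embedding"]))  # token count, one vector per input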
@@ -332,6 +495,7 @@ if __name__ == "__main__":
         help="Trust remote code (e.g., from HuggingFace) when"
         "downloading the model and tokenizer.",
     )
+    parser.add_argument("--embed-in-truncate", action="store_true")
 
     args = parser.parse_args()
    worker = BigDLLLMWorker(
@@ -346,6 +510,7 @@ if __name__ == "__main__":
         args.device,
         args.no_register,
         args.trust_remote_code,
+        args.embed_in_truncate,
         args.speculative,
     )
     uvicorn.run(app, host=args.host, port=args.port, log_level="info")
@@ -49,6 +49,10 @@ import numpy as np
 import os
 from ipex_llm.utils.common import invalidInputError
 from typing import List, Optional, Tuple, Union
+import subprocess
+import sys
+
+_IS_VLLM_AVAILABLE = None
 
 
 def is_auto_gptq_available():
@@ -60,7 +64,16 @@ def is_auto_awq_available():
 
 
 def is_vllm_available():
-    return importlib.util.find_spec("vllm") is not None
+    global _IS_VLLM_AVAILABLE
+    if _IS_VLLM_AVAILABLE is not None:
+        return _IS_VLLM_AVAILABLE
+    reqs = subprocess.check_output([sys.executable, '-m', 'pip', 'list'])
+    installed_packages = [r.decode().split(' ')[0] for r in reqs.split()]
+    if 'vllm' in installed_packages:
+        _IS_VLLM_AVAILABLE = True
+    else:
+        _IS_VLLM_AVAILABLE = False
+    return _IS_VLLM_AVAILABLE
 
 
 def is_torch_distributed_initialized():
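The last two hunks belong to the second changed file: is_vllm_available() now shells out to pip list once and caches the answer in the module-level _IS_VLLM_AVAILABLE instead of probing importlib on every call. For comparison only (not part of the commit), an in-process variant could cache the original find_spec probe instead:

import functools
import importlib.util

@functools.lru_cache(maxsize=1)
def is_vllm_available_cached() -> bool:
    # cached alternative to the pip-list subprocess; asks the import machinery directly
    return importlib.util.find_spec("vllm") is not None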