Add embeddings (#10931)

Guancheng Fu 2024-05-07 09:07:02 +08:00 committed by GitHub
parent d649236321
commit 49ab5a2b0e
2 changed files with 179 additions and 1 deletion


@@ -19,6 +19,9 @@ A model worker that executes the model based on BigDL-LLM.
Relies on load_model method
"""
import torch
import torch.nn.functional as F
import gc
import argparse
import asyncio
import atexit
@@ -31,6 +34,7 @@ from fastapi.concurrency import run_in_threadpool
from fastapi.responses import StreamingResponse, JSONResponse
import uvicorn
from fastchat.constants import ErrorCode, SERVER_ERROR_MSG
from fastchat.serve.base_model_worker import BaseModelWorker
from fastchat.serve.model_worker import (
logger,
@@ -63,6 +67,7 @@ class BigDLLLMWorker(BaseModelWorker):
device: str = "cpu",
no_register: bool = False,
trust_remote_code: bool = False,
embed_in_truncate: bool = False,
speculative: bool = False,
stream_interval: int = 4,
):
@@ -93,9 +98,158 @@ )
)
self.stream_interval = stream_interval
self.context_len = get_context_length(self.model.config)
self.embed_in_truncate = embed_in_truncate
if not no_register:
self.init_heart_beat()

    def __process_embed_chunk(self, input_ids, attention_mask, **model_type_dict):
        if model_type_dict.get("is_bert"):
            # BERT-family encoders return token embeddings directly.
            model_output = self.model(input_ids)
            if model_type_dict.get("is_robert"):
                data = model_output.last_hidden_state
            else:
                data = model_output[0]
        elif model_type_dict.get("is_t5"):
            # T5 is encoder-decoder; only the encoder output is pooled.
            model_output = self.model(input_ids, decoder_input_ids=input_ids)
            data = model_output.encoder_last_hidden_state
        else:
            model_output = self.model(input_ids, output_hidden_states=True)
            if model_type_dict.get("is_chatglm"):
                # ChatGLM emits (seq_len, batch, hidden); move batch first.
                data = model_output.hidden_states[-1].transpose(0, 1)
            else:
                data = model_output.hidden_states[-1]
        if hasattr(self.model, "use_cls_pooling") and self.model.use_cls_pooling:
            # CLS pooling: take the first ([CLS]) token's embedding.
            sum_embeddings = data[:, 0]
        else:
            # Mean pooling, step 1: zero out padded positions, then sum over tokens.
            mask = attention_mask.unsqueeze(-1).expand(data.size()).float()
            masked_embeddings = data * mask
            sum_embeddings = torch.sum(masked_embeddings, dim=1)
        token_num = torch.sum(attention_mask).item()
        return sum_embeddings, token_num
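For intuition, the masked mean pooling above can be checked on a toy batch. A minimal standalone sketch (not part of the commit), with hidden size 2:

```python
import torch

# Two sequences of length 3; the second has one padding slot.
data = torch.tensor([[[1., 1.], [3., 3.], [5., 5.]],
                     [[2., 2.], [4., 4.], [0., 0.]]])
attention_mask = torch.tensor([[1, 1, 1],
                               [1, 1, 0]])

mask = attention_mask.unsqueeze(-1).expand(data.size()).float()
sum_embeddings = torch.sum(data * mask, dim=1)  # padded slots contribute zero
token_num = torch.sum(attention_mask).item()    # 5 real tokens in the batch

print(sum_embeddings)  # tensor([[9., 9.], [6., 6.]])
print(token_num)       # 5
```

Note that `get_embeddings` below divides the summed embeddings by the batch-wide token count rather than per-sequence counts, mirroring upstream FastChat; since the result is L2-normalized afterwards, the scalar divisor does not change the final embedding direction.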
@torch.inference_mode()
def get_embeddings(self, params):
self.call_ct += 1
try:
            tokenizer = self.tokenizer
            ret = {"embedding": [], "token_num": 0}
            # Flags for the model family; __process_embed_chunk uses them to
            # pick the right forward pass and output field.
            model_type_dict = {
                "is_llama": "llama" in str(type(self.model)),
                "is_t5": "t5" in str(type(self.model)),
                "is_chatglm": "chatglm" in str(type(self.model)),
                "is_bert": "bert" in str(type(self.model)),
                "is_robert": "robert" in str(type(self.model)),
            }
            # Optionally truncate inputs to the context window while encoding.
            if self.embed_in_truncate:
encoding = tokenizer.batch_encode_plus(
params["input"],
padding=True,
truncation="longest_first",
return_tensors="pt",
max_length=self.context_len,
)
else:
encoding = tokenizer.batch_encode_plus(
params["input"], padding=True, return_tensors="pt"
)
input_ids = encoding["input_ids"].to(self.device)
            # Build the attention mask from padding: 1 for real tokens, 0 for pads.
            attention_mask = input_ids != tokenizer.pad_token_id
base64_encode = params.get("encoding_format", None)
            if self.embed_in_truncate:
                # Inputs already fit the context window: embed in one pass.
                embedding, token_num = self.__process_embed_chunk(
                    input_ids, attention_mask, **model_type_dict
                )
                if (
                    not hasattr(self.model, "use_cls_pooling")
                    or not self.model.use_cls_pooling
                ):
                    # Mean pooling, step 2: divide token sums by the token count.
                    embedding = embedding / token_num
                normalized_embeddings = F.normalize(embedding, p=2, dim=1)
                ret["token_num"] = token_num
            else:
                # Long inputs: embed in context-length chunks and recombine.
                all_embeddings = []
                all_token_num = 0
for i in range(0, input_ids.size(1), self.context_len):
chunk_input_ids = input_ids[:, i:i + self.context_len]
chunk_attention_mask = attention_mask[:, i:i + self.context_len]
                    # Prepend a [CLS] token (and a matching mask bit) so CLS
                    # pooling has a token to read in every chunk.
if (
hasattr(self.model, "use_cls_pooling")
and self.model.use_cls_pooling
):
cls_tokens = (
torch.zeros(
(chunk_input_ids.size(0), 1),
dtype=chunk_input_ids.dtype,
device=chunk_input_ids.device,
)
+ tokenizer.cls_token_id
)
chunk_input_ids = torch.cat(
[cls_tokens, chunk_input_ids], dim=-1
)
mask = torch.ones(
(chunk_attention_mask.size(0), 1),
dtype=chunk_attention_mask.dtype,
device=chunk_attention_mask.device,
)
chunk_attention_mask = torch.cat(
[mask, chunk_attention_mask], dim=-1
)
chunk_embeddings, token_num = self.__process_embed_chunk(
chunk_input_ids, chunk_attention_mask, **model_type_dict
)
                    # With CLS pooling each chunk yields one CLS vector; weight
                    # it by the chunk's token count so longer chunks count more.
                    # Without CLS pooling the chunk embeddings are already sums.
                    if (
                        hasattr(self.model, "use_cls_pooling")
                        and self.model.use_cls_pooling
                    ):
                        all_embeddings.append(chunk_embeddings * token_num)
                    else:
                        all_embeddings.append(chunk_embeddings)
                    all_token_num += token_num
                # Combine chunks into a single token-weighted average, then
                # L2-normalize to unit length.
                all_embeddings_tensor = torch.stack(all_embeddings)
                embedding = torch.sum(all_embeddings_tensor, dim=0) / all_token_num
                normalized_embeddings = F.normalize(embedding, p=2, dim=1)
                ret["token_num"] = all_token_num
if base64_encode == "base64":
out_embeddings = self.__encode_base64(normalized_embeddings)
else:
out_embeddings = normalized_embeddings.tolist()
ret["embedding"] = out_embeddings
            # Release cached device memory accumulated during the forward pass.
            gc.collect()
            torch.cuda.empty_cache()
            if self.device == "xpu":
                torch.xpu.empty_cache()
            if self.device == "npu":
                torch.npu.empty_cache()
except torch.cuda.OutOfMemoryError as e:
ret = {
"text": f"{SERVER_ERROR_MSG}\n\n({e})",
"error_code": ErrorCode.CUDA_OUT_OF_MEMORY,
}
except (ValueError, RuntimeError) as e:
ret = {
"text": f"{SERVER_ERROR_MSG}\n\n({e})",
"error_code": ErrorCode.INTERNAL_ERROR,
}
return ret
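`get_embeddings` calls a private `__encode_base64` helper that does not appear in this hunk. Assuming it matches FastChat's `model_worker` helper of the same name, it base64-encodes each embedding's raw float buffer, roughly like this sketch:

```python
import base64
from typing import List

import torch

def encode_base64(embeddings: torch.Tensor) -> List[str]:
    # Serialize each row's raw float bytes, matching the OpenAI-style
    # `encoding_format="base64"` convention.
    embeddings = embeddings.cpu()
    return [
        base64.b64encode(e.numpy().tobytes()).decode("utf-8")
        for e in embeddings
    ]
```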
def generate_stream_gate(self, params):
self.call_ct += 1
        # context length is self.context_len
@@ -277,6 +431,15 @@ async def api_get_status(request: Request):
return worker.get_status()
@app.post("/worker_get_embeddings")
async def api_get_embeddings(request: Request):
params = await request.json()
await acquire_worker_semaphore()
embedding = worker.get_embeddings(params)
release_worker_semaphore()
return JSONResponse(content=embedding)
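A worker exposing this endpoint can be exercised with a plain HTTP POST. A minimal client sketch; the host and port are assumptions (21002 is FastChat's default worker port):

```python
import requests

# Assumed worker address; adjust to your deployment.
resp = requests.post(
    "http://localhost:21002/worker_get_embeddings",
    json={"input": ["What is BigDL-LLM?"]},
)
result = resp.json()
print(result["token_num"])          # number of non-padding tokens embedded
print(len(result["embedding"][0]))  # embedding dimension of the first input
```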
@app.post("/count_token")
async def api_count_token(request: Request):
params = await request.json()
@@ -332,6 +495,7 @@ if __name__ == "__main__":
help="Trust remote code (e.g., from HuggingFace) when"
"downloading the model and tokenizer.",
)
parser.add_argument("--embed-in-truncate", action="store_true")
args = parser.parse_args()
worker = BigDLLLMWorker(
@@ -346,6 +510,7 @@
args.device,
args.no_register,
args.trust_remote_code,
args.embed_in_truncate,
args.speculative,
)
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
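On the client side, embeddings requested with `encoding_format="base64"` come back as base64 strings. A decoding sketch, assuming float32 payloads (the actual dtype depends on how the model was loaded):

```python
import base64

import numpy as np

def decode_embedding(b64: str) -> np.ndarray:
    # Inverse of the worker's base64 encoding: raw float32 bytes.
    # The float32 dtype is an assumption, not confirmed by this diff.
    return np.frombuffer(base64.b64decode(b64), dtype=np.float32)
```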


@@ -49,6 +49,10 @@ import numpy as np
import os
from ipex_llm.utils.common import invalidInputError
from typing import List, Optional, Tuple, Union
import subprocess
import sys
# Cached result of is_vllm_available(); the pip subprocess check below is
# slow, so it runs at most once per process.
_IS_VLLM_AVAILABLE = None
def is_auto_gptq_available():
@@ -60,7 +64,16 @@ def is_auto_awq_available():
 def is_vllm_available():
-    return importlib.util.find_spec("vllm") is not None
+    global _IS_VLLM_AVAILABLE
+    if _IS_VLLM_AVAILABLE is not None:
+        return _IS_VLLM_AVAILABLE
+    reqs = subprocess.check_output([sys.executable, '-m', 'pip', 'list'])
+    installed_packages = [r.decode().split(' ')[0] for r in reqs.split()]
+    if 'vllm' in installed_packages:
+        _IS_VLLM_AVAILABLE = True
+    else:
+        _IS_VLLM_AVAILABLE = False
+    return _IS_VLLM_AVAILABLE
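For reference, `importlib.metadata` can perform a similar installed-distribution check in-process, without importing `vllm` or spawning a pip subprocess. An alternative sketch (not part of the commit):

```python
from importlib import metadata

def is_vllm_installed() -> bool:
    # Look up the installed distribution's metadata; this neither imports
    # vllm nor shells out to pip.
    try:
        metadata.version("vllm")
        return True
    except metadata.PackageNotFoundError:
        return False
```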
def is_torch_distributed_initialized():