From 49ab5a2b0e08506371289c0bcc2b0e795c9e6d21 Mon Sep 17 00:00:00 2001
From: Guancheng Fu <110874468+gc-fu@users.noreply.github.com>
Date: Tue, 7 May 2024 09:07:02 +0800
Subject: [PATCH] Add embeddings (#10931)

---
 .../serving/fastchat/ipex_llm_worker.py       | 165 ++++++++++++++++++
 .../llm/src/ipex_llm/transformers/convert.py  |  15 +-
 2 files changed, 179 insertions(+), 1 deletion(-)

diff --git a/python/llm/src/ipex_llm/serving/fastchat/ipex_llm_worker.py b/python/llm/src/ipex_llm/serving/fastchat/ipex_llm_worker.py
index 2c234645..cf08e928 100644
--- a/python/llm/src/ipex_llm/serving/fastchat/ipex_llm_worker.py
+++ b/python/llm/src/ipex_llm/serving/fastchat/ipex_llm_worker.py
@@ -19,6 +19,9 @@ A model worker that executes the model based on BigDL-LLM.
 
 Relies on load_model method
 """
+import torch
+import torch.nn.functional as F
+import gc
 import argparse
 import asyncio
 import atexit
@@ -31,6 +34,7 @@ from fastapi.concurrency import run_in_threadpool
 from fastapi.responses import StreamingResponse, JSONResponse
 import uvicorn
 
+from fastchat.constants import ErrorCode, SERVER_ERROR_MSG
 from fastchat.serve.base_model_worker import BaseModelWorker
 from fastchat.serve.model_worker import (
     logger,
@@ -63,6 +67,7 @@ class BigDLLLMWorker(BaseModelWorker):
         device: str = "cpu",
         no_register: bool = False,
         trust_remote_code: bool = False,
+        embed_in_truncate: bool = False,
         speculative: bool = False,
         stream_interval: int = 4,
     ):
@@ -93,9 +98,158 @@ class BigDLLLMWorker(BaseModelWorker):
         )
         self.stream_interval = stream_interval
         self.context_len = get_context_length(self.model.config)
+        self.embed_in_truncate = embed_in_truncate
         if not no_register:
             self.init_heart_beat()
 
+    def __process_embed_chunk(self, input_ids, attention_mask, **model_type_dict):
+        if model_type_dict.get("is_bert"):
+            model_output = self.model(input_ids)
+            if model_type_dict.get("is_robert"):
+                data = model_output.last_hidden_state
+            else:
+                data = model_output[0]
+        elif model_type_dict.get("is_t5"):
+            model_output = self.model(input_ids, decoder_input_ids=input_ids)
+            data = model_output.encoder_last_hidden_state
+        else:
+            model_output = self.model(input_ids, output_hidden_states=True)
+            if model_type_dict.get("is_chatglm"):
+                data = model_output.hidden_states[-1].transpose(0, 1)
+            else:
+                data = model_output.hidden_states[-1]
+
+        if hasattr(self.model, "use_cls_pooling") and self.model.use_cls_pooling:
+            sum_embeddings = data[:, 0]
+        else:
+            mask = attention_mask.unsqueeze(-1).expand(data.size()).float()
+            masked_embeddings = data * mask
+            sum_embeddings = torch.sum(masked_embeddings, dim=1)
+        token_num = torch.sum(attention_mask).item()
+
+        return sum_embeddings, token_num
+
+    @torch.inference_mode()
+    def get_embeddings(self, params):
+        self.call_ct += 1
+
+        try:
+            # Get tokenizer
+            tokenizer = self.tokenizer
+            ret = {"embedding": [], "token_num": 0}
+
+            # Based on conditions of different model_type
+            model_type_dict = {
+                "is_llama": "llama" in str(type(self.model)),
+                "is_t5": "t5" in str(type(self.model)),
+                "is_chatglm": "chatglm" in str(type(self.model)),
+                "is_bert": "bert" in str(type(self.model)),
+                "is_robert": "robert" in str(type(self.model)),
+            }
+
+            if self.embed_in_truncate:
+                encoding = tokenizer.batch_encode_plus(
+                    params["input"],
+                    padding=True,
+                    truncation="longest_first",
+                    return_tensors="pt",
+                    max_length=self.context_len,
+                )
+            else:
+                encoding = tokenizer.batch_encode_plus(
+                    params["input"], padding=True, return_tensors="pt"
+                )
+            input_ids = encoding["input_ids"].to(self.device)
+            # Check if we need attention_mask or not.
+            attention_mask = input_ids != tokenizer.pad_token_id
+
+            base64_encode = params.get("encoding_format", None)
+
+            if self.embed_in_truncate:
+                embedding, token_num = self.__process_embed_chunk(
+                    input_ids, attention_mask, **model_type_dict
+                )
+                if (
+                    not hasattr(self.model, "use_cls_pooling")
+                    or not self.model.use_cls_pooling
+                ):
+                    embedding = embedding / token_num
+                normalized_embeddings = F.normalize(embedding, p=2, dim=1)
+                ret["token_num"] = token_num
+            else:
+                all_embeddings = []
+                all_token_num = 0
+                for i in range(0, input_ids.size(1), self.context_len):
+                    chunk_input_ids = input_ids[:, i:i + self.context_len]
+                    chunk_attention_mask = attention_mask[:, i:i + self.context_len]
+
+                    # add cls token and mask to get cls embedding
+                    if (
+                        hasattr(self.model, "use_cls_pooling")
+                        and self.model.use_cls_pooling
+                    ):
+                        cls_tokens = (
+                            torch.zeros(
+                                (chunk_input_ids.size(0), 1),
+                                dtype=chunk_input_ids.dtype,
+                                device=chunk_input_ids.device,
+                            )
+                            + tokenizer.cls_token_id
+                        )
+                        chunk_input_ids = torch.cat(
+                            [cls_tokens, chunk_input_ids], dim=-1
+                        )
+                        mask = torch.ones(
+                            (chunk_attention_mask.size(0), 1),
+                            dtype=chunk_attention_mask.dtype,
+                            device=chunk_attention_mask.device,
+                        )
+                        chunk_attention_mask = torch.cat(
+                            [mask, chunk_attention_mask], dim=-1
+                        )
+
+                    chunk_embeddings, token_num = self.__process_embed_chunk(
+                        chunk_input_ids, chunk_attention_mask, **model_type_dict
+                    )
+                    if (
+                        hasattr(self.model, "use_cls_pooling")
+                        and self.model.use_cls_pooling
+                    ):
+                        all_embeddings.append(chunk_embeddings * token_num)
+                    else:
+                        all_embeddings.append(chunk_embeddings)
+                    all_token_num += token_num
+
+                all_embeddings_tensor = torch.stack(all_embeddings)
+                embedding = torch.sum(all_embeddings_tensor, dim=0) / all_token_num
+                normalized_embeddings = F.normalize(embedding, p=2, dim=1)
+
+                ret["token_num"] = all_token_num
+
+            if base64_encode == "base64":
+                out_embeddings = self.__encode_base64(normalized_embeddings)
+            else:
+                out_embeddings = normalized_embeddings.tolist()
+            ret["embedding"] = out_embeddings
+
+            gc.collect()
+            torch.cuda.empty_cache()
+            if self.device == "xpu":
+                torch.xpu.empty_cache()
+            if self.device == "npu":
+                torch.npu.empty_cache()
+        except torch.cuda.OutOfMemoryError as e:
+            ret = {
+                "text": f"{SERVER_ERROR_MSG}\n\n({e})",
+                "error_code": ErrorCode.CUDA_OUT_OF_MEMORY,
+            }
+        except (ValueError, RuntimeError) as e:
+            ret = {
+                "text": f"{SERVER_ERROR_MSG}\n\n({e})",
+                "error_code": ErrorCode.INTERNAL_ERROR,
+            }
+        return ret
+
     def generate_stream_gate(self, params):
         self.call_ct += 1
         # context length is self.context_length
@@ -277,6 +431,15 @@ async def api_get_status(request: Request):
     return worker.get_status()
 
 
+@app.post("/worker_get_embeddings")
+async def api_get_embeddings(request: Request):
+    params = await request.json()
+    await acquire_worker_semaphore()
+    embedding = worker.get_embeddings(params)
+    release_worker_semaphore()
+    return JSONResponse(content=embedding)
+
+
 @app.post("/count_token")
 async def api_count_token(request: Request):
     params = await request.json()
@@ -332,6 +495,7 @@ if __name__ == "__main__":
         help="Trust remote code (e.g., from HuggingFace) when"
         "downloading the model and tokenizer.",
     )
+    parser.add_argument("--embed-in-truncate", action="store_true")
    args = parser.parse_args()
 
     worker = BigDLLLMWorker(
@@ -346,6 +510,7 @@ if __name__ == "__main__":
         args.device,
         args.no_register,
         args.trust_remote_code,
+        args.embed_in_truncate,
         args.speculative,
     )
     uvicorn.run(app, host=args.host, port=args.port, log_level="info")
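[Editor's note: a minimal client-side sketch of the new /worker_get_embeddings route, not part of the patch. It assumes the worker is reachable on localhost:21002 (FastChat's usual worker port; adjust to your deployment). The payload keys mirror what get_embeddings() reads: a required "input" list of strings and an optional "encoding_format" that switches the response to base64.]

    # Hypothetical usage sketch; host/port are assumptions.
    import requests

    resp = requests.post(
        "http://localhost:21002/worker_get_embeddings",
        json={"input": ["hello world", "ipex-llm"]},
        # optionally: json={"input": [...], "encoding_format": "base64"}
    )
    result = resp.json()
    # Success: {"embedding": [[...], ...], "token_num": <int>}
    # Failure: {"text": "...", "error_code": <int>} from the except branches above
    print(result["token_num"])
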
log_level="info") diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py index a3573bc4..553e4e6c 100644 --- a/python/llm/src/ipex_llm/transformers/convert.py +++ b/python/llm/src/ipex_llm/transformers/convert.py @@ -49,6 +49,10 @@ import numpy as np import os from ipex_llm.utils.common import invalidInputError from typing import List, Optional, Tuple, Union +import subprocess +import sys + +_IS_VLLM_AVAILABLE = None def is_auto_gptq_available(): @@ -60,7 +64,16 @@ def is_auto_awq_available(): def is_vllm_available(): - return importlib.util.find_spec("vllm") is not None + global _IS_VLLM_AVAILABLE + if _IS_VLLM_AVAILABLE is not None: + return _IS_VLLM_AVAILABLE + reqs = subprocess.check_output([sys.executable, '-m', 'pip', 'list']) + installed_packages = [r.decode().split(' ')[0] for r in reqs.split()] + if 'vllm' in installed_packages: + _IS_VLLM_AVAILABLE = True + else: + _IS_VLLM_AVAILABLE = False + return _IS_VLLM_AVAILABLE def is_torch_distributed_initialized():