diff --git a/python/llm/example/GPU/HF-Transformers-AutoModels/Model/codeshell/README.md b/python/llm/example/GPU/HF-Transformers-AutoModels/Model/codeshell/README.md
new file mode 100644
index 00000000..e12e8163
--- /dev/null
+++ b/python/llm/example/GPU/HF-Transformers-AutoModels/Model/codeshell/README.md
@@ -0,0 +1,38 @@
+# CodeShell
+
+In this directory, you will find an example of how to use this CodeShell server together with the CodeShell VSCode extension.
+
+## 0. Extra Environment Preparations
+
+Assuming you have already configured your GPU environment, a few extra preparations are needed:
+
+1. Install the extra requirements:
+    ```
+    pip install uvicorn fastapi sse_starlette
+    ```
+
+2. Search for `codeshell` in the VSCode extension marketplace and install the `CodeShell VSCode Extension`.
+
+3. Change the extension settings:
+    - set `Code Shell: Run Env For LLMs` to `GPU with TGI toolkit`
+    - disable `Code Shell: Auto Trigger Completion` (use `Alt + \` to trigger completion manually)
+
+4. Download WisdomShell/CodeShell-7B-Chat (do not use CodeShell-7B).
+
+## 1. How to use this server
+
+```
+python server.py [--option value]
+```
+
+The available options are listed below; an example invocation follows the list.
+
+1. `--checkpoint-path <path>`: path to the Hugging Face model checkpoint
+2. `--device xpu`: run the model on an Intel GPU (the default is `cpu`)
+3. `--multi-turn`: enable multi-turn conversation; without it only the latest turn of the chat input is kept
+4. `--cpu-embedding`: move the Embedding layer to CPU
+5. `--max-context <length>`: clip the context length for Code Completion (other features are unaffected); set it to 99999 to disable clipping
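+
+For example, a typical invocation on an Intel GPU could look like the following (the checkpoint path and the `--max-context` value are illustrative placeholders; adjust them to your setup):
+
+```
+python server.py --checkpoint-path ./CodeShell-7B-Chat --device xpu --cpu-embedding --max-context 1024
+```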
+
+## 2. Note
+
+In my tests, when using a VSCode remote connection to a remote machine, with the extension installed and this server running on that remote machine, all extension features except Code Completion can be used.
+
+Without a remote connection, with the extension installed and this server running on the local machine, Code Completion can also be used.
diff --git a/python/llm/example/GPU/HF-Transformers-AutoModels/Model/codeshell/server.py b/python/llm/example/GPU/HF-Transformers-AutoModels/Model/codeshell/server.py
new file mode 100644
index 00000000..5a77812c
--- /dev/null
+++ b/python/llm/example/GPU/HF-Transformers-AutoModels/Model/codeshell/server.py
@@ -0,0 +1,288 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import time
+import json
+from tqdm import tqdm
+from argparse import ArgumentParser
+from typing import Dict, List
+from threading import Thread
+import torch
+import uvicorn
+from fastapi import FastAPI
+from fastapi.responses import JSONResponse
+from fastapi.middleware.cors import CORSMiddleware
+from pydantic import BaseModel
+from transformers import AutoTokenizer
+# from transformers import AutoModelForCausalLM, AutoModel
+from bigdl.llm.transformers import AutoModelForCausalLM, AutoModel
+from transformers.generation import GenerationConfig, TextIteratorStreamer
+from transformers import StoppingCriteriaList, StoppingCriteria
+from sse_starlette.sse import EventSourceResponse
+
+
+app = FastAPI()
+
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+
+class GenerationParameters(BaseModel):
+    max_new_tokens: int
+    temperature: float
+    repetition_penalty: float
+    top_p: float
+    do_sample: bool
+    stop: List[str]
+
+
+class GenerationRequest(BaseModel):
+    inputs: str
+    parameters: GenerationParameters
+
+
+class StopWordsCriteria(StoppingCriteria):
+    """Custom `StoppingCriteria` which checks if all generated functions in the batch are completed."""
+    def __init__(self, input_length, stop_words, tokenizer):
+        self.input_length = input_length
+        self.stop_words = stop_words
+        self.stop_words += ["||"]
+        self.tokenizer = tokenizer
+
+    def __call__(self, input_ids, scores, **kwargs):
+        """Returns true if all generated sequences contain any of the end-of-function strings."""
+        texts = [self.tokenizer.decode(ids[self.input_length:]) for ids in input_ids]
+        dones = [any(stop_word in text for stop_word in self.stop_words) for text in texts]
+        return all(dones)
+
+
+@app.post("/generate")
+async def generate(request: GenerationRequest):
+    global model, tokenizer, device, max_context
+
+    if device == 'xpu':
+        torch.xpu.empty_cache()
+
+    prompt = request.inputs
+    input_ids = tokenizer.encode(prompt, return_tensors="pt")
+    input_length = len(input_ids[0])
+
+    if input_length > max_context:
+        # the prompt is too long: clip the fill-in-the-middle prefix/suffix proportionally
+        tokens = list(input_ids[0])
+        prefix_index = tokens.index(70001)  # fim_prefix
+        middle_index = tokens.index(70002)  # fim_middle
+        suffix_index = tokens.index(70003)  # fim_suffix
+
+        prefix_tokens = tokens[prefix_index+1:suffix_index]
+        suffix_tokens = tokens[suffix_index+1:middle_index]
+        prefix_len = suffix_index - prefix_index - 1
+        suffix_len = middle_index - suffix_index - 1
+
+        if prefix_len + suffix_len > max_context:
+            new_prefix_len = max_context * prefix_len // (prefix_len + suffix_len)
+            new_suffix_len = max_context * suffix_len // (prefix_len + suffix_len)
+            new_prefix_tokens = prefix_tokens[-new_prefix_len:]
+            new_suffix_tokens = suffix_tokens[:new_suffix_len]
+
+            input_ids = torch.tensor(
+                tokens[:prefix_index+1]
+                + new_prefix_tokens
+                + tokens[suffix_index:suffix_index+1]
+                + new_suffix_tokens
+                + tokens[middle_index:]
+            ).reshape(1, -1)
+            input_length = len(input_ids[0])
+            prompt = tokenizer.decode(input_ids[0])
+
+    input_ids = input_ids.to(device)
+
+    stopping_criteria = StoppingCriteriaList(
+        [StopWordsCriteria(input_length, request.parameters.stop, tokenizer)]
+    )
+
+    generation_kwargs = dict(stopping_criteria=stopping_criteria,
+                             max_new_tokens=request.parameters.max_new_tokens,
+                             temperature=request.parameters.temperature,
+                             repetition_penalty=request.parameters.repetition_penalty,
+                             top_p=request.parameters.top_p,
+                             do_sample=request.parameters.do_sample)
+
+    print('-'*80)
+    print('input prompt:', prompt)
+    print('input length:', input_length)
+    print('-'*80)
+
+    output_ids = model.generate(input_ids, **generation_kwargs)
+    output_text = tokenizer.decode(output_ids[0])
+
+    return JSONResponse({
+        "generated_text": output_text[len(prompt):]
+    })
+
+
+@app.post("/generate_stream")
+async def generate_stream(request: GenerationRequest):
+    global model, tokenizer, device, multi_turn
+
+    if device == 'xpu':
+        torch.xpu.empty_cache()
+
+    prompt = request.inputs
+
+    if multi_turn:
+        prompt = prompt
+    else:
+        # extract the last turn input
+        human_ins = "## human"
+        first_ins = prompt.find(human_ins)
+        last_ins = prompt.rfind(human_ins)
+        prompt = prompt[:first_ins] + prompt[last_ins:]
+
+    input_ids = tokenizer(prompt, return_tensors="pt")
+    input_length = len(input_ids['input_ids'][0])
+    input_ids = input_ids.to(device)
+
+    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
+    stopping_criteria = StoppingCriteriaList(
+        [StopWordsCriteria(input_length, request.parameters.stop, tokenizer)]
+    )
+
+    max_batch = 1024
+    if input_length <= max_batch:
+        past_key_values = None
+    else:
+        # pre-compute past_key_values for the long prompt in chunks of max_batch tokens
+        with torch.inference_mode():
+            past_key_values = None
+            for start_pos in range(0, input_length - 1, max_batch):
+                end_pos = min(start_pos + max_batch, input_length - 1)
+                output = model.forward(input_ids['input_ids'][:, start_pos:end_pos],
+                                       past_key_values=past_key_values)
+                past_key_values = output.past_key_values
+
+    generation_kwargs = dict(input_ids,
+                             past_key_values=past_key_values,
+                             streamer=streamer,
+                             stopping_criteria=stopping_criteria,
+                             max_new_tokens=request.parameters.max_new_tokens,
+                             temperature=request.parameters.temperature,
+                             repetition_penalty=request.parameters.repetition_penalty,
+                             top_p=request.parameters.top_p,
+                             do_sample=request.parameters.do_sample)
+
+    print('-'*80)
+    print('input prompt:', prompt)
+    print('input length:', input_length)
+    print('-'*80)
+
+    thread = Thread(target=model.generate, kwargs=generation_kwargs)
+    thread.start()
+
+    def create_response(streamer):
+        for word in tqdm(streamer, "Generating Tokens", unit="token"):
+            yield json.dumps({
+                "token": {
+                    "id": 0,
+                    "text": word,
+                },
+            })
+
+    return EventSourceResponse(create_response(streamer), media_type="text/event-stream")
+
+
+def _get_args():
+    parser = ArgumentParser()
+    parser.add_argument(
+        "-c",
+        "--checkpoint-path",
+        type=str,
+        default="CodeShell-7B-Chat",
+        help="Checkpoint name or path, default to %(default)r",
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cpu",
+        help="Device name."
+    )
+    parser.add_argument(
+        "--server-port",
+        type=int,
+        default=8080,
+        help="Demo server port."
+    )
+    parser.add_argument(
+        "--server-name",
+        type=str,
+        default="127.0.0.1",
+        help="Demo server name. Default: 127.0.0.1, which is only visible from the local computer."
+ " If you want other computers to access your server, use 0.0.0.0 instead.", + ) + parser.add_argument( + "--multi-turn", + action="store_true", + help="Enable multi-turn chat", + ) + parser.add_argument( + "--cpu-embedding", + action="store_true", + help="Move Embedding layer to CPU" + ) + parser.add_argument( + "--max-context", + type=int, + default=300, + help="Max context length when using code completion", + ) + + args = parser.parse_args() + return args + +if __name__ == "__main__": + args = _get_args() + + tokenizer = AutoTokenizer.from_pretrained( + args.checkpoint_path, + trust_remote_code=True, + ) + + model = AutoModelForCausalLM.from_pretrained( + args.checkpoint_path, + trust_remote_code=True, + load_in_4bit=True, + cpu_embedding=args.cpu_embedding + ).eval() + + device = args.device + multi_turn = args.multi_turn + max_context = args.max_context + + if device == 'xpu': + import intel_extension_for_pytorch as ipex + + model = model.to(device) + + model.generation_config = GenerationConfig.from_pretrained( + args.checkpoint_path, + trust_remote_code=True, + resume_download=True, + ) + + uvicorn.run(app, host=args.server_name, port=args.server_port, workers=1)