add codeshell example (#9743)

Yishuo Wang 2023-12-25 10:54:01 +08:00 committed by GitHub
parent daf536fb2d
commit be13b162fe
2 changed files with 326 additions and 0 deletions


@@ -0,0 +1,38 @@
# CodeShell
In this directory, you'll find instructions on how to use this CodeShell server with the CodeShell VSCode extension.

## 0. Extra Environment Preparations

Suppose you have already configured the GPU environment; you will still need some extra preparations:
1. Install the extra requirements:
```
pip install uvicorn fastapi sse_starlette
```
2. Search for `codeshell` in the VSCode extension marketplace, then install the `CodeShell VSCode Extension`.
3. Change the extension settings:
   - Change `Code Shell: Run Env For LLMs` to `GPU with TGI toolkit`.
   - Disable `Code Shell: Auto Trigger Completion` (use `Alt + \` to trigger completion manually).
4. Download `WisdomShell/CodeShell-7B-Chat` (do not use `CodeShell-7B`); one way to fetch the checkpoint is sketched below.
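
As a minimal sketch of step 4, you can download the checkpoint with `huggingface_hub` (assuming it is available in your environment, as it is installed together with `transformers`; the local directory name here is just an example, and cloning the model repository with `git lfs` works equally well):

```
from huggingface_hub import snapshot_download

# Download the chat checkpoint to a local folder, e.g. ./CodeShell-7B-Chat
snapshot_download(repo_id="WisdomShell/CodeShell-7B-Chat",
                  local_dir="./CodeShell-7B-Chat")
```
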
## 1. How to use this server
```
python server.py [--option value]
```
1. `--checkpoint-path <path>`: path to the Hugging Face model checkpoint.
2. `--device xpu`: run the model on an Intel GPU (default: `cpu`).
3. `--multi-turn`: enable multi-turn conversation; without this flag, only the last turn of the chat history is sent to the model.
4. `--cpu-embedding`: move the Embedding layer to CPU.
5. `--max-context <number>`: clip the context length for Code Completion (other features are unaffected; default: 300). Set it to a large value such as 99999 to effectively disable clipping.
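
Once the server is running, the extension talks to it through the `/generate` and `/generate_stream` endpoints defined in `server.py`. As a quick sanity check, here is a minimal sketch of calling `/generate` directly (assuming the default host/port `127.0.0.1:8080` and that the `requests` package is installed; the prompt and sampling values are only illustrative):

```
import requests

# The request body mirrors the GenerationRequest model in server.py
payload = {
    "inputs": "def fibonacci(n):",
    "parameters": {
        "max_new_tokens": 128,
        "temperature": 0.2,
        "repetition_penalty": 1.1,
        "top_p": 0.95,
        "do_sample": True,
        "stop": ["|<end|"],
    },
}

resp = requests.post("http://127.0.0.1:8080/generate", json=payload)
print(resp.json()["generated_text"])
```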

## 2. Note

In my tests, when using a VSCode remote connection to a remote machine (with the extension installed and this server running on that remote machine), all extension features except Code Completion work.
If you instead install the extension and run this server on the local machine without a remote connection, Code Completion works as well.


@@ -0,0 +1,288 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import time
import json
from tqdm import tqdm
from argparse import ArgumentParser
from typing import Dict, List
from threading import Thread

import torch
import uvicorn
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoTokenizer
# from transformers import AutoModelForCausalLM, AutoModel
from bigdl.llm.transformers import AutoModelForCausalLM, AutoModel
from transformers.generation import GenerationConfig, TextIteratorStreamer
from transformers import StoppingCriteriaList, StoppingCriteria
from sse_starlette.sse import EventSourceResponse

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class GenerationParameters(BaseModel):
    max_new_tokens: int
    temperature: float
    repetition_penalty: float
    top_p: float
    do_sample: bool
    stop: List[str]


class GenerationRequest(BaseModel):
    inputs: str
    parameters: GenerationParameters


class StopWordsCriteria(StoppingCriteria):
    """Custom `StoppingCriteria` which checks if all generated functions in the batch are completed."""

    def __init__(self, input_length, stop_words, tokenizer):
        self.input_length = input_length
        self.stop_words = stop_words
        self.stop_words += ["|<end|", "|end>|"]
        self.tokenizer = tokenizer

    def __call__(self, input_ids, scores, **kwargs):
        """Returns true if all generated sequences contain any of the end-of-function strings."""
        texts = [self.tokenizer.decode(ids[self.input_length:]) for ids in input_ids]
        dones = [any(stop_word in text for stop_word in self.stop_words) for text in texts]
        return all(dones)
@app.post("/generate")
async def generate(request: GenerationRequest):
global model, tokenizer, device, max_context
if device == 'xpu':
torch.xpu.empty_cache()
prompt = request.inputs
input_ids = tokenizer.encode(prompt, return_tensors="pt")
input_length = len(input_ids[0])
if input_length > max_context:
tokens = list(input_ids[0])
prefix_index = tokens.index(70001) # fim_prefix
middle_index = tokens.index(70002) # fim_middle
suffix_index = tokens.index(70003) # fim_suffix
prefix_tokens = tokens[prefix_index+1:suffix_index]
suffix_tokens = tokens[suffix_index+1:middle_index]
prefix_len = suffix_index - prefix_index - 1
suffix_len = middle_index - suffix_index - 1
if prefix_len + suffix_len > max_context:
new_prefix_len = max_context * prefix_len // (prefix_len + suffix_len)
new_suffix_len = max_context * suffix_len // (prefix_len + suffix_len)
new_prefix_tokens = prefix_tokens[-new_prefix_len:]
new_suffix_tokens = suffix_tokens[:new_suffix_len]
input_ids = torch.tensor(
tokens[:prefix_index+1] +
new_prefix_tokens +
tokens[suffix_index:suffix_index+1] +
new_suffix_tokens +
tokens[middle_index:]
).reshape(1, -1)
input_length = len(input_ids[0])
prompt = tokenizer.decode(input_ids[0])
input_ids = input_ids.to(device)
stopping_criteria = StoppingCriteriaList(
[ StopWordsCriteria(input_length, request.parameters.stop, tokenizer) ]
)
generation_kwargs = dict(stopping_criteria=stopping_criteria,
max_new_tokens=request.parameters.max_new_tokens,
temperature=request.parameters.temperature,
repetition_penalty=request.parameters.repetition_penalty,
top_p=request.parameters.top_p,
do_sample=request.parameters.do_sample)
print('-'*80)
print('input prompt:', prompt)
print('input length:', input_length)
print('-'*80)
output_ids = model.generate(input_ids, **generation_kwargs)
output_text = tokenizer.decode(output_ids[0])
return JSONResponse({
"generated_text": output_text[len(prompt):]
})
@app.post("/generate_stream")
async def generate_stream(request: GenerationRequest):
global model, tokenizer, device, multi_turn
if device == 'xpu':
torch.xpu.empty_cache()
prompt = request.inputs
if multi_turn:
prompt = prompt
else:
# extract the last turn input
human_ins = "## human"
first_ins = prompt.find(human_ins)
last_ins = prompt.rfind(human_ins)
prompt = prompt[:first_ins] + prompt[last_ins:]
input_ids = tokenizer(prompt, return_tensors="pt")
input_length = len(input_ids['input_ids'][0])
input_ids = input_ids.to(device)
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
stopping_criteria = StoppingCriteriaList(
[ StopWordsCriteria(input_length, request.parameters.stop, tokenizer) ]
)
max_batch = 1024
if input_length <= max_batch:
past_key_values = None
else:
with torch.inference_mode():
past_key_values = None
for start_pos in range(0, input_length - 1, max_batch):
end_pos = min(start_pos + max_batch, input_length - 1)
output = model.forward(input_ids['input_ids'][:, start_pos:end_pos],
past_key_values=past_key_values)
past_key_values = output.past_key_values
generation_kwargs = dict(input_ids,
past_key_values=past_key_values,
streamer=streamer,
stopping_criteria=stopping_criteria,
max_new_tokens=request.parameters.max_new_tokens,
temperature=request.parameters.temperature,
repetition_penalty=request.parameters.repetition_penalty,
top_p=request.parameters.top_p,
do_sample=request.parameters.do_sample)
print('-'*80)
print('input prompt:', prompt)
print('input length:', input_length)
print('-'*80)
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
def create_response(streamer):
for word in tqdm(streamer, "Generating Tokens", unit="token"):
yield json.dumps({
"token": {
"id": 0,
"text": word,
},
})
return EventSourceResponse(create_response(streamer), media_type="text/event-stream")


def _get_args():
    parser = ArgumentParser()
    parser.add_argument(
        "-c",
        "--checkpoint-path",
        type=str,
        default="CodeShell-7B-Chat",
        help="Checkpoint name or path, default to %(default)r",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        help="Device name.",
    )
    parser.add_argument(
        "--server-port",
        type=int,
        default=8080,
        help="Demo server port.",
    )
    parser.add_argument(
        "--server-name",
        type=str,
        default="127.0.0.1",
        help="Demo server name. Default: 127.0.0.1, which is only visible from the local computer."
             " If you want other computers to access your server, use 0.0.0.0 instead.",
    )
    parser.add_argument(
        "--multi-turn",
        action="store_true",
        help="Enable multi-turn chat",
    )
    parser.add_argument(
        "--cpu-embedding",
        action="store_true",
        help="Move Embedding layer to CPU",
    )
    parser.add_argument(
        "--max-context",
        type=int,
        default=300,
        help="Max context length when using code completion",
    )
    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = _get_args()

    tokenizer = AutoTokenizer.from_pretrained(
        args.checkpoint_path,
        trust_remote_code=True,
    )
    # Load the model with BigDL-LLM 4-bit optimization.
    model = AutoModelForCausalLM.from_pretrained(
        args.checkpoint_path,
        trust_remote_code=True,
        load_in_4bit=True,
        cpu_embedding=args.cpu_embedding,
    ).eval()

    device = args.device
    multi_turn = args.multi_turn
    max_context = args.max_context
    if device == 'xpu':
        import intel_extension_for_pytorch as ipex
        model = model.to(device)

    model.generation_config = GenerationConfig.from_pretrained(
        args.checkpoint_path,
        trust_remote_code=True,
        resume_download=True,
    )

    uvicorn.run(app, host=args.server_name, port=args.server_port, workers=1)