add codeshell example (#9743)
parent daf536fb2d
commit be13b162fe

2 changed files with 326 additions and 0 deletions
@@ -0,0 +1,38 @@
# CodeShell

In this directory, you will find instructions for using this CodeShell server with the CodeShell VSCode extension.

## 0. Extra Environment Preparations

Assuming you have already configured your GPU environment, you will need some extra preparation:

1. Install the extra requirements:

    ```
    pip install uvicorn fastapi sse_starlette
    ```

2. Search for `codeshell` in the VSCode extension marketplace, then install the `CodeShell VSCode Extension`.

3. Change the extension settings:

    - set `Code Shell: Run Env For LLMs` to `GPU with TGI toolkit`
    - disable `Code Shell: Auto Trigger Completion` (use `Alt + \` to trigger completion manually)

4. Download `WisdomShell/CodeShell-7B-Chat` (do not use `CodeShell-7B`); see the sketch right after this list.
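
One way to fetch the checkpoint is with `huggingface_hub` (not in the requirement list above, but it is usually installed together with `transformers`). This is only a sketch; the `local_dir` value is an assumption that happens to match the server's default `--checkpoint-path`:

```
from huggingface_hub import snapshot_download

# Download the chat checkpoint next to server.py; the folder name
# "CodeShell-7B-Chat" matches the server's default --checkpoint-path.
snapshot_download(repo_id="WisdomShell/CodeShell-7B-Chat", local_dir="CodeShell-7B-Chat")
```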
## 1. How to Use This Server

```
python server.py [--option value]
```

1. `--checkpoint-path <path>`: path to the Hugging Face model checkpoint (default: `CodeShell-7B-Chat`)
2. `--device <device>`: device to run on; pass `xpu` to enable the GPU (default: `cpu`)
3. `--multi-turn`: enable multi-turn conversation; without this flag only the last turn of the chat is sent to the model
4. `--cpu-embedding`: move the Embedding layer to CPU
5. `--max-context <number>`: clip the context length for Code Completion (other features are unaffected); set it to a large value such as `99999` to effectively disable clipping (default: `300`)
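
For a quick smoke test of the `/generate` endpoint (the extension normally issues these requests itself), something like the following should work, assuming the server is running with the default `127.0.0.1:8080` and that `requests` is installed (it is not in the requirement list above). Every field of `parameters` is required by the server's request model:

```
import requests

payload = {
    "inputs": "def fibonacci(n):\n    ",   # plain test prompt
    "parameters": {
        "max_new_tokens": 64,
        "temperature": 0.2,
        "repetition_penalty": 1.1,
        "top_p": 0.95,
        "do_sample": True,
        "stop": ["\n\n"],
    },
}

resp = requests.post("http://127.0.0.1:8080/generate", json=payload)
print(resp.json()["generated_text"])
```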
## 2. Note

In my tests, when using a VSCode remote connection to a remote machine, with the extension installed and this server running on that remote machine, every extension feature except Code Completion can be used.

When not using a remote connection, with the extension installed and this server running on the local machine, Code Completion can be used as well.
@@ -0,0 +1,288 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import time
import json
from tqdm import tqdm
from argparse import ArgumentParser
from typing import Dict, List
from threading import Thread

import torch
import uvicorn
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoTokenizer
# from transformers import AutoModelForCausalLM, AutoModel
from bigdl.llm.transformers import AutoModelForCausalLM, AutoModel
from transformers.generation import GenerationConfig, TextIteratorStreamer
from transformers import StoppingCriteriaList, StoppingCriteria
from sse_starlette.sse import EventSourceResponse


app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class GenerationParameters(BaseModel):
    max_new_tokens: int
    temperature: float
    repetition_penalty: float
    top_p: float
    do_sample: bool
    stop: List[str]


class GenerationRequest(BaseModel):
    inputs: str
    parameters: GenerationParameters


class StopWordsCriteria(StoppingCriteria):
    """Custom `StoppingCriteria` which checks if all generated functions in the batch are completed."""
    def __init__(self, input_length, stop_words, tokenizer):
        self.input_length = input_length
        self.stop_words = stop_words
        self.stop_words += ["|<end|", "|end>|"]
        self.tokenizer = tokenizer

    def __call__(self, input_ids, scores, **kwargs):
        """Returns true if all generated sequences contain any of the end-of-function strings."""
        texts = [ self.tokenizer.decode(ids[self.input_length:]) for ids in input_ids ]
        dones = [ any(stop_word in text for stop_word in self.stop_words) for text in texts ]
        return all(dones)
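

# The two routes below implement the HTTP API the CodeShell VSCode extension talks to
# in its "GPU with TGI toolkit" mode (see the README): /generate returns the whole
# completion in one JSON response, /generate_stream streams tokens as server-sent events.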
@app.post("/generate")
async def generate(request: GenerationRequest):
    global model, tokenizer, device, max_context

    if device == 'xpu':
        torch.xpu.empty_cache()

    prompt = request.inputs
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    input_length = len(input_ids[0])
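
    # When the completion prompt is longer than --max-context, shrink it.
    # The slicing below assumes the tokenized prompt is laid out as
    #   <fim_prefix> {code before cursor} <fim_suffix> {code after cursor} <fim_middle>
    # (special-token ids 70001 / 70003 / 70002, per the inline comments), and trims
    # the prefix and suffix segments proportionally so they fit within max_context.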
    if input_length > max_context:
        tokens = list(input_ids[0])
        prefix_index = tokens.index(70001)   # fim_prefix
        middle_index = tokens.index(70002)   # fim_middle
        suffix_index = tokens.index(70003)   # fim_suffix

        prefix_tokens = tokens[prefix_index+1:suffix_index]
        suffix_tokens = tokens[suffix_index+1:middle_index]
        prefix_len = suffix_index - prefix_index - 1
        suffix_len = middle_index - suffix_index - 1

        if prefix_len + suffix_len > max_context:
            new_prefix_len = max_context * prefix_len // (prefix_len + suffix_len)
            new_suffix_len = max_context * suffix_len // (prefix_len + suffix_len)
            new_prefix_tokens = prefix_tokens[-new_prefix_len:]
            new_suffix_tokens = suffix_tokens[:new_suffix_len]

            input_ids = torch.tensor(
                tokens[:prefix_index+1] +
                new_prefix_tokens +
                tokens[suffix_index:suffix_index+1] +
                new_suffix_tokens +
                tokens[middle_index:]
            ).reshape(1, -1)
            input_length = len(input_ids[0])
            prompt = tokenizer.decode(input_ids[0])

    input_ids = input_ids.to(device)

    stopping_criteria = StoppingCriteriaList(
        [ StopWordsCriteria(input_length, request.parameters.stop, tokenizer) ]
    )

    generation_kwargs = dict(stopping_criteria=stopping_criteria,
                             max_new_tokens=request.parameters.max_new_tokens,
                             temperature=request.parameters.temperature,
                             repetition_penalty=request.parameters.repetition_penalty,
                             top_p=request.parameters.top_p,
                             do_sample=request.parameters.do_sample)

    print('-'*80)
    print('input prompt:', prompt)
    print('input length:', input_length)
    print('-'*80)

    output_ids = model.generate(input_ids, **generation_kwargs)
    output_text = tokenizer.decode(output_ids[0])

    return JSONResponse({
                "generated_text": output_text[len(prompt):]
            })

@app.post("/generate_stream")
async def generate_stream(request: GenerationRequest):
    global model, tokenizer, device, multi_turn

    if device == 'xpu':
        torch.xpu.empty_cache()

    prompt = request.inputs

    if not multi_turn:
        # single-turn mode: keep everything before the first "## human" marker
        # (e.g. the system prompt) plus the last "## human" turn only
        human_ins = "## human"
        first_ins = prompt.find(human_ins)
        last_ins = prompt.rfind(human_ins)
        prompt = prompt[:first_ins] + prompt[last_ins:]

    input_ids = tokenizer(prompt, return_tensors="pt")
    input_length = len(input_ids['input_ids'][0])
    input_ids = input_ids.to(device)

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
    stopping_criteria = StoppingCriteriaList(
        [ StopWordsCriteria(input_length, request.parameters.stop, tokenizer) ]
    )
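
    # For prompts longer than max_batch tokens, pre-fill the KV cache in chunks to
    # keep peak memory down, then pass the cache to model.generate() below. The last
    # prompt token is left out of the pre-fill so generate() still forwards it when
    # producing the first new token.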
    max_batch = 1024
    if input_length <= max_batch:
        past_key_values = None
    else:
        with torch.inference_mode():
            past_key_values = None
            for start_pos in range(0, input_length - 1, max_batch):
                end_pos = min(start_pos + max_batch, input_length - 1)
                output = model.forward(input_ids['input_ids'][:, start_pos:end_pos],
                                       past_key_values=past_key_values)
                past_key_values = output.past_key_values

    generation_kwargs = dict(input_ids,
                             past_key_values=past_key_values,
                             streamer=streamer,
                             stopping_criteria=stopping_criteria,
                             max_new_tokens=request.parameters.max_new_tokens,
                             temperature=request.parameters.temperature,
                             repetition_penalty=request.parameters.repetition_penalty,
                             top_p=request.parameters.top_p,
                             do_sample=request.parameters.do_sample)

    print('-'*80)
    print('input prompt:', prompt)
    print('input length:', input_length)
    print('-'*80)

    # run generation on a background thread so tokens can be streamed as they arrive
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    def create_response(streamer):
        for word in tqdm(streamer, "Generating Tokens", unit="token"):
            yield json.dumps({
                "token": {
                    "id": 0,
                    "text": word,
                },
            })

    return EventSourceResponse(create_response(streamer), media_type="text/event-stream")

def _get_args():
    parser = ArgumentParser()
    parser.add_argument(
        "-c",
        "--checkpoint-path",
        type=str,
        default="CodeShell-7B-Chat",
        help="Checkpoint name or path, default to %(default)r",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        help="Device name."
    )
    parser.add_argument(
        "--server-port",
        type=int,
        default=8080,
        help="Demo server port."
    )
    parser.add_argument(
        "--server-name",
        type=str,
        default="127.0.0.1",
        help="Demo server name. Default: 127.0.0.1, which is only visible from the local computer."
        " If you want other computers to access your server, use 0.0.0.0 instead.",
    )
    parser.add_argument(
        "--multi-turn",
        action="store_true",
        help="Enable multi-turn chat",
    )
    parser.add_argument(
        "--cpu-embedding",
        action="store_true",
        help="Move Embedding layer to CPU"
    )
    parser.add_argument(
        "--max-context",
        type=int,
        default=300,
        help="Max context length when using code completion",
    )

    args = parser.parse_args()
    return args

if __name__ == "__main__":
    args = _get_args()

    tokenizer = AutoTokenizer.from_pretrained(
        args.checkpoint_path,
        trust_remote_code=True,
    )

    # load the checkpoint with BigDL-LLM INT4 optimizations
    model = AutoModelForCausalLM.from_pretrained(
        args.checkpoint_path,
        trust_remote_code=True,
        load_in_4bit=True,
        cpu_embedding=args.cpu_embedding
    ).eval()

    device = args.device
    multi_turn = args.multi_turn
    max_context = args.max_context

    if device == 'xpu':
        # imported for its side effect of enabling PyTorch's `xpu` device
        import intel_extension_for_pytorch as ipex

    model = model.to(device)

    model.generation_config = GenerationConfig.from_pretrained(
        args.checkpoint_path,
        trust_remote_code=True,
        resume_download=True,
    )

    uvicorn.run(app, host=args.server_name, port=args.server_port, workers=1)