add codeshell example (#9743)
parent daf536fb2d
commit be13b162fe
2 changed files with 326 additions and 0 deletions

@@ -0,0 +1,38 @@
# CodeShell

In this directory, you will find how to use this CodeShell server with the CodeShell VSCode extension.

## 0. Extra Environment Preparations

Assuming you have already configured your GPU environment, you will need a few extra preparations:

1. Install the extra requirements:

    ```
    pip install uvicorn fastapi sse_starlette
    ```

2. Search for `codeshell` in the VSCode extension marketplace, then install the `CodeShell VSCode Extension`.

3. Change the extension settings:

    - set `Code Shell: Run Env For LLMs` to `GPU with TGI toolkit`
    - disable `Code Shell: Auto Trigger Completion` (use `Alt + \` to trigger completion manually)

4. Download WisdomShell/CodeShell-7B-Chat (do not use CodeShell-7B); one way to download it is sketched below.
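
The sketch below shows one way to fetch the checkpoint with the `huggingface_hub` package; the local directory is only an example, and any other download method works too.

```
from huggingface_hub import snapshot_download

# Download the chat checkpoint into a local folder (the path is an example),
# then pass that folder to the server via --checkpoint-path.
snapshot_download(repo_id="WisdomShell/CodeShell-7B-Chat",
                  local_dir="./CodeShell-7B-Chat")
```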
## 1. How to use this server

```
python server.py [--option value]
```

1. `--checkpoint-path <path>`: path to the Hugging Face model checkpoint
2. `--device xpu`: run the model on an Intel GPU (the default device is `cpu`)
3. `--multi-turn`: enable multi-turn conversation (otherwise only single-turn conversation is supported)
4. `--cpu-embedding`: move the Embedding layer to the CPU
5. `--max-context <number>`: clip the context length used for Code Completion (other features are unaffected); set it to 99999 to effectively disable clipping
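
For example, `python server.py --checkpoint-path ./CodeShell-7B-Chat --device xpu` starts the server on an Intel GPU at the default address `127.0.0.1:8080`. The extension talks to the server through its `/generate` and `/generate_stream` endpoints, so you can also test it directly. Below is a minimal sketch of a `/generate` request using the `requests` package; the prompt text, sampling values, and address are placeholders, not the exact requests the extension sends.

```
import requests

# Request body matching the GenerationRequest schema in server.py;
# all values below are examples only.
payload = {
    "inputs": "## human:\nWrite a Python function that reverses a string.\n## assistant:\n",
    "parameters": {
        "max_new_tokens": 128,
        "temperature": 0.2,
        "repetition_penalty": 1.1,
        "top_p": 0.95,
        "do_sample": True,
        "stop": ["|<end|"],
    },
}

resp = requests.post("http://127.0.0.1:8080/generate", json=payload)
print(resp.json()["generated_text"])
```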
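
The `/generate_stream` endpoint returns Server-Sent Events, where each event's data is a JSON object of the form `{"token": {"id": 0, "text": ...}}`. The sketch below consumes that stream with `requests`; again, the request body and address are examples only.

```
import json
import requests

# Same request shape as the /generate sketch above; values are examples.
payload = {
    "inputs": "## human:\nExplain what a generator is in Python.\n## assistant:\n",
    "parameters": {"max_new_tokens": 128, "temperature": 0.2,
                   "repetition_penalty": 1.1, "top_p": 0.95,
                   "do_sample": True, "stop": []},
}

with requests.post("http://127.0.0.1:8080/generate_stream",
                   json=payload, stream=True) as resp:
    for line in resp.iter_lines():
        # SSE data lines look like: b'data: {"token": {...}}'
        if not line or not line.startswith(b"data:"):
            continue
        event = json.loads(line[len(b"data:"):].strip())
        print(event["token"]["text"], end="", flush=True)
```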
## 2. Note

In my tests, when using a VSCode remote connection to a remote machine, with the extension installed and this server running on that remote machine, all extension features except Code Completion can be used.

Without a remote connection, with the extension installed and this server running on the local machine, Code Completion can also be used.
@@ -0,0 +1,288 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import time
import json
from tqdm import tqdm
from argparse import ArgumentParser
from typing import Dict, List
from threading import Thread
import torch
import uvicorn
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoTokenizer
# from transformers import AutoModelForCausalLM, AutoModel
from bigdl.llm.transformers import AutoModelForCausalLM, AutoModel
from transformers.generation import GenerationConfig, TextIteratorStreamer
from transformers import StoppingCriteriaList, StoppingCriteria
from sse_starlette.sse import EventSourceResponse


app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Request body schemas for the /generate and /generate_stream endpoints.
class GenerationParameters(BaseModel):
    max_new_tokens: int
    temperature: float
    repetition_penalty: float
    top_p: float
    do_sample: bool
    stop: List[str]


class GenerationRequest(BaseModel):
    inputs: str
    parameters: GenerationParameters


class StopWordsCriteria(StoppingCriteria):
    """Custom `StoppingCriteria` which checks if all generated functions in the batch are completed."""
    def __init__(self, input_length, stop_words, tokenizer):
        self.input_length = input_length
        self.stop_words = stop_words
        self.stop_words += ["|<end|", "|end>|"]
        self.tokenizer = tokenizer

    def __call__(self, input_ids, scores, **kwargs):
        """Returns true if all generated sequences contain any of the end-of-function strings."""
        texts = [self.tokenizer.decode(ids[self.input_length:]) for ids in input_ids]
        dones = [any(stop_word in text for stop_word in self.stop_words) for text in texts]
        return all(dones)


@app.post("/generate")
async def generate(request: GenerationRequest):
    global model, tokenizer, device, max_context

    if device == 'xpu':
        torch.xpu.empty_cache()

    prompt = request.inputs
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    input_length = len(input_ids[0])

    # Code Completion prompts are in fill-in-the-middle (FIM) format:
    # <fim_prefix> prefix <fim_suffix> suffix <fim_middle>. If the prompt is longer
    # than --max-context, trim the prefix and suffix proportionally around the FIM
    # special tokens so the total context stays within the limit.
    if input_length > max_context:
        tokens = list(input_ids[0])
        prefix_index = tokens.index(70001)   # fim_prefix
        middle_index = tokens.index(70002)   # fim_middle
        suffix_index = tokens.index(70003)   # fim_suffix

        prefix_tokens = tokens[prefix_index+1:suffix_index]
        suffix_tokens = tokens[suffix_index+1:middle_index]
        prefix_len = suffix_index - prefix_index - 1
        suffix_len = middle_index - suffix_index - 1

        if prefix_len + suffix_len > max_context:
            new_prefix_len = max_context * prefix_len // (prefix_len + suffix_len)
            new_suffix_len = max_context * suffix_len // (prefix_len + suffix_len)
            # keep the end of the prefix and the beginning of the suffix,
            # i.e. the code closest to the cursor
            new_prefix_tokens = prefix_tokens[-new_prefix_len:]
            new_suffix_tokens = suffix_tokens[:new_suffix_len]

            input_ids = torch.tensor(
                tokens[:prefix_index+1] +
                new_prefix_tokens +
                tokens[suffix_index:suffix_index+1] +
                new_suffix_tokens +
                tokens[middle_index:]
            ).reshape(1, -1)
            input_length = len(input_ids[0])
            prompt = tokenizer.decode(input_ids[0])

    input_ids = input_ids.to(device)

    stopping_criteria = StoppingCriteriaList(
        [StopWordsCriteria(input_length, request.parameters.stop, tokenizer)]
    )

    generation_kwargs = dict(stopping_criteria=stopping_criteria,
                             max_new_tokens=request.parameters.max_new_tokens,
                             temperature=request.parameters.temperature,
                             repetition_penalty=request.parameters.repetition_penalty,
                             top_p=request.parameters.top_p,
                             do_sample=request.parameters.do_sample)

    print('-'*80)
    print('input prompt:', prompt)
    print('input length:', input_length)
    print('-'*80)

    output_ids = model.generate(input_ids, **generation_kwargs)
    output_text = tokenizer.decode(output_ids[0])

    # return only the newly generated text, without the echoed prompt
    return JSONResponse({
        "generated_text": output_text[len(prompt):]
    })


@app.post("/generate_stream")
async def generate_stream(request: GenerationRequest):
    global model, tokenizer, device, multi_turn

    if device == 'xpu':
        torch.xpu.empty_cache()

    prompt = request.inputs

    if not multi_turn:
        # extract the last turn of input: keep only the text before the first
        # "## human" marker plus the last "## human" turn
        human_ins = "## human"
        first_ins = prompt.find(human_ins)
        last_ins = prompt.rfind(human_ins)
        prompt = prompt[:first_ins] + prompt[last_ins:]

    input_ids = tokenizer(prompt, return_tensors="pt")
    input_length = len(input_ids['input_ids'][0])
    input_ids = input_ids.to(device)

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
    stopping_criteria = StoppingCriteriaList(
        [StopWordsCriteria(input_length, request.parameters.stop, tokenizer)]
    )

    # For long prompts, prefill the KV cache in chunks of `max_batch` tokens
    # (all but the last input token) to bound peak memory; `model.generate`
    # then continues from the cached past_key_values.
    max_batch = 1024
    if input_length <= max_batch:
        past_key_values = None
    else:
        with torch.inference_mode():
            past_key_values = None
            for start_pos in range(0, input_length - 1, max_batch):
                end_pos = min(start_pos + max_batch, input_length - 1)
                output = model.forward(input_ids['input_ids'][:, start_pos:end_pos],
                                       past_key_values=past_key_values)
                past_key_values = output.past_key_values

    generation_kwargs = dict(input_ids,
                             past_key_values=past_key_values,
                             streamer=streamer,
                             stopping_criteria=stopping_criteria,
                             max_new_tokens=request.parameters.max_new_tokens,
                             temperature=request.parameters.temperature,
                             repetition_penalty=request.parameters.repetition_penalty,
                             top_p=request.parameters.top_p,
                             do_sample=request.parameters.do_sample)

    print('-'*80)
    print('input prompt:', prompt)
    print('input length:', input_length)
    print('-'*80)

    # run generation in a background thread and stream tokens as SSE events
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    def create_response(streamer):
        for word in tqdm(streamer, "Generating Tokens", unit="token"):
            yield json.dumps({
                "token": {
                    "id": 0,
                    "text": word,
                },
            })

    return EventSourceResponse(create_response(streamer), media_type="text/event-stream")


def _get_args():
    parser = ArgumentParser()
    parser.add_argument(
        "-c",
        "--checkpoint-path",
        type=str,
        default="CodeShell-7B-Chat",
        help="Checkpoint name or path, default to %(default)r",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        help="Device name."
    )
    parser.add_argument(
        "--server-port",
        type=int,
        default=8080,
        help="Demo server port."
    )
    parser.add_argument(
        "--server-name",
        type=str,
        default="127.0.0.1",
        help="Demo server name. Default: 127.0.0.1, which is only visible from the local computer."
        " If you want other computers to access your server, use 0.0.0.0 instead.",
    )
    parser.add_argument(
        "--multi-turn",
        action="store_true",
        help="Enable multi-turn chat",
    )
    parser.add_argument(
        "--cpu-embedding",
        action="store_true",
        help="Move Embedding layer to CPU"
    )
    parser.add_argument(
        "--max-context",
        type=int,
        default=300,
        help="Max context length when using code completion",
    )

    args = parser.parse_args()
    return args

if __name__ == "__main__":
    args = _get_args()

    tokenizer = AutoTokenizer.from_pretrained(
        args.checkpoint_path,
        trust_remote_code=True,
    )

    # load the model with bigdl-llm 4-bit optimization; optionally keep the
    # Embedding layer on the CPU to save GPU memory
    model = AutoModelForCausalLM.from_pretrained(
        args.checkpoint_path,
        trust_remote_code=True,
        load_in_4bit=True,
        cpu_embedding=args.cpu_embedding
    ).eval()

    device = args.device
    multi_turn = args.multi_turn
    max_context = args.max_context

    if device == 'xpu':
        import intel_extension_for_pytorch as ipex

    model = model.to(device)

    model.generation_config = GenerationConfig.from_pretrained(
        args.checkpoint_path,
        trust_remote_code=True,
        resume_download=True,
    )

    uvicorn.run(app, host=args.server_name, port=args.server_port, workers=1)