#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import argparse

import gradio as gr
from openai import OpenAI

# Argument parser setup: command-line options for the chatbot front end.
parser = argparse.ArgumentParser(
    description='Chatbot Interface with Customizable Parameters')
# Base URL of the OpenAI-compatible API server that serves the model.
parser.add_argument('--model-url',
                    type=str,
                    default='http://localhost:8000/v1',
                    help='Model URL')
# Model name sent with every chat completion request (mandatory).
parser.add_argument('-m',
                    '--model',
                    type=str,
                    required=True,
                    help='Model name for the chatbot')
# Address and port handed to the Gradio server at launch.
parser.add_argument("--host",
                    type=str,
                    default=None,
                    help='Server name passed to the Gradio launcher '
                    '(default: None)')
parser.add_argument("--port",
                    type=int,
                    default=8001,
                    help='Server port passed to the Gradio launcher '
                    '(default: 8001)')

# Parse the arguments
args = parser.parse_args()

# Set OpenAI's API key and API base to use vLLM's API server.
# The key is the fixed placeholder "EMPTY"; the base URL comes from
# the --model-url option.
openai_api_key = "EMPTY"
openai_api_base = args.model_url

# Create an OpenAI client to interact with the API server
client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)


def predict(message, history):
    """Stream the model's reply to ``message`` for the chat interface.

    Builds an OpenAI-style message list (a fixed system prompt, the prior
    turns, then the new user message), sends it as a streaming chat
    completion request, and yields the reply as it grows.

    Args:
        message: The newest user message.
        history: Prior turns. Each entry is either a ``(user, assistant)``
            pair or a dict with ``"role"`` and ``"content"`` keys. A
            ``None`` side of a pair is skipped.

    Yields:
        The reply text accumulated so far, once for every stream chunk
        that carries at least one choice.
    """
    # Convert chat history to OpenAI format
    history_openai_format = [{
        "role": "system",
        "content": "You are a great ai assistant."
    }]
    for turn in history:
        if isinstance(turn, dict):
            # Dict-style history: forward only the role and content keys.
            history_openai_format.append({
                "role": turn["role"],
                "content": turn["content"]
            })
            continue
        human, assistant = turn
        if human is not None:
            history_openai_format.append({"role": "user", "content": human})
        if assistant is not None:
            history_openai_format.append({
                "role": "assistant",
                "content": assistant
            })
    history_openai_format.append({"role": "user", "content": message})

    # Create a chat completion request and send it to the API server
    stream = client.chat.completions.create(
        model=args.model,  # Model name to use
        messages=history_openai_format,  # Chat history
        stream=True,  # Stream response
    )

    # Read and return generated text from response stream
    partial_message = ""
    for chunk in stream:
        # A chunk may arrive with no choices; indexing [0] would raise
        # IndexError, so skip it.
        if not chunk.choices:
            continue
        # delta.content may be None, so fall back to an empty string.
        partial_message += (chunk.choices[0].delta.content or "")
        yield partial_message


# Create and launch a chat interface with Gradio
# NOTE(review): share=True asks Gradio for a public share link, so the
# chatbot may be reachable from outside this host - confirm that is intended.
gr.ChatInterface(predict).queue().launch(server_name=args.host,
                                         server_port=args.port,
                                         share=True)