ipex-llm/python/llm/example/GPU/LangChain/vllm.py

from ipex_llm.langchain.vllm.vllm import VLLM
from langchain.chains import LLMChain
from langchain_core.prompts import PromptTemplate
import argparse


def main(args):
    # Load the model through the ipex-llm vLLM wrapper on an Intel GPU ("xpu"),
    # quantizing the weights to the requested low-bit format.
    llm = VLLM(
        model=args.model_path,
        trust_remote_code=True,  # mandatory for hf models
        max_new_tokens=args.max_tokens,
        top_k=10,
        top_p=0.95,
        temperature=0.8,
        max_model_len=2048,
        enforce_eager=True,
        load_in_low_bit=args.load_in_low_bit,
        device="xpu",
        tensor_parallel_size=args.tensor_parallel_size,
    )

    # Direct, single-prompt generation.
    print(llm.invoke(args.question))
template = """Question: {question}
Answer: Let's think step by step."""""
prompt = PromptTemplate.from_template(template)
llm_chain = LLMChain(prompt=prompt, llm=llm)
print(llm_chain.invoke("Who was the US president in the year the first Pokemon game was released?"))
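
    # Note: LLMChain is deprecated in recent LangChain releases. A minimal
    # sketch of the equivalent runnable pipe syntax (assuming a langchain-core
    # version that supports it; reuses the `prompt` and `llm` objects above):
    #
    #     chain = prompt | llm
    #     print(chain.invoke({"question": "Who was the US president in the year "
    #                         "the first Pokemon game was released?"}))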


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='LangChain-integrated vLLM example')
    parser.add_argument('-m', '--model-path', type=str, required=True,
                        help='path to the transformers model')
    parser.add_argument('-q', '--question', type=str, default='What is the capital of France?',
                        help='question you want to ask')
    parser.add_argument('-t', '--max-tokens', type=int, default=128,
                        help='max new tokens to generate')
    parser.add_argument('-p', '--tensor-parallel-size', type=int, default=1,
                        help='vLLM tensor parallel size')
    parser.add_argument('-l', '--load-in-low-bit', type=str, default='sym_int4',
                        help='low-bit format to quantize the model weights to')
    args = parser.parse_args()
    main(args)
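
# Example invocation (the model path is a placeholder; substitute a local
# checkpoint or Hugging Face model id):
#   python vllm.py -m /path/to/Llama-2-7b-chat-hf -t 128 -p 1 -l sym_int4 \
#       -q "What is the capital of France?"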