45 lines
1.6 KiB
Python
45 lines
1.6 KiB
Python
from ipex_llm.langchain.vllm.vllm import VLLM
|
|
from langchain.chains import LLMChain
|
|
from langchain_core.prompts import PromptTemplate
|
|
import argparse
|
|
|
|
def main(args: argparse.Namespace) -> None:
    """Query a vLLM-backed LangChain LLM directly and through a prompt chain.

    Builds a ``VLLM`` instance from the parsed command-line options, prints
    the answer to ``args.question``, then prints the result of running a
    fixed demonstration question through an ``LLMChain`` that uses a
    "think step by step" prompt template.

    Args:
        args: Parsed options. Reads ``model_path``, ``load_in_low_bit``,
            ``tensor_parallel_size``, ``question`` and, when present,
            ``max_tokens``.
    """
    llm = VLLM(
        model=args.model_path,
        trust_remote_code=True,  # mandatory for hf models
        # Honour the --max-tokens option instead of a hard-coded limit; fall
        # back to 128 for callers whose namespace has no such attribute.
        max_new_tokens=getattr(args, "max_tokens", 128),
        top_k=10,
        top_p=0.95,
        temperature=0.8,
        max_model_len=2048,
        enforce_eager=True,
        load_in_low_bit=args.load_in_low_bit,
        device="xpu",
        tensor_parallel_size=args.tensor_parallel_size,
    )

    # Plain completion: send the user's question straight to the model.
    print(llm.invoke(args.question))

    # Written as a single-line literal so the prompt text does not depend on
    # the source indentation of a triple-quoted string.
    template = "Question: {question}\n\nAnswer: Let's think step by step."
    prompt = PromptTemplate.from_template(template)

    llm_chain = LLMChain(prompt=prompt, llm=llm)

    # Chained completion: the fixed question is substituted into the template.
    print(llm_chain.invoke("Who was the US president in the year the first Pokemon game was released?"))
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: parse the command-line options and hand them to main().
    parser = argparse.ArgumentParser(description='Langchain integrated vLLM example')
    parser.add_argument('-m', '--model-path', type=str, required=True,
                        help='the path to transformers model')
    parser.add_argument('-q', '--question', type=str, default='What is the capital of France?',
                        help='question you want to ask.')
    parser.add_argument('-t', '--max-tokens', type=int, default=128,
                        help='max tokens to generate')
    parser.add_argument('-p', '--tensor-parallel-size', type=int, default=1,
                        help="vLLM tensor parallel size")
    parser.add_argument('-l', '--load-in-low-bit', type=str, default='sym_int4',
                        help="low bit format")
    args = parser.parse_args()

    main(args)
|
|
|