From b9eae23c7933b2f0f6df7b1d3d47e7d8dd596313 Mon Sep 17 00:00:00 2001
From: Ruonan Wang <105281011+rnwang04@users.noreply.github.com>
Date: Mon, 26 Jun 2023 13:46:43 +0800
Subject: [PATCH] LLM: add chatglm-6b example for transformer_int4 usage
 (#8392)

* add example for chatglm-6b

* fix
---
 python/llm/example/transformers_int4.py | 51 ++++++++++++++++++-------
 1 file changed, 38 insertions(+), 13 deletions(-)

diff --git a/python/llm/example/transformers_int4.py b/python/llm/example/transformers_int4.py
index 6d128af1..f4014b0c 100644
--- a/python/llm/example/transformers_int4.py
+++ b/python/llm/example/transformers_int4.py
@@ -16,21 +16,46 @@
 import torch
 import os
-from bigdl.llm.transformers import AutoModelForCausalLM
-from transformers import LlamaTokenizer
+import time
+import argparse
+from bigdl.llm.transformers import AutoModelForCausalLM, AutoModel
+from transformers import LlamaTokenizer, AutoTokenizer
 
 if __name__ == '__main__':
-    model_path = 'decapoda-research/llama-7b-hf'
+    parser = argparse.ArgumentParser(description='Transformer INT4 example')
+    parser.add_argument('--repo-id-or-model-path', type=str, default="decapoda-research/llama-7b-hf",
+                        choices=['decapoda-research/llama-7b-hf', 'THUDM/chatglm-6b'],
+                        help='The huggingface repo id for the large language model to be downloaded'
+                             ', or the path to the huggingface checkpoint folder')
+    args = parser.parse_args()
+    model_path = args.repo_id_or_model_path
 
+    if model_path == 'decapoda-research/llama-7b-hf':
+        # load_in_4bit=True in bigdl.llm.transformers will convert
+        # the relevant layers in the model into int4 format
+        model = AutoModelForCausalLM.from_pretrained(model_path, load_in_4bit=True)
+        tokenizer = LlamaTokenizer.from_pretrained(model_path)
-    # load_in_4bit=True in bigdl.llm.transformers will convert
-    # the relevant layers in the model into int4 format
-    model = AutoModelForCausalLM.from_pretrained(model_path, load_in_4bit=True)
-    tokenizer = LlamaTokenizer.from_pretrained(model_path)
 
+        input_str = "Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun"
-    input_str = "Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun"
 
+        with torch.inference_mode():
+            st = time.time()
+            input_ids = tokenizer.encode(input_str, return_tensors="pt")
+            output = model.generate(input_ids, do_sample=False, max_new_tokens=32)
+            output_str = tokenizer.decode(output[0], skip_special_tokens=True)
+            end = time.time()
+            print(output_str)
+            print(f'Inference time: {end-st} s')
+    elif model_path == 'THUDM/chatglm-6b':
+        model = AutoModel.from_pretrained(model_path, trust_remote_code=True, load_in_4bit=True)
+        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
-    with torch.inference_mode():
-        input_ids = tokenizer.encode(input_str, return_tensors="pt")
-        output = model.generate(input_ids, do_sample=False, max_new_tokens=32)
-        output_str = tokenizer.decode(output[0], skip_special_tokens=True)
-        print(output_str)
+
+        input_str = "晚上睡不着应该怎么办"  # "What should I do if I can't sleep at night?"
+
+        with torch.inference_mode():
+            st = time.time()
+            input_ids = tokenizer.encode(input_str, return_tensors="pt")
+            output = model.generate(input_ids, do_sample=False, max_new_tokens=32)
+            output_str = tokenizer.decode(output[0], skip_special_tokens=True)
+            end = time.time()
+            print(output_str)
+            print(f'Inference time: {end-st} s')
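
Usage sketch (not part of the patch): given the argparse definition in the diff, the
updated example would be invoked roughly as below. The script path, flag name,
choices, and default are taken from the diff above; everything else is an assumption.

    # default: runs the decapoda-research/llama-7b-hf branch
    python python/llm/example/transformers_int4.py

    # selects the ChatGLM-6B branch added by this patch
    python python/llm/example/transformers_int4.py --repo-id-or-model-path THUDM/chatglm-6b

Note that the ChatGLM-6B branch passes trust_remote_code=True, since that checkpoint
ships custom modeling code that transformers must be allowed to execute.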