LLM: add chatglm-6b example for transformer_int4 usage (#8392)

* add example for chatglm-6b

* fix
Ruonan Wang 2023-06-26 13:46:43 +08:00 committed by GitHub
parent 19e19efb4c
commit b9eae23c79

@@ -16,21 +16,46 @@
 import torch
 import os
-from bigdl.llm.transformers import AutoModelForCausalLM
-from transformers import LlamaTokenizer
+import time
+import argparse
+from bigdl.llm.transformers import AutoModelForCausalLM, AutoModel
+from transformers import LlamaTokenizer, AutoTokenizer
 
 if __name__ == '__main__':
-    model_path = 'decapoda-research/llama-7b-hf'
+    parser = argparse.ArgumentParser(description='Transformer INT4 example')
+    parser.add_argument('--repo-id-or-model-path', type=str, default="decapoda-research/llama-7b-hf",
+                        choices=['decapoda-research/llama-7b-hf', 'THUDM/chatglm-6b'],
+                        help='The huggingface repo id for the large language model to be downloaded'
+                             ', or the path to the huggingface checkpoint folder')
+    args = parser.parse_args()
+    model_path = args.repo_id_or_model_path
 
-    # load_in_4bit=True in bigdl.llm.transformers will convert
-    # the relevant layers in the model into int4 format
-    model = AutoModelForCausalLM.from_pretrained(model_path, load_in_4bit=True)
-    tokenizer = LlamaTokenizer.from_pretrained(model_path)
+    if model_path == 'decapoda-research/llama-7b-hf':
+        # load_in_4bit=True in bigdl.llm.transformers will convert
+        # the relevant layers in the model into int4 format
+        model = AutoModelForCausalLM.from_pretrained(model_path, load_in_4bit=True)
+        tokenizer = LlamaTokenizer.from_pretrained(model_path)
+        input_str = "Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun"
 
-    input_str = "Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun"
-
-    with torch.inference_mode():
-        input_ids = tokenizer.encode(input_str, return_tensors="pt")
-        output = model.generate(input_ids, do_sample=False, max_new_tokens=32)
-        output_str = tokenizer.decode(output[0], skip_special_tokens=True)
-        print(output_str)
+        with torch.inference_mode():
+            st = time.time()
+            input_ids = tokenizer.encode(input_str, return_tensors="pt")
+            output = model.generate(input_ids, do_sample=False, max_new_tokens=32)
+            output_str = tokenizer.decode(output[0], skip_special_tokens=True)
+            end = time.time()
+            print(output_str)
+            print(f'Inference time: {end-st} s')
+    elif model_path == 'THUDM/chatglm-6b':
+        model = AutoModel.from_pretrained(model_path, trust_remote_code=True, load_in_4bit=True)
+        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+        input_str = "晚上睡不着应该怎么办"
+
+        with torch.inference_mode():
+            st = time.time()
+            input_ids = tokenizer.encode(input_str, return_tensors="pt")
+            output = model.generate(input_ids, do_sample=False, max_new_tokens=32)
+            output_str = tokenizer.decode(output[0], skip_special_tokens=True)
+            end = time.time()
+            print(output_str)
+            print(f'Inference time: {end-st} s')
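
For readers who only want the new chatglm-6b path, the sketch below condenses it into a standalone script. It is assembled from the code added in this diff (AutoModel/AutoTokenizer with trust_remote_code=True and load_in_4bit=True) and adds nothing beyond it; it assumes network access so that THUDM/chatglm-6b can be downloaded from Hugging Face on first use.

import time
import torch
from bigdl.llm.transformers import AutoModel
from transformers import AutoTokenizer

model_path = 'THUDM/chatglm-6b'

# load_in_4bit=True converts the relevant layers in the model into int4 format;
# trust_remote_code=True is required because chatglm-6b ships its own modeling code
model = AutoModel.from_pretrained(model_path, trust_remote_code=True, load_in_4bit=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

input_str = "晚上睡不着应该怎么办"  # "What should I do if I can't sleep at night?"

with torch.inference_mode():
    st = time.time()
    input_ids = tokenizer.encode(input_str, return_tensors="pt")
    output = model.generate(input_ids, do_sample=False, max_new_tokens=32)
    output_str = tokenizer.decode(output[0], skip_special_tokens=True)
    end = time.time()
    print(output_str)
    print(f'Inference time: {end-st} s')

With the committed example, the same path is selected from the command line by passing --repo-id-or-model-path THUDM/chatglm-6b, which routes loading through AutoModel/AutoTokenizer instead of the LLaMA-specific classes.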