LLM: add chatglm-6b example for transformer_int4 usage (#8392)
* add example for chatglm-6b
* fix
parent 19e19efb4c
commit b9eae23c79

1 changed file with 38 additions and 13 deletions
@@ -16,21 +16,46 @@
 import torch
 import os
-from bigdl.llm.transformers import AutoModelForCausalLM
-from transformers import LlamaTokenizer
 import time
+import argparse
+from bigdl.llm.transformers import AutoModelForCausalLM, AutoModel
+from transformers import LlamaTokenizer, AutoTokenizer
 
 if __name__ == '__main__':
-    model_path = 'decapoda-research/llama-7b-hf'
-    # load_in_4bit=True in bigdl.llm.transformers will convert
-    # the relevant layers in the model into int4 format
-    model = AutoModelForCausalLM.from_pretrained(model_path, load_in_4bit=True)
-    tokenizer = LlamaTokenizer.from_pretrained(model_path)
-
-    input_str = "Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun"
-
-    with torch.inference_mode():
-        st = time.time()
-        input_ids = tokenizer.encode(input_str, return_tensors="pt")
-        output = model.generate(input_ids, do_sample=False, max_new_tokens=32)
-        output_str = tokenizer.decode(output[0], skip_special_tokens=True)
-        end = time.time()
-        print(output_str)
-        print(f'Inference time: {end-st} s')
+    parser = argparse.ArgumentParser(description='Transformer INT4 example')
+    parser.add_argument('--repo-id-or-model-path', type=str, default="decapoda-research/llama-7b-hf",
+                        choices=['decapoda-research/llama-7b-hf', 'THUDM/chatglm-6b'],
+                        help='The huggingface repo id for the large language model to be downloaded'
+                             ', or the path to the huggingface checkpoint folder')
+    args = parser.parse_args()
+    model_path = args.repo_id_or_model_path
+
+    if model_path == 'decapoda-research/llama-7b-hf':
+        # load_in_4bit=True in bigdl.llm.transformers will convert
+        # the relevant layers in the model into int4 format
+        model = AutoModelForCausalLM.from_pretrained(model_path, load_in_4bit=True)
+        tokenizer = LlamaTokenizer.from_pretrained(model_path)
+
+        input_str = "Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun"
+
+        with torch.inference_mode():
+            st = time.time()
+            input_ids = tokenizer.encode(input_str, return_tensors="pt")
+            output = model.generate(input_ids, do_sample=False, max_new_tokens=32)
+            output_str = tokenizer.decode(output[0], skip_special_tokens=True)
+            end = time.time()
+            print(output_str)
+            print(f'Inference time: {end-st} s')
+    elif model_path == 'THUDM/chatglm-6b':
+        model = AutoModel.from_pretrained(model_path, trust_remote_code=True, load_in_4bit=True)
+        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+
+        input_str = "晚上睡不着应该怎么办"
+
+        with torch.inference_mode():
+            st = time.time()
+            input_ids = tokenizer.encode(input_str, return_tensors="pt")
+            output = model.generate(input_ids, do_sample=False, max_new_tokens=32)
+            output_str = tokenizer.decode(output[0], skip_special_tokens=True)
+            end = time.time()
+            print(output_str)
+            print(f'Inference time: {end-st} s')
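For reference, the new chatglm-6b path can be exercised standalone. The sketch below is a minimal example, assuming bigdl-llm and the THUDM/chatglm-6b checkpoint are available locally; it repeats only calls that appear in the diff above (the changed file's path is not shown on this page, so no script name is assumed):

    import torch
    from bigdl.llm.transformers import AutoModel
    from transformers import AutoTokenizer

    # Load chatglm-6b with BigDL-LLM's INT4 conversion (load_in_4bit=True),
    # as the new elif branch in the diff does; trust_remote_code=True is
    # required because chatglm-6b ships custom modeling code.
    model = AutoModel.from_pretrained('THUDM/chatglm-6b', trust_remote_code=True, load_in_4bit=True)
    tokenizer = AutoTokenizer.from_pretrained('THUDM/chatglm-6b', trust_remote_code=True)

    with torch.inference_mode():
        # Greedy decoding of the same Chinese prompt used in the example
        # ("What should I do if I can't sleep at night").
        input_ids = tokenizer.encode("晚上睡不着应该怎么办", return_tensors="pt")
        output = model.generate(input_ids, do_sample=False, max_new_tokens=32)
        print(tokenizer.decode(output[0], skip_special_tokens=True))

When running the example script itself, the same branch is selected with --repo-id-or-model-path THUDM/chatglm-6b.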