[LLM] Add UTs of load_low_bit for transformers-style API (#10001)

* Add UTs for transformers API load_low_bit generation

* Small fixes

* Remove replit-code from CPU tests due to a current load_low_bit issue with MPT

* Small change

* Small reorganization of LLM unit tests on CPU

* Small fixes
Yuwen Hu 2024-01-29 10:18:23 +08:00 committed by GitHub
parent d720554d43
commit c6d4f91777
6 changed files with 121 additions and 66 deletions
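
For context, the save/load round trip these new UTs exercise looks roughly like this (a minimal sketch assembled from the test code below; the environment variable, arguments, and prompt are the ones the tests use, while the final print is only illustrative):

import os
import tempfile

import torch
from bigdl.llm.transformers import AutoModel
from transformers import AutoTokenizer

# Path to an original (non-quantized) checkpoint, as in the tests.
model_path = os.environ.get('ORIGINAL_CHATGLM2_6B_PATH')

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# Load with 4-bit quantization, then persist the quantized weights.
model = AutoModel.from_pretrained(model_path,
                                  load_in_4bit=True,
                                  optimize_model=True,
                                  trust_remote_code=True)

with tempfile.TemporaryDirectory() as tempdir:
    model.save_low_bit(tempdir)
    # Reload directly from the low-bit checkpoint; the original weights are not needed.
    loaded_model = AutoModel.load_low_bit(tempdir,
                                          optimize_model=True,
                                          trust_remote_code=True)
    with torch.inference_mode():
        input_ids = tokenizer.encode("What is the capital of France?\n\n",
                                     return_tensors="pt")
        output = loaded_model.generate(input_ids, max_new_tokens=32)
        print(tokenizer.decode(output[0], skip_special_tokens=True))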


@@ -1,59 +0,0 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import pytest
from bigdl.llm.transformers import AutoModelForCausalLM, AutoModel
from transformers import LlamaTokenizer, AutoTokenizer
llama_model_path = os.environ.get('LLAMA_ORIGIN_PATH')
bloom_model_path = os.environ.get('BLOOM_ORIGIN_PATH')
chatglm2_6b_model_path = os.environ.get('ORIGINAL_CHATGLM2_6B_PATH')
replit_code_model_path = os.environ.get('ORIGINAL_REPLIT_CODE_PATH')
prompt = "Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun"
@pytest.mark.parametrize("Model, Tokenizer, model_path, prompt", [
    (AutoModelForCausalLM, LlamaTokenizer, llama_model_path, prompt),
    (AutoModelForCausalLM, AutoTokenizer, bloom_model_path, prompt),
    (AutoModel, AutoTokenizer, chatglm2_6b_model_path, prompt),
    (AutoModelForCausalLM, AutoTokenizer, replit_code_model_path, prompt)
])
def test_optimize_model(Model, Tokenizer, model_path, prompt):
    tokenizer = Tokenizer.from_pretrained(model_path, trust_remote_code=True)
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    model = Model.from_pretrained(model_path,
                                  load_in_4bit=True,
                                  optimize_model=False,
                                  trust_remote_code=True)
    logits_base_model = (model(input_ids)).logits
    model = Model.from_pretrained(model_path,
                                  load_in_4bit=True,
                                  optimize_model=True,
                                  trust_remote_code=True)
    logits_optimized_model = (model(input_ids)).logits
    diff = abs(logits_base_model - logits_optimized_model).flatten()
    assert any(diff) is False


if __name__ == '__main__':
    pytest.main([__file__])


@@ -17,11 +17,13 @@
import unittest
import os
import pytest
import tempfile
import time
import torch
import pytest
from bigdl.llm.transformers import AutoModel, AutoModelForCausalLM, AutoModelForSpeechSeq2Seq
from transformers import AutoTokenizer
from transformers import AutoTokenizer, LlamaTokenizer
class TestTransformersAPI(unittest.TestCase):
@@ -109,5 +111,60 @@ class TestTransformersAPI(unittest.TestCase):
        res = 'Paris' in output_str
        self.assertTrue(res)


@pytest.mark.parametrize('prompt, answer', [
    ('What is the capital of France?\n\n', 'Paris')
])
@pytest.mark.parametrize('Model, Tokenizer, model_path', [
    (AutoModel, AutoTokenizer, os.environ.get('ORIGINAL_CHATGLM2_6B_PATH')),
])
def test_load_low_bit_completion(Model, Tokenizer, model_path, prompt, answer):
    tokenizer = Tokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = Model.from_pretrained(model_path,
                                  load_in_4bit=True,
                                  optimize_model=True,
                                  trust_remote_code=True)

    with tempfile.TemporaryDirectory() as tempdir:
        model.save_low_bit(tempdir)
        loaded_model = Model.load_low_bit(tempdir,
                                          optimize_model=True,
                                          trust_remote_code=True)

        with torch.inference_mode():
            input_ids = tokenizer.encode(prompt, return_tensors="pt")
            output = loaded_model.generate(input_ids, max_new_tokens=32)
            output_str = tokenizer.decode(output[0], skip_special_tokens=True)

            assert answer in output_str


prompt = "Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun"


@pytest.mark.parametrize("Model, Tokenizer, model_path, prompt", [
    (AutoModelForCausalLM, LlamaTokenizer, os.environ.get('LLAMA_ORIGIN_PATH'), prompt),
    (AutoModelForCausalLM, AutoTokenizer, os.environ.get('BLOOM_ORIGIN_PATH'), prompt),
    (AutoModel, AutoTokenizer, os.environ.get('ORIGINAL_CHATGLM2_6B_PATH'), prompt),
    (AutoModelForCausalLM, AutoTokenizer, os.environ.get('ORIGINAL_REPLIT_CODE_PATH'), prompt)
])
def test_optimize_model(Model, Tokenizer, model_path, prompt):
    tokenizer = Tokenizer.from_pretrained(model_path, trust_remote_code=True)
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    model = Model.from_pretrained(model_path,
                                  load_in_4bit=True,
                                  optimize_model=False,
                                  trust_remote_code=True)
    logits_base_model = (model(input_ids)).logits
    model = Model.from_pretrained(model_path,
                                  load_in_4bit=True,
                                  optimize_model=True,
                                  trust_remote_code=True)
    logits_optimized_model = (model(input_ids)).logits
    diff = abs(logits_base_model - logits_optimized_model).flatten()
    assert any(diff) is False


if __name__ == '__main__':
    pytest.main([__file__])


@@ -16,6 +16,8 @@
import os
import pytest
import tempfile
import torch
from bigdl.llm.transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
@@ -48,6 +50,31 @@ def test_optimize_model(Model, Tokenizer, model_path, prompt):
    assert any(diff) is False


@pytest.mark.parametrize('prompt, answer', [
    ('What is the capital of France?\n\n', 'Paris')
])
@pytest.mark.parametrize('Model, Tokenizer, model_path', [
    (AutoModelForCausalLM, AutoTokenizer, mistral_model_path),
])
def test_load_low_bit_completion(Model, Tokenizer, model_path, prompt, answer):
    tokenizer = Tokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = Model.from_pretrained(model_path,
                                  load_in_4bit=True,
                                  optimize_model=True,
                                  trust_remote_code=True)

    with tempfile.TemporaryDirectory() as tempdir:
        model.save_low_bit(tempdir)
        loaded_model = Model.load_low_bit(tempdir,
                                          optimize_model=True,
                                          trust_remote_code=True)

        with torch.inference_mode():
            input_ids = tokenizer.encode(prompt, return_tensors="pt")
            output = loaded_model.generate(input_ids, max_new_tokens=32)
            output_str = tokenizer.decode(output[0], skip_special_tokens=True)

            assert answer in output_str


if __name__ == '__main__':
    pytest.main([__file__])


@@ -17,6 +17,7 @@
import os, time
import pytest
import tempfile
import torch
from bigdl.llm.transformers import AutoModelForCausalLM, AutoModel, AutoModelForSpeechSeq2Seq
@@ -50,6 +51,36 @@ def test_completion(Model, Tokenizer, model_path, prompt, answer):
    assert answer in output_str


@pytest.mark.parametrize('prompt, answer', [
    ('What is the capital of France?\n\n', 'Paris')
])
@pytest.mark.parametrize('Model, Tokenizer, model_path', [
    (AutoModelForCausalLM, LlamaTokenizer, os.environ.get('LLAMA2_7B_ORIGIN_PATH')),
    (AutoModel, AutoTokenizer, os.environ.get('CHATGLM2_6B_ORIGIN_PATH')),
])
def test_load_low_bit_completion(Model, Tokenizer, model_path, prompt, answer):
    tokenizer = Tokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = Model.from_pretrained(model_path,
                                  load_in_4bit=True,
                                  optimize_model=True,
                                  trust_remote_code=True)

    with tempfile.TemporaryDirectory() as tempdir:
        model.save_low_bit(tempdir)
        loaded_model = Model.load_low_bit(tempdir,
                                          optimize_model=True,
                                          trust_remote_code=True)

        with torch.inference_mode():
            loaded_model = loaded_model.to(device)
            input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
            output = loaded_model.generate(input_ids, max_new_tokens=32)
            loaded_model.to('cpu')  # deallocate gpu memory
            output_str = tokenizer.decode(output[0], skip_special_tokens=True)

            assert answer in output_str


def test_transformers_auto_model_for_speech_seq2seq_int4():
    with torch.inference_mode():
        from transformers import WhisperProcessor


@@ -9,18 +9,17 @@ set -e
echo "# Start testing inference"
start=$(date "+%s")
python -m pytest -s ${LLM_INFERENCE_TEST_DIR} -k "not test_transformers" -v \
--ignore=${LLM_INFERENCE_TEST_DIR}/test_optimize_mistral.py
python -m pytest -s ${LLM_INFERENCE_TEST_DIR}/test_call_models.py -v
if [ -z "$THREAD_NUM" ]; then
THREAD_NUM=2
fi
export OMP_NUM_THREADS=$THREAD_NUM
python -m pytest -s ${LLM_INFERENCE_TEST_DIR} -k test_transformers -v \
--ignore=${LLM_INFERENCE_TEST_DIR}/test_optimize_mistral.py
python -m pytest -s ${LLM_INFERENCE_TEST_DIR}/test_transformers_api.py -v
python -m pytest -s ${LLM_INFERENCE_TEST_DIR}/test_optimize_model_api.py -v
python -m pip install transformers==4.34.0
python -m pytest -s ${LLM_INFERENCE_TEST_DIR}/test_optimize_mistral.py -v
python -m pytest -s ${LLM_INFERENCE_TEST_DIR}/test_transformesr_api_434.py -v
python -m pip install transformers==4.31.0
now=$(date "+%s")