[LLM] add CausalLM and Speech UT (#8597)

Authored by Song Jiaming on 2023-07-25 11:22:36 +08:00; committed by GitHub
parent 9c897ac7db
commit 650b82fa6e
3 changed files with 65 additions and 8 deletions

File 1 of 3: GitHub Actions workflow (LLM tests)

@@ -43,6 +43,10 @@ env:
   LLM_DIR: ./llm
   ORIGINAL_CHATGLM2_6B_PATH: ./llm/chatglm2-6b/
+  ORIGINAL_REPLIT_CODE_PATH: ./llm/replit-code-v1-3b/
+  ORIGINAL_WHISPER_TINY_PATH: ./llm/whisper-tiny/
+  SPEECH_DATASET_PATH: ./llm/datasets/librispeech_asr_dummy
 
 # A workflow run is made up of one or more jobs that can run sequentially or in parallel
 jobs:
@@ -100,6 +104,18 @@ jobs:
             echo "Directory $ORIGINAL_CHATGLM2_6B_PATH not found. Downloading from FTP server..."
             wget -r -nH --no-verbose --cut-dirs=1 $LLM_FTP_URL/${ORIGINAL_CHATGLM2_6B_PATH:2} -P $LLM_DIR
           fi
+          if [ ! -d $ORIGINAL_REPLIT_CODE_PATH ]; then
+            echo "Directory $ORIGINAL_REPLIT_CODE_PATH not found. Downloading from FTP server..."
+            wget -r -nH --no-verbose --cut-dirs=1 $LLM_FTP_URL/${ORIGINAL_REPLIT_CODE_PATH:2} -P $LLM_DIR
+          fi
+          if [ ! -d $ORIGINAL_WHISPER_TINY_PATH ]; then
+            echo "Directory $ORIGINAL_WHISPER_TINY_PATH not found. Downloading from FTP server..."
+            wget -r -nH --no-verbose --cut-dirs=1 $LLM_FTP_URL/${ORIGINAL_WHISPER_TINY_PATH:2} -P $LLM_DIR
+          fi
+          if [ ! -d $SPEECH_DATASET_PATH ]; then
+            echo "Directory $SPEECH_DATASET_PATH not found. Downloading from FTP server..."
+            wget -r -nH --no-verbose --cut-dirs=1 $LLM_FTP_URL/${SPEECH_DATASET_PATH:2} -P $LLM_DIR
+          fi
 
       - name: Run LLM cli test
         uses: ./.github/actions/llm/cli-test
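
Note: each new download block reuses the existing ChatGLM2 pattern unchanged. A minimal standalone sketch of that pattern, with a hypothetical MODEL_PATH (LLM_FTP_URL and LLM_DIR are assumed to be set as in the workflow env above):

MODEL_PATH=./llm/example-model/   # hypothetical path; the real values come from the env block
if [ ! -d "$MODEL_PATH" ]; then
  echo "Directory $MODEL_PATH not found. Downloading from FTP server..."
  # ${MODEL_PATH:2} strips the leading "./" so the rest of the path can be
  # appended to the FTP URL; -r recurses, -nH drops the host directory,
  # --cut-dirs=1 drops the first remote path component, and -P places the
  # downloaded files under $LLM_DIR.
  wget -r -nH --no-verbose --cut-dirs=1 "$LLM_FTP_URL/${MODEL_PATH:2}" -P "$LLM_DIR"
fi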

File 2 of 3: transformers API unit test (Python)

@@ -17,10 +17,10 @@
 import unittest
 import os
+import pytest
 import time
 import torch
-from bigdl.llm.transformers import AutoModel
+from bigdl.llm.transformers import AutoModel, AutoModelForCausalLM, AutoModelForSpeechSeq2Seq
 from transformers import AutoTokenizer
 
 
 class TestTransformersAPI(unittest.TestCase):
@@ -32,11 +32,11 @@ class TestTransformersAPI(unittest.TestCase):
         else:
             self.n_threads = 2
 
-    def test_transformers_int4(self):
+    def test_transformers_auto_model_int4(self):
         model_path = os.environ.get('ORIGINAL_CHATGLM2_6B_PATH')
         model = AutoModel.from_pretrained(model_path, trust_remote_code=True, load_in_4bit=True)
         tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
-        input_str = "晚上睡不着应该怎么办"
+        input_str = "Tell me the capital of France.\n\n"
 
         with torch.inference_mode():
             st = time.time()
@@ -46,8 +46,49 @@ class TestTransformersAPI(unittest.TestCase):
             end = time.time()
             print('Prompt:', input_str)
             print('Output:', output_str)
             print(f'Inference time: {end-st} s')
+            res = 'Paris' in output_str
+            self.assertTrue(res)
+
+    def test_transformers_auto_model_for_causal_lm_int4(self):
+        model_path = os.environ.get('ORIGINAL_REPLIT_CODE_PATH')
+        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+        input_str = 'def hello():\n print("hello world")\n'
+        model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, load_in_4bit=True)
+
+        with torch.inference_mode():
+            st = time.time()
+            input_ids = tokenizer.encode(input_str, return_tensors="pt")
+            output = model.generate(input_ids, do_sample=False, max_new_tokens=32)
+            output_str = tokenizer.decode(output[0], skip_special_tokens=True)
+            end = time.time()
+            print('Prompt:', input_str)
+            print('Output:', output_str)
+            print(f'Inference time: {end-st} s')
+            res = '\nhello()' in output_str
+            self.assertTrue(res)
+
+    def test_transformers_auto_model_for_speech_seq2seq_int4(self):
+        from transformers import WhisperProcessor, WhisperForConditionalGeneration
+        from datasets import load_from_disk
+        model_path = os.environ.get('ORIGINAL_WHISPER_TINY_PATH')
+        dataset_path = os.environ.get('SPEECH_DATASET_PATH')
+        processor = WhisperProcessor.from_pretrained(model_path)
+        ds = load_from_disk(dataset_path)
+        sample = ds[0]["audio"]
+        input_features = processor(sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt").input_features
+        model = AutoModelForSpeechSeq2Seq.from_pretrained(model_path, trust_remote_code=True, load_in_4bit=True)
+
+        with torch.inference_mode():
+            st = time.time()
+            predicted_ids = model.generate(input_features)
+            # decode token ids to text
+            transcription = processor.batch_decode(predicted_ids, skip_special_tokens=False)
+            end = time.time()
+            print('Output:', transcription)
+            print(f'Inference time: {end-st} s')
+            res = 'Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.' in transcription[0]
+            self.assertTrue(res)
 
 
 if __name__ == '__main__':
-    unittest.main()
+    pytest.main([__file__])
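
Note: to run this test file outside CI, the four environment variables must point at local copies of the models and dataset. A sketch, assuming the same relative layout the workflow downloads into (the test file path is a guess, not shown in this commit):

export ORIGINAL_CHATGLM2_6B_PATH=./llm/chatglm2-6b/
export ORIGINAL_REPLIT_CODE_PATH=./llm/replit-code-v1-3b/
export ORIGINAL_WHISPER_TINY_PATH=./llm/whisper-tiny/
export SPEECH_DATASET_PATH=./llm/datasets/librispeech_asr_dummy
python -m pytest -s path/to/this/test_file.py   # assumed location of the test file above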

File 3 of 3: inference test runner script (shell)

@@ -9,13 +9,13 @@ set -e
 
 echo "# Start testing inference"
 start=$(date "+%s")
 
-python -m pytest -s ${LLM_INFERENCE_TEST_DIR} -k "not test_transformers_int4"
+python -m pytest -s ${LLM_INFERENCE_TEST_DIR} -k "not test_transformers"
 
 if [ -z "$THREAD_NUM" ]; then
   THREAD_NUM=2
 fi
 export OMP_NUM_THREADS=$THREAD_NUM
-taskset -c 0-$((THREAD_NUM - 1)) python -m pytest -s ${LLM_INFERENCE_TEST_DIR} -k test_transformers_int4
+taskset -c 0-$((THREAD_NUM - 1)) python -m pytest -s ${LLM_INFERENCE_TEST_DIR} -k test_transformers
 
 now=$(date "+%s")
 time=$((now-start))
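
Note: pytest's -k option matches substrings of test names, so widening the filter from test_transformers_int4 to test_transformers is what lets the second, thread-pinned pass pick up all three renamed tests (test_transformers_auto_model_int4, test_transformers_auto_model_for_causal_lm_int4, test_transformers_auto_model_for_speech_seq2seq_int4) while the first pass skips them. A quick way to preview what each expression selects, without running anything (assuming LLM_INFERENCE_TEST_DIR is set):

# lists the three transformers INT4 tests picked up by the second pass
python -m pytest --collect-only -q ${LLM_INFERENCE_TEST_DIR} -k test_transformers
# lists everything the first pass runs
python -m pytest --collect-only -q ${LLM_INFERENCE_TEST_DIR} -k "not test_transformers"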