[LLM] add CausalLM and Speech UT (#8597)

Song Jiaming 2023-07-25 11:22:36 +08:00 committed by GitHub
parent 9c897ac7db
commit 650b82fa6e
3 changed files with 65 additions and 8 deletions

Changed file 1 of 3: GitHub Actions workflow (LLM unit tests)

@@ -43,6 +43,10 @@ env:
LLM_DIR: ./llm
ORIGINAL_CHATGLM2_6B_PATH: ./llm/chatglm2-6b/
ORIGINAL_REPLIT_CODE_PATH: ./llm/replit-code-v1-3b/
ORIGINAL_WHISPER_TINY_PATH: ./llm/whisper-tiny/
SPEECH_DATASET_PATH: ./llm/datasets/librispeech_asr_dummy
# A workflow run is made up of one or more jobs that can run sequentially or in parallel
jobs:
@@ -100,6 +104,18 @@ jobs:
echo "Directory $ORIGINAL_CHATGLM2_6B_PATH not found. Downloading from FTP server..."
wget -r -nH --no-verbose --cut-dirs=1 $LLM_FTP_URL/${ORIGINAL_CHATGLM2_6B_PATH:2} -P $LLM_DIR
fi
if [ ! -d $ORIGINAL_REPLIT_CODE_PATH ]; then
echo "Directory $ORIGINAL_REPLIT_CODE_PATH not found. Downloading from FTP server..."
wget -r -nH --no-verbose --cut-dirs=1 $LLM_FTP_URL/${ORIGINAL_REPLIT_CODE_PATH:2} -P $LLM_DIR
fi
if [ ! -d $ORIGINAL_WHISPER_TINY_PATH ]; then
echo "Directory $ORIGINAL_WHISPER_TINY_PATH not found. Downloading from FTP server..."
wget -r -nH --no-verbose --cut-dirs=1 $LLM_FTP_URL/${ORIGINAL_WHISPER_TINY_PATH:2} -P $LLM_DIR
fi
if [ ! -d $SPEECH_DATASET_PATH ]; then
echo "Directory $SPEECH_DATASET_PATH not found. Downloading from FTP server..."
wget -r -nH --no-verbose --cut-dirs=1 $LLM_FTP_URL/${SPEECH_DATASET_PATH:2} -P $LLM_DIR
fi
- name: Run LLM cli test
uses: ./.github/actions/llm/cli-test
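The four download guards above share one pattern: fetch a directory from the FTP mirror only when the local cache is missing, with `${VAR:2}` stripping the leading `./` from each path before it is appended to `$LLM_FTP_URL`. A minimal Python sketch of the same guard (an illustration, not part of the commit; `ensure_local_copy` is a hypothetical helper and `wget` is assumed to be on PATH):

import os
import subprocess

def ensure_local_copy(path: str, ftp_base: str, dest_dir: str) -> None:
    """Download `path` from the FTP mirror only when the local cache is missing."""
    if not os.path.isdir(path):
        print(f"Directory {path} not found. Downloading from FTP server...")
        # Same flags as the workflow's wget call: recursive, no host directory,
        # quiet output, drop the first remote path component, save under dest_dir.
        # path[2:] mirrors ${VAR:2} in the workflow, stripping the leading "./".
        subprocess.run(
            ["wget", "-r", "-nH", "--no-verbose", "--cut-dirs=1",
             f"{ftp_base}/{path[2:]}", "-P", dest_dir],
            check=True,
        )

# Example usage with one of the paths defined in the workflow's env block:
ensure_local_copy("./llm/whisper-tiny/", os.environ.get("LLM_FTP_URL", ""), "./llm")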

Changed file 2 of 3: transformers API unit tests (Python)

@@ -17,10 +17,10 @@
import unittest
import os
import pytest
import time
import torch
from bigdl.llm.transformers import AutoModel
from bigdl.llm.transformers import AutoModel, AutoModelForCausalLM, AutoModelForSpeechSeq2Seq
from transformers import AutoTokenizer
class TestTransformersAPI(unittest.TestCase):
@@ -32,11 +32,11 @@ class TestTransformersAPI(unittest.TestCase):
else:
self.n_threads = 2
def test_transformers_int4(self):
def test_transformers_auto_model_int4(self):
model_path = os.environ.get('ORIGINAL_CHATGLM2_6B_PATH')
model = AutoModel.from_pretrained(model_path, trust_remote_code=True, load_in_4bit=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
input_str = "晚上睡不着应该怎么办"
input_str = "Tell me the capital of France.\n\n"
with torch.inference_mode():
st = time.time()
@@ -47,7 +47,48 @@
print('Prompt:', input_str)
print('Output:', output_str)
print(f'Inference time: {end-st} s')
res = 'Paris' in output_str
self.assertTrue(res)
def test_transformers_auto_model_for_causal_lm_int4(self):
model_path = os.environ.get('ORIGINAL_REPLIT_CODE_PATH')
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
input_str = 'def hello():\n print("hello world")\n'
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, load_in_4bit=True)
with torch.inference_mode():
st = time.time()
input_ids = tokenizer.encode(input_str, return_tensors="pt")
output = model.generate(input_ids, do_sample=False, max_new_tokens=32)
output_str = tokenizer.decode(output[0], skip_special_tokens=True)
end = time.time()
print('Prompt:', input_str)
print('Output:', output_str)
print(f'Inference time: {end-st} s')
res = '\nhello()' in output_str
self.assertTrue(res)
def test_transformers_auto_model_for_speech_seq2seq_int4(self):
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_from_disk
model_path = os.environ.get('ORIGINAL_WHISPER_TINY_PATH')
dataset_path = os.environ.get('SPEECH_DATASET_PATH')
processor = WhisperProcessor.from_pretrained(model_path)
ds = load_from_disk(dataset_path)
sample = ds[0]["audio"]
input_features = processor(sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt").input_features
model = AutoModelForSpeechSeq2Seq.from_pretrained(model_path, trust_remote_code=True, load_in_4bit=True)
with torch.inference_mode():
st = time.time()
predicted_ids = model.generate(input_features)
# decode token ids to text
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=False)
end = time.time()
print('Output:', transcription)
print(f'Inference time: {end-st} s')
res = 'Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.' in transcription[0]
self.assertTrue(res)
if __name__ == '__main__':
unittest.main()
pytest.main([__file__])
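The speech test reads a pre-saved dataset via `load_from_disk`, so `SPEECH_DATASET_PATH` must contain a directory produced by `Dataset.save_to_disk`. One plausible way to build that fixture (an assumption; the CI actually downloads a pre-built copy from the FTP server) is to snapshot the standard Whisper smoke-test split from the Hugging Face Hub:

# Sketch, not part of the commit: build the on-disk dataset that
# load_from_disk(SPEECH_DATASET_PATH) expects. The hub dataset below is the
# usual Whisper test fixture; whether the FTP copy matches it is an assumption.
from datasets import load_dataset

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
ds.save_to_disk("./llm/datasets/librispeech_asr_dummy")  # matches SPEECH_DATASET_PATH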

Changed file 3 of 3: LLM inference test script (shell)

@@ -9,13 +9,13 @@ set -e
echo "# Start testing inference"
start=$(date "+%s")
python -m pytest -s ${LLM_INFERENCE_TEST_DIR} -k "not test_transformers_int4"
python -m pytest -s ${LLM_INFERENCE_TEST_DIR} -k "not test_transformers"
if [ -z "$THREAD_NUM" ]; then
THREAD_NUM=2
fi
export OMP_NUM_THREADS=$THREAD_NUM
taskset -c 0-$((THREAD_NUM - 1)) python -m pytest -s ${LLM_INFERENCE_TEST_DIR} -k test_transformers_int4
taskset -c 0-$((THREAD_NUM - 1)) python -m pytest -s ${LLM_INFERENCE_TEST_DIR} -k test_transformers
now=$(date "+%s")
time=$((now-start))
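The widened `-k` filters split the suite into two passes: everything except the transformers tests runs first, then all transformers tests run pinned to `THREAD_NUM` cores via `taskset`, with `OMP_NUM_THREADS` set to the same count. A quick way to confirm the pinning from inside a pinned test (a sketch, not part of the commit; `os.sched_getaffinity` is Linux-only):

import os

# Under `taskset -c 0-$((THREAD_NUM - 1))` this process should see exactly
# THREAD_NUM CPUs, matching the OMP_NUM_THREADS the script exports.
visible = len(os.sched_getaffinity(0))
print(f"{visible} CPUs visible; OMP_NUM_THREADS={os.environ.get('OMP_NUM_THREADS')}")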