[LLM] add CausalLM and Speech UT (#8597)
parent 9c897ac7db
commit 650b82fa6e
3 changed files with 65 additions and 8 deletions
.github/workflows/llm_unit_tests_linux.yml (vendored, 16 changes)
@@ -43,6 +43,10 @@ env:
   LLM_DIR: ./llm
   ORIGINAL_CHATGLM2_6B_PATH: ./llm/chatglm2-6b/
+  ORIGINAL_REPLIT_CODE_PATH: ./llm/replit-code-v1-3b/
+  ORIGINAL_WHISPER_TINY_PATH: ./llm/whisper-tiny/
+  SPEECH_DATASET_PATH: ./llm/datasets/librispeech_asr_dummy
 
 # A workflow run is made up of one or more jobs that can run sequentially or in parallel
 jobs:
@@ -100,6 +104,18 @@ jobs:
           echo "Directory $ORIGINAL_CHATGLM2_6B_PATH not found. Downloading from FTP server..."
           wget -r -nH --no-verbose --cut-dirs=1 $LLM_FTP_URL/${ORIGINAL_CHATGLM2_6B_PATH:2} -P $LLM_DIR
         fi
+        if [ ! -d $ORIGINAL_REPLIT_CODE_PATH ]; then
+          echo "Directory $ORIGINAL_REPLIT_CODE_PATH not found. Downloading from FTP server..."
+          wget -r -nH --no-verbose --cut-dirs=1 $LLM_FTP_URL/${ORIGINAL_REPLIT_CODE_PATH:2} -P $LLM_DIR
+        fi
+        if [ ! -d $ORIGINAL_WHISPER_TINY_PATH ]; then
+          echo "Directory $ORIGINAL_WHISPER_TINY_PATH not found. Downloading from FTP server..."
+          wget -r -nH --no-verbose --cut-dirs=1 $LLM_FTP_URL/${ORIGINAL_WHISPER_TINY_PATH:2} -P $LLM_DIR
+        fi
+        if [ ! -d $SPEECH_DATASET_PATH ]; then
+          echo "Directory $SPEECH_DATASET_PATH not found. Downloading from FTP server..."
+          wget -r -nH --no-verbose --cut-dirs=1 $LLM_FTP_URL/${SPEECH_DATASET_PATH:2} -P $LLM_DIR
+        fi
 
     - name: Run LLM cli test
       uses: ./.github/actions/llm/cli-test
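Note on the download commands above: ${ORIGINAL_REPLIT_CODE_PATH:2} and the other expansions use bash substring expansion to drop the leading "./" so the path joins cleanly onto $LLM_FTP_URL, while wget's -r -nH --cut-dirs=1 mirrors the remote directory without the hostname and the first path component, writing under -P $LLM_DIR. A minimal sketch of the expansion, using the value from the env block above:

# minimal sketch: ${VAR:2} strips the first two characters ("./")
ORIGINAL_REPLIT_CODE_PATH=./llm/replit-code-v1-3b/
echo "${ORIGINAL_REPLIT_CODE_PATH:2}"   # prints: llm/replit-code-v1-3b/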
In the transformers API unit test (the file path is not shown in this view):

@@ -17,10 +17,10 @@
 import unittest
 import os
 
 import pytest
 import time
 import torch
-from bigdl.llm.transformers import AutoModel
+from bigdl.llm.transformers import AutoModel, AutoModelForCausalLM, AutoModelForSpeechSeq2Seq
 from transformers import AutoTokenizer
 
 class TestTransformersAPI(unittest.TestCase):
@@ -32,11 +32,11 @@ class TestTransformersAPI(unittest.TestCase):
         else:
             self.n_threads = 2
 
-    def test_transformers_int4(self):
+    def test_transformers_auto_model_int4(self):
         model_path = os.environ.get('ORIGINAL_CHATGLM2_6B_PATH')
         model = AutoModel.from_pretrained(model_path, trust_remote_code=True, load_in_4bit=True)
         tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
-        input_str = "晚上睡不着应该怎么办"
+        input_str = "Tell me the capital of France.\n\n"
 
         with torch.inference_mode():
             st = time.time()
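Note: the removed prompt is Chinese for "What should I do if I can't sleep at night?"; the English replacement gives the 'Paris' assertion added in the next hunk a deterministic target.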
@@ -47,7 +47,48 @@ class TestTransformersAPI(unittest.TestCase):
         print('Prompt:', input_str)
         print('Output:', output_str)
         print(f'Inference time: {end-st} s')
+        res = 'Paris' in output_str
+        self.assertTrue(res)
+
+    def test_transformers_auto_model_for_causal_lm_int4(self):
+        model_path = os.environ.get('ORIGINAL_REPLIT_CODE_PATH')
+        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+        input_str = 'def hello():\n    print("hello world")\n'
+        model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, load_in_4bit=True)
+        with torch.inference_mode():
+            st = time.time()
+            input_ids = tokenizer.encode(input_str, return_tensors="pt")
+            output = model.generate(input_ids, do_sample=False, max_new_tokens=32)
+            output_str = tokenizer.decode(output[0], skip_special_tokens=True)
+            end = time.time()
+        print('Prompt:', input_str)
+        print('Output:', output_str)
+        print(f'Inference time: {end-st} s')
+        res = '\nhello()' in output_str
+        self.assertTrue(res)
+
+    def test_transformers_auto_model_for_speech_seq2seq_int4(self):
+        from transformers import WhisperProcessor, WhisperForConditionalGeneration
+        from datasets import load_from_disk
+        model_path = os.environ.get('ORIGINAL_WHISPER_TINY_PATH')
+        dataset_path = os.environ.get('SPEECH_DATASET_PATH')
+        processor = WhisperProcessor.from_pretrained(model_path)
+        ds = load_from_disk(dataset_path)
+        sample = ds[0]["audio"]
+        input_features = processor(sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt").input_features
+        model = AutoModelForSpeechSeq2Seq.from_pretrained(model_path, trust_remote_code=True, load_in_4bit=True)
+        with torch.inference_mode():
+            st = time.time()
+            predicted_ids = model.generate(input_features)
+            # decode token ids to text
+            transcription = processor.batch_decode(predicted_ids, skip_special_tokens=False)
+            end = time.time()
+        print('Output:', transcription)
+        print(f'Inference time: {end-st} s')
+        res = 'Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.' in transcription[0]
+        self.assertTrue(res)
+
 
 if __name__ == '__main__':
-    unittest.main()
+    pytest.main([__file__])
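The new tests read their model and dataset locations from the environment, so a local run mirrors the workflow's env block. A sketch, assuming the test file is named test_transformers_api.py (an assumed name; the real path is not shown here):

# sketch of a local run; test_transformers_api.py is an assumed file name,
# and the exported paths mirror the workflow env block above
export ORIGINAL_REPLIT_CODE_PATH=./llm/replit-code-v1-3b/
export ORIGINAL_WHISPER_TINY_PATH=./llm/whisper-tiny/
export SPEECH_DATASET_PATH=./llm/datasets/librispeech_asr_dummy
python -m pytest -s test_transformers_api.py -k "causal_lm or speech_seq2seq"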
In the inference test-run script:

@@ -9,13 +9,13 @@ set -e
 echo "# Start testing inference"
 start=$(date "+%s")
 
-python -m pytest -s ${LLM_INFERENCE_TEST_DIR} -k "not test_transformers_int4"
+python -m pytest -s ${LLM_INFERENCE_TEST_DIR} -k "not test_transformers"
 
 if [ -z "$THREAD_NUM" ]; then
   THREAD_NUM=2
 fi
 export OMP_NUM_THREADS=$THREAD_NUM
-taskset -c 0-$((THREAD_NUM - 1)) python -m pytest -s ${LLM_INFERENCE_TEST_DIR} -k test_transformers_int4
+taskset -c 0-$((THREAD_NUM - 1)) python -m pytest -s ${LLM_INFERENCE_TEST_DIR} -k test_transformers
 
 now=$(date "+%s")
 time=$((now-start))
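pytest's -k option matches substrings of test names, so widening the filter from test_transformers_int4 to test_transformers keeps the split intact after the rename: the first invocation skips every transformers test, and the taskset invocation, pinned to cores 0 through THREAD_NUM-1, runs all three of them under the OMP_NUM_THREADS cap. To preview what each expression selects without running anything (a sketch, reusing the same LLM_INFERENCE_TEST_DIR):

# sketch: dry-run collection to show which tests each -k expression matches
python -m pytest --collect-only -q ${LLM_INFERENCE_TEST_DIR} -k "test_transformers"
python -m pytest --collect-only -q ${LLM_INFERENCE_TEST_DIR} -k "not test_transformers"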