[LLM] add CausalLM and Speech UT (#8597)

parent 9c897ac7db
commit 650b82fa6e

3 changed files with 65 additions and 8 deletions

.github/workflows/llm_unit_tests_linux.yml (vendored, 16 changed lines)
@@ -43,6 +43,10 @@ env:
   LLM_DIR: ./llm
   ORIGINAL_CHATGLM2_6B_PATH: ./llm/chatglm2-6b/
+  ORIGINAL_REPLIT_CODE_PATH: ./llm/replit-code-v1-3b/
+  ORIGINAL_WHISPER_TINY_PATH: ./llm/whisper-tiny/
+  SPEECH_DATASET_PATH: ./llm/datasets/librispeech_asr_dummy


 # A workflow run is made up of one or more jobs that can run sequentially or in parallel
 jobs:
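These environment variables feed the unit tests added in this commit; a minimal sketch of how a test picks them up, using os.environ.get exactly as the test file below does (the values shown in comments come from the hunk above):

    import os

    # Paths exported by the workflow; None if the variable is unset.
    model_path = os.environ.get('ORIGINAL_REPLIT_CODE_PATH')    # ./llm/replit-code-v1-3b/
    dataset_path = os.environ.get('SPEECH_DATASET_PATH')        # ./llm/datasets/librispeech_asr_dummy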
@@ -100,6 +104,18 @@ jobs:
           echo "Directory $ORIGINAL_CHATGLM2_6B_PATH not found. Downloading from FTP server..."
           wget -r -nH --no-verbose --cut-dirs=1 $LLM_FTP_URL/${ORIGINAL_CHATGLM2_6B_PATH:2} -P $LLM_DIR
         fi
+        if [ ! -d $ORIGINAL_REPLIT_CODE_PATH ]; then
+          echo "Directory $ORIGINAL_REPLIT_CODE_PATH not found. Downloading from FTP server..."
+          wget -r -nH --no-verbose --cut-dirs=1 $LLM_FTP_URL/${ORIGINAL_REPLIT_CODE_PATH:2} -P $LLM_DIR
+        fi
+        if [ ! -d $ORIGINAL_WHISPER_TINY_PATH ]; then
+          echo "Directory $ORIGINAL_WHISPER_TINY_PATH not found. Downloading from FTP server..."
+          wget -r -nH --no-verbose --cut-dirs=1 $LLM_FTP_URL/${ORIGINAL_WHISPER_TINY_PATH:2} -P $LLM_DIR
+        fi
+        if [ ! -d $SPEECH_DATASET_PATH ]; then
+          echo "Directory $SPEECH_DATASET_PATH not found. Downloading from FTP server..."
+          wget -r -nH --no-verbose --cut-dirs=1 $LLM_FTP_URL/${SPEECH_DATASET_PATH:2} -P $LLM_DIR
+        fi

     - name: Run LLM cli test
       uses: ./.github/actions/llm/cli-test
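Each new step repeats the idempotent guard already used for ChatGLM2: skip the download when the directory exists, and note that ${VAR:2} is bash substring expansion stripping the leading "./" so the path can be appended to $LLM_FTP_URL. A hypothetical Python equivalent of one guard (fetch_if_missing is an invented name, not part of the repo; it assumes wget is on PATH):

    import os
    import subprocess

    def fetch_if_missing(local_dir: str, ftp_base: str, dest_root: str) -> None:
        """Mirror of the workflow guard: download only when the directory is absent."""
        if not os.path.isdir(local_dir):
            print(f"Directory {local_dir} not found. Downloading from FTP server...")
            # local_dir[2:] plays the role of ${VAR:2} in bash: drop the leading "./"
            subprocess.run(
                ["wget", "-r", "-nH", "--no-verbose", "--cut-dirs=1",
                 f"{ftp_base}/{local_dir[2:]}", "-P", dest_root],
                check=True,
            )

The remaining hunks modify the transformers API unit test and the inference test-runner script.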
@@ -17,10 +17,10 @@

 import unittest
 import os
+import pytest
 import time
 import torch
-from bigdl.llm.transformers import AutoModel
+from bigdl.llm.transformers import AutoModel, AutoModelForCausalLM, AutoModelForSpeechSeq2Seq
 from transformers import AutoTokenizer


 class TestTransformersAPI(unittest.TestCase):
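The two new imports come from bigdl.llm.transformers, which mirrors the Hugging Face Auto-class API; the tests below load them with load_in_4bit=True. A minimal sketch of the pattern, assuming a local model directory (the path string is a placeholder):

    from bigdl.llm.transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        "path/to/model",      # placeholder; the tests use $ORIGINAL_REPLIT_CODE_PATH
        trust_remote_code=True,
        load_in_4bit=True,    # BigDL-LLM INT4 quantization on load
    )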
@@ -32,11 +32,11 @@ class TestTransformersAPI(unittest.TestCase):
         else:
             self.n_threads = 2

-    def test_transformers_int4(self):
+    def test_transformers_auto_model_int4(self):
         model_path = os.environ.get('ORIGINAL_CHATGLM2_6B_PATH')
         model = AutoModel.from_pretrained(model_path, trust_remote_code=True, load_in_4bit=True)
         tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
-        input_str = "晚上睡不着应该怎么办"
+        input_str = "Tell me the capital of France.\n\n"

         with torch.inference_mode():
             st = time.time()
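The hunk elides the body of the inference block; a representative sketch of what sits between the two timestamps, modeled on the CausalLM test added below (not the file's exact lines):

    # Representative only: tokenize, generate, decode.
    input_ids = tokenizer.encode(input_str, return_tensors="pt")
    output = model.generate(input_ids, max_new_tokens=32)
    output_str = tokenizer.decode(output[0], skip_special_tokens=True)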
@@ -46,8 +46,49 @@
             end = time.time()
         print('Prompt:', input_str)
         print('Output:', output_str)
         print(f'Inference time: {end-st} s')
+        res = 'Paris' in output_str
+        self.assertTrue(res)
+
+    def test_transformers_auto_model_for_causal_lm_int4(self):
+        model_path = os.environ.get('ORIGINAL_REPLIT_CODE_PATH')
+        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+        input_str = 'def hello():\n print("hello world")\n'
+        model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, load_in_4bit=True)
+        with torch.inference_mode():
+            st = time.time()
+            input_ids = tokenizer.encode(input_str, return_tensors="pt")
+            output = model.generate(input_ids, do_sample=False, max_new_tokens=32)
+            output_str = tokenizer.decode(output[0], skip_special_tokens=True)
+            end = time.time()
+        print('Prompt:', input_str)
+        print('Output:', output_str)
+        print(f'Inference time: {end-st} s')
+        res = '\nhello()' in output_str
+        self.assertTrue(res)
+
+    def test_transformers_auto_model_for_speech_seq2seq_int4(self):
+        from transformers import WhisperProcessor, WhisperForConditionalGeneration
+        from datasets import load_from_disk
+        model_path = os.environ.get('ORIGINAL_WHISPER_TINY_PATH')
+        dataset_path = os.environ.get('SPEECH_DATASET_PATH')
+        processor = WhisperProcessor.from_pretrained(model_path)
+        ds = load_from_disk(dataset_path)
+        sample = ds[0]["audio"]
+        input_features = processor(sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt").input_features
+        model = AutoModelForSpeechSeq2Seq.from_pretrained(model_path, trust_remote_code=True, load_in_4bit=True)
+        with torch.inference_mode():
+            st = time.time()
+            predicted_ids = model.generate(input_features)
+            # decode token ids to text
+            transcription = processor.batch_decode(predicted_ids, skip_special_tokens=False)
+            end = time.time()
+        print('Output:', transcription)
+        print(f'Inference time: {end-st} s')
+        res = 'Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.' in transcription[0]
+        self.assertTrue(res)


 if __name__ == '__main__':
-    unittest.main()
+    pytest.main([__file__])
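The test-runner script below broadens its -k filters to match: the old expression "not test_transformers_int4" would no longer exclude the renamed and newly added tests, while the shared prefix test_transformers selects all three so they run in the thread-pinned invocation. The same selection can be made from Python (the file name here is a placeholder):

    import pytest

    # Equivalent to: python -m pytest -s $LLM_INFERENCE_TEST_DIR -k test_transformers
    pytest.main(["-s", "-k", "test_transformers", "test_transformers_api.py"])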
@@ -9,13 +9,13 @@ set -e
 echo "# Start testing inference"
 start=$(date "+%s")

-python -m pytest -s ${LLM_INFERENCE_TEST_DIR} -k "not test_transformers_int4"
+python -m pytest -s ${LLM_INFERENCE_TEST_DIR} -k "not test_transformers"

 if [ -z "$THREAD_NUM" ]; then
   THREAD_NUM=2
 fi
 export OMP_NUM_THREADS=$THREAD_NUM
-taskset -c 0-$((THREAD_NUM - 1)) python -m pytest -s ${LLM_INFERENCE_TEST_DIR} -k test_transformers_int4
+taskset -c 0-$((THREAD_NUM - 1)) python -m pytest -s ${LLM_INFERENCE_TEST_DIR} -k test_transformers

 now=$(date "+%s")
 time=$((now-start))
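For reference, the taskset call combined with OMP_NUM_THREADS keeps the OpenMP worker count and the CPU affinity consistent. A rough Python equivalent on Linux, using the standard-library affinity call (a sketch of the idea, not part of this commit):

    import os

    thread_num = int(os.environ.get("THREAD_NUM", "2"))
    os.environ["OMP_NUM_THREADS"] = str(thread_num)
    # Pin this process to cores 0..THREAD_NUM-1, like taskset -c 0-$((THREAD_NUM - 1))
    os.sched_setaffinity(0, set(range(thread_num)))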