Add inference test for Whisper model on Arc (#9330)

* Add inference test for Whisper model

* Remove unnecessary inference time measurement
This commit is contained in:
Cheen Hau, 俊豪 2023-11-03 10:15:52 +08:00 committed by GitHub
parent 63411dff75
commit 8f23fb04dc
3 changed files with 35 additions and 6 deletions

View file

@ -221,11 +221,13 @@ jobs:
run: |
echo "DATASET_DIR=${ORIGIN_DIR}/../datasets" >> "$GITHUB_ENV"
echo "ABIRATE_ENGLISH_QUOTES_PATH=${ORIGIN_DIR}/../datasets/abirate_english_quotes" >> "$GITHUB_ENV"
echo "SPEECH_DATASET_PATH=${ORIGIN_DIR}/../datasets/librispeech_asr_dummy" >> "$GITHUB_ENV"
echo "LLAMA2_7B_ORIGIN_PATH=${ORIGIN_DIR}/Llama-2-7b-chat-hf" >> "$GITHUB_ENV"
echo "CHATGLM2_6B_ORIGIN_PATH=${ORIGIN_DIR}/chatglm2-6b" >> "$GITHUB_ENV"
echo "FALCON_7B_ORIGIN_PATH=${ORIGIN_DIR}/falcon-7b-instruct-with-patch" >> "$GITHUB_ENV"
echo "MPT_7B_ORIGIN_PATH=${ORIGIN_DIR}/mpt-7b-chat" >> "$GITHUB_ENV"
echo "WHISPER_TINY_ORIGIN_PATH=${ORIGIN_DIR}/whisper-tiny" >> "$GITHUB_ENV"
- name: Checkout repo
uses: actions/checkout@v3
@ -275,6 +277,10 @@ jobs:
echo "Directory $MPT_7B_ORIGIN_PATH not found. Downloading from FTP server..."
wget -r -nH --no-verbose --cut-dirs=1 $LLM_FTP_URL/llm/mpt-7b-chat -P $ORIGIN_DIR
fi
if [ ! -d $WHISPER_TINY_ORIGIN_PATH ]; then
echo "Directory $WHISPER_TINY_ORIGIN_PATH not found. Downloading from FTP server..."
wget -r -nH --no-verbose --cut-dirs=1 $LLM_FTP_URL/llm/whisper-tiny -P $ORIGIN_DIR
fi
if [ ! -d $DATASET_DIR ]; then
mkdir -p $DATASET_DIR
fi
@ -282,12 +288,16 @@ jobs:
echo "Directory $ABIRATE_ENGLISH_QUOTES_PATH not found. Downloading from FTP server..."
wget -r -nH --no-verbose --cut-dirs=2 $LLM_FTP_URL/llm/datasets/abirate_english_quotes -P $DATASET_DIR
fi
if [ ! -d $SPEECH_DATASET_PATH ]; then
echo "Directory $SPEECH_DATASET_PATH not found. Downloading from FTP server..."
wget -r -nH --no-verbose --cut-dirs=2 $LLM_FTP_URL/llm/datasets/librispeech_asr_dummy -P $DATASET_DIR
fi
- name: Run LLM inference test
shell: bash
run: |
source /opt/intel/oneapi/setvars.sh
python -m pip install expecttest einops
python -m pip install expecttest einops librosa
bash python/llm/test/run-llm-inference-tests-gpu.sh
- name: Run LLM example tests

View file

@ -15,10 +15,10 @@
#
import os
import os, time
import pytest
from bigdl.llm.transformers import AutoModelForCausalLM, AutoModel
from bigdl.llm.transformers import AutoModelForCausalLM, AutoModel, AutoModelForSpeechSeq2Seq
from transformers import LlamaTokenizer, AutoTokenizer
device = os.environ['DEVICE']
@ -41,14 +41,33 @@ def test_completion(Model, Tokenizer, model_path, prompt, answer):
load_in_4bit=True,
optimize_model=True,
trust_remote_code=True)
model = model.to(device) # deallocate gpu memory
model = model.to(device)
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
output = model.generate(input_ids, max_new_tokens=32)
model.to('cpu')
model.to('cpu') # deallocate gpu memory
output_str = tokenizer.decode(output[0], skip_special_tokens=True)
assert answer in output_str
def test_transformers_auto_model_for_speech_seq2seq_int4():
from transformers import WhisperProcessor
from datasets import load_from_disk
model_path = os.environ.get('WHISPER_TINY_ORIGIN_PATH')
dataset_path = os.environ.get('SPEECH_DATASET_PATH')
processor = WhisperProcessor.from_pretrained(model_path)
ds = load_from_disk(dataset_path)
sample = ds[0]["audio"]
input_features = processor(sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt").input_features
input_features = input_features.to(device)
model = AutoModelForSpeechSeq2Seq.from_pretrained(model_path, trust_remote_code=True, load_in_4bit=True, optimize_model=True)
model = model.to(device)
predicted_ids = model.generate(input_features)
# decode token ids to text
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=False)
model.to('cpu')
print('Output:', transcription)
assert 'Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.' in transcription[0]
if __name__ == '__main__':
pytest.main([__file__])

View file

@ -17,7 +17,7 @@ if [ -z "$THREAD_NUM" ]; then
THREAD_NUM=2
fi
export OMP_NUM_THREADS=$THREAD_NUM
pytest ${LLM_INFERENCE_TEST_DIR} -v -s
pytest ${LLM_INFERENCE_TEST_DIR} -v -s
now=$(date "+%s")
time=$((now-start))