Add inference test for Whisper model on Arc (#9330)

* Add inference test for Whisper model

* Remove unnecessary inference time measurement
Cheen Hau, 俊豪 authored 2023-11-03 10:15:52 +08:00, committed by GitHub
parent 63411dff75
commit 8f23fb04dc
3 changed files with 35 additions and 6 deletions


@@ -221,11 +221,13 @@ jobs:
         run: |
           echo "DATASET_DIR=${ORIGIN_DIR}/../datasets" >> "$GITHUB_ENV"
           echo "ABIRATE_ENGLISH_QUOTES_PATH=${ORIGIN_DIR}/../datasets/abirate_english_quotes" >> "$GITHUB_ENV"
+          echo "SPEECH_DATASET_PATH=${ORIGIN_DIR}/../datasets/librispeech_asr_dummy" >> "$GITHUB_ENV"
           echo "LLAMA2_7B_ORIGIN_PATH=${ORIGIN_DIR}/Llama-2-7b-chat-hf" >> "$GITHUB_ENV"
           echo "CHATGLM2_6B_ORIGIN_PATH=${ORIGIN_DIR}/chatglm2-6b" >> "$GITHUB_ENV"
           echo "FALCON_7B_ORIGIN_PATH=${ORIGIN_DIR}/falcon-7b-instruct-with-patch" >> "$GITHUB_ENV"
           echo "MPT_7B_ORIGIN_PATH=${ORIGIN_DIR}/mpt-7b-chat" >> "$GITHUB_ENV"
+          echo "WHISPER_TINY_ORIGIN_PATH=${ORIGIN_DIR}/whisper-tiny" >> "$GITHUB_ENV"

       - name: Checkout repo
         uses: actions/checkout@v3
@@ -275,6 +277,10 @@ jobs:
           echo "Directory $MPT_7B_ORIGIN_PATH not found. Downloading from FTP server..."
           wget -r -nH --no-verbose --cut-dirs=1 $LLM_FTP_URL/llm/mpt-7b-chat -P $ORIGIN_DIR
           fi
+          if [ ! -d $WHISPER_TINY_ORIGIN_PATH ]; then
+            echo "Directory $WHISPER_TINY_ORIGIN_PATH not found. Downloading from FTP server..."
+            wget -r -nH --no-verbose --cut-dirs=1 $LLM_FTP_URL/llm/whisper-tiny -P $ORIGIN_DIR
+          fi
           if [ ! -d $DATASET_DIR ]; then
             mkdir -p $DATASET_DIR
           fi
@@ -282,12 +288,16 @@ jobs:
            echo "Directory $ABIRATE_ENGLISH_QUOTES_PATH not found. Downloading from FTP server..."
            wget -r -nH --no-verbose --cut-dirs=2 $LLM_FTP_URL/llm/datasets/abirate_english_quotes -P $DATASET_DIR
           fi
+          if [ ! -d $SPEECH_DATASET_PATH ]; then
+            echo "Directory $SPEECH_DATASET_PATH not found. Downloading from FTP server..."
+            wget -r -nH --no-verbose --cut-dirs=2 $LLM_FTP_URL/llm/datasets/librispeech_asr_dummy -P $DATASET_DIR
+          fi

       - name: Run LLM inference test
         shell: bash
         run: |
           source /opt/intel/oneapi/setvars.sh
-          python -m pip install expecttest einops
+          python -m pip install expecttest einops librosa
           bash python/llm/test/run-llm-inference-tests-gpu.sh
       - name: Run LLM example tests
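
Note: the test reads the speech dataset with `load_from_disk` rather than fetching it from the Hub at test time, so the FTP mirror must hold a ready-made `datasets` snapshot. A minimal sketch of how such a snapshot could be produced, assuming the commonly used `hf-internal-testing/librispeech_asr_dummy` dataset (whose first validation sample matches the transcription asserted below); the CI's actual copy may have been built differently:

    # Sketch: build the on-disk dataset that the test later reads with
    # load_from_disk(). The hf-internal-testing source is an assumption.
    from datasets import load_dataset

    ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
    ds.save_to_disk("datasets/librispeech_asr_dummy")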


@@ -15,10 +15,10 @@
 #
-import os
+import os, time
 import pytest
-from bigdl.llm.transformers import AutoModelForCausalLM, AutoModel
+from bigdl.llm.transformers import AutoModelForCausalLM, AutoModel, AutoModelForSpeechSeq2Seq
 from transformers import LlamaTokenizer, AutoTokenizer

 device = os.environ['DEVICE']
@@ -41,14 +41,33 @@ def test_completion(Model, Tokenizer, model_path, prompt, answer):
                                  load_in_4bit=True,
                                  optimize_model=True,
                                  trust_remote_code=True)
-    model = model.to(device)  # deallocate gpu memory
+    model = model.to(device)
     input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
     output = model.generate(input_ids, max_new_tokens=32)
-    model.to('cpu')
+    model.to('cpu')  # deallocate gpu memory
     output_str = tokenizer.decode(output[0], skip_special_tokens=True)

     assert answer in output_str

+def test_transformers_auto_model_for_speech_seq2seq_int4():
+    from transformers import WhisperProcessor
+    from datasets import load_from_disk
+    model_path = os.environ.get('WHISPER_TINY_ORIGIN_PATH')
+    dataset_path = os.environ.get('SPEECH_DATASET_PATH')
+    processor = WhisperProcessor.from_pretrained(model_path)
+    ds = load_from_disk(dataset_path)
+    sample = ds[0]["audio"]
+    input_features = processor(sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt").input_features
+    input_features = input_features.to(device)
+    model = AutoModelForSpeechSeq2Seq.from_pretrained(model_path, trust_remote_code=True, load_in_4bit=True, optimize_model=True)
+    model = model.to(device)
+    predicted_ids = model.generate(input_features)
+    # decode token ids to text
+    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=False)
+    model.to('cpu')
+    print('Output:', transcription)
+    assert 'Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.' in transcription[0]
+
 if __name__ == '__main__':
     pytest.main([__file__])
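
Note: `librosa` was added to the workflow's pip install most likely because `datasets` needs it to decode the audio column; it also makes it easy to run the same INT4 pipeline on an arbitrary audio file. A minimal standalone sketch of that flow; `sample.wav` is a placeholder, and `DEVICE` would be `xpu` on Arc (which requires intel_extension_for_pytorch) or `cpu` elsewhere:

    # Sketch: the BigDL-LLM INT4 Whisper pipeline on a local audio file.
    import os
    import librosa
    from transformers import WhisperProcessor
    from bigdl.llm.transformers import AutoModelForSpeechSeq2Seq

    device = os.environ.get('DEVICE', 'cpu')
    model_path = os.environ.get('WHISPER_TINY_ORIGIN_PATH')

    processor = WhisperProcessor.from_pretrained(model_path)
    # Whisper's feature extractor expects 16 kHz mono audio; librosa resamples on load.
    audio, sr = librosa.load("sample.wav", sr=16000)
    input_features = processor(audio, sampling_rate=sr, return_tensors="pt").input_features

    model = AutoModelForSpeechSeq2Seq.from_pretrained(model_path, load_in_4bit=True,
                                                      optimize_model=True, trust_remote_code=True)
    model = model.to(device)
    predicted_ids = model.generate(input_features.to(device))
    print(processor.batch_decode(predicted_ids, skip_special_tokens=True))
    model.to('cpu')  # free GPU memory, mirroring the test's cleanup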


@@ -17,7 +17,7 @@ if [ -z "$THREAD_NUM" ]; then
   THREAD_NUM=2
 fi
 export OMP_NUM_THREADS=$THREAD_NUM

 pytest ${LLM_INFERENCE_TEST_DIR} -v -s
 now=$(date "+%s")
 time=$((now-start))
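
Note: to reproduce only the new test outside CI, the environment variables it reads must be set before the module is collected. A minimal sketch; the paths and the test module filename are placeholders:

    # Sketch: run just the new Whisper test locally. All paths are placeholders;
    # DEVICE is 'xpu' on Arc, 'cpu' otherwise.
    import os
    import pytest

    os.environ['DEVICE'] = 'xpu'
    os.environ['WHISPER_TINY_ORIGIN_PATH'] = '/models/whisper-tiny'
    os.environ['SPEECH_DATASET_PATH'] = '/datasets/librispeech_asr_dummy'

    pytest.main(['test_transformers_api.py::test_transformers_auto_model_for_speech_seq2seq_int4',
                 '-v', '-s'])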