Add inference test for Whisper model on Arc (#9330)
* Add inference test for Whisper model * Remove unnecessary inference time measurement
This commit is contained in:
parent
63411dff75
commit
8f23fb04dc
3 changed files with 35 additions and 6 deletions
12
.github/workflows/llm_unit_tests.yml
vendored
12
.github/workflows/llm_unit_tests.yml
vendored
|
|
@ -221,11 +221,13 @@ jobs:
|
|||
run: |
|
||||
echo "DATASET_DIR=${ORIGIN_DIR}/../datasets" >> "$GITHUB_ENV"
|
||||
echo "ABIRATE_ENGLISH_QUOTES_PATH=${ORIGIN_DIR}/../datasets/abirate_english_quotes" >> "$GITHUB_ENV"
|
||||
echo "SPEECH_DATASET_PATH=${ORIGIN_DIR}/../datasets/librispeech_asr_dummy" >> "$GITHUB_ENV"
|
||||
|
||||
echo "LLAMA2_7B_ORIGIN_PATH=${ORIGIN_DIR}/Llama-2-7b-chat-hf" >> "$GITHUB_ENV"
|
||||
echo "CHATGLM2_6B_ORIGIN_PATH=${ORIGIN_DIR}/chatglm2-6b" >> "$GITHUB_ENV"
|
||||
echo "FALCON_7B_ORIGIN_PATH=${ORIGIN_DIR}/falcon-7b-instruct-with-patch" >> "$GITHUB_ENV"
|
||||
echo "MPT_7B_ORIGIN_PATH=${ORIGIN_DIR}/mpt-7b-chat" >> "$GITHUB_ENV"
|
||||
echo "WHISPER_TINY_ORIGIN_PATH=${ORIGIN_DIR}/whisper-tiny" >> "$GITHUB_ENV"
|
||||
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@v3
|
||||
|
|
@ -275,6 +277,10 @@ jobs:
|
|||
echo "Directory $MPT_7B_ORIGIN_PATH not found. Downloading from FTP server..."
|
||||
wget -r -nH --no-verbose --cut-dirs=1 $LLM_FTP_URL/llm/mpt-7b-chat -P $ORIGIN_DIR
|
||||
fi
|
||||
if [ ! -d $WHISPER_TINY_ORIGIN_PATH ]; then
|
||||
echo "Directory $WHISPER_TINY_ORIGIN_PATH not found. Downloading from FTP server..."
|
||||
wget -r -nH --no-verbose --cut-dirs=1 $LLM_FTP_URL/llm/whisper-tiny -P $ORIGIN_DIR
|
||||
fi
|
||||
if [ ! -d $DATASET_DIR ]; then
|
||||
mkdir -p $DATASET_DIR
|
||||
fi
|
||||
|
|
@ -282,12 +288,16 @@ jobs:
|
|||
echo "Directory $ABIRATE_ENGLISH_QUOTES_PATH not found. Downloading from FTP server..."
|
||||
wget -r -nH --no-verbose --cut-dirs=2 $LLM_FTP_URL/llm/datasets/abirate_english_quotes -P $DATASET_DIR
|
||||
fi
|
||||
if [ ! -d $SPEECH_DATASET_PATH ]; then
|
||||
echo "Directory $SPEECH_DATASET_PATH not found. Downloading from FTP server..."
|
||||
wget -r -nH --no-verbose --cut-dirs=2 $LLM_FTP_URL/llm/datasets/librispeech_asr_dummy -P $DATASET_DIR
|
||||
fi
|
||||
|
||||
- name: Run LLM inference test
|
||||
shell: bash
|
||||
run: |
|
||||
source /opt/intel/oneapi/setvars.sh
|
||||
python -m pip install expecttest einops
|
||||
python -m pip install expecttest einops librosa
|
||||
bash python/llm/test/run-llm-inference-tests-gpu.sh
|
||||
|
||||
- name: Run LLM example tests
|
||||
|
|
|
|||
|
|
@ -15,10 +15,10 @@
|
|||
#
|
||||
|
||||
|
||||
import os
|
||||
import os, time
|
||||
import pytest
|
||||
|
||||
from bigdl.llm.transformers import AutoModelForCausalLM, AutoModel
|
||||
from bigdl.llm.transformers import AutoModelForCausalLM, AutoModel, AutoModelForSpeechSeq2Seq
|
||||
from transformers import LlamaTokenizer, AutoTokenizer
|
||||
|
||||
device = os.environ['DEVICE']
|
||||
|
|
@ -41,14 +41,33 @@ def test_completion(Model, Tokenizer, model_path, prompt, answer):
|
|||
load_in_4bit=True,
|
||||
optimize_model=True,
|
||||
trust_remote_code=True)
|
||||
model = model.to(device) # deallocate gpu memory
|
||||
model = model.to(device)
|
||||
|
||||
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
|
||||
output = model.generate(input_ids, max_new_tokens=32)
|
||||
model.to('cpu')
|
||||
model.to('cpu') # deallocate gpu memory
|
||||
output_str = tokenizer.decode(output[0], skip_special_tokens=True)
|
||||
|
||||
assert answer in output_str
|
||||
|
||||
def test_transformers_auto_model_for_speech_seq2seq_int4():
|
||||
from transformers import WhisperProcessor
|
||||
from datasets import load_from_disk
|
||||
model_path = os.environ.get('WHISPER_TINY_ORIGIN_PATH')
|
||||
dataset_path = os.environ.get('SPEECH_DATASET_PATH')
|
||||
processor = WhisperProcessor.from_pretrained(model_path)
|
||||
ds = load_from_disk(dataset_path)
|
||||
sample = ds[0]["audio"]
|
||||
input_features = processor(sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt").input_features
|
||||
input_features = input_features.to(device)
|
||||
model = AutoModelForSpeechSeq2Seq.from_pretrained(model_path, trust_remote_code=True, load_in_4bit=True, optimize_model=True)
|
||||
model = model.to(device)
|
||||
predicted_ids = model.generate(input_features)
|
||||
# decode token ids to text
|
||||
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=False)
|
||||
model.to('cpu')
|
||||
print('Output:', transcription)
|
||||
assert 'Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.' in transcription[0]
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
|
|
|
|||
Loading…
Reference in a new issue