* Rename bigdl/llm to ipex_llm * rm python/llm/src/bigdl * from bigdl.llm to from ipex_llm
98 lines
No EOL
4 KiB
Python
98 lines
No EOL
4 KiB
Python
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
|
|
|
|
from datasets import load_dataset
|
|
from ipex_llm.transformers import AutoModelForSpeechSeq2Seq
|
|
from transformers import WhisperProcessor
|
|
import torch
|
|
from evaluate import load
|
|
import time
|
|
import argparse
|
|
import pandas as pd
|
|
import os
|
|
import csv
|
|
from datetime import date
|
|
|
|
# Absolute, symlink-resolved directory that contains this script; the
# optional CSV results are written to a "results" folder beneath it.
current_dir = os.path.split(os.path.realpath(__file__))[0]
|
|
|
|
def get_args():
    """Parse the command-line options of the Whisper evaluation script.

    Returns:
        argparse.Namespace: with the attributes ``model_path`` and
        ``data_type`` (both required), ``device`` (defaults to ``'cpu'``),
        ``load_in_low_bit`` (defaults to ``'sym_int4'``) and ``save_result``
        (``True`` only when the flag is given).
    """
    parser = argparse.ArgumentParser(description="Evaluate Whisper performance and accuracy")
    parser.add_argument('--model_path', required=True, help='pretrained model path')
    parser.add_argument('--data_type', required=True, help='clean, other')
    # Give the optional flag an explicit default so that omitting it never
    # hands a None device to the rest of the script.
    parser.add_argument('--device', required=False, default='cpu', help='cpu, xpu')
    parser.add_argument('--load_in_low_bit', default='sym_int4', help='Specify whether to load data in low bit format (e.g., 4-bit)')
    parser.add_argument('--save_result', action='store_true', help='Save the results to a CSV file')

    return parser.parse_args()
|
|
|
|
def model_name_from_path(model_path):
    """Return the last component of ``model_path`` to use as the model name.

    Trailing path separators are ignored, so ``/models/whisper-tiny/`` and
    ``/models/whisper-tiny`` both give ``whisper-tiny``.
    """
    return os.path.basename(os.path.normpath(model_path))


if __name__ == '__main__':
    args = get_args()
    # Treat both a missing (None) and an empty --device value as CPU.
    if not args.device:
        args.device = "cpu"

    # Only the first 500 samples of the test split are evaluated.
    speech_dataset = load_dataset('./librispeech_asr.py', name=args.data_type, split='test').select(range(500))
    processor = WhisperProcessor.from_pretrained(args.model_path)
    forced_decoder_ids = processor.get_decoder_prompt_ids(language='en', task='transcribe')

    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        args.model_path,
        load_in_low_bit=args.load_in_low_bit,
        optimize_model=True,
    ).eval().to(args.device)
    # The decoder prompt is passed explicitly to generate() below instead of
    # being taken from the model config.
    model.config.forced_decoder_ids = None

    def map_to_pred(batch):
        """Transcribe one sample and attach reference, prediction and timing.

        Adds the keys ``reference``, ``prediction``, ``length`` (audio
        duration, samples / sampling rate) and ``time`` (elapsed time covering
        feature extraction, reference normalization and generation) to
        ``batch`` and returns it.
        """
        audio = batch["audio"]
        # perf_counter is monotonic, so the measured interval cannot be
        # distorted by system clock adjustments.
        start_time = time.perf_counter()
        input_features = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt").input_features
        batch["reference"] = processor.tokenizer._normalize(batch['text'])

        with torch.no_grad():
            predicted_ids = model.generate(input_features.to(args.device), forced_decoder_ids=forced_decoder_ids, use_cache=True)[0]
            if args.device == "xpu":
                # Wait for queued XPU work to finish before stopping the timer.
                torch.xpu.synchronize()

        infer_time = time.perf_counter() - start_time
        transcription = processor.decode(predicted_ids)
        batch["prediction"] = processor.tokenizer._normalize(transcription)
        batch["length"] = len(audio["array"])/audio["sampling_rate"]
        batch["time"] = infer_time
        print(batch["reference"])
        print(batch["prediction"])
        return batch

    result = speech_dataset.map(map_to_pred, keep_in_memory=True)
    wer = load("./wer")
    # The first sample is left out of the speed totals (presumably to exclude
    # warm-up cost -- confirm); WER below still uses every sample.
    speech_length = sum(result["length"][1:])
    prc_time = sum(result["time"][1:])

    MODEL = model_name_from_path(args.model_path)
    RTF = prc_time/speech_length
    RTX = speech_length/prc_time
    WER = 100 * wer.compute(references=result["reference"], predictions=result["prediction"])

    today = date.today()
    if args.save_result:
        csv_name = f'{current_dir}/results/{MODEL}-{args.data_type}-{args.device}-{args.load_in_low_bit}-{today}.csv'
        os.makedirs(os.path.dirname(csv_name), exist_ok=True)
        with open(csv_name, mode='a', newline='') as file:
            csv_writer = csv.writer(file)
            # Write the header row only when the file is still empty.
            file.seek(0, os.SEEK_END)
            if file.tell() == 0:
                csv_writer.writerow(["models","precision","WER","RTF"])
            csv_writer.writerow([MODEL, args.load_in_low_bit, WER, RTF])
            print(f'Results saved to {csv_name}')

    print("Realtime Factor(RTF) is : %.4f" % RTF)
    print("Realtime X(RTX) is : %.2f" % RTX)
    print(f'WER is {WER}')