diff --git a/python/llm/example/transformers/transformers_int4/GPU/voiceassistant/README.md b/python/llm/example/transformers/transformers_int4/GPU/voiceassistant/README.md index de12ffc7..ab5943fb 100644 --- a/python/llm/example/transformers/transformers_int4/GPU/voiceassistant/README.md +++ b/python/llm/example/transformers/transformers_int4/GPU/voiceassistant/README.md @@ -19,7 +19,10 @@ pip install --pre --upgrade bigdl-llm[xpu] -f https://developer.intel.com/ipex-w pip install librosa soundfile datasets pip install accelerate pip install SpeechRecognition sentencepiece colorama +# If you failed to install PyAudio, try to run sudo apt install portaudio19-dev on ubuntu +pip install PyAudio inquirer sounddevice ``` + ### 2. Configures OneAPI environment variables ```bash source /opt/intel/oneapi/setvars.sh @@ -44,4 +47,55 @@ Arguments info: - `--n-predict N_PREDICT`: argument defining the max number of tokens to predict. It is default to be `32`. #### Sample Output -Should be tested on a linux machine with microphone. +```bash +(llm) bigdl@bigdl-llm:~/Documents/voiceassistant$ python generate.py --llama2-repo-id-or-model-path /mnt/windows/demo/models/Llama-2-7b-chat-hf --whisper-repo-id-or-model-path /mnt/windows/demo/models/whisper-medium +/home/bigdl/anaconda3/envs/llm/lib/python3.9/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: ''If you don't plan on using image functionality from `torchvision.io`, you can ignore this warning. Otherwise, there might be something wrong with your environment. Did you have `libjpeg` or `libpng` installed before building `torchvision` from source? + warn( + +[?] Which microphone do you choose?: Default + > Default + HDA Intel PCH: ALC274 Analog (hw:0,0) + HDA Intel PCH: HDMI 0 (hw:0,3) + HDA Intel PCH: HDMI 1 (hw:0,7) + HDA Intel PCH: HDMI 2 (hw:0,8) + HDA Intel PCH: HDMI 3 (hw:0,9) + HDA Intel PCH: HDMI 4 (hw:0,10) + HDA Intel PCH: HDMI 5 (hw:0,11) + HDA Intel PCH: HDMI 6 (hw:0,12) + HDA Intel PCH: HDMI 7 (hw:0,13) + HDA Intel PCH: HDMI 8 (hw:0,14) + HDA Intel PCH: HDMI 9 (hw:0,15) + HDA Intel PCH: HDMI 10 (hw:0,16) + +The device name Default is selected. +Downloading builder script: 100%|██████████████████████████████████████████████████████| 5.17k/5.17k [00:00<00:00, 14.3MB/s] +Downloading data: 100%|████████████████████████████████████████████████████████████████████████████████████████| 9.08M/9.08M [00:01<00:00, 4.75MB/s] +Downloading data files: 100%|████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:04<00:00, 4.57s/it]] +Extracting data files: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 39.98it/s] +Generating validation split: 73 examples [00:00, 5328.37 examples/s] +Converting and loading models... +Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:09<00:00, 3.04s/it] +/home/bigdl/anaconda3/envs/yina-llm/lib/python3.9/site-packages/transformers/generation/configuration_utils.py:362: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.9` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`. This was detected when initializing the generation config instance, which means the corresponding file may hold incorrect parameterization and should be fixed. + warnings.warn( +/home/bigdl/anaconda3/envs/yina-llm/lib/python3.9/site-packages/transformers/generation/configuration_utils.py:367: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.6` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`. This was detected when initializing the generation config instance, which means the corresponding file may hold incorrect parameterization and should be fixed. + warnings.warn( +/home/bigdl/anaconda3/envs/yina-llm/lib/python3.9/site-packages/transformers/generation/utils.py:1411: UserWarning: You have modified the pretrained model configuration to control generation. This is a deprecated strategy to control generation and will be removed soon, in a future version. Please use a generation configuration file (see https://huggingface.co/docs/transformers/main_classes/text_generation ) + warnings.warn( +Calibrating... +Listening now... +Recognizing... + +Whisper : + What is AI? + +BigDL-LLM: + Artificial intelligence (AI) is the broader field of research and development aimed at creating machines that can perform tasks that typically require human intelligence, +Listening now... +Recognizing... + +Whisper : + Tell me something about Intel + +BigDL-LLM: + Intel is a well-known technology company that specializes in designing, manufacturing, and selling computer hardware components and semiconductor products. +``` diff --git a/python/llm/example/transformers/transformers_int4/GPU/voiceassistant/generate.py b/python/llm/example/transformers/transformers_int4/GPU/voiceassistant/generate.py index 568d2fea..e669af29 100644 --- a/python/llm/example/transformers/transformers_int4/GPU/voiceassistant/generate.py +++ b/python/llm/example/transformers/transformers_int4/GPU/voiceassistant/generate.py @@ -17,13 +17,15 @@ import os import torch import time +import intel_extension_for_pytorch as ipex import argparse import numpy as np +import inquirer +import sounddevice from bigdl.llm.transformers import AutoModelForCausalLM from bigdl.llm.transformers import AutoModelForSpeechSeq2Seq from transformers import LlamaTokenizer -import intel_extension_for_pytorch as ipex from transformers import WhisperProcessor from transformers import TextStreamer from colorama import Fore @@ -49,27 +51,6 @@ def get_prompt(message: str, chat_history: list[tuple[str, str]], texts.append(f'{message} [/INST]') return ''.join(texts) -def get_input_features(r): - with sr.Microphone(device_index=1, sample_rate=16000) as source: - print("Calibrating...") - r.adjust_for_ambient_noise(source, duration=5) - - print(Fore.YELLOW + "Listening now..." + Fore.RESET) - try: - audio = r.listen(source, timeout=5, phrase_time_limit=30) - # refer to https://github.com/openai/whisper/blob/main/whisper/audio.py#L63 - frame_data = np.frombuffer(audio.frame_data, np.int16).flatten().astype(np.float32) / 32768.0 - input_features = processor(frame_data, sampling_rate=audio.sample_rate, return_tensors="pt").input_features - input_features = input_features.half().contiguous().to('xpu') - print("Recognizing...") - except Exception as e: - unrecognized_speech_text = ( - f"Sorry, I didn't catch that. Exception was: \n {e}" - ) - print(unrecognized_speech_text) - - return input_features - if __name__ == '__main__': parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for Llama2 model') parser.add_argument('--llama2-repo-id-or-model-path', type=str, default="meta-llama/Llama-2-7b-chat-hf", @@ -82,6 +63,21 @@ if __name__ == '__main__': help='Max tokens to predict') args = parser.parse_args() + + # Select device + mics = sr.Microphone.list_microphone_names() + mics.insert(0, "Default") + questions = [ + inquirer.List('device_name', + message="Which microphone do you choose?", + choices=mics) + ] + answers = inquirer.prompt(questions) + device_name = answers['device_name'] + idx = mics.index(device_name) + device_index = None if idx == 0 else idx - 1 + print(f"The device name {device_name} is selected.") + whisper_model_path = args.whisper_repo_id_or_model_path llama_model_path = args.llama2_repo_id_or_model_path @@ -95,10 +91,10 @@ if __name__ == '__main__': # generate token ids whisper = AutoModelForSpeechSeq2Seq.from_pretrained(whisper_model_path, load_in_4bit=True, optimize_model=False) whisper.config.forced_decoder_ids = None - whisper = whisper.half().to('xpu') + whisper = whisper.to('xpu') llama_model = AutoModelForCausalLM.from_pretrained(llama_model_path, load_in_4bit=True, trust_remote_code=True, optimize_model=False) - llama_model = llama_model.half().to('xpu') + llama_model = llama_model.to('xpu') tokenizer = LlamaTokenizer.from_pretrained(llama_model_path) r = sr.Recognizer() @@ -107,7 +103,7 @@ if __name__ == '__main__': # warm up sample = ds[2]["audio"] input_features = processor(sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt").input_features - input_features = input_features.half().contiguous().to('xpu') + input_features = input_features.contiguous().to('xpu') torch.xpu.synchronize() predicted_ids = whisper.generate(input_features) torch.xpu.synchronize() @@ -117,15 +113,32 @@ if __name__ == '__main__': output = llama_model.generate(input_ids, do_sample=False, max_new_tokens=32) output_str = tokenizer.decode(output[0], skip_special_tokens=True) torch.xpu.synchronize() - - while 1: - input_features = get_input_features(r) - predicted_ids = whisper.generate(input_features) - output_str = processor.batch_decode(predicted_ids, skip_special_tokens=True) - output_str = output_str[0] - print("\n" + Fore.GREEN + "Whisper : " + Fore.RESET + "\n" + output_str) - print("\n" + Fore.BLUE + "BigDL-LLM: " + Fore.RESET) - prompt = get_prompt(output_str, [], system_prompt=DEFAULT_SYSTEM_PROMPT) - input_ids = tokenizer.encode(prompt, return_tensors="pt").to('xpu') - streamer = TextStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True) - _ = llama_model.generate(input_ids, streamer=streamer, do_sample=False, max_new_tokens=args.n_predict) + + with sr.Microphone(device_index=device_index, sample_rate=16000) as source: + print("Calibrating...") + r.adjust_for_ambient_noise(source, duration=5) + + while 1: + print(Fore.YELLOW + "Listening now..." + Fore.RESET) + try: + audio = r.listen(source, timeout=5, phrase_time_limit=30) + # refer to https://github.com/openai/whisper/blob/main/whisper/audio.py#L63 + frame_data = np.frombuffer(audio.frame_data, np.int16).flatten().astype(np.float32) / 32768.0 + print("Recognizing...") + input_features = processor(frame_data, sampling_rate=audio.sample_rate, return_tensors="pt").input_features + input_features = input_features.contiguous().to('xpu') + except Exception as e: + unrecognized_speech_text = ( + f"Sorry, I didn't catch that. Exception was: \n {e}" + ) + print(unrecognized_speech_text) + + predicted_ids = whisper.generate(input_features) + output_str = processor.batch_decode(predicted_ids, skip_special_tokens=True) + output_str = output_str[0] + print("\n" + Fore.GREEN + "Whisper : " + Fore.RESET + "\n" + output_str) + print("\n" + Fore.BLUE + "BigDL-LLM: " + Fore.RESET) + prompt = get_prompt(output_str, [], system_prompt=DEFAULT_SYSTEM_PROMPT) + input_ids = tokenizer.encode(prompt, return_tensors="pt").to('xpu') + streamer = TextStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True) + _ = llama_model.generate(input_ids, streamer=streamer, do_sample=False, max_new_tokens=args.n_predict)