diff --git a/.github/workflows/llm_unit_tests.yml b/.github/workflows/llm_unit_tests.yml index e6d01a6c..77797414 100644 --- a/.github/workflows/llm_unit_tests.yml +++ b/.github/workflows/llm_unit_tests.yml @@ -245,7 +245,6 @@ jobs: echo "FALCON_7B_ORIGIN_PATH=${ORIGIN_DIR}/falcon-7b-instruct-with-patch" >> "$GITHUB_ENV" echo "MPT_7B_ORIGIN_PATH=${ORIGIN_DIR}/mpt-7b-chat" >> "$GITHUB_ENV" echo "WHISPER_TINY_ORIGIN_PATH=${ORIGIN_DIR}/whisper-tiny" >> "$GITHUB_ENV" - echo "MISTRAL_7B_INSTRUCT_V0_1_ORIGIN_PATH=${ORIGIN_DIR}/Mistral-7B-Instruct-v0.1" >> "$GITHUB_ENV" echo "BAICHUAN2_7B_ORIGIN_PATH=${ORIGIN_DIR}/Baichuan2-7B-Chat" >> "$GITHUB_ENV" echo "QWEN_7B_ORIGIN_PATH=${ORIGIN_DIR}/Qwen-7B-Chat" >> "$GITHUB_ENV" diff --git a/python/llm/test/inference_gpu/test_transformers_api.py b/python/llm/test/inference_gpu/test_transformers_api.py index b8c5903b..9a9bb8a3 100644 --- a/python/llm/test/inference_gpu/test_transformers_api.py +++ b/python/llm/test/inference_gpu/test_transformers_api.py @@ -34,6 +34,9 @@ print(f'Running on {device}') (AutoModel, AutoTokenizer, os.environ.get('CHATGLM2_6B_ORIGIN_PATH')), (AutoModelForCausalLM, AutoTokenizer, os.environ.get('FALCON_7B_ORIGIN_PATH')), (AutoModelForCausalLM, AutoTokenizer, os.environ.get('MPT_7B_ORIGIN_PATH')), + # (AutoModelForCausalLM, AutoTokenizer, os.environ.get('MISTRAL_7B_INSTRUCT_V0_1_ORIGIN_PATH')), + # (AutoModelForCausalLM, AutoTokenizer, os.environ.get('BAICHUAN2_7B_ORIGIN_PATH')), + # (AutoModelForCausalLM, AutoTokenizer, os.environ.get('QWEN_7B_ORIGIN_PATH')), ]) def test_completion(Model, Tokenizer, model_path, prompt, answer): with torch.inference_mode(): diff --git a/python/llm/test/inference_gpu/test_transformers_api_RMSNorm.py b/python/llm/test/inference_gpu/test_transformers_api_RMSNorm.py new file mode 100644 index 00000000..7e8898fc --- /dev/null +++ b/python/llm/test/inference_gpu/test_transformers_api_RMSNorm.py @@ -0,0 +1,154 @@ + +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import gc +import pytest + +import torch +from bigdl.llm.transformers import AutoModelForCausalLM, AutoModel +from transformers import LlamaTokenizer, AutoTokenizer + +device = os.environ['DEVICE'] +print(f'Running on {device}') + +PROMPT = "Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun" +TEST_MODEL_LIST = [ + ("Llama2-7B", AutoModelForCausalLM, LlamaTokenizer, os.environ.get('LLAMA2_7B_ORIGIN_PATH')), + ("ChatGLM2-6B", AutoModel, AutoTokenizer, os.environ.get('CHATGLM2_6B_ORIGIN_PATH')), + ("Mistral-7B-Instruct-v0.1", AutoModelForCausalLM, AutoTokenizer, os.environ.get('MISTRAL_7B_INSTRUCT_V0_1_ORIGIN_PATH')), + ("Baichuan2-7B-Chat", AutoModelForCausalLM, AutoTokenizer, os.environ.get('BAICHUAN2_7B_ORIGIN_PATH')), + ("Qwen-7B-Chat", AutoModelForCausalLM, AutoTokenizer, os.environ.get('QWEN_7B_ORIGIN_PATH')), +] + +class Test_Optimize_Gpu_Model: + def setup_method(self): + self.layer_outputs = [] + self.pre_layer_outputs = [] + + def run_optimize_gpu_model(self, Name, Model, Tokenizer, model_path, RMSNorm_layer, layer_before_RMSNorm, lower_bound): + with torch.inference_mode(): + def pre_forward_hook(module, input, output, layer_name): + self.pre_layer_outputs.append(output) + + def forward_hook(module, input, output, layer_name): + self.layer_outputs.append(output) + + + tokenizer = Tokenizer.from_pretrained(model_path, trust_remote_code=True) + input_ids = tokenizer.encode(PROMPT, return_tensors="pt").to(device) + + model = Model.from_pretrained(model_path, + load_in_4bit=True, + optimize_model=False, + trust_remote_code=True) + + model = model.to(device) + + for layer_name, layer_module in model.named_modules(): + if layer_name == layer_before_RMSNorm: + layer_module.register_forward_hook( + lambda module, input, output, layer_name=layer_name: pre_forward_hook(module, input, + output, layer_name)) + if layer_name == RMSNorm_layer: + layer_module.register_forward_hook( + lambda module, input, output, layer_name=layer_name: forward_hook(module, input, + output, layer_name)) + logits_base_model = (model(input_ids)).logits + # the list `layer_output` has only one element. + layer_tensor = self.layer_outputs.pop() + model.to('cpu') + opt_model = Model.from_pretrained(model_path, + load_in_4bit=True, + optimize_model=True, + trust_remote_code=True) + opt_model = opt_model.to(device) + + + def replace_forward_hook(module, input, output, layer_name): + output = self.pre_layer_outputs[0] + return output + + for layer_name, layer_module in opt_model.named_modules(): + if layer_name == layer_before_RMSNorm: + layer_module.register_forward_hook( + lambda module, input, output, layer_name=layer_name: replace_forward_hook(module, input, + output, layer_name)) + if layer_name == RMSNorm_layer: + layer_module.register_forward_hook( + lambda module, input, output, layer_name=layer_name: forward_hook(module, input, + output, layer_name)) + logits_optimized_model = (opt_model(input_ids)).logits + # the list `layer_output` has only one element. + opt_layer_tensor = self.layer_outputs[0] + opt_model.to('cpu') + + RMSNorm_output_diff = [] + for i, (t1, t2) in enumerate(zip(layer_tensor, opt_layer_tensor)): + if t1 is not None and t2 is not None: + if isinstance(t1, torch.Tensor) and isinstance(t2, torch.Tensor): + RMSNorm_output_diff.append(t1 - t2) + max_diff_tensor = [torch.max(item).item() for item in RMSNorm_output_diff] + print(max_diff_tensor) + torch.xpu.empty_cache() + del model + del opt_model + gc.collect() + assert all(max_diff <= lower_bound for max_diff in max_diff_tensor) + + @pytest.mark.parametrize('Name, Model, Tokenizer, model_path',TEST_MODEL_LIST) + def test_dynamic_functions(self, Name, Model, Tokenizer, model_path): + if Name == "Llama2-7B": + self.Llama2_7B_gpu_model(Name, Model, Tokenizer, model_path) + elif Name == "ChatGLM2-6B": + self.Chatglm2_gpu_model(Name, Model, Tokenizer, model_path) + elif Name == "Mistral-7B-Instruct-v0.1": + self.Mistral_gpu_model(Name, Model, Tokenizer, model_path) + elif Name == "Baichuan2-7B-Chat": + self.Baichuan_gpu_model(Name, Model, Tokenizer, model_path) + elif Name == "Qwen-7B-Chat": + self.Qwen_gpu_model(Name, Model, Tokenizer, model_path) + + def Llama2_7B_gpu_model(self, Name, Model, Tokenizer, model_path): + layer_before_RMSNorm = "model.layers.30" + RMSNorm_layer = "model.layers.31.input_layernorm" + lower_bound = 1e-6 + self.run_optimize_gpu_model(Name, Model, Tokenizer, model_path, RMSNorm_layer, layer_before_RMSNorm, lower_bound) + + def Chatglm2_gpu_model(self, Name, Model, Tokenizer, model_path): + layer_before_RMSNorm = "transformer.encoder.layers.26" + RMSNorm_layer = "transformer.encoder.layers.27.input_layernorm" + lower_bound = 2e-6 + self.run_optimize_gpu_model(Name, Model, Tokenizer, model_path, RMSNorm_layer, layer_before_RMSNorm, lower_bound) + + def Mistral_gpu_model(self, Name, Model, Tokenizer, model_path): + layer_before_RMSNorm = "model.layers.30" + RMSNorm_layer = "model.layers.31.input_layernorm" + lower_bound = 6e-6 + self.run_optimize_gpu_model(Name, Model, Tokenizer, model_path, RMSNorm_layer, layer_before_RMSNorm, lower_bound) + + def Baichuan_gpu_model(self, Name, Model, Tokenizer, model_path): + layer_before_RMSNorm = "model.layers.30" + RMSNorm_layer = "model.layers.31.input_layernorm" + lower_bound = 5e-7 + self.run_optimize_gpu_model(Name, Model, Tokenizer, model_path, RMSNorm_layer, layer_before_RMSNorm, lower_bound) + + def Qwen_gpu_model(self, Name, Model, Tokenizer, model_path): + layer_before_RMSNorm = "transformer.h.30" + RMSNorm_layer = "transformer.h.31.ln_1" + lower_bound = 2e-6 + self.run_optimize_gpu_model(Name, Model, Tokenizer, model_path, RMSNorm_layer, layer_before_RMSNorm, lower_bound) \ No newline at end of file diff --git a/python/llm/test/inference_gpu/test_transformers_api_attention.py b/python/llm/test/inference_gpu/test_transformers_api_attention.py index a0d1864b..51cbc0d0 100644 --- a/python/llm/test/inference_gpu/test_transformers_api_attention.py +++ b/python/llm/test/inference_gpu/test_transformers_api_attention.py @@ -16,6 +16,7 @@ # import os +import gc import pytest import torch @@ -31,6 +32,9 @@ TEST_MODEL_LIST = [ ("Llama2-7B", AutoModelForCausalLM, LlamaTokenizer, os.environ.get('LLAMA2_7B_ORIGIN_PATH')), ("Falcon-7B", AutoModelForCausalLM, AutoTokenizer, os.environ.get('FALCON_7B_ORIGIN_PATH')), ("ChatGLM2-6B", AutoModel, AutoTokenizer, os.environ.get('CHATGLM2_6B_ORIGIN_PATH')), + ("Mistral-7B-Instruct-v0.1", AutoModelForCausalLM, AutoTokenizer, os.environ.get('MISTRAL_7B_INSTRUCT_V0_1_ORIGIN_PATH')), + ("Baichuan2-7B-Chat", AutoModelForCausalLM, AutoTokenizer, os.environ.get('BAICHUAN2_7B_ORIGIN_PATH')), + ("Qwen-7B-Chat", AutoModelForCausalLM, AutoTokenizer, os.environ.get('QWEN_7B_ORIGIN_PATH')), ] class Test_Optimize_Gpu_Model: @@ -113,6 +117,10 @@ class Test_Optimize_Gpu_Model: max_diff_tensor = [torch.max(item).item() for item in attn_output_diff] print(max_diff_tensor) + torch.xpu.empty_cache() + del model + del opt_model + gc.collect() assert all(max_diff <= lower_bound for max_diff in max_diff_tensor) @@ -126,6 +134,12 @@ class Test_Optimize_Gpu_Model: self.Falcon_7B_gpu_model(Name, Model, Tokenizer, model_path) elif Name == "ChatGLM2-6B": self.Chatglm2_gpu_model(Name, Model, Tokenizer, model_path) + elif Name == "Mistral-7B-Instruct-v0.1": + self.Mistral_gpu_model(Name, Model, Tokenizer, model_path) + elif Name == "Baichuan2-7B-Chat": + self.Baichuan_gpu_model(Name, Model, Tokenizer, model_path) + elif Name == "Qwen-7B-Chat": + self.Qwen_gpu_model(Name, Model, Tokenizer, model_path) def MPT_7B_gpu_model(self, Name, Model, Tokenizer, model_path): @@ -139,7 +153,7 @@ class Test_Optimize_Gpu_Model: # currently only compare the output of the last self-attention layer. layer_norm = "model.layers.31.input_layernorm" self_attn = "model.layers.31.self_attn" - lower_bound = 5e-2 + lower_bound = 8e-3 self.run_optimize_gpu_model(Name, Model, Tokenizer, model_path, self_attn, layer_norm, lower_bound) def Falcon_7B_gpu_model(self, Name, Model, Tokenizer, model_path): @@ -153,5 +167,26 @@ class Test_Optimize_Gpu_Model: # currently only need to compare the output of one self-attention layer. layer_norm = "transformer.encoder.layers.27.input_layernorm" self_attn = "transformer.encoder.layers.27.self_attention" - lower_bound = 5e-3 + lower_bound = 1e-3 + self.run_optimize_gpu_model(Name, Model, Tokenizer, model_path, self_attn, layer_norm, lower_bound) + + def Mistral_gpu_model(self, Name, Model, Tokenizer, model_path): + # currently only need to compare the output of one self-attention layer. + layer_norm = "model.layers.31.input_layernorm" + self_attn = "model.layers.31.self_attn" + lower_bound = 9e-3 + self.run_optimize_gpu_model(Name, Model, Tokenizer, model_path, self_attn, layer_norm, lower_bound) + + def Baichuan_gpu_model(self, Name, Model, Tokenizer, model_path): + # currently only need to compare the output of one self-attention layer. + layer_norm = "model.layers.31.input_layernorm" + self_attn = "model.layers.31.self_attn" + lower_bound = 2e-3 + self.run_optimize_gpu_model(Name, Model, Tokenizer, model_path, self_attn, layer_norm, lower_bound) + + def Qwen_gpu_model(self, Name, Model, Tokenizer, model_path): + # currently only need to compare the output of one self-attention layer. + layer_norm = "transformer.h.31.ln_1" + self_attn = "transformer.h.31.attn" + lower_bound = 8e-3 self.run_optimize_gpu_model(Name, Model, Tokenizer, model_path, self_attn, layer_norm, lower_bound) \ No newline at end of file diff --git a/python/llm/test/inference_gpu/test_transformers_api_final_logits.py b/python/llm/test/inference_gpu/test_transformers_api_final_logits.py index 7ff2aa78..02e3cc27 100644 --- a/python/llm/test/inference_gpu/test_transformers_api_final_logits.py +++ b/python/llm/test/inference_gpu/test_transformers_api_final_logits.py @@ -16,6 +16,7 @@ import os +import gc import pytest import torch @@ -48,20 +49,23 @@ def test_optimize_model(Name, Model, Tokenizer, model_path): model.to('cpu') # deallocate gpu memory - model = Model.from_pretrained(model_path, + opt_model = Model.from_pretrained(model_path, load_in_4bit=True, optimize_model=True, trust_remote_code=True) - model = model.to(device) - logits_optimized_model = (model(input_ids)).logits - model.to('cpu') + opt_model = opt_model.to(device) + logits_optimized_model = (opt_model(input_ids)).logits + opt_model.to('cpu') tol = 1e-03 num_false = torch.isclose(logits_optimized_model, logits_base_model, rtol=tol, atol=tol)\ .flatten().tolist().count(False) percent_false = num_false / logits_optimized_model.numel() - + torch.xpu.empty_cache() + del model + del opt_model + gc.collect() assert percent_false < 1e-02 diff --git a/python/llm/test/inference_gpu/test_transformers_api_mlp.py b/python/llm/test/inference_gpu/test_transformers_api_mlp.py index 4bba9c4c..16431cba 100644 --- a/python/llm/test/inference_gpu/test_transformers_api_mlp.py +++ b/python/llm/test/inference_gpu/test_transformers_api_mlp.py @@ -15,6 +15,7 @@ # import os +import gc import pytest import torch @@ -103,7 +104,10 @@ class Test_Optimize_Gpu_Model: max_diff_tensor = [torch.max(item).item() for item in MLP_output_diff] print(max_diff_tensor) - + torch.xpu.empty_cache() + del model + del opt_model + gc.collect() assert all(max_diff <= lower_bound for max_diff in max_diff_tensor) @pytest.mark.parametrize('Name, Model, Tokenizer, model_path',TEST_MODEL_LIST) diff --git a/python/llm/test/run-llm-inference-tests-gpu-434.sh b/python/llm/test/run-llm-inference-tests-gpu-434.sh index a88bb637..91a1676d 100644 --- a/python/llm/test/run-llm-inference-tests-gpu-434.sh +++ b/python/llm/test/run-llm-inference-tests-gpu-434.sh @@ -18,7 +18,9 @@ start=$(date "+%s") # fi # export OMP_NUM_THREADS=$THREAD_NUM export BIGDL_LLM_XMX_DISABLED=1 +pytest ${LLM_INFERENCE_TEST_DIR}/test_transformers_api_attention.py -v -s -k "Mistral" pytest ${LLM_INFERENCE_TEST_DIR}/test_transformers_api_mlp.py -v -s -k "Mistral" +pytest ${LLM_INFERENCE_TEST_DIR}/test_transformers_api_RMSNorm.py -v -s -k "Mistral" unset BIGDL_LLM_XMX_DISABLED now=$(date "+%s") diff --git a/python/llm/test/run-llm-inference-tests-gpu.sh b/python/llm/test/run-llm-inference-tests-gpu.sh index 351b6ec7..130d58d3 100644 --- a/python/llm/test/run-llm-inference-tests-gpu.sh +++ b/python/llm/test/run-llm-inference-tests-gpu.sh @@ -20,8 +20,9 @@ start=$(date "+%s") pytest ${LLM_INFERENCE_TEST_DIR}/test_transformers_api.py -v -s export BIGDL_LLM_XMX_DISABLED=1 pytest ${LLM_INFERENCE_TEST_DIR}/test_transformers_api_final_logits.py -v -s -pytest ${LLM_INFERENCE_TEST_DIR}/test_transformers_api_attention.py -v -s +pytest ${LLM_INFERENCE_TEST_DIR}/test_transformers_api_attention.py -v -s -k "not Mistral" pytest ${LLM_INFERENCE_TEST_DIR}/test_transformers_api_mlp.py -v -s -k "not Mistral" +pytest ${LLM_INFERENCE_TEST_DIR}/test_transformers_api_RMSNorm.py -v -s -k "not Mistral" unset BIGDL_LLM_XMX_DISABLED now=$(date "+%s")