diff --git a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/README.md b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/README.md
index e7f84ab4..3dc327dc 100644
--- a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/README.md
+++ b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/README.md
@@ -82,6 +82,7 @@ The example below shows how to run the **_optimized model implementations_** on
 - [Llama3-8B](./llama.py)
 - [Qwen2-1.5B](./qwen2.py)
 - [MiniCPM-1B](./minicpm.py)
+- [Baichuan2-7B](./baichuan2.py)
 
 ```bash
 # to run Llama-2-7b-chat-hf
@@ -95,6 +96,9 @@ python qwen2.py
 
 # to run MiniCPM-1B-sft-bf16
 python minicpm.py
+
+# to run Baichuan2-7B-Chat
+python baichuan2.py
 ```
 
 Arguments info:
diff --git a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/baichuan2.py b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/baichuan2.py
index a53de2ab..f3f4cb10 100644
--- a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/baichuan2.py
+++ b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/baichuan2.py
@@ -46,15 +46,15 @@ if __name__ == "__main__":
     parser.add_argument(
         "--repo-id-or-model-path",
         type=str,
-        default="meta-llama/Llama-2-7b-chat-hf",
-        help="The huggingface repo id for the Llama2 model to be downloaded"
+        default="baichuan-inc/Baichuan2-7B-Chat",
+        help="The huggingface repo id for the Baichuan2 model to be downloaded"
         ", or the path to the huggingface checkpoint folder",
     )
     parser.add_argument('--prompt', type=str, default="What is AI?",
                         help='Prompt to infer')
     parser.add_argument("--n-predict", type=int, default=32, help="Max tokens to predict")
     parser.add_argument("--max-output-len", type=int, default=1024)
-    parser.add_argument("--max-prompt-len", type=int, default=768)
+    parser.add_argument("--max-prompt-len", type=int, default=512)
     parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)
     parser.add_argument("--intra-pp", type=int, default=2)
     parser.add_argument("--inter-pp", type=int, default=2)
@@ -68,7 +68,7 @@ if __name__ == "__main__":
         trust_remote_code=True,
         attn_implementation="eager",
         load_in_low_bit="sym_int4",
-        enable_mp=True,
+        optimize_model=True,
         max_output_len=args.max_output_len,
         max_prompt_len=args.max_prompt_len,
         intra_pp=args.intra_pp,
diff --git a/python/llm/src/ipex_llm/transformers/npu_model.py b/python/llm/src/ipex_llm/transformers/npu_model.py
index ecfa0f69..96551d48 100644
--- a/python/llm/src/ipex_llm/transformers/npu_model.py
+++ b/python/llm/src/ipex_llm/transformers/npu_model.py
@@ -77,6 +77,8 @@ class _BaseAutoModelClass:
         :param load_in_low_bit: str value, options are ``'sym_int4'``, ``'sym_int8'``,
                                 ``'fp16'``, ``'fp32'``.
                                 Relevant low bit optimizations will be applied to the model.
+        :param optimize_model: boolean value, Whether to further optimize the low_bit llm model.
+                               Default to be ``False``.
         :return: a model instance
         """
         if kwargs.get("device_map", None) not in [None, "cpu", "auto"]:
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py
index b436c317..37767402 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py
@@ -272,7 +272,8 @@ class LowBitLlamaMultiDecoderlayer(NNFactory):
         input_2d = self.convert_to_fp16(input_2d)
 
         # attention
-        proj = self.linear(input_2d, 3 * self.hidden_size, self.hidden_size, bias=False, wt_dtype=self.dtype)
+        proj = self.linear(input_2d, 3 * self.hidden_size,
+                           self.hidden_size, bias=False, wt_dtype=self.dtype)
         # proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
         proj = self.reshape(proj, [-1, 3, self.hidden_size])  # b*s, 3, h
         proj = self.unsqueeze(proj, [0])  # b, s, 3, h
@@ -282,13 +283,16 @@ class LowBitLlamaMultiDecoderlayer(NNFactory):
         proj = self.unsqueeze(proj, [1])
         print("proj shape after unsqueeze", proj.shape)
         # query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        query_states = self.reshape(proj[0, ...], [self.batch_size, self.seq_len, self.num_heads, self.head_dim])
+        query_states = self.reshape(proj[0, ...], [self.batch_size,
+                                    self.seq_len, self.num_heads, self.head_dim])
         query_states = self.transpose(query_states, [0, 2, 1, 3])
         # key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = self.reshape(proj[1, ...], [self.batch_size, self.seq_len, self.num_heads, self.head_dim])
+        key_states = self.reshape(proj[1, ...], [self.batch_size,
+                                  self.seq_len, self.num_heads, self.head_dim])
         key_states = self.transpose(key_states, [0, 2, 1, 3])
         # value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        value_states = self.reshape(proj[2, ...], [self.batch_size, self.seq_len, self.num_heads, self.head_dim])
+        value_states = self.reshape(proj[2, ...], [self.batch_size,
+                                    self.seq_len, self.num_heads, self.head_dim])
         if self.transpose_value:
             value_states = self.transpose(value_states, [0, 2, 3, 1])
         else:
@@ -309,7 +313,8 @@ class LowBitLlamaMultiDecoderlayer(NNFactory):
         else:
             value_states = self.concat(past_value, value_states, axis=-2)
 
-        attn_weight = self.matmul(query_states, key_states, False, True) / (math.sqrt(self.head_dim))
+        attn_weight = self.matmul(query_states, key_states, False, True) / (
+            math.sqrt(self.head_dim))
         attn_weight = self.eltwise_add(attn_weight, attention_mask)
         attn_weight = self.convert_to_fp32(attn_weight)
         attn_weight = self.softmax(attn_weight, -1)
@@ -349,7 +354,8 @@ class LowBitLlamaMultiDecoderlayer(NNFactory):
                              bias=False, wt_dtype=self.dtype)  # type: ignore[attr-defined]
         mm1 = self.eltwise_mul(self.swish(mm1), mm2)  # type: ignore[attr-defined]
         # down proj
-        hidden_states = self.linear(mm1, self.hidden_size, self.intermediate_size, bias=False, wt_dtype=self.dtype)
+        hidden_states = self.linear(mm1, self.hidden_size,
+                                    self.intermediate_size, bias=False, wt_dtype=self.dtype)
         hidden_states = self.eltwise_add(residual, hidden_states)
 
         hidden_states = self.convert_to_fp16(hidden_states)
@@ -374,7 +380,7 @@ class LowBitLlamaMultiDecoderlayer(NNFactory):
         cos = self.squeeze(cos)  # [seq_len, dim]
         sin = self.squeeze(sin)  # [seq_len, dim]
         # cos = cos[position_ids]
-        cos = self.unsqueeze(cos, [0, 1]) # [bs, 1, seq_len, dim]
+        cos = self.unsqueeze(cos, [0, 1])  # [bs, 1, seq_len, dim]
         # sin = sin[position_ids]
         sin = self.unsqueeze(sin, [0, 1])  # [bs, 1, seq_len, dim]
 
@@ -1098,13 +1104,15 @@ def gen_baichuan_fused_model_forward(prefill_runner, decode_runner):
 
         # retrieve input_ids and inputs_embeds
         if input_ids is not None and inputs_embeds is not None:
-            invalidInputError(False, "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+            invalidInputError(False, "You cannot specify both decoder_input_ids\
+                and decoder_inputs_embeds at the same time")
         elif input_ids is not None:
             batch_size, seq_length = input_ids.shape
         elif inputs_embeds is not None:
             batch_size, seq_length, _ = inputs_embeds.shape
         else:
-            invalidInputError(False, "You have to specify either decoder_input_ids or decoder_inputs_embeds")
+            invalidInputError(False, "You have to specify either decoder_input_ids\
+                or decoder_inputs_embeds")
 
         seq_length_with_past = seq_length
         past_key_values_length = 0
@@ -1120,7 +1128,8 @@ def gen_baichuan_fused_model_forward(prefill_runner, decode_runner):
         if position_ids is None:
             device = input_ids.device if input_ids is not None else inputs_embeds.device
             position_ids = torch.arange(
-                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+                past_key_values_length, seq_length + past_key_values_length,
+                dtype=torch.long, device=device
             )
             position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
         else:
@@ -1146,7 +1155,8 @@ def gen_baichuan_fused_model_forward(prefill_runner, decode_runner):
         if self.gradient_checkpointing and self.training:
             if use_cache:
                 logger.warning_once(
-                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                    "`use_cache=True` is incompatible with gradient checkpointing.\
+                        Setting `use_cache=False`..."
                 )
                 use_cache = False
 
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py
index 1bb6077c..3d74880b 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py
@@ -124,7 +124,8 @@ def optimize_llm(
             prefill_runner=prefill_runner, decode_runner=decode_runner
         )
         convert_forward(model, module.MiniCPMModel, minicpm_model_forward)
-    elif model.config.model_type == "baichuan":
+    elif model.config.model_type == "baichuan" and model.config.num_hidden_layers == 32:
+        # for Baichuan2-7B
         from ipex_llm.transformers.npu_models.baichuan_mp import gen_baichuan_fused_model_forward
         from ipex_llm.transformers.npu_models.baichuan_mp import DecodeRunner, PrefillRunner
         decode_runner = DecodeRunner(
@@ -141,8 +142,8 @@ def optimize_llm(
             transpose_value_cache=transpose_value_cache,
         )
         baichuan_model_forward = gen_baichuan_fused_model_forward(
-            prefill_runner=prefill_runner, decode_runner=decode_runner
-        )
+            prefill_runner=prefill_runner, decode_runner=decode_runner
+            )
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         convert_forward(model, module.BaichuanModel, baichuan_model_forward)
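
For reference, a minimal sketch of how the new Baichuan2-7B example is driven end to end. It mirrors only the keyword arguments visible in the hunks above (`load_in_low_bit`, `optimize_model`, `max_output_len`, `max_prompt_len`, `intra_pp`, `inter_pp`); the `AutoModelForCausalLM` import path and the tokenizer/`generate()` flow are assumed from the sibling NPU examples and are not part of this diff.

```python
# Hypothetical usage sketch for the Baichuan2-7B NPU example; the import path
# and the generate() flow are assumed from the other NPU examples, not shown
# in this diff.
import torch
from transformers import AutoTokenizer
from ipex_llm.transformers.npu_model import AutoModelForCausalLM

model_path = "baichuan-inc/Baichuan2-7B-Chat"

model = AutoModelForCausalLM.from_pretrained(
    model_path,
    trust_remote_code=True,
    attn_implementation="eager",
    load_in_low_bit="sym_int4",   # sym_int4 / sym_int8 / fp16 / fp32
    optimize_model=True,          # replaces the old enable_mp flag
    max_output_len=1024,          # defaults taken from baichuan2.py above
    max_prompt_len=512,
    intra_pp=2,
    inter_pp=2,
)

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
inputs = tokenizer("What is AI?", return_tensors="pt")
with torch.inference_mode():
    output = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```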