diff --git a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/Pipeline-Models/README.md b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/Pipeline-Models/README.md
index d30ce356..2e230e96 100644
--- a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/Pipeline-Models/README.md
+++ b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/Pipeline-Models/README.md
@@ -8,6 +8,7 @@ In this directory, you will find examples on how to directly run HuggingFace `tr
 |------------|----------------------------------------------------------------|
 | Llama2 | [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) |
 | Llama3 | [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) |
+| Llama3.2 | [meta-llama/Llama-3.2-1B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct), [meta-llama/Llama-3.2-3B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct) |
 | Qwen2 | [Qwen/Qwen2-1.5B-Instruct](https://huggingface.co/Qwen/Qwen2-1.5B-Instruct) |
 | Qwen2.5 | [Qwen/Qwen2.5-7b-Instruct](https://huggingface.co/Qwen/Qwen2.5-7b-Instruct) |
 | Baichuan2 | [baichuan-inc/Baichuan2-7B-Chat](https://huggingface.co/baichuan-inc/Baichuan-7B-Chat) |
@@ -28,6 +29,9 @@ conda activate llm
 
 :: install ipex-llm with 'npu' option
 pip install --pre --upgrade ipex-llm[npu]
+
+:: [optional] for Llama-3.2-1B-Instruct & Llama-3.2-3B-Instruct
+pip install transformers==4.45.0 accelerate==0.33.0
 ```
 
 ## 2. Runtime Configurations
@@ -48,6 +52,12 @@ python llama2.py
 
 :: to run Meta-Llama-3-8B-Instruct
 python llama3.py
 
+:: to run Llama-3.2-1B-Instruct
+python llama3.py --repo-id-or-model-path "meta-llama/Llama-3.2-1B-Instruct"
+
+:: to run Llama-3.2-3B-Instruct
+python llama3.py --repo-id-or-model-path "meta-llama/Llama-3.2-3B-Instruct"
+
 :: to run Qwen2.5-7b-Instruct
 python qwen.py
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py
index f595d111..cf452199 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py
@@ -124,11 +124,12 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
         if self.cached_cos is None:
             if mode == "prefill":
                 position_ids = self.create_input_op((self.batch_size, self.seq_len), dtype=np.int64)
-                self.cos = self.create_input_op((self.batch_size, self.cos_len, self.head_dim))
-                self.sin = self.create_input_op((self.batch_size, self.cos_len, self.head_dim))
-            else:
-                self.cos = self.create_input_op((self.batch_size, self.cos_len, self.head_dim))
-                self.sin = self.create_input_op((self.batch_size, self.cos_len, self.head_dim))
+            cos = self.create_input_op((self.batch_size, self.cos_len, self.head_dim),
+                                       dtype=np.float32)
+            self.cos = self.convert_to_fp16(cos)
+            sin = self.create_input_op((self.batch_size, self.cos_len, self.head_dim),
+                                       dtype=np.float32)
+            self.sin = self.convert_to_fp16(sin)
         else:
             position_ids = self.create_input_op((self.batch_size, self.seq_len), dtype=np.int64)
             cos = self.constant(self.cached_cos)
@@ -367,7 +368,7 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
         )
 
         if self.cached_cos is None:
-            inputs += (cos.to(torch.float16), sin.to(torch.float16))
+            inputs += (cos.to(torch.float32), sin.to(torch.float32))
         else:
             inputs += (position_ids.to(torch.int64),)
 
@@ -496,7 +497,7 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module):
                   attention_mask.to(torch.int64),
                   position_ids.to(torch.int64))
         if self.cached_cos is None:
-            inputs += (cos.to(torch.float16), sin.to(torch.float16),)
+            inputs += (cos.to(torch.float32), sin.to(torch.float32),)
         inputs += (self.layer_norm_0, self.layer_norm_1)
         hidden_states, past_key, past_value = run_model(
             inputs, self.op_parameters, backend_cls, self.op_id, replica=2
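With no cached cos/sin (the Llama 3.2 path), the decoder graph above now declares cos/sin as float32 runtime inputs and downcasts them to fp16 inside the graph, and the host-side `forward` passes them as float32 accordingly. For reference, a minimal host-side sketch of the values those inputs carry, mirroring the `Llama32Embedding` graph added later in this patch; the function name, rope theta, and `attention_scaling=1.0` below are illustrative, not from the repo:

```python
import torch

def llama32_rope_reference(inv_freq, position_ids, attention_scaling):
    # freqs = inv_freq * position_ids, duplicated along the last dim, then cos/sin * scaling
    inv = inv_freq[None, :, None].float()            # (1, head_dim // 2, 1)
    pos = position_ids[:, None, :].float()           # (batch, 1, seq_len)
    freqs = (inv @ pos).transpose(1, 2)              # (batch, seq_len, head_dim // 2)
    emb = torch.cat((freqs, freqs), dim=-1)          # (batch, seq_len, head_dim)
    return emb.cos() * attention_scaling, emb.sin() * attention_scaling

# illustrative values: head_dim = 64, rope theta = 500000, no extra scaling
inv_freq = 1.0 / (500000.0 ** (torch.arange(0, 64, 2).float() / 64))
cos, sin = llama32_rope_reference(inv_freq, torch.arange(8)[None, :], attention_scaling=1.0)
print(cos.shape, cos.dtype)  # torch.Size([1, 8, 64]) torch.float32
```

Keeping these tensors in float32 on the host matches the updated `run_model`/`run_decoders` helpers below, which no longer force every non-integer input to fp16.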
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py
index 2997c0e8..13997d9f 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py
@@ -54,8 +54,7 @@ def run_model(
 
     # Reshape input
     input_dtype = x[0].dtype
-    x_np = [set_contiguous(elem).numpy() if elem.dtype == torch.int64 else
-            set_contiguous(elem).to(torch.float16).numpy() for elem in x]
+    x_np = [set_contiguous(elem).numpy() for elem in x]
     op_args = []
     op_args_flatten = []
     for w in weights:
@@ -651,8 +650,7 @@ class LLMBaseNNFactory(NNFactory):
 
     @staticmethod
     def run_decoders(inputs, decoders, models_ptr=None):
-        x_np = [elem.numpy() if elem.dtype == torch.int64 else
-                elem.to(torch.float16).numpy() for elem in inputs]
+        x_np = [elem.numpy() for elem in inputs]
 
         num_decoders = len(decoders)
         num_inputs = len(x_np)
diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py
index 4d3e8d3e..47f6732a 100644
--- a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py
+++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py
@@ -233,8 +233,12 @@ def convert_llm(model: torch.nn.Module,
         model.num_layers = layer_num
         model.transpose_value_cache = transpose_value_cache
 
+        if hasattr(model.model.layers[0].self_attn.rotary_emb, "cos_cached"):
+            model_type = "llama"
+        else:
+            model_type = "llama_32"
         try:
-            res = InitLLMPipeline("llama", kv_len, model.num_head, model.head_dim, layer_num,
+            res = InitLLMPipeline(model_type, kv_len, model.num_head, model.head_dim, layer_num,
                                   model.vocab_size, weight_dir, "model", first_blob_path,
                                   last_blob_path, os.path.join(temp_dir, "decoder_layer"),
                                   layernorm_const)
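The same `hasattr(..., "cos_cached")` test reappears in `llama.py` below to choose the embedding graph and the cached cos/sin handling, so the whole dispatch can be read as one predicate. A standalone restatement of that check; the helper name is hypothetical and not part of the patch:

```python
def detect_llama_pipeline_type(model) -> str:
    # In this pipeline, Llama 2 / Llama 3 checkpoints keep per-layer cos/sin caches on the
    # rotary embedding, while the Llama 3.2 models (loaded with transformers 4.45, per the
    # README) do not and compute cos/sin at runtime, so they take the "llama_32" path.
    rotary_emb = model.model.layers[0].self_attn.rotary_emb
    return "llama" if hasattr(rotary_emb, "cos_cached") else "llama_32"
```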
diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/llama.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/llama.py
index 09faad35..4e3674a7 100644
--- a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/llama.py
+++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/llama.py
@@ -19,6 +19,68 @@ import torch
 import numpy as np
 import os
 from .common import update_names_of_IR_and_export_blob, LLMEmbedding, LowBitLLMLMHead
+from intel_npu_acceleration_library.backend.factory import NNFactory
+
+
+class Llama32Embedding(NNFactory):
+    def __init__(
+        self,
+        vocab_size,
+        embedding_dim,
+        embedding_weight,
+        padding_idx,
+        inv_freq,
+        attention_scaling,
+        dtype,  # fp16
+        device: str = "NPU",
+    ):
+        super().__init__(False, device)
+        self.vocab_size = vocab_size
+        self.embedding_dim = embedding_dim
+        self.padding_idx = padding_idx
+        self.attention_scaling = attention_scaling
+        self.dtype = dtype
+
+        # define input
+        weight = self.constant(embedding_weight)
+        input = self.parameter((1, 1), dtype=np.int32)
+        position_ids = self.parameter((1, 1), dtype=np.int64)
+        inv_freq = self.constant(inv_freq)
+
+        # embed_tokens module
+        if padding_idx == -1:
+            padding_idx += vocab_size
+
+        axis_node = self.constant(np.array([0], dtype=np.int64))
+        if padding_idx is not None:
+            masked_embeddings = np.ones(weight.shape, dtype=np.float16)
+            masked_embeddings[padding_idx, :] = 0.0  # mask
+
+            node_mask = self.constant(masked_embeddings)
+            node_masked_w = self.eltwise_mul(weight, node_mask)
+            res = self.gather(node_masked_w, input, axis_node, 0)
+        else:
+            res = self.gather(weight, input, axis_node, 0)
+
+        # rotary_emb module
+        inv_freq = self.reshape(inv_freq, (1, inv_freq.shape[0], 1))
+        position_ids = self.reshape(position_ids, (1, 1, 1))
+        freqs = self.eltwise_mul(self.convert_to_fp32(inv_freq),
+                                 self.convert_to_fp32(position_ids))
+        freqs = self.transpose(freqs, [0, 2, 1])
+        emb = self.concat(freqs, freqs, axis=2)
+        cos = self.cos(emb)
+        sin = self.sin(emb)
+        cos = cos * self.attention_scaling
+        sin = sin * self.attention_scaling
+
+        # define outputs
+        res = self.convert_to_fp16(res)
+        cos = self.convert_to_fp32(cos)
+        sin = self.convert_to_fp32(sin)
+
+        print("start compiling")
+        self.compile()
 
 
 def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir):
@@ -71,14 +133,27 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir):
             bin_file = os.path.join(weight_dir, f"model_lm_head_input_{1+idx}.bin")
             weight.tofile(bin_file)
 
-    embedding_layer = model.model.embed_tokens
-    new_embedding = LLMEmbedding(
-        vocab_size=model.config.vocab_size,
-        embedding_dim=model.config.hidden_size,
-        embedding_weight=embedding_layer.weight.to(torch.float16).detach().numpy(),
-        padding_idx=model.config.pad_token_id,
-        dtype=np.float16,
-    )
+    if hasattr(model.model.layers[0].self_attn.rotary_emb, "cos_cached"):
+        # llama-2-7B & llama-3-8B
+        embedding_layer = model.model.embed_tokens
+        new_embedding = LLMEmbedding(
+            vocab_size=model.config.vocab_size,
+            embedding_dim=model.config.hidden_size,
+            embedding_weight=embedding_layer.weight.to(torch.float16).detach().numpy(),
+            padding_idx=model.config.pad_token_id,
+            dtype=np.float16,
+        )
+    else:
+        # llama-3.2-3B & llama-3.2-1B
+        new_embedding = Llama32Embedding(
+            vocab_size=model.config.vocab_size,
+            embedding_dim=model.config.hidden_size,
+            embedding_weight=model.model.embed_tokens.weight.to(torch.float16).detach().numpy(),
+            padding_idx=model.config.pad_token_id,
+            inv_freq=model.model.rotary_emb.inv_freq.to(torch.float16),
+            attention_scaling=model.model.rotary_emb.attention_scaling,
+            dtype=np.float16,
+        )
     first_blob_path = update_names_of_IR_and_export_blob(new_embedding, "embedding",
                                                          temp_dir)
     return first_blob_path, last_blob_path
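For readers tracing the `Llama32Embedding` graph: the constant-mask, `eltwise_mul`, and `gather` sequence reproduces an embedding lookup with a zeroed padding row. A plain-PyTorch sketch of the same semantics (illustration only; not code from the patch, and the real work happens inside the compiled NPU graph):

```python
import torch

def embed_with_padding_idx(weight, input_ids, padding_idx):
    # zero the padding row up front, then gather rows by token id --
    # the same effect as multiplying by the constant mask before the gather
    masked = weight.clone()
    if padding_idx is not None:
        masked[padding_idx, :] = 0.0
    return masked[input_ids]

weight = torch.randn(16, 4)        # (vocab_size, hidden_size)
ids = torch.tensor([[3]])          # (1, 1), as in the graph's (1, 1) input parameter
print(embed_with_padding_idx(weight, ids, padding_idx=0).shape)  # torch.Size([1, 1, 4])
```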
@@ -135,8 +210,14 @@ def convert_llama_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
             scales.append(l.scale)
         weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
 
-    cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
-    cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
+    if hasattr(curr_layer.self_attn.rotary_emb, "cos_cached"):
+        # llama-2-7B & llama-3-8B
+        cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
+        cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
+    else:
+        # llama-3.2-3B & llama-3.2-1B
+        cached_cos = None
+        cached_sin = None
     layer_norm_0 = curr_layer.input_layernorm.weight.to(torch.float16)
     layer_norm_1 = curr_layer.post_attention_layernorm.weight.to(torch.float16)
 
@@ -168,14 +249,26 @@ def convert_llama_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
                                                     f"decoder_layer_{layer_idx}",
                                                     temp_dir)
 
-    if layernorm_const:
-        st_idx = 5
+    if hasattr(curr_layer.self_attn.rotary_emb, "cos_cached"):
+        # llama-2-7B & llama-3-8B
+        if layernorm_const:
+            st_idx = 5
+        else:
+            input_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_3.bin")
+            post_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_4.bin")
+            layer_norm_0.data.numpy().tofile(input_lm_bin_file)
+            layer_norm_1.data.numpy().tofile(post_lm_bin_file)
+            st_idx = 7
     else:
-        input_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_3.bin")
-        post_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_4.bin")
-        layer_norm_0.data.numpy().tofile(input_lm_bin_file)
-        layer_norm_1.data.numpy().tofile(post_lm_bin_file)
-        st_idx = 7
+        # llama-3.2-3B & llama-3.2-1B
+        if layernorm_const:
+            st_idx = 6
+        else:
+            input_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_4.bin")
+            post_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_5.bin")
+            layer_norm_0.data.numpy().tofile(input_lm_bin_file)
+            layer_norm_1.data.numpy().tofile(post_lm_bin_file)
+            st_idx = 8
     for idx, (weight, scale) in enumerate(weights):
         bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2}.bin")
         weight.numpy().tofile(bin_file)
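The `st_idx` branch above encodes where the per-layer quantized-weight binaries start: since the Llama 3.2 decode graph (as set up in `llama_mp.py` above) feeds cos/sin runtime inputs in place of the single `position_ids` input, the layernorm bin slots and the first weight slot each shift by one. A hypothetical summary helper, not part of the patch:

```python
def first_weight_bin_index(has_cached_cos: bool, layernorm_const: bool) -> int:
    if has_cached_cos:                       # llama-2-7B & llama-3-8B
        return 5 if layernorm_const else 7   # norms are constants, or written to input_3 / input_4
    else:                                    # llama-3.2-1B & llama-3.2-3B
        return 6 if layernorm_const else 8   # norms are constants, or written to input_4 / input_5

assert first_weight_bin_index(True, True) == 5
assert first_weight_bin_index(False, False) == 8
```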