From ea55235cbd4f7bf4ea30938c25cba44fa4c25165 Mon Sep 17 00:00:00 2001
From: binbin Deng <108676127+plusbang@users.noreply.github.com>
Date: Mon, 9 Dec 2024 14:06:27 +0800
Subject: [PATCH] [NPU] Support glm-edge models (#12511)

---
 .../HF-Transformers-AutoModels/LLM/README.md  |  12 ++
 .../NPU/HF-Transformers-AutoModels/LLM/glm.py | 123 +++++++++++++
 .../src/ipex_llm/transformers/npu_model.py    |  17 ++
 .../transformers/npu_models/glm_edge.py       | 167 ++++++++++++++++++
 python/llm/src/ipex_llm/utils/__init__.py     |   5 +-
 5 files changed, 323 insertions(+), 1 deletion(-)
 create mode 100644 python/llm/example/NPU/HF-Transformers-AutoModels/LLM/glm.py
 create mode 100644 python/llm/src/ipex_llm/transformers/npu_models/glm_edge.py

diff --git a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/README.md b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/README.md
index defed0bd..b1ffab77 100644
--- a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/README.md
+++ b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/README.md
@@ -11,6 +11,7 @@ In this directory, you will find examples on how to directly run HuggingFace `tr
 | Llama3.2-3B | [meta-llama/Llama-3.2-3B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct) |
 | Chatglm3 | [THUDM/chatglm3-6b](https://huggingface.co/THUDM/chatglm3-6b) |
 | Chatglm2 | [THUDM/chatglm2-6b](https://huggingface.co/THUDM/chatglm2-6b) |
+| GLM-Edge | [THUDM/glm-edge-1.5b-chat](https://huggingface.co/THUDM/glm-edge-1.5b-chat), [THUDM/glm-edge-4b-chat](https://huggingface.co/THUDM/glm-edge-4b-chat) |
 | Qwen2 | [Qwen/Qwen2-7B-Instruct](https://huggingface.co/Qwen/Qwen2-7B-Instruct), [Qwen/Qwen2-1.5B-Instruct](https://huggingface.co/Qwen/Qwen2-1.5B-Instruct) |
 | Qwen2.5 | [Qwen/Qwen2.5-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-7B-Instruct) |
 | MiniCPM | [openbmb/MiniCPM-1B-sft-bf16](https://huggingface.co/openbmb/MiniCPM-1B-sft-bf16), [openbmb/MiniCPM-2B-sft-bf16](https://huggingface.co/openbmb/MiniCPM-2B-sft-bf16) |
@@ -38,6 +39,9 @@ pip install --pre --upgrade ipex-llm[npu]
 
 :: [optional] for Llama-3.2-1B-Instruct & Llama-3.2-3B-Instruct
 pip install transformers==4.45.0 accelerate==0.33.0
+
+:: [optional] for glm-edge-1.5b-chat & glm-edge-4b-chat
+pip install transformers==4.47.0 accelerate==0.26.0
 ```
 
 ## 2. Runtime Configurations
@@ -94,6 +98,8 @@ The examples below show how to run the **_optimized HuggingFace model implementa
 - [Qwen2.5-7B](./qwen.py)
 - [MiniCPM-1B](./minicpm.py)
 - [MiniCPM-2B](./minicpm.py)
+- [GLM-Edge-1.5B-Chat](./glm.py)
+- [GLM-Edge-4B-Chat](./glm.py)
 - [Baichuan2-7B](./baichuan2.py)
 
 ### Run
@@ -125,6 +131,12 @@ python minicpm.py --repo-id-or-model-path "openbmb/MiniCPM-1B-sft-bf16" --save-d
 :: to run MiniCPM-2B-sft-bf16
 python minicpm.py --repo-id-or-model-path "openbmb/MiniCPM-2B-sft-bf16" --save-directory
 
+:: to run glm-edge-1.5b-chat
+python glm.py --repo-id-or-model-path "THUDM/glm-edge-1.5b-chat" --save-directory
+
+:: to run glm-edge-4b-chat
+python glm.py --repo-id-or-model-path "THUDM/glm-edge-4b-chat" --save-directory
+
 :: to run Baichuan2-7B-Chat
 python baichuan2.py --repo-id-or-model-path "baichuan-inc/Baichuan2-7B-Chat" --save-directory
 ```
diff --git a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/glm.py b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/glm.py
new file mode 100644
index 00000000..637f612a
--- /dev/null
+++ b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/glm.py
@@ -0,0 +1,123 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import torch
+import time
+import argparse
+
+from ipex_llm.transformers.npu_model import AutoModelForCausalLM
+from transformers import AutoTokenizer, TextStreamer
+
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Predict Tokens using `generate()` API for npu model"
+    )
+    parser.add_argument(
+        "--repo-id-or-model-path",
+        type=str,
+        default="THUDM/glm-edge-1.5b-chat",
+        help="The huggingface repo id for the glm-edge model to be downloaded"
+        ", or the path to the huggingface checkpoint folder",
+    )
+    parser.add_argument('--prompt', type=str, default="What is AI?",
+                        help='Prompt to infer')
+    parser.add_argument("--n-predict", type=int, default=32, help="Max tokens to predict")
+    parser.add_argument("--max-context-len", type=int, default=1024)
+    parser.add_argument("--max-prompt-len", type=int, default=512)
+    parser.add_argument('--low-bit', type=str, default="sym_int4",
+                        help='Load in low bit to use')
+    parser.add_argument("--disable-streaming", action="store_true", default=False)
+    parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)
+    parser.add_argument("--save-directory", type=str,
+                        required=True,
+                        help="The path of folder to save converted model. "
+                             "If the path does not exist, the low-bit model will be saved there; "
" + "Else, lowbit model will be loaded.", + ) + + args = parser.parse_args() + model_path = args.repo_id_or_model_path + + if not os.path.exists(args.save_directory): + model = AutoModelForCausalLM.from_pretrained( + model_path, + torch_dtype=torch.float16, + trust_remote_code=True, + attn_implementation="eager", + load_in_low_bit=args.low_bit, + optimize_model=True, + max_context_len=args.max_context_len, + max_prompt_len=args.max_prompt_len, + transpose_value_cache=not args.disable_transpose_value_cache, + save_directory=args.save_directory + ) + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + tokenizer.save_pretrained(args.save_directory) + else: + model = AutoModelForCausalLM.load_low_bit( + args.save_directory, + attn_implementation="eager", + torch_dtype=torch.float16, + optimize_model=True, + max_context_len=args.max_context_len, + max_prompt_len=args.max_prompt_len, + transpose_value_cache=not args.disable_transpose_value_cache, + ) + tokenizer = AutoTokenizer.from_pretrained(args.save_directory, trust_remote_code=True) + + if args.disable_streaming: + streamer = None + else: + streamer = TextStreamer(tokenizer=tokenizer, skip_special_tokens=True) + + print("-" * 80) + print("done") + with torch.inference_mode(): + print("finish to load") + for i in range(3): + message = [{"role": "user", "content": args.prompt}] + + inputs = tokenizer.apply_chat_template( + message, + return_tensors="pt", + add_generation_prompt=True, + return_dict=True, + ) + _input_ids = inputs["input_ids"] + + print("-" * 20, "Input", "-" * 20) + print("input length:", len(_input_ids[0])) + input_str = tokenizer.decode(_input_ids[0], skip_special_tokens=False) + print(input_str) + print("-" * 20, "Output", "-" * 20) + st = time.time() + output = model.generate( + _input_ids, num_beams=1, do_sample=False, max_new_tokens=args.n_predict, streamer=streamer + ) + end = time.time() + if args.disable_streaming: + output_str = tokenizer.decode(output[0], skip_special_tokens=False) + print(output_str) + print(f"Inference time: {end-st} s") + + print("-" * 80) + print("done") + print("success shut down") diff --git a/python/llm/src/ipex_llm/transformers/npu_model.py b/python/llm/src/ipex_llm/transformers/npu_model.py index b2b08fe2..673a56ec 100644 --- a/python/llm/src/ipex_llm/transformers/npu_model.py +++ b/python/llm/src/ipex_llm/transformers/npu_model.py @@ -180,6 +180,23 @@ class _BaseAutoModelClass: logger.info(f"Converting model, it may takes up to several minutes ...") + if hasattr(model, "config") and model.config.model_type == "glm": + # convert to llama structure + from .npu_models.glm_edge import convert_config, load_weights, convert_state_dict + import json + original_path = model.config._name_or_path + del model + + with open(os.path.join(original_path, "config.json")) as f: + original_config = json.load(f) + config = convert_config(original_config) + original_state_dict = load_weights(original_path) + new_dict, _ = convert_state_dict(original_state_dict, config, + original_config.get("partial_rotary_factor", 1.0), + decouple_tied_embeddings=False) + torch.set_default_dtype(config.torch_dtype) + model = cls.HF_Model.from_pretrained(original_path, config=config, state_dict=new_dict) + if hasattr(model, "config"): model.config.update({"optimize_model": optimize_model}) diff --git a/python/llm/src/ipex_llm/transformers/npu_models/glm_edge.py b/python/llm/src/ipex_llm/transformers/npu_models/glm_edge.py new file mode 100644 index 00000000..b65339a3 --- /dev/null +++ 
@@ -0,0 +1,167 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+import os
+import torch
+from safetensors.torch import load_file
+from tokenizers import processors
+
+from transformers import LlamaConfig, PreTrainedTokenizerFast
+from ipex_llm.transformers.utils import invalidInputError
+
+
+VIT_KEY = "vit_path"  # FIXME: just made at random
+VIT_FILE = "vit_adapter.pt"
+
+
+def load_weights(input_dir: str):
+    safetensor_files = [os.path.join(input_dir, x) for x in os.listdir(input_dir)
+                        if x.endswith(".safetensors")]
+    bin_files = [os.path.join(input_dir, x) for x in os.listdir(input_dir) if x.endswith(".bin")]
+
+    all_weights = {}
+
+    if safetensor_files:
+        if len(safetensor_files) > 1:
+            safetensor_files = sorted(safetensor_files, key=lambda x: int(x.rsplit("-", 3)[1]))
+        for file in safetensor_files:
+            tensors = load_file(file)
+            all_weights.update(tensors)
+        return all_weights
+
+    elif bin_files:
+        if len(bin_files) > 1:
+            bin_files = sorted(bin_files, key=lambda x: int(x.rsplit("-", 3)[1]))
+        for file in bin_files:
+            tensors = torch.load(file, map_location="cpu")
+            all_weights.update(tensors)
+        return all_weights
+
+    else:
+        invalidInputError(False, "No .safetensors or .bin files found in the specified directory.")
+
+
+def convert_state_dict(original_state_dict: dict, config: LlamaConfig,
+                       partial_rotary_factor: float, decouple_tied_embeddings=False):
+    hidden_size, num_heads = config.hidden_size, config.num_attention_heads
+    num_key_value_heads = config.num_key_value_heads
+    head_dim = hidden_size // num_heads
+    rotary_dim = int(partial_rotary_factor * head_dim)
+    inv_freq = 1.0 / (config.rope_theta ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
+
+    # permute for sliced rotary
+    def permute_weight(w, num_heads, rotary_dim):
+        w = w.view(num_heads, head_dim, hidden_size)
+        w, w_pass = w[:, :rotary_dim, :], w[:, rotary_dim:, :]
+        w = w.view(num_heads, rotary_dim // 2, 2, hidden_size).transpose(1, 2)\
+            .reshape(num_heads, rotary_dim, hidden_size)
+        return torch.cat([w, w_pass], dim=1).view(num_heads * head_dim, hidden_size)
+
+    def permute_bias(b, num_heads, rotary_dim):
+        b = b.view(num_heads, head_dim)
+        b, b_pass = b[:, :rotary_dim], b[:, rotary_dim:]
+        b = b.view(num_heads, rotary_dim // 2, 2).transpose(1, 2).reshape(num_heads, rotary_dim)
+        return torch.cat([b, b_pass], dim=1).view(num_heads * head_dim)
+
+    new_dict, vit_dict = {}, {}
+    param_count = 0
+    index_dict = {"weight_map": {}}
+    for key, value in original_state_dict.items():
+        if "model.vision" in key:  # vit
+            vit_dict[key.replace("model.vision.", "")] = value.detach().clone()
+        elif "q_proj." in key:
+            if "weight" in key:
+                new_dict[key] = permute_weight(value, num_heads, rotary_dim)
+            elif config.attention_bias:  # bias
+                new_dict[key] = permute_bias(value, num_heads, rotary_dim)
+        elif "k_proj." in key:
+            if "weight" in key:
+                new_dict[key] = permute_weight(value, num_key_value_heads, rotary_dim)
+            elif config.attention_bias:  # bias
+                new_dict[key] = permute_bias(value, num_key_value_heads, rotary_dim)
+        elif "v_proj." in key:
+            if "bias" in key and not config.attention_bias:
+                continue
+            new_dict[key] = value
+        elif "o_proj." in key:
+            new_dict[key] = value
+            if config.attention_bias:  # bias
+                new_dict[key.replace("weight", "bias")] = torch.zeros(hidden_size,
+                                                                      dtype=value.dtype)
+        elif "gate_up_proj." in key:
+            gate_proj, up_proj = value.chunk(2, dim=0)
+            new_dict[key.replace("gate_up_proj.", "gate_proj.")] = gate_proj
+            new_dict[key.replace("gate_up_proj.", "up_proj.")] = up_proj
+        else:
+            new_dict[key] = value
+
+    for layer_i in range(config.num_hidden_layers):
+        new_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq.clone()
+
+    if decouple_tied_embeddings:
+        new_dict["transformer.output_layer.weight"] = \
+            original_state_dict["model.embed_tokens.weight"].clone()
+
+    return new_dict, vit_dict
+
+
+def convert_config(original_config: dict, decouple_tied_embeddings=False):
+    similar_keys_to_keep = [
+        "num_attention_heads",
+        "hidden_size",
+        "intermediate_size",
+        "num_hidden_layers",
+        "rms_norm_eps",
+        "num_key_value_heads",
+        "vocab_size",
+        "partial_rotary_factor",
+        "rope_theta",
+        "max_position_embeddings",
+        "attention_bias",
+        "torch_dtype",
+        "tie_word_embeddings",
+        "bos_token_id",
+        "eos_token_id",
+        "pad_token_id",
+        "boi_token_id",
+        "eoi_token_id",
+        "vision_config",
+    ]
+    new_config_kwargs = {k: v for k, v in original_config.items() if k in similar_keys_to_keep}
+    if original_config.get("partial_rotary_factor", 1) < 1:
+        new_config_kwargs["rope_dim"] = original_config["head_dim"] * \
+            original_config["partial_rotary_factor"]
+    if decouple_tied_embeddings:
+        new_config_kwargs["tie_word_embeddings"] = False
+    if "vision_config" in original_config:
+        new_config_kwargs["vision_config"] = original_config["vision_config"]
+        new_config_kwargs[VIT_KEY] = VIT_FILE
+    if "bos_token_id" not in new_config_kwargs:
+        new_config_kwargs["bos_token_id"] = None
+
+    new_config = LlamaConfig(**new_config_kwargs)
+    return new_config
+
+
+def convert_glm_tokenizer(input_dir):
+    fast_tok = PreTrainedTokenizerFast.from_pretrained(input_dir,
+                                                       model_input_names=["input_ids",
+                                                                          "attention_mask"])
+    fast_tok._tokenizer.post_processor = processors.Sequence(
+        [processors.ByteLevel(trim_offsets=False)],
+    )
+    return fast_tok
diff --git a/python/llm/src/ipex_llm/utils/__init__.py b/python/llm/src/ipex_llm/utils/__init__.py
index 4fed203d..aecb58b2 100644
--- a/python/llm/src/ipex_llm/utils/__init__.py
+++ b/python/llm/src/ipex_llm/utils/__init__.py
@@ -22,7 +22,10 @@ import transformers
 
 trans_version = transformers.__version__
 
-if trans_version >= "4.45.0":
+if trans_version >= "4.47.0":
+    # TODO
+    pass
+elif trans_version >= "4.45.0":
     from .benchmark_util_4_45 import BenchmarkWrapper
 elif trans_version >= "4.44.0":
     from .benchmark_util_4_44 import BenchmarkWrapper
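
The `model_type == "glm"` branch added to `npu_model.py` in this patch rebuilds a GLM-Edge checkpoint as a Llama-structured model before the usual NPU conversion runs. Below is a minimal standalone sketch of that flow using only the helpers introduced in `glm_edge.py`; the checkpoint path is a placeholder, the plain `transformers.AutoModelForCausalLM` stands in for `cls.HF_Model`, and the snippet is illustrative rather than part of the patch.

# Illustrative sketch of the glm -> llama conversion path wired into npu_model.py.
# "/path/to/glm-edge-1.5b-chat" is a placeholder for a locally downloaded checkpoint.
import json
import os

import torch
from transformers import AutoModelForCausalLM

from ipex_llm.transformers.npu_models.glm_edge import (
    convert_config, convert_state_dict, load_weights,
)

original_path = "/path/to/glm-edge-1.5b-chat"

with open(os.path.join(original_path, "config.json")) as f:
    original_config = json.load(f)

# Map the GLM-Edge config and weights onto a LlamaConfig and a llama-style state dict.
config = convert_config(original_config)
state_dict = load_weights(original_path)
new_dict, _ = convert_state_dict(state_dict, config,
                                 original_config.get("partial_rotary_factor", 1.0),
                                 decouple_tied_embeddings=False)

# Instantiate a Llama-structured model that carries the converted GLM-Edge weights,
# mirroring the cls.HF_Model.from_pretrained(...) call in the patch.
torch.set_default_dtype(config.torch_dtype)
model = AutoModelForCausalLM.from_pretrained(original_path, config=config,
                                             state_dict=new_dict)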