[NPU] Support glm-edge models (#12511)
parent: 12c78978dd
commit: ea55235cbd

5 changed files with 323 additions and 1 deletion
@@ -11,6 +11,7 @@ In this directory, you will find examples on how to directly run HuggingFace `transformers` models …
 | Llama3.2-3B | [meta-llama/Llama-3.2-3B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct) |
 | Chatglm3 | [THUDM/chatglm3-6b](https://huggingface.co/THUDM/chatglm3-6b) |
 | Chatglm2 | [THUDM/chatglm2-6b](https://huggingface.co/THUDM/chatglm2-6b) |
+| GLM-Edge | [THUDM/glm-edge-1.5b-chat](https://huggingface.co/THUDM/glm-edge-1.5b-chat), [THUDM/glm-edge-4b-chat](https://huggingface.co/THUDM/glm-edge-4b-chat) |
 | Qwen2 | [Qwen/Qwen2-7B-Instruct](https://huggingface.co/Qwen/Qwen2-7B-Instruct), [Qwen/Qwen2-1.5B-Instruct](https://huggingface.co/Qwen/Qwen2-1.5B-Instruct) |
 | Qwen2.5 | [Qwen/Qwen2.5-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-7B-Instruct) |
 | MiniCPM | [openbmb/MiniCPM-1B-sft-bf16](https://huggingface.co/openbmb/MiniCPM-1B-sft-bf16), [openbmb/MiniCPM-2B-sft-bf16](https://huggingface.co/openbmb/MiniCPM-2B-sft-bf16) |

@@ -38,6 +39,9 @@ pip install --pre --upgrade ipex-llm[npu]
 :: [optional] for Llama-3.2-1B-Instruct & Llama-3.2-3B-Instruct
 pip install transformers==4.45.0 accelerate==0.33.0
+
+:: [optional] for glm-edge-1.5b-chat & glm-edge-4b-chat
+pip install transformers==4.47.0 accelerate==0.26.0
 ```
 
 ## 2. Runtime Configurations
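
The glm-edge examples pin different library versions than the other models. A quick illustrative environment check (not part of this commit) before running them:

```python
# Illustrative only: confirm the versions pinned above for glm-edge
# are the ones actually installed in the current environment.
import transformers
import accelerate

assert transformers.__version__ == "4.47.0", transformers.__version__
assert accelerate.__version__ == "0.26.0", accelerate.__version__
```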

@@ -94,6 +98,8 @@ The examples below show how to run the **_optimized HuggingFace model implementations_** …
 - [Qwen2.5-7B](./qwen.py)
 - [MiniCPM-1B](./minicpm.py)
 - [MiniCPM-2B](./minicpm.py)
+- [GLM-Edge-1.5B-Chat](./glm.py)
+- [GLM-Edge-4B-Chat](./glm.py)
 - [Baichuan2-7B](./baichuan2.py)
 
 ### Run

@@ -125,6 +131,12 @@ python minicpm.py --repo-id-or-model-path "openbmb/MiniCPM-1B-sft-bf16" --save-directory <converted_model_path>
 :: to run MiniCPM-2B-sft-bf16
 python minicpm.py --repo-id-or-model-path "openbmb/MiniCPM-2B-sft-bf16" --save-directory <converted_model_path>
 
+:: to run glm-edge-1.5b-chat
+python glm.py --repo-id-or-model-path "THUDM/glm-edge-1.5b-chat" --save-directory <converted_model_path>
+
+:: to run glm-edge-4b-chat
+python glm.py --repo-id-or-model-path "THUDM/glm-edge-4b-chat" --save-directory <converted_model_path>
+
 :: to run Baichuan2-7B-Chat
 python baichuan2.py --repo-id-or-model-path "baichuan-inc/Baichuan2-7B-Chat" --save-directory <converted_model_path>
 ```

python/llm/example/NPU/HF-Transformers-AutoModels/LLM/glm.py (new file, 123 lines)
@@ -0,0 +1,123 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import torch
import time
import argparse

from ipex_llm.transformers.npu_model import AutoModelForCausalLM
from transformers import AutoTokenizer, TextStreamer

from transformers.utils import logging

logger = logging.get_logger(__name__)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Predict Tokens using `generate()` API for npu model"
    )
    parser.add_argument(
        "--repo-id-or-model-path",
        type=str,
        default="THUDM/glm-edge-1.5b-chat",
        help="The huggingface repo id for the glm-edge model to be downloaded"
        ", or the path to the huggingface checkpoint folder",
    )
    parser.add_argument('--prompt', type=str, default="What is AI?",
                        help='Prompt to infer')
    parser.add_argument("--n-predict", type=int, default=32, help="Max tokens to predict")
    parser.add_argument("--max-context-len", type=int, default=1024)
    parser.add_argument("--max-prompt-len", type=int, default=512)
    parser.add_argument('--low-bit', type=str, default="sym_int4",
                        help='Low-bit format to load the model in')
    parser.add_argument("--disable-streaming", action="store_true", default=False)
    parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)
    parser.add_argument("--save-directory", type=str,
                        required=True,
                        help="The path of the folder to save the converted model. "
                             "If the path does not exist, the low-bit model will be saved there; "
                             "otherwise, the low-bit model will be loaded from it.",
                        )

    args = parser.parse_args()
    model_path = args.repo_id_or_model_path

    if not os.path.exists(args.save_directory):
        # First run: convert the HuggingFace checkpoint to a low-bit NPU model
        # and save it to --save-directory.
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.float16,
            trust_remote_code=True,
            attn_implementation="eager",
            load_in_low_bit=args.low_bit,
            optimize_model=True,
            max_context_len=args.max_context_len,
            max_prompt_len=args.max_prompt_len,
            transpose_value_cache=not args.disable_transpose_value_cache,
            save_directory=args.save_directory
        )
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        tokenizer.save_pretrained(args.save_directory)
    else:
        # Subsequent runs: reload the previously converted low-bit model.
        model = AutoModelForCausalLM.load_low_bit(
            args.save_directory,
            attn_implementation="eager",
            torch_dtype=torch.float16,
            optimize_model=True,
            max_context_len=args.max_context_len,
            max_prompt_len=args.max_prompt_len,
            transpose_value_cache=not args.disable_transpose_value_cache,
        )
        tokenizer = AutoTokenizer.from_pretrained(args.save_directory, trust_remote_code=True)

    if args.disable_streaming:
        streamer = None
    else:
        streamer = TextStreamer(tokenizer=tokenizer, skip_special_tokens=True)

    print("-" * 80)
    print("finished loading model")
    with torch.inference_mode():
        for i in range(3):
            message = [{"role": "user", "content": args.prompt}]

            inputs = tokenizer.apply_chat_template(
                message,
                return_tensors="pt",
                add_generation_prompt=True,
                return_dict=True,
            )
            _input_ids = inputs["input_ids"]

            print("-" * 20, "Input", "-" * 20)
            print("input length:", len(_input_ids[0]))
            input_str = tokenizer.decode(_input_ids[0], skip_special_tokens=False)
            print(input_str)
            print("-" * 20, "Output", "-" * 20)
            st = time.time()
            output = model.generate(
                _input_ids, num_beams=1, do_sample=False,
                max_new_tokens=args.n_predict, streamer=streamer
            )
            end = time.time()
            if args.disable_streaming:
                output_str = tokenizer.decode(output[0], skip_special_tokens=False)
                print(output_str)
            print(f"Inference time: {end - st} s")

    print("-" * 80)
    print("done")
    print("successfully shut down")
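
The script converts the checkpoint on the first run and reloads the saved result on later runs. A minimal sketch of the reload path, assuming a directory produced by an earlier `glm.py` run (the directory name here is hypothetical):

```python
# Minimal sketch of the reload path glm.py takes on later runs; the
# directory name is hypothetical and must come from a previous conversion.
import torch
from ipex_llm.transformers.npu_model import AutoModelForCausalLM
from transformers import AutoTokenizer

save_dir = "./glm-edge-1.5b-chat-npu"  # hypothetical: output of a prior glm.py run
model = AutoModelForCausalLM.load_low_bit(
    save_dir,
    attn_implementation="eager",
    torch_dtype=torch.float16,
    optimize_model=True,
    max_context_len=1024,
    max_prompt_len=512,
)
tokenizer = AutoTokenizer.from_pretrained(save_dir, trust_remote_code=True)
```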

@@ -180,6 +180,23 @@ class _BaseAutoModelClass:
         logger.info("Converting model, it may take up to several minutes ...")
 
+        if hasattr(model, "config") and model.config.model_type == "glm":
+            # convert to llama structure
+            from .npu_models.glm_edge import convert_config, load_weights, convert_state_dict
+            import json
+            original_path = model.config._name_or_path
+            del model
+
+            with open(os.path.join(original_path, "config.json")) as f:
+                original_config = json.load(f)
+            config = convert_config(original_config)
+            original_state_dict = load_weights(original_path)
+            new_dict, _ = convert_state_dict(original_state_dict, config,
+                                             original_config.get("partial_rotary_factor", 1.0),
+                                             decouple_tied_embeddings=False)
+            torch.set_default_dtype(config.torch_dtype)
+            model = cls.HF_Model.from_pretrained(original_path, config=config, state_dict=new_dict)
+
         if hasattr(model, "config"):
             model.config.update({"optimize_model": optimize_model})
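
The hunk above rebuilds a glm-edge checkpoint as a Llama-structured model before the usual NPU optimization. A standalone sketch of that flow, assuming a local checkpoint path (hypothetical) and with transformers' `AutoModelForCausalLM` standing in for `cls.HF_Model`:

```python
# Standalone sketch of the glm -> llama conversion above. The checkpoint
# path is hypothetical; AutoModelForCausalLM stands in for cls.HF_Model.
import json
import os
import torch
from transformers import AutoModelForCausalLM
from ipex_llm.transformers.npu_models.glm_edge import (
    convert_config, convert_state_dict, load_weights,
)

ckpt = "/path/to/glm-edge-1.5b-chat"  # hypothetical local checkpoint folder
with open(os.path.join(ckpt, "config.json")) as f:
    original_config = json.load(f)

config = convert_config(original_config)      # glm config dict -> LlamaConfig
state_dict = load_weights(ckpt)               # merge all weight shards
new_dict, _ = convert_state_dict(state_dict, config,
                                 original_config.get("partial_rotary_factor", 1.0),
                                 decouple_tied_embeddings=False)
torch.set_default_dtype(config.torch_dtype)
model = AutoModelForCausalLM.from_pretrained(ckpt, config=config, state_dict=new_dict)
```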

python/llm/src/ipex_llm/transformers/npu_models/glm_edge.py (new file, 167 lines)
@@ -0,0 +1,167 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


import os
import torch
from safetensors.torch import load_file
from tokenizers import processors

from transformers import LlamaConfig, PreTrainedTokenizerFast
from ipex_llm.transformers.utils import invalidInputError


VIT_KEY = "vit_path"  # FIXME: just made at random
VIT_FILE = "vit_adapter.pt"


def load_weights(input_dir: str):
    # Merge all checkpoint shards in input_dir into a single state dict,
    # preferring .safetensors files over .bin files.
    safetensor_files = [os.path.join(input_dir, x) for x in os.listdir(input_dir)
                        if x.endswith(".safetensors")]
    bin_files = [os.path.join(input_dir, x) for x in os.listdir(input_dir) if x.endswith(".bin")]

    all_weights = {}

    if safetensor_files:
        if len(safetensor_files) > 1:
            # sort shards by the index in "model-XXXXX-of-YYYYY.safetensors"
            safetensor_files = sorted(safetensor_files, key=lambda x: int(x.rsplit("-", 3)[1]))
        for file in safetensor_files:
            tensors = load_file(file)
            all_weights.update(tensors)
        return all_weights

    elif bin_files:
        if len(bin_files) > 1:
            bin_files = sorted(bin_files, key=lambda x: int(x.rsplit("-", 3)[1]))
        for file in bin_files:
            tensors = torch.load(file, map_location="cpu")
            all_weights.update(tensors)
        return all_weights

    else:
        invalidInputError(False, "No .safetensors or .bin files found in the specified directory.")


def convert_state_dict(original_state_dict: dict, config: LlamaConfig,
                       partial_rotary_factor: float, decouple_tied_embeddings=False):
    hidden_size, num_heads = config.hidden_size, config.num_attention_heads
    num_key_value_heads = config.num_key_value_heads
    head_dim = hidden_size // num_heads
    rotary_dim = int(partial_rotary_factor * head_dim)
    inv_freq = 1.0 / (config.rope_theta ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))

    # permute for sliced rotary
    def permute_weight(w, num_heads, rotary_dim):
        w = w.view(num_heads, head_dim, hidden_size)
        w, w_pass = w[:, :rotary_dim, :], w[:, rotary_dim:, :]
        w = w.view(num_heads, rotary_dim // 2, 2, hidden_size).transpose(1, 2)\
            .reshape(num_heads, rotary_dim, hidden_size)
        return torch.cat([w, w_pass], dim=1).view(num_heads * head_dim, hidden_size)

    def permute_bias(b, num_heads, rotary_dim):
        b = b.view(num_heads, head_dim)
        b, b_pass = b[:, :rotary_dim], b[:, rotary_dim:]
        b = b.view(num_heads, rotary_dim // 2, 2).transpose(1, 2).reshape(num_heads, rotary_dim)
        return torch.cat([b, b_pass], dim=1).view(num_heads * head_dim)

    new_dict, vit_dict = {}, {}
    for key, value in original_state_dict.items():
        if "model.vision" in key:  # vit
            vit_dict[key.replace("model.vision.", "")] = value.detach().clone()
        elif "q_proj." in key:
            if "weight" in key:
                new_dict[key] = permute_weight(value, num_heads, rotary_dim)
            elif config.attention_bias:  # bias
                new_dict[key] = permute_bias(value, num_heads, rotary_dim)
        elif "k_proj." in key:
            if "weight" in key:
                new_dict[key] = permute_weight(value, num_key_value_heads, rotary_dim)
            elif config.attention_bias:  # bias
                new_dict[key] = permute_bias(value, num_key_value_heads, rotary_dim)
        elif "v_proj." in key:
            if "bias" in key and not config.attention_bias:
                continue
            new_dict[key] = value
        elif "o_proj." in key:
            new_dict[key] = value
            if config.attention_bias:  # bias
                new_dict[key.replace("weight", "bias")] = torch.zeros(hidden_size,
                                                                      dtype=value.dtype)
        elif "gate_up_proj." in key:
            # split GLM's fused gate_up_proj into Llama's gate_proj / up_proj
            gate_proj, up_proj = value.chunk(2, dim=0)
            new_dict[key.replace("gate_up_proj.", "gate_proj.")] = gate_proj
            new_dict[key.replace("gate_up_proj.", "up_proj.")] = up_proj
        else:
            new_dict[key] = value

    for layer_i in range(config.num_hidden_layers):
        new_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq.clone()

    if decouple_tied_embeddings:
        new_dict["transformer.output_layer.weight"] = \
            original_state_dict["model.embed_tokens.weight"].clone()

    return new_dict, vit_dict


def convert_config(original_config: dict, decouple_tied_embeddings=False):
    similar_keys_to_keep = [
        "num_attention_heads",
        "hidden_size",
        "intermediate_size",
        "num_hidden_layers",
        "rms_norm_eps",
        "num_key_value_heads",
        "vocab_size",
        "partial_rotary_factor",
        "rope_theta",
        "max_position_embeddings",
        "attention_bias",
        "torch_dtype",
        "tie_word_embeddings",
        "bos_token_id",
        "eos_token_id",
        "pad_token_id",
        "boi_token_id",
        "eoi_token_id",
        "vision_config",
    ]
    new_config_kwargs = {k: v for k, v in original_config.items() if k in similar_keys_to_keep}
    # use .get(): original_config is a plain dict, so getattr() would always
    # return the default and this branch would never be taken
    if original_config.get("partial_rotary_factor", 1) < 1:
        new_config_kwargs["rope_dim"] = original_config["head_dim"] * \
            original_config["partial_rotary_factor"]
    if decouple_tied_embeddings:
        new_config_kwargs["tie_word_embeddings"] = False
    if "vision_config" in original_config:
        new_config_kwargs["vision_config"] = original_config["vision_config"]
        new_config_kwargs[VIT_KEY] = VIT_FILE
    if "bos_token_id" not in new_config_kwargs:
        new_config_kwargs["bos_token_id"] = None

    new_config = LlamaConfig(**new_config_kwargs)
    return new_config


def convert_glm_tokenizer(input_dir):
    fast_tok = PreTrainedTokenizerFast.from_pretrained(input_dir,
                                                       model_input_names=["input_ids",
                                                                          "attention_mask"])
    fast_tok._tokenizer.post_processor = processors.Sequence(
        [processors.ByteLevel(trim_offsets=False)],
    )
    return fast_tok
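
`load_weights` above orders sharded checkpoints by the shard index embedded in the conventional `model-XXXXX-of-YYYYY` file naming. A tiny illustration of the sort key (the file names are hypothetical):

```python
# Hypothetical shard names: the key extracts "00002" from
# "model-00002-of-00003.safetensors" and sorts numerically.
files = ["model-00002-of-00003.safetensors",
         "model-00001-of-00003.safetensors",
         "model-00003-of-00003.safetensors"]
print(sorted(files, key=lambda x: int(x.rsplit("-", 3)[1])))
# ['model-00001-of-00003.safetensors', 'model-00002-of-00003.safetensors',
#  'model-00003-of-00003.safetensors']
```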
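The q/k permutation in `convert_state_dict` regroups GLM's interleaved rotary pairs (x0, x1, x2, x3, …) into the half-split layout ((x0, x2, …), (x1, x3, …)) that Llama-style rotary embeddings expect. A minimal sketch of the row reordering, with toy sizes and full rotary coverage (not the real model dimensions):

```python
import torch

# Toy sizes (one head, head_dim = rotary_dim = 4): the permutation sends
# weight rows (0, 1, 2, 3) to the half-split order (0, 2, 1, 3).
num_heads, head_dim, hidden_size, rotary_dim = 1, 4, 4, 4
w = torch.arange(head_dim * hidden_size, dtype=torch.float32).view(head_dim, hidden_size)

v = w.view(num_heads, head_dim, hidden_size)
rot, rest = v[:, :rotary_dim, :], v[:, rotary_dim:, :]
rot = rot.view(num_heads, rotary_dim // 2, 2, hidden_size).transpose(1, 2) \
         .reshape(num_heads, rotary_dim, hidden_size)
out = torch.cat([rot, rest], dim=1).view(num_heads * head_dim, hidden_size)
print(out[:, 0].tolist())  # [0.0, 8.0, 4.0, 12.0] -> rows reordered 0, 2, 1, 3
```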
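`convert_config` keeps only the Llama-compatible keys and fills in a `bos_token_id` when the glm config has none. An illustrative call, with made-up values rather than a real glm-edge checkpoint config:

```python
from transformers import LlamaConfig
from ipex_llm.transformers.npu_models.glm_edge import convert_config

# Made-up, minimal glm-style config dict for illustration only.
toy = {
    "model_type": "glm",        # not in similar_keys_to_keep -> dropped
    "hidden_size": 2048,
    "num_attention_heads": 16,
    "num_key_value_heads": 4,
    "num_hidden_layers": 28,
    "intermediate_size": 6144,
    "rms_norm_eps": 1e-5,
    "vocab_size": 59264,
    "rope_theta": 10000.0,
    "max_position_embeddings": 8192,
    "attention_bias": False,
    "torch_dtype": "bfloat16",
    "tie_word_embeddings": True,
}
cfg = convert_config(toy)
assert isinstance(cfg, LlamaConfig)
print(cfg.bos_token_id)         # None -- filled in when the glm config has none
```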

@@ -22,7 +22,10 @@ import transformers
 trans_version = transformers.__version__
 
-if trans_version >= "4.45.0":
+if trans_version >= "4.47.0":
+    # TODO
+    pass
+elif trans_version >= "4.45.0":
     from .benchmark_util_4_45 import BenchmarkWrapper
 elif trans_version >= "4.44.0":
     from .benchmark_util_4_44 import BenchmarkWrapper
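
Note that the gate above compares version strings lexicographically, which happens to order these particular values correctly but breaks in general. A sketch of the failure mode and a numeric alternative (`packaging` is an assumption, not what this file uses):

```python
# Not part of the commit: lexicographic comparison misorders multi-digit
# minor versions; packaging.version compares them numerically.
from packaging.version import Version

assert "4.100.0" < "4.45.0"                    # string compare: wrong order
assert Version("4.100.0") > Version("4.45.0")  # numeric compare: correct
```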