[NPU] update save-load API usage (#12473)

binbin Deng 2024-12-03 09:46:15 +08:00 committed by GitHub
parent 26adb82ee3
commit ab01753b1c
20 changed files with 166 additions and 188 deletions

@ -634,7 +634,7 @@ def transformers_int4_npu_win(repo_id,
model = AutoModel.from_pretrained(model_path, load_in_low_bit=low_bit, optimize_model=optimize_model,
trust_remote_code=True, use_cache=True, max_context_len=max_context_len, max_prompt_len=int(in_out_len[0]),
quantization_group_size=npu_group_size, transpose_value_cache=transpose_value_cache,
attn_implementation="eager", torch_dtype=torch.float16).eval()
save_directory=save_directory, attn_implementation="eager", torch_dtype=torch.float16).eval()
model = model.llm
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
else:
@ -702,6 +702,7 @@ def transformers_int4_npu_pipeline_win(repo_id,
in_out_len = in_out_pairs[0].split("-")
max_context_len = max(int(in_out_len[0]) + int(in_out_len[1]), 1024)
mixed_precision = True if npu_group_size == 0 else False
save_directory = "./save_converted_model_dir"
# Load model in 4 bit,
# which converts the relevant layers in the model into INT4 format
st = time.perf_counter()
@ -709,7 +710,8 @@ def transformers_int4_npu_pipeline_win(repo_id,
model = AutoModelForCausalLM.from_pretrained(model_path, load_in_low_bit=low_bit, trust_remote_code=True, pipeline=True, torch_dtype=torch.float16,
optimize_model=optimize_model, max_context_len=max_context_len, max_prompt_len=int(in_out_len[0]),
quantization_group_size=npu_group_size, transpose_value_cache=transpose_value_cache,
use_cache=True, attn_implementation="eager", mixed_precision=mixed_precision).eval()
use_cache=True, attn_implementation="eager", mixed_precision=mixed_precision,
save_directory=save_directory).eval()
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
end = time.perf_counter()
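For orientation, here is a minimal sketch of what the updated pipeline benchmark call looks like after this change. Only the new `save_directory` keyword is taken from the diff above; the `ipex_llm.transformers.npu_model` import path and the concrete argument values are assumptions for illustration.

```python
import torch
from transformers import AutoTokenizer
from ipex_llm.transformers.npu_model import AutoModelForCausalLM  # assumed import path

model_path = "meta-llama/Llama-2-7b-chat-hf"      # illustrative
save_directory = "./save_converted_model_dir"     # new: folder for the converted model

model = AutoModelForCausalLM.from_pretrained(
    model_path,
    load_in_low_bit="sym_int4",
    trust_remote_code=True,
    pipeline=True,
    torch_dtype=torch.float16,
    optimize_model=True,
    max_context_len=1024,
    max_prompt_len=512,
    transpose_value_cache=True,
    use_cache=True,
    attn_implementation="eager",
    save_directory=save_directory,                # added by this commit
).eval()
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
```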

@ -47,45 +47,45 @@ In the example [generate.py](./generate.py), we show a basic use case for a Llam
```cmd
:: to run Llama-2-7b-chat-hf
python llama2.py
python llama2.py --repo-id-or-model-path "meta-llama/Llama-2-7b-chat-hf" --save-directory <converted_model_path>
:: to run Meta-Llama-3-8B-Instruct
python llama3.py
python llama3.py --repo-id-or-model-path "meta-llama/Meta-Llama-3-8B-Instruct" --save-directory <converted_model_path>
:: to run Llama-3.2-1B-Instruct
python llama3.py --repo-id-or-model-path "meta-llama/Llama-3.2-1B-Instruct"
python llama3.py --repo-id-or-model-path "meta-llama/Llama-3.2-1B-Instruct" --save-directory <converted_model_path>
:: to run Llama-3.2-3B-Instruct
python llama3.py --repo-id-or-model-path "meta-llama/Llama-3.2-3B-Instruct"
python llama3.py --repo-id-or-model-path "meta-llama/Llama-3.2-3B-Instruct" --save-directory <converted_model_path>
:: to run Qwen2.5-7B-Instruct
python qwen.py
python qwen.py --repo-id-or-model-path "Qwen/Qwen2.5-7B-Instruct" --save-directory <converted_model_path>
:: to run Qwen2-1.5B-Instruct
python qwen.py --repo-id-or-model-path "Qwen/Qwen2-1.5B-Instruct" --low_bit "sym_int8"
python qwen.py --repo-id-or-model-path "Qwen/Qwen2-1.5B-Instruct" --low-bit sym_int8 --save-directory <converted_model_path>
:: to run Qwen2.5-3B-Instruct
python qwen.py --repo-id-or-model-path "Qwen/Qwen2.5-3B-Instruct" --low_bit "sym_int8"
python qwen.py --repo-id-or-model-path "Qwen/Qwen2.5-3B-Instruct" --low-bit sym_int8 --save-directory <converted_model_path>
:: to run Baichuan2-7B-Chat
python baichuan2.py
python baichuan2.py --repo-id-or-model-path "baichuan-inc/Baichuan2-7B-Chat" --save-directory <converted_model_path>
:: to run MiniCPM-1B-sft-bf16
python minicpm.py
python minicpm.py --repo-id-or-model-path "openbmb/MiniCPM-1B-sft-bf16" --save-directory <converted_model_path>
:: to run MiniCPM-2B-sft-bf16
python minicpm.py --repo-id-or-model-path "openbmb/MiniCPM-2B-sft-bf16"
python minicpm.py --repo-id-or-model-path "openbmb/MiniCPM-2B-sft-bf16" --save-directory <converted_model_path>
```
Arguments info:
- `--repo-id-or-model-path REPO_ID_OR_MODEL_PATH`: argument defining the huggingface repo id for the model (e.g. `meta-llama/Llama-2-7b-chat-hf`) to be downloaded, or the path to the huggingface checkpoint folder.
- `--lowbit-path LOWBIT_MODEL_PATH`: argument defining the path to save/load lowbit version of the model. If it is an empty string, the original pretrained model specified by `REPO_ID_OR_MODEL_PATH` will be loaded. If it is an existing path, the lowbit model in `LOWBIT_MODEL_PATH` will be loaded. If it is a non-existing path, the original pretrained model specified by `REPO_ID_OR_MODEL_PATH` will be loaded, and the converted lowbit version will be saved into `LOWBIT_MODEL_PATH`. It is default to be `''`, i.e. an empty string.
- `--prompt PROMPT`: argument defining the prompt to be inferred. It is default to be `What is AI?`.
- `--n-predict N_PREDICT`: argument defining the max number of tokens to predict. It is default to be `32`.
- `--max-context-len MAX_CONTEXT_LEN`: Defines the maximum sequence length for both input and output tokens. It is default to be `1024`.
- `--max-prompt-len MAX_PROMPT_LEN`: Defines the maximum number of tokens that the input prompt can contain. It is default to be `512`.
- `--disable-transpose-value-cache`: Disable the optimization of transposing value cache.
- `--disable-streaming`: Disable streaming mode of generation.
- `--save-directory SAVE_DIRECTORY`: argument defining the path to save the converted model. If it is a non-existing path, the original pretrained model specified by `REPO_ID_OR_MODEL_PATH` will be loaded and its converted lowbit version will be saved into `SAVE_DIRECTORY`; otherwise, the lowbit model in `SAVE_DIRECTORY` will be loaded directly (see the sketch below).
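Under the hood, every example script now follows the same save/load pattern. The sketch below reconstructs it from the script diffs in this commit; the `ipex_llm.transformers.npu_model` import path is an assumption and the argument values are illustrative.

```python
import os
import torch
from transformers import AutoTokenizer
from ipex_llm.transformers.npu_model import AutoModelForCausalLM  # assumed import path

model_path = "meta-llama/Llama-2-7b-chat-hf"   # --repo-id-or-model-path
save_directory = "./llama2-7b-npu"             # --save-directory

if not os.path.exists(save_directory):
    # First run: convert the checkpoint and save the lowbit model into save_directory.
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        optimize_model=True,
        pipeline=True,
        load_in_low_bit="sym_int4",
        max_context_len=1024,
        max_prompt_len=512,
        torch_dtype=torch.float16,
        attn_implementation="eager",
        transpose_value_cache=True,
        save_directory=save_directory,
    )
else:
    # Later runs: load the already-converted lowbit model directly from save_directory.
    model = AutoModelForCausalLM.load_low_bit(
        save_directory,
        attn_implementation="eager",
        torch_dtype=torch.float16,
        max_context_len=1024,
        max_prompt_len=512,
    )
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
```

Note that an explicit `model.save_low_bit(...)` call is no longer needed: with this commit, `from_pretrained` saves the converted model itself.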
### Sample Output of Streaming Mode
#### [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf)

@ -49,12 +49,6 @@ if __name__ == "__main__":
help="The huggingface repo id for the Baichuan2 model to be downloaded"
", or the path to the huggingface checkpoint folder",
)
parser.add_argument("--lowbit-path", type=str,
default="",
help="The path to the lowbit model folder, leave blank if you do not want to save. \
If path not exists, lowbit model will be saved there. \
Else, lowbit model will be loaded.",
)
parser.add_argument('--prompt', type=str, default="What is AI?",
help='Prompt to infer')
parser.add_argument("--n-predict", type=int, default=32, help="Max tokens to predict")
@ -63,11 +57,17 @@ if __name__ == "__main__":
parser.add_argument("--quantization_group_size", type=int, default=0)
parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)
parser.add_argument("--disable-streaming", action="store_true", default=False)
parser.add_argument("--save-directory", type=str,
required=True,
help="The path of folder to save converted model, "
"If path not exists, lowbit model will be saved there. "
"Else, lowbit model will be loaded.",
)
args = parser.parse_args()
model_path = args.repo_id_or_model_path
if not args.lowbit_path or not os.path.exists(args.lowbit_path):
if not os.path.exists(args.save_directory):
model = AutoModelForCausalLM.from_pretrained(model_path,
optimize_model=True,
pipeline=True,
@ -77,10 +77,11 @@ if __name__ == "__main__":
torch_dtype=torch.float16,
attn_implementation="eager",
transpose_value_cache=not args.disable_transpose_value_cache,
trust_remote_code=True)
trust_remote_code=True,
save_directory=args.save_directory)
else:
model = AutoModelForCausalLM.load_low_bit(
args.lowbit_path,
args.save_directory,
attn_implementation="eager",
torch_dtype=torch.float16,
max_context_len=args.max_context_len,
@ -92,9 +93,6 @@ if __name__ == "__main__":
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
if args.lowbit_path and not os.path.exists(args.lowbit_path):
model.save_low_bit(args.lowbit_path)
if args.disable_streaming:
streamer = None
else:

@ -49,12 +49,6 @@ if __name__ == "__main__":
help="The huggingface repo id for the Llama2 model to be downloaded"
", or the path to the huggingface checkpoint folder",
)
parser.add_argument("--lowbit-path", type=str,
default="",
help="The path to the lowbit model folder, leave blank if you do not want to save. \
If path not exists, lowbit model will be saved there. \
Else, lowbit model will be loaded.",
)
parser.add_argument('--prompt', type=str, default="What is AI?",
help='Prompt to infer')
parser.add_argument("--n-predict", type=int, default=32, help="Max tokens to predict")
@ -63,11 +57,17 @@ if __name__ == "__main__":
parser.add_argument("--quantization_group_size", type=int, default=0)
parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)
parser.add_argument("--disable-streaming", action="store_true", default=False)
parser.add_argument("--save-directory", type=str,
required=True,
help="The path of folder to save converted model, "
"If path not exists, lowbit model will be saved there. "
"Else, lowbit model will be loaded.",
)
args = parser.parse_args()
model_path = args.repo_id_or_model_path
if not args.lowbit_path or not os.path.exists(args.lowbit_path):
if not os.path.exists(args.save_directory):
model = AutoModelForCausalLM.from_pretrained(model_path,
optimize_model=True,
pipeline=True,
@ -76,10 +76,11 @@ if __name__ == "__main__":
quantization_group_size=args.quantization_group_size,
torch_dtype=torch.float16,
attn_implementation="eager",
transpose_value_cache=not args.disable_transpose_value_cache)
transpose_value_cache=not args.disable_transpose_value_cache,
save_directory=args.save_directory)
else:
model = AutoModelForCausalLM.load_low_bit(
args.lowbit_path,
args.save_directory,
attn_implementation="eager",
torch_dtype=torch.float16,
max_context_len=args.max_context_len,
@ -90,9 +91,6 @@ if __name__ == "__main__":
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
if args.lowbit_path and not os.path.exists(args.lowbit_path):
model.save_low_bit(args.lowbit_path)
if args.disable_streaming:
streamer = None
else:

@ -55,12 +55,6 @@ if __name__ == "__main__":
help="The huggingface repo id for the Llama3 model to be downloaded"
", or the path to the huggingface checkpoint folder",
)
parser.add_argument("--lowbit-path", type=str,
default="",
help="The path to the lowbit model folder, leave blank if you do not want to save. \
If path not exists, lowbit model will be saved there. \
Else, lowbit model will be loaded.",
)
parser.add_argument('--prompt', type=str, default="What is AI?",
help='Prompt to infer')
parser.add_argument("--n-predict", type=int, default=32, help="Max tokens to predict")
@ -69,11 +63,17 @@ if __name__ == "__main__":
parser.add_argument("--quantization_group_size", type=int, default=0)
parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)
parser.add_argument("--disable-streaming", action="store_true", default=False)
parser.add_argument("--save-directory", type=str,
required=True,
help="The path of folder to save converted model, "
"If path not exists, lowbit model will be saved there. "
"Else, lowbit model will be loaded.",
)
args = parser.parse_args()
model_path = args.repo_id_or_model_path
if not args.lowbit_path or not os.path.exists(args.lowbit_path):
if not os.path.exists(args.save_directory):
model = AutoModelForCausalLM.from_pretrained(model_path,
torch_dtype=torch.float16,
optimize_model=True,
@ -82,10 +82,11 @@ if __name__ == "__main__":
max_prompt_len=args.max_prompt_len,
quantization_group_size=args.quantization_group_size,
attn_implementation="eager",
transpose_value_cache=not args.disable_transpose_value_cache)
transpose_value_cache=not args.disable_transpose_value_cache,
save_directory=args.save_directory)
else:
model = AutoModelForCausalLM.load_low_bit(
args.lowbit_path,
args.save_directory,
attn_implementation="eager",
torch_dtype=torch.float16,
max_context_len=args.max_context_len,
@ -96,9 +97,6 @@ if __name__ == "__main__":
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
if args.lowbit_path and not os.path.exists(args.lowbit_path):
model.save_low_bit(args.lowbit_path)
if args.disable_streaming:
streamer = None
else:

@ -36,12 +36,6 @@ if __name__ == "__main__":
help="The huggingface repo id for the MiniCPM model to be downloaded"
", or the path to the huggingface checkpoint folder",
)
parser.add_argument("--lowbit-path", type=str,
default="",
help="The path to the lowbit model folder, leave blank if you do not want to save. \
If path not exists, lowbit model will be saved there. \
Else, lowbit model will be loaded.",
)
parser.add_argument('--prompt', type=str, default="What is AI?",
help='Prompt to infer')
parser.add_argument("--n-predict", type=int, default=32, help="Max tokens to predict")
@ -50,11 +44,17 @@ if __name__ == "__main__":
parser.add_argument("--quantization_group_size", type=int, default=0)
parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)
parser.add_argument("--disable-streaming", action="store_true", default=False)
parser.add_argument("--save-directory", type=str,
required=True,
help="The path of folder to save converted model, "
"If path not exists, lowbit model will be saved there. "
"Else, lowbit model will be loaded.",
)
args = parser.parse_args()
model_path = args.repo_id_or_model_path
if not args.lowbit_path or not os.path.exists(args.lowbit_path):
if not os.path.exists(args.save_directory):
model = AutoModelForCausalLM.from_pretrained(model_path,
optimize_model=True,
pipeline=True,
@ -64,10 +64,11 @@ if __name__ == "__main__":
attn_implementation="eager",
quantization_group_size=args.quantization_group_size,
transpose_value_cache=not args.disable_transpose_value_cache,
trust_remote_code=True)
trust_remote_code=True,
save_directory=args.save_directory)
else:
model = AutoModelForCausalLM.load_low_bit(
args.lowbit_path,
args.save_directory,
attn_implementation="eager",
torch_dtype=torch.float16,
max_context_len=args.max_context_len,
@ -79,9 +80,6 @@ if __name__ == "__main__":
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
if args.lowbit_path and not os.path.exists(args.lowbit_path):
model.save_low_bit(args.lowbit_path)
if args.disable_streaming:
streamer = None
else:

@ -36,27 +36,27 @@ if __name__ == "__main__":
help="The huggingface repo id for the Qwen model to be downloaded"
", or the path to the huggingface checkpoint folder",
)
parser.add_argument("--lowbit-path", type=str,
default="",
help="The path to the lowbit model folder, leave blank if you do not want to save. \
If path not exists, lowbit model will be saved there. \
Else, lowbit model will be loaded.",
)
parser.add_argument('--prompt', type=str, default="AI是什么?",
help='Prompt to infer')
parser.add_argument("--n-predict", type=int, default=32, help="Max tokens to predict")
parser.add_argument("--max-context-len", type=int, default=1024)
parser.add_argument("--max-prompt-len", type=int, default=512)
parser.add_argument("--quantization_group_size", type=int, default=0)
parser.add_argument('--low_bit', type=str, default="sym_int4",
parser.add_argument('--low-bit', type=str, default="sym_int4",
help='Low bit precision to quantize the model')
parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)
parser.add_argument("--disable-streaming", action="store_true", default=False)
parser.add_argument("--save-directory", type=str,
required=True,
help="The path of folder to save converted model, "
"If path not exists, lowbit model will be saved there. "
"Else, lowbit model will be loaded.",
)
args = parser.parse_args()
model_path = args.repo_id_or_model_path
if not args.lowbit_path or not os.path.exists(args.lowbit_path):
if not os.path.exists(args.save_directory):
model = AutoModelForCausalLM.from_pretrained(model_path,
optimize_model=True,
pipeline=True,
@ -68,10 +68,11 @@ if __name__ == "__main__":
attn_implementation="eager",
transpose_value_cache=not args.disable_transpose_value_cache,
mixed_precision=True,
trust_remote_code=True)
trust_remote_code=True,
save_directory=args.save_directory)
else:
model = AutoModelForCausalLM.load_low_bit(
args.lowbit_path,
args.save_directory,
attn_implementation="eager",
torch_dtype=torch.float16,
max_context_len=args.max_context_len,
@ -81,9 +82,6 @@ if __name__ == "__main__":
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
if args.lowbit_path and not os.path.exists(args.lowbit_path):
model.save_low_bit(args.lowbit_path)
if args.disable_streaming:
streamer = None
else:

@ -99,45 +99,44 @@ The examples below show how to run the **_optimized HuggingFace model implementa
### Run
```cmd
:: to run Llama-2-7b-chat-hf
python llama2.py --repo-id-or-model-path meta-llama/Llama-2-7b-chat-hf --save-directory <converted_model_path>
python llama2.py --repo-id-or-model-path "meta-llama/Llama-2-7b-chat-hf" --save-directory <converted_model_path>
:: to run Meta-Llama-3-8B-Instruct
python llama3.py --repo-id-or-model-path meta-llama/Meta-Llama-3-8B-Instruct --save-directory <converted_model_path>
python llama3.py --repo-id-or-model-path "meta-llama/Meta-Llama-3-8B-Instruct" --save-directory <converted_model_path>
:: to run Llama-3.2-1B-Instruct
python llama3.py --repo-id-or-model-path meta-llama/Llama-3.2-1B-Instruct --save-directory <converted_model_path>
python llama3.py --repo-id-or-model-path "meta-llama/Llama-3.2-1B-Instruct" --save-directory <converted_model_path>
:: to run Llama-3.2-3B-Instruct
python llama3.py --repo-id-or-model-path meta-llama/Llama-3.2-3B-Instruct --save-directory <converted_model_path>
python llama3.py --repo-id-or-model-path "meta-llama/Llama-3.2-3B-Instruct" --save-directory <converted_model_path>
:: to run Qwen2-1.5B-Instruct
python qwen.py --repo-id-or-model-path Qwen/Qwen2-1.5B-Instruct --low_bit sym_int8 --save-directory <converted_model_path>
python qwen.py --repo-id-or-model-path "Qwen/Qwen2-1.5B-Instruct" --low-bit sym_int8 --save-directory <converted_model_path>
:: to run Qwen2.5-3B-Instruct
python qwen.py --repo-id-or-model-path Qwen/Qwen2.5-3B-Instruct --low_bit sym_int8 --save-directory <converted_model_path>
python qwen.py --repo-id-or-model-path "Qwen/Qwen2.5-3B-Instruct" --low-bit sym_int8 --save-directory <converted_model_path>
:: to run Qwen2.5-7B-Instruct
python qwen.py --repo-id-or-model-path Qwen/Qwen2.5-7B-Instruct --save-directory <converted_model_path>
python qwen.py --repo-id-or-model-path "Qwen/Qwen2.5-7B-Instruct" --save-directory <converted_model_path>
:: to run MiniCPM-1B-sft-bf16
python minicpm.py --repo-id-or-model-path openbmb/MiniCPM-1B-sft-bf16 --save-directory <converted_model_path>
python minicpm.py --repo-id-or-model-path "openbmb/MiniCPM-1B-sft-bf16" --save-directory <converted_model_path>
:: to run MiniCPM-2B-sft-bf16
python minicpm.py --repo-id-or-model-path openbmb/MiniCPM-2B-sft-bf16 --save-directory <converted_model_path>
python minicpm.py --repo-id-or-model-path "openbmb/MiniCPM-2B-sft-bf16" --save-directory <converted_model_path>
:: to run Baichuan2-7B-Chat
python baichuan2.py
python baichuan2.py --repo-id-or-model-path "baichuan-inc/Baichuan2-7B-Chat" --save-directory <converted_model_path>
```
Arguments info:
- `--repo-id-or-model-path REPO_ID_OR_MODEL_PATH`: argument defining the huggingface repo id for the Llama2 model (i.e. `meta-llama/Llama-2-7b-chat-hf`) to be downloaded, or the path to the huggingface checkpoint folder. It is default to be `'meta-llama/Llama-2-7b-chat-hf'`.
- `--lowbit-path LOWBIT_MODEL_PATH`: argument defining the path to save/load lowbit version of the model. If it is an empty string, the original pretrained model specified by `REPO_ID_OR_MODEL_PATH` will be loaded. If it is an existing path, the lowbit model in `LOWBIT_MODEL_PATH` will be loaded. If it is a non-existing path, the original pretrained model specified by `REPO_ID_OR_MODEL_PATH` will be loaded, and the converted lowbit version will be saved into `LOWBIT_MODEL_PATH`. It is default to be `''`, i.e. an empty string.
- `--prompt PROMPT`: argument defining the prompt to be inferred (with integrated prompt format for chat). It is default to be `What is AI?`.
- `--n-predict N_PREDICT`: argument defining the max number of tokens to predict. It is default to be `32`.
- `--max-context-len MAX_CONTEXT_LEN`: Defines the maximum sequence length for both input and output tokens. It is default to be `1024`.
- `--max-prompt-len MAX_PROMPT_LEN`: Defines the maximum number of tokens that the input prompt can contain. It is default to be `512`.
- `--disable-transpose-value-cache`: Disable the optimization of transposing value cache.
- `--save-directory SAVE_DIRECTORY`: argument defining the path to save converted model.
- `--save-directory SAVE_DIRECTORY`: argument defining the path to save the converted model. If it is a non-existing path, the original pretrained model specified by `REPO_ID_OR_MODEL_PATH` will be loaded and its converted lowbit version will be saved into `SAVE_DIRECTORY`; otherwise, the lowbit model in `SAVE_DIRECTORY` will be loaded directly (see the sketch below).
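As the script diffs below show, `--save-directory` is now a required flag and is forwarded straight to `from_pretrained`. A minimal sketch of that wiring, assuming the same `ipex_llm.transformers.npu_model` import path (values are illustrative):

```python
import argparse
import torch
from ipex_llm.transformers.npu_model import AutoModelForCausalLM  # assumed import path

parser = argparse.ArgumentParser()
parser.add_argument("--repo-id-or-model-path", type=str, default="meta-llama/Llama-2-7b-chat-hf")
parser.add_argument("--low-bit", type=str, default="sym_int4")
parser.add_argument("--save-directory", type=str, required=True,
                    help="Folder for the converted model: filled on the first run, "
                         "loaded directly on later runs.")
args = parser.parse_args()

# save_directory must always be provided now; the conversion path raises an error without it.
model = AutoModelForCausalLM.from_pretrained(
    args.repo_id_or_model_path,
    load_in_low_bit=args.low_bit,
    optimize_model=True,
    torch_dtype=torch.float16,
    attn_implementation="eager",
    save_directory=args.save_directory,
)
```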
### Troubleshooting

@ -50,57 +50,49 @@ if __name__ == "__main__":
help="The huggingface repo id for the Baichuan2 model to be downloaded"
", or the path to the huggingface checkpoint folder",
)
parser.add_argument("--lowbit-path", type=str,
default="",
help="The path to the lowbit model folder, leave blank if you do not want to save. \
If path not exists, lowbit model will be saved there. \
Else, lowbit model will be loaded.",
)
parser.add_argument('--prompt', type=str, default="What is AI?",
help='Prompt to infer')
parser.add_argument("--n-predict", type=int, default=32, help="Max tokens to predict")
parser.add_argument("--max-context-len", type=int, default=1024)
parser.add_argument("--max-prompt-len", type=int, default=512)
parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)
parser.add_argument("--intra-pp", type=int, default=2)
parser.add_argument("--inter-pp", type=int, default=2)
parser.add_argument("--save-directory", type=str,
required=True,
help="The path of folder to save converted model, "
"If path not exists, lowbit model will be saved there. "
"Else, lowbit model will be loaded.",
)
args = parser.parse_args()
model_path = args.repo_id_or_model_path
if not args.lowbit_path or not os.path.exists(args.lowbit_path):
if not os.path.exists(args.save_directory):
model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=torch.bfloat16,
torch_dtype=torch.float16,
trust_remote_code=True,
attn_implementation="eager",
load_in_low_bit="sym_int4",
optimize_model=True,
max_context_len=args.max_context_len,
max_prompt_len=args.max_prompt_len,
intra_pp=args.intra_pp,
inter_pp=args.inter_pp,
transpose_value_cache=not args.disable_transpose_value_cache,
save_directory=args.save_directory
)
else:
model = AutoModelForCausalLM.load_low_bit(
args.lowbit_path,
args.save_directory,
attn_implementation="eager",
torch_dtype=torch.bfloat16,
torch_dtype=torch.float16,
optimize_model=True,
max_context_len=args.max_context_len,
max_prompt_len=args.max_prompt_len,
intra_pp=args.intra_pp,
inter_pp=args.inter_pp,
transpose_value_cache=not args.disable_transpose_value_cache,
trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
if args.lowbit_path and not os.path.exists(args.lowbit_path):
model.save_low_bit(args.lowbit_path)
DEFAULT_SYSTEM_PROMPT = """\
"""

@ -50,12 +50,6 @@ if __name__ == "__main__":
help="The huggingface repo id for the Llama2 model to be downloaded"
", or the path to the huggingface checkpoint folder",
)
parser.add_argument("--lowbit-path", type=str,
default="",
help="The path to the lowbit model folder, leave blank if you do not want to save. \
If path not exists, lowbit model will be saved there. \
Else, lowbit model will be loaded.",
)
parser.add_argument('--prompt', type=str, default="What is AI?",
help='Prompt to infer')
parser.add_argument("--n-predict", type=int, default=32, help="Max tokens to predict")
@ -66,13 +60,13 @@ if __name__ == "__main__":
required=True,
help="The path of folder to save converted model, "
"If path not exists, lowbit model will be saved there. "
"Else, program will raise error.",
"Else, lowbit model will be loaded.",
)
args = parser.parse_args()
model_path = args.repo_id_or_model_path
if not args.lowbit_path or not os.path.exists(args.lowbit_path):
if not os.path.exists(args.save_directory):
model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=torch.float16,
@ -87,22 +81,17 @@ if __name__ == "__main__":
)
else:
model = AutoModelForCausalLM.load_low_bit(
args.lowbit_path,
args.save_directory,
attn_implementation="eager",
torch_dtype=torch.float16,
optimize_model=True,
max_context_len=args.max_context_len,
max_prompt_len=args.max_prompt_len,
intra_pp=args.intra_pp,
inter_pp=args.inter_pp,
transpose_value_cache=not args.disable_transpose_value_cache,
)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
if args.lowbit_path and not os.path.exists(args.lowbit_path):
model.save_low_bit(args.lowbit_path)
DEFAULT_SYSTEM_PROMPT = """\
"""

@ -51,12 +51,6 @@ if __name__ == "__main__":
help="The huggingface repo id for the Llama3 model to be downloaded"
", or the path to the huggingface checkpoint folder",
)
parser.add_argument("--lowbit-path", type=str,
default="",
help="The path to the lowbit model folder, leave blank if you do not want to save. \
If path not exists, lowbit model will be saved there. \
Else, lowbit model will be loaded.",
)
parser.add_argument('--prompt', type=str, default="What is AI?",
help='Prompt to infer')
parser.add_argument("--n-predict", type=int, default=32, help="Max tokens to predict")
@ -67,13 +61,13 @@ if __name__ == "__main__":
required=True,
help="The path of folder to save converted model, "
"If path not exists, lowbit model will be saved there. "
"Else, program will raise error.",
"Else, lowbit model will be loaded.",
)
args = parser.parse_args()
model_path = args.repo_id_or_model_path
if not args.lowbit_path or not os.path.exists(args.lowbit_path):
if not os.path.exists(args.save_directory):
model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=torch.float16,
@ -88,22 +82,17 @@ if __name__ == "__main__":
)
else:
model = AutoModelForCausalLM.load_low_bit(
args.lowbit_path,
args.save_directory,
attn_implementation="eager",
torch_dtype=torch.float16,
optimize_model=True,
max_context_len=args.max_context_len,
max_prompt_len=args.max_prompt_len,
intra_pp=args.intra_pp,
inter_pp=args.inter_pp,
transpose_value_cache=not args.disable_transpose_value_cache,
)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
if args.lowbit_path and not os.path.exists(args.lowbit_path):
model.save_low_bit(args.lowbit_path)
DEFAULT_SYSTEM_PROMPT = """\
"""

@ -37,12 +37,6 @@ if __name__ == "__main__":
help="The huggingface repo id for the Llama2 model to be downloaded"
", or the path to the huggingface checkpoint folder",
)
parser.add_argument("--lowbit-path", type=str,
default="",
help="The path to the lowbit model folder, leave blank if you do not want to save. \
If path not exists, lowbit model will be saved there. \
Else, lowbit model will be loaded.",
)
parser.add_argument('--prompt', type=str, default="What is AI?",
help='Prompt to infer')
parser.add_argument("--n-predict", type=int, default=32, help="Max tokens to predict")
@ -53,12 +47,12 @@ if __name__ == "__main__":
required=True,
help="The path of folder to save converted model, "
"If path not exists, lowbit model will be saved there. "
"Else, program will raise error.",
"Else, lowbit model will be loaded.",
)
args = parser.parse_args()
model_path = args.repo_id_or_model_path
if not args.lowbit_path or not os.path.exists(args.lowbit_path):
if not os.path.exists(args.save_directory):
model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=torch.float16,
@ -73,22 +67,17 @@ if __name__ == "__main__":
)
else:
model = AutoModelForCausalLM.load_low_bit(
args.lowbit_path,
args.save_directory,
attn_implementation="eager",
torch_dtype=torch.float16,
optimize_model=True,
max_context_len=args.max_context_len,
max_prompt_len=args.max_prompt_len,
intra_pp=args.intra_pp,
inter_pp=args.inter_pp,
transpose_value_cache=not args.disable_transpose_value_cache,
trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
if args.lowbit_path and not os.path.exists(args.lowbit_path):
model.save_low_bit(args.lowbit_path)
print("-" * 80)
print("done")
with torch.inference_mode():

@ -37,32 +37,26 @@ if __name__ == "__main__":
help="The huggingface repo id for the Qwen2 or Qwen2.5 model to be downloaded"
", or the path to the huggingface checkpoint folder",
)
parser.add_argument("--lowbit-path", type=str,
default="",
help="The path to the lowbit model folder, leave blank if you do not want to save. \
If path not exists, lowbit model will be saved there. \
Else, lowbit model will be loaded.",
)
parser.add_argument('--prompt', type=str, default="AI是什么?",
help='Prompt to infer')
parser.add_argument("--n-predict", type=int, default=32, help="Max tokens to predict")
parser.add_argument("--max-context-len", type=int, default=1024)
parser.add_argument("--max-prompt-len", type=int, default=960)
parser.add_argument("--quantization_group_size", type=int, default=0)
parser.add_argument('--low_bit', type=str, default="sym_int4",
parser.add_argument('--low-bit', type=str, default="sym_int4",
help='Load in low bit to use')
parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)
parser.add_argument("--save-directory", type=str,
required=True,
help="The path of folder to save converted model, "
"If path not exists, lowbit model will be saved there. "
"Else, program will raise error.",
"Else, lowbit model will be loaded.",
)
args = parser.parse_args()
model_path = args.repo_id_or_model_path
if not args.lowbit_path or not os.path.exists(args.lowbit_path):
if not os.path.exists(args.save_directory):
model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=torch.float16,
@ -79,22 +73,17 @@ if __name__ == "__main__":
)
else:
model = AutoModelForCausalLM.load_low_bit(
args.lowbit_path,
args.save_directory,
attn_implementation="eager",
torch_dtype=torch.float16,
optimize_model=True,
max_context_len=args.max_context_len,
max_prompt_len=args.max_prompt_len,
intra_pp=args.intra_pp,
inter_pp=args.inter_pp,
transpose_value_cache=not args.disable_transpose_value_cache,
)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
if args.lowbit_path and not os.path.exists(args.lowbit_path):
model.save_low_bit(args.lowbit_path)
print("-" * 80)
print("done")
messages = [{"role": "system", "content": "You are a helpful assistant."},

@ -103,10 +103,10 @@ The examples below show how to run the **_optimized HuggingFace & FunASR model i
### 4.1 Run MiniCPM-Llama3-V-2_5 & MiniCPM-V-2_6
```bash
# to run MiniCPM-Llama3-V-2_5
python minicpm-llama3-v2.5.py
python minicpm-llama3-v2.5.py --save-directory <converted_model_path>
# to run MiniCPM-V-2_6
python minicpm_v_2_6.py
python minicpm_v_2_6.py --save-directory <converted_model_path>
```
Arguments info:
@ -117,6 +117,7 @@ Arguments info:
- `--max-output-len MAX_OUTPUT_LEN`: Defines the maximum sequence length for both input and output tokens. It is default to be `1024`.
- `--max-prompt-len MAX_PROMPT_LEN`: Defines the maximum number of tokens that the input prompt can contain. It is default to be `512`.
- `--disable-transpose-value-cache`: Disable the optimization of transposing value cache.
- `--save-directory SAVE_DIRECTORY`: argument defining the path to save the converted model. If it is a non-existing path, the original pretrained model specified by `REPO_ID_OR_MODEL_PATH` will be loaded and its converted lowbit version will be saved into `SAVE_DIRECTORY`; otherwise, the lowbit model in `SAVE_DIRECTORY` will be loaded directly.
#### Sample Output
##### [openbmb/MiniCPM-V-2_6](https://huggingface.co/openbmb/MiniCPM-V-2_6)
@ -134,12 +135,13 @@ The image features a young child holding and showing off a white teddy bear wear
### 4.2 Run Speech_Paraformer-Large
```bash
# to run Speech_Paraformer-Large
python speech_paraformer-large.py
python speech_paraformer-large.py --save-directory <converted_model_path>
```
Arguments info:
- `--repo-id-or-model-path REPO_ID_OR_MODEL_PATH`: argument defining the asr repo id for the model (i.e. `iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch`) to be downloaded, or the path to the asr checkpoint folder.
- `--load_in_low_bit`: argument defining the `load_in_low_bit` format used. It is default to be `sym_int8`, `sym_int4` can also be used.
- `--save-directory SAVE_DIRECTORY`: argument defining the path to save the converted model. If it is a non-existing path, the original pretrained model specified by `REPO_ID_OR_MODEL_PATH` will be loaded and its converted lowbit version will be saved into `SAVE_DIRECTORY`; otherwise, the lowbit model in `SAVE_DIRECTORY` will be loaded directly.
#### Sample Output
##### [iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch](https://www.modelscope.cn/models/iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch)
@ -157,11 +159,12 @@ rtf_avg: 0.232: 100%|███████████████████
### 4.3 Run Bce-Embedding-Base-V1
```bash
# to run Bce-Embedding-Base-V1
python bce-embedding.py
python bce-embedding.py --save-directory <converted_model_path>
```
Arguments info:
- `--repo-id-or-model-path REPO_ID_OR_MODEL_PATH`: argument defining the huggingface repo id for the embedding model (i.e. `maidalun1020/bce-embedding-base_v1`) to be downloaded, or the path to the huggingface checkpoint folder.
- `--save-directory SAVE_DIRECTORY`: argument defining the path to save the converted model. If it is a non-existing path, the original pretrained model specified by `REPO_ID_OR_MODEL_PATH` will be loaded and its converted lowbit version will be saved into `SAVE_DIRECTORY`; otherwise, the lowbit model in `SAVE_DIRECTORY` will be loaded directly (see the sketch below).
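A minimal sketch of the converted-model call for the embedding example after this change. Only the keyword arguments mirror the `bce-embedding.py` diff later in this commit; the `AutoModel` class name, its import path, and the remaining values are assumptions for illustration.

```python
import torch
from ipex_llm.transformers.npu_model import AutoModel  # assumption: the embedding example's wrapper class

model_path = "maidalun1020/bce-embedding-base_v1"
save_directory = "./bce-embedding-npu"   # passed via --save-directory

model = AutoModel.from_pretrained(
    model_path,
    torch_dtype=torch.float16,           # illustrative
    optimize_model=True,
    max_context_len=1024,
    max_prompt_len=512,
    transpose_value_cache=True,
    save_directory=save_directory,       # converted model is written here on the first run
)
```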
#### Sample Output
##### [maidalun1020/bce-embedding-base_v1](https://huggingface.co/maidalun1020/bce-embedding-base_v1)

@ -35,19 +35,17 @@ if __name__ == "__main__":
help="The huggingface repo id for the bce-embedding model to be downloaded"
", or the path to the huggingface checkpoint folder",
)
parser.add_argument("--lowbit-path", type=str,
default="",
help="The path to the lowbit model folder, leave blank if you do not want to save. \
If path not exists, lowbit model will be saved there. \
Else, lowbit model will be loaded.",
)
parser.add_argument('--prompt', type=str, default="'sentence_0', 'sentence_1'",
help='Prompt to infer')
parser.add_argument("--max-context-len", type=int, default=1024)
parser.add_argument("--max-prompt-len", type=int, default=512)
parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)
parser.add_argument("--intra-pp", type=int, default=2)
parser.add_argument("--inter-pp", type=int, default=2)
parser.add_argument("--save-directory", type=str,
required=True,
help="The path of folder to save converted model, "
"If path not exists, lowbit model will be saved there. "
"Else, lowbit model will be loaded.",
)
args = parser.parse_args()
model_path = args.repo_id_or_model_path
@ -60,9 +58,8 @@ if __name__ == "__main__":
optimize_model=True,
max_context_len=args.max_context_len,
max_prompt_len=args.max_prompt_len,
intra_pp=args.intra_pp,
inter_pp=args.inter_pp,
transpose_value_cache=not args.disable_transpose_value_cache,
save_directory=args.save_directory
)
# list of sentences

@ -48,8 +48,12 @@ if __name__ == "__main__":
parser.add_argument("--max-context-len", type=int, default=1024)
parser.add_argument("--max-prompt-len", type=int, default=512)
parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)
parser.add_argument("--intra-pp", type=int, default=2)
parser.add_argument("--inter-pp", type=int, default=2)
parser.add_argument("--save-directory", type=str,
required=True,
help="The path of folder to save converted model, "
"If path not exists, lowbit model will be saved there. "
"Else, lowbit model will be loaded.",
)
args = parser.parse_args()
model_path = args.repo_id_or_model_path
@ -63,9 +67,8 @@ if __name__ == "__main__":
optimize_model=True,
max_context_len=args.max_context_len,
max_prompt_len=args.max_prompt_len,
intra_pp=args.intra_pp,
inter_pp=args.inter_pp,
transpose_value_cache=not args.disable_transpose_value_cache,
save_directory=args.save_directory
)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

@ -39,8 +39,12 @@ if __name__ == '__main__':
parser.add_argument("--max-context-len", type=int, default=1024)
parser.add_argument("--max-prompt-len", type=int, default=512)
parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)
parser.add_argument("--intra-pp", type=int, default=None)
parser.add_argument("--inter-pp", type=int, default=None)
parser.add_argument("--save-directory", type=str,
required=True,
help="The path of folder to save converted model, "
"If path not exists, lowbit model will be saved there. "
"Else, lowbit model will be loaded.",
)
args = parser.parse_args()
model_path = args.repo_id_or_model_path
@ -54,9 +58,8 @@ if __name__ == '__main__':
optimize_model=True,
max_context_len=args.max_context_len,
max_prompt_len=args.max_prompt_len,
intra_pp=args.intra_pp,
inter_pp=args.inter_pp,
transpose_value_cache=not args.disable_transpose_value_cache,
save_directory=args.save_directory
)
tokenizer = AutoTokenizer.from_pretrained(model_path,
trust_remote_code=True)

@ -35,8 +35,12 @@ if __name__ == "__main__":
)
parser.add_argument('--load_in_low_bit', type=str, default="sym_int8",
help='Load in low bit to use')
parser.add_argument("--intra-pp", type=int, default=2)
parser.add_argument("--inter-pp", type=int, default=2)
parser.add_argument("--save-directory", type=str,
required=True,
help="The path of folder to save converted model, "
"If path not exists, lowbit model will be saved there. "
"Else, lowbit model will be loaded.",
)
args = parser.parse_args()
model_path = args.repo_id_or_model_path
@ -47,8 +51,7 @@ if __name__ == "__main__":
load_in_low_bit=args.load_in_low_bit,
low_cpu_mem_usage=True,
optimize_model=True,
intra_pp=args.intra_pp,
inter_pp=args.inter_pp,
save_directory=args.save_directory
)
res = model.generate(input=f"{model.model_path}/example/asr_example.wav",

@ -45,6 +45,9 @@ def ignore_argument(kwargs: dict, key: "str"):
def save_low_bit(self, model_dir: str, *args, **kwargs):
if hasattr(self, "save_directory"):
warnings.warn(f"Model is already saved at {self.save_directory}")
return 1
origin_device = self.device
kwargs["safe_serialization"] = False
self.save_pretrained(model_dir, *args, **kwargs)
@ -255,6 +258,9 @@ class _BaseAutoModelClass:
save_directory = kwargs.pop('save_directory', None)
fuse_layers = kwargs.pop('fuse_layers', None)
imatrix_data = kwargs.pop('imatrix_data', None)
invalidInputError(save_directory is not None,
"Please provide the path to save converted model "
"through `save_directory`.")
if hasattr(model, "llm"):
llm = model.llm
@ -312,6 +318,8 @@ class _BaseAutoModelClass:
save_directory=save_directory,
fuse_layers=fuse_layers)
model.save_low_bit = types.MethodType(save_low_bit, model)
model.save_low_bit(save_directory)
logger.info(f"Converted model has already saved to {save_directory}.")
return model
@classmethod
@ -398,6 +406,7 @@ class _BaseAutoModelClass:
mixed_precision = config_dict.pop("mixed_precision", False)
quantization_group_size = config_dict.pop("group_size", 0)
optimize_model = config_dict.pop("optimize_model", False)
enable_cpp_backend = "weight_idx" in config_dict
invalidInputError(
qtype,
@ -412,6 +421,26 @@ class _BaseAutoModelClass:
f" expected: sym_int8_rtn, sym_int4_rtn. "
)
if enable_cpp_backend:
from .npu_models.npu_llm_cpp import load_model_from_file
from .npu_models.convert import generate
dummy_model = torch.nn.Module()
try:
model_ptr = load_model_from_file(pretrained_model_name_or_path)
dummy_model.config = PretrainedConfig.from_dict(config_dict)
dummy_model.model_ptr = model_ptr
dummy_model.save_directory = pretrained_model_name_or_path
dummy_model.kv_len = config_dict['kv_len']
dummy_model.vocab_size = config_dict['vocab_size']
except:
invalidInputError(False,
"False to InitLLMPipeline.")
dummy_model.eval()
# patch generate function
import types
dummy_model.generate = types.MethodType(generate, dummy_model)
return dummy_model
has_remote_code = hasattr(config, "auto_map") and cls.HF_Model.__name__ in config.auto_map
has_local_code = type(config) in cls.HF_Model._model_mapping.keys()
trust_remote_code = resolve_trust_remote_code(
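Taken together, these changes give `from_pretrained` a convert-and-save contract. The sketch below summarizes the resulting caller-side behaviour; the import path is an assumption and the keyword values are illustrative.

```python
from ipex_llm.transformers.npu_model import AutoModelForCausalLM  # assumed import path

# 1) save_directory is now mandatory: the new invalidInputError check rejects calls without it,
#    and the converted model is written to that folder as part of from_pretrained.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",
    load_in_low_bit="sym_int4",
    optimize_model=True,
    attn_implementation="eager",
    save_directory="./llama2-7b-npu",
)

# 2) An explicit save_low_bit call afterwards only warns "Model is already saved at ..."
#    and returns early, since the model was saved during conversion.
model.save_low_bit("./another-dir")

# 3) Later runs load the saved folder directly; if the saved config contains "weight_idx",
#    load_low_bit switches to the C++ pipeline backend with a patched generate().
model = AutoModelForCausalLM.load_low_bit(
    "./llama2-7b-npu",
    attn_implementation="eager",
    max_context_len=1024,
    max_prompt_len=512,
)
```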

@ -389,6 +389,7 @@ def optimize_llm_single_process(
model_ptr = load_model_from_file(save_directory)
model.kv_len = kv_len
model.model_ptr = model_ptr
model.save_directory = save_directory
model.vocab_size = model.config.vocab_size
except:
invalidInputError(False,