[NPU] Add minicpm-2b support for npu multi-processing (#11949)

* add minicpm-2b support

* update example for minicpm-2b

* add LNL NPU driver requirement in readme
This commit is contained in:
SONG Ge 2024-08-28 18:08:49 +08:00 committed by GitHub
parent 0fbb10259a
commit 5ca7390082
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 16 additions and 2 deletions

View file

@ -82,6 +82,7 @@ The example below shows how to run the **_optimized model implementations_** on
- [Llama3-8B](./llama.py)
- [Qwen2-1.5B](./qwen2.py)
- [MiniCPM-1B](./minicpm.py)
- [MiniCPM-2B](./minicpm.py)
- [Baichuan2-7B](./baichuan2.py)
```bash
@ -97,6 +98,9 @@ python qwen2.py
# to run MiniCPM-1B-sft-bf16
python minicpm.py
# to run MiniCPM-2B-sft-bf16 (LNL driver version: 32.0.101.2715)
python minicpm.py --repo-id-or-model-path openbmb/MiniCPM-2B-sft-bf16
# to run Baichuan2-7B-Chat
python baichuan2.py
```
@ -124,6 +128,9 @@ python qwen2.py --disable-transpose-value-cache
# to run MiniCPM-1B-sft-bf16
python minicpm.py --disable-transpose-value-cache
# to run MiniCPM-2B-sft-bf16 (LNL driver version: 32.0.101.2715)
python minicpm.py --repo-id-or-model-path openbmb/MiniCPM-2B-sft-bf16 --disable-transpose-value-cache
```

View file

@ -151,18 +151,25 @@ def optimize_llm(
modeling_module_name = model.__class__.__module__
module = importlib.import_module(modeling_module_name)
if model.config.num_hidden_layers == 52:
# for minicpm-1b
transpose_cache = transpose_value_cache
elif model.config.num_hidden_layers == 40:
# for minicpm-2b
transpose_cache = False
decode_runner = DecodeRunner(
model,
max_seq_len=max_output_len,
inter_pp=inter_pp,
intra_pp=intra_pp,
transpose_value_cache=transpose_value_cache,
transpose_value_cache=transpose_cache,
)
prefill_runner = PrefillRunner(
model,
max_output_len=max_output_len,
max_prompt_len=max_prompt_len,
transpose_value_cache=transpose_value_cache,
transpose_value_cache=transpose_cache,
)
minicpm_model_forward = gen_minicpm_fused_model_forward(
prefill_runner=prefill_runner, decode_runner=decode_runner