Fix Mixtral GGUF Wrong Output Issue (#9930)
* Fix Mixtral GGUF Wrong Output Issue
* fix style
* fix style
parent 453df868c9
commit 5184f400f9

3 changed files with 28 additions and 15 deletions
````diff
@@ -28,7 +28,7 @@ conda create -n llm python=3.9 # recommend to use Python 3.9
 conda activate llm
 
 pip install --pre --upgrade bigdl-llm[all] # install the latest bigdl-llm nightly build with 'all' option
-pip install transformers==4.34.0 # upgrade transformers
+pip install transformers==4.36.0 # upgrade transformers
 ```
 
 ### 2. Run
 After setting up the Python environment, you could run the example by following steps.
````
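The pin in this CPU example README matters: the Mixtral classes the loader imports (`MixtralConfig`, `MixtralForCausalLM`) first shipped in transformers 4.36.0, so on 4.34.0 the import itself fails. A quick sanity check you could run after installing (illustrative, not part of the example):

```python
# Illustrative sanity check: Mixtral support only exists in
# transformers >= 4.36.0, which is why the README pins 4.36.0.
from packaging import version
import transformers

assert version.parse(transformers.__version__) >= version.parse("4.36.0"), \
    f"transformers {transformers.__version__} predates Mixtral support"

# On 4.36.0+ this import succeeds; on 4.34.0 it raises ImportError.
from transformers import MixtralConfig, MixtralForCausalLM
```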
````diff
@@ -28,7 +28,7 @@ conda activate llm
 
 # below command will install intel_extension_for_pytorch==2.1.10+xpu as default
 pip install --pre --upgrade bigdl-llm[xpu] -f https://developer.intel.com/ipex-whl-stable-xpu
-pip install transformers==4.34.0 # upgrade transformers
+pip install transformers==4.36.0 # upgrade transformers
 ```
 
 ### 2. Configures OneAPI environment variables
````
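The GPU example README gets the same pin. If you want to confirm the XPU wheel is actually active before running, a hypothetical pre-flight check (assuming the install command above succeeded and the OneAPI variables from the next step are sourced) could be:

```python
# Illustrative pre-flight check: confirm the XPU build of IPEX is usable.
import torch
import intel_extension_for_pytorch as ipex  # registers the 'xpu' device

print(ipex.__version__)          # expect something like 2.1.10+xpu
print(torch.xpu.is_available())  # True once OneAPI variables are sourced
```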
````diff
@@ -17,7 +17,7 @@
 import os
 import torch
 from accelerate import init_empty_weights
-from accelerate.utils import set_module_tensor_to_device as fill_model
+from accelerate.utils import set_module_tensor_to_device
 from tempfile import NamedTemporaryFile
 from transformers import MixtralConfig, MixtralForCausalLM, LlamaTokenizer
````
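Dropping the `fill_model` alias is cosmetic, but the underlying call is what makes the loader work: the model is created under `init_empty_weights()` with meta-device parameters, and `set_module_tensor_to_device` materializes one tensor at a time on a real device. A minimal standalone sketch of that pattern (toy `nn.Linear`, not the Mixtral loader):

```python
# Minimal sketch of the accelerate pattern used by the GGUF loader:
# build a model with empty (meta) weights, then fill tensors one by one.
import torch
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device

with init_empty_weights():
    layer = torch.nn.Linear(4, 2)  # parameters live on the 'meta' device

weight = torch.randn(2, 4)         # e.g. a tensor decoded from a GGUF file
set_module_tensor_to_device(layer, "weight", "cpu", weight, dtype=torch.float)
set_module_tensor_to_device(layer, "bias", "cpu", torch.zeros(2), dtype=torch.float)

print(layer(torch.ones(4)))        # the layer is now usable on CPU
```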
````diff
@@ -53,17 +53,30 @@ def load_gguf_mixtral(loader: GGUFFileLoader, dtype: torch.dtype = torch.float):
     with init_empty_weights():
         model = MixtralForCausalLM(mixtral_config)
 
     # define an operator function that passed to low-level gguf API
     def process_mixtral(name, tensor):
         # prepare module's name in transformers
         module_name = get_mixtral_module_name(name)
         # prepare module's weight in transformers
         if 'ffn_gate_inp' in name:
             # gguf weight needs to reshape for ffn_gate_inp
-            fill_model(model,
-                       module_name,
-                       "cpu",
-                       tensor.reshape(num_local_experts, hidden_size),
-                       dtype=dtype)
-        else:
-            fill_model(model,
+            tensor = tensor.reshape(num_local_experts, hidden_size)
+        elif name.endswith("attn_q.weight"):
+            head, hd_size = tensor.shape[0], tensor.shape[1:]
+            tensor = (tensor.reshape(n_head,
+                                     head // n_head // 2,
+                                     2,
+                                     *hd_size)
+                      .swapaxes(1, 2)
+                      .reshape(tensor.shape))
+        elif name.endswith("attn_k.weight"):
+            head, hd_size = tensor.shape[0], tensor.shape[1:]
+            tensor = (tensor.reshape(n_head_kv,
+                                     head // n_head_kv // 2,
+                                     2,
+                                     *hd_size)
+                      .swapaxes(1, 2)
+                      .reshape(tensor.shape))
+        set_module_tensor_to_device(model,
+                                    module_name,
+                                    "cpu",
+                                    tensor,
+                                    dtype=dtype)
````
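The reshape/swapaxes round-trip on `attn_q.weight` and `attn_k.weight` is the actual fix for the wrong output: GGUF checkpoints store these projections with the rotary pairs interleaved row by row (the Meta/llama.cpp layout), while the transformers Mixtral attention expects the two rotary halves stored contiguously within each head, so loading the weights unpermuted scrambles attention. The expression is effectively the same permute used when converting Meta-layout checkpoints to the HF layout. A toy demonstration that it is a pure per-head row permutation (sizes and names below are illustrative, not real Mixtral dimensions):

```python
import torch

# Toy setup: one head whose 8 rows are tagged with their original row index,
# so the permutation is visible directly in the output.
n_head = 1                 # illustrative; the real model has many heads
rows, cols = 8, 4          # rows = n_head * head_dim, cols is arbitrary here
w = torch.arange(rows).repeat_interleave(cols).reshape(rows, cols).float()

# The exact expression from the fix, applied to the toy tensor:
head, hd_size = w.shape[0], w.shape[1:]
permuted = (w.reshape(n_head, head // n_head // 2, 2, *hd_size)
             .swapaxes(1, 2)
             .reshape(w.shape))

# Rows [0, 1, 2, 3, 4, 5, 6, 7] come out as [0, 2, 4, 6, 1, 3, 5, 7]:
# interleaved (even, odd) pairs are regrouped into two contiguous halves,
# the split-halves layout that the HF rotary embedding code expects.
print(permuted[:, 0].tolist())  # [0.0, 2.0, 4.0, 6.0, 1.0, 3.0, 5.0, 7.0]
```

With more heads the same regrouping happens independently inside each head's block of rows, and `n_head_kv` plays the same role for the K projection, which has fewer heads under grouped-query attention.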