Fix Mixtral GGUF Wrong Output Issue (#9930)
* Fix Mixtral GGUF Wrong Output Issue
* fix style
* fix style
This commit is contained in:
parent 453df868c9
commit 5184f400f9

3 changed files with 28 additions and 15 deletions
CPU example README:

````diff
@@ -28,7 +28,7 @@ conda create -n llm python=3.9 # recommend to use Python 3.9
 conda activate llm
 
 pip install --pre --upgrade bigdl-llm[all] # install the latest bigdl-llm nightly build with 'all' option
-pip install transformers==4.34.0 # upgrade transformers
+pip install transformers==4.36.0 # upgrade transformers
 ```
 
 ### 2. Run
 
 After setting up the Python environment, you could run the example by following steps.
````
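For context on the version bump in both READMEs: the Mixtral model classes (`MixtralConfig`, `MixtralForCausalLM`) first shipped in `transformers` 4.36.0, so the old 4.34.0 pin cannot import them at all. A minimal sanity check one could run in the example environment (the check itself is illustrative, not part of the commit):

```python
import transformers
from packaging import version

# Mixtral support landed in transformers 4.36.0; with the older 4.34.0 pin
# the import below raises ImportError.
assert version.parse(transformers.__version__) >= version.parse("4.36.0"), \
    f"need transformers>=4.36.0 for Mixtral, got {transformers.__version__}"
from transformers import MixtralConfig, MixtralForCausalLM
```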
GPU example README:

````diff
@@ -28,7 +28,7 @@ conda activate llm
 
 # below command will install intel_extension_for_pytorch==2.1.10+xpu as default
 pip install --pre --upgrade bigdl-llm[xpu] -f https://developer.intel.com/ipex-whl-stable-xpu
-pip install transformers==4.34.0 # upgrade transformers
+pip install transformers==4.36.0 # upgrade transformers
 ```
 
 ### 2. Configures OneAPI environment variables
````
Mixtral GGUF loader (Python), imports:

```diff
@@ -17,7 +17,7 @@
 import os
 import torch
 from accelerate import init_empty_weights
-from accelerate.utils import set_module_tensor_to_device as fill_model
+from accelerate.utils import set_module_tensor_to_device
 from tempfile import NamedTemporaryFile
 from transformers import MixtralConfig, MixtralForCausalLM, LlamaTokenizer
```
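The import change just drops the `fill_model` alias in favor of `accelerate`'s real API name, `set_module_tensor_to_device`. For context, the loader pairs it with `init_empty_weights`, as the next hunk shows: the model skeleton is built on the meta device with no weight storage, and each tensor read from the GGUF file is then materialized into it. A minimal sketch of that pattern, assuming default Mixtral dimensions (the parameter name below is just an example):

```python
import torch
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
from transformers import MixtralConfig, MixtralForCausalLM

# Build the module tree on the meta device: no real weight memory is allocated.
with init_empty_weights():
    model = MixtralForCausalLM(MixtralConfig())

# Materialize one parameter on CPU from a freshly loaded tensor.
hidden = model.config.hidden_size
set_module_tensor_to_device(model,
                            "model.layers.0.self_attn.o_proj.weight",
                            "cpu",
                            torch.zeros(hidden, hidden),
                            dtype=torch.float16)
```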
Mixtral GGUF loader (Python), `load_gguf_mixtral`:

```diff
@@ -53,21 +53,34 @@ def load_gguf_mixtral(loader: GGUFFileLoader, dtype: torch.dtype = torch.float):
     with init_empty_weights():
         model = MixtralForCausalLM(mixtral_config)
 
+    # define an operator function that is passed to the low-level gguf API
     def process_mixtral(name, tensor):
+        # prepare the module's name in transformers
         module_name = get_mixtral_module_name(name)
+        # prepare the module's weight in transformers
         if 'ffn_gate_inp' in name:
-            # gguf weight needs to reshape for ffn_gate_inp
-            fill_model(model,
-                       module_name,
-                       "cpu",
-                       tensor.reshape(num_local_experts, hidden_size),
-                       dtype=dtype)
-        else:
-            fill_model(model,
-                       module_name,
-                       "cpu",
-                       tensor,
-                       dtype=dtype)
+            tensor = tensor.reshape(num_local_experts, hidden_size)
+        elif name.endswith("attn_q.weight"):
+            head, hd_size = tensor.shape[0], tensor.shape[1:]
+            tensor = (tensor.reshape(n_head,
+                                     head // n_head // 2,
+                                     2,
+                                     *hd_size)
+                      .swapaxes(1, 2)
+                      .reshape(tensor.shape))
+        elif name.endswith("attn_k.weight"):
+            head, hd_size = tensor.shape[0], tensor.shape[1:]
+            tensor = (tensor.reshape(n_head_kv,
+                                     head // n_head_kv // 2,
+                                     2,
+                                     *hd_size)
+                      .swapaxes(1, 2)
+                      .reshape(tensor.shape))
+        set_module_tensor_to_device(model,
+                                    module_name,
+                                    "cpu",
+                                    tensor,
+                                    dtype=dtype)
 
     tensor_loader = loader.tensor_loader
     tensor_loader.load_while_process(process_mixtral)
```
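The substance of the fix is the two new `elif` branches. GGUF files keep `attn_q.weight` and `attn_k.weight` in llama.cpp's rotary layout, where the two halves of each rotary-embedding pair are interleaved row by row within a head, while transformers' Mixtral attention expects each head's rows as two contiguous halves; loading the GGUF layout unchanged is what produced the wrong output. The `reshape`/`swapaxes`/`reshape` chain undoes the interleaving. A standalone sketch of the same transform (`unpermute` and the toy shapes are illustrative, not names from the commit):

```python
import torch

def unpermute(tensor: torch.Tensor, n_head: int) -> torch.Tensor:
    # Split each head's rows into rotary pairs, then swap the pair axis out,
    # so interleaved pairs become two contiguous half-blocks per head.
    head, hd_size = tensor.shape[0], tensor.shape[1:]
    return (tensor.reshape(n_head, head // n_head // 2, 2, *hd_size)
                  .swapaxes(1, 2)
                  .reshape(tensor.shape))

# Toy example: one head, head_dim = 8, a single input column.
w = torch.arange(8.0).reshape(8, 1)
print(unpermute(w, n_head=1).flatten().tolist())
# [0.0, 2.0, 4.0, 6.0, 1.0, 3.0, 5.0, 7.0]
# Rows (0,1), (2,3), ... were interleaved pairs; afterwards the even-indexed
# rows fill the first half and the odd-indexed rows the second half.
```

For keys the same transform runs with `n_head_kv` instead of `n_head`, since Mixtral uses grouped-query attention with fewer key/value heads than query heads.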