LLM: support fp16 embedding & add mlp fusion for iq2_xxs (#10219)

* add fp16 embed

* small fixes

* fix style

* fix style

* fix comment
Ruonan Wang 2024-02-23 17:26:24 +08:00 committed by GitHub
parent eeecd9fc08
commit 28513f3978
6 changed files with 38 additions and 12 deletions

View file

@@ -192,7 +192,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                                  current_key_name=None, convert_shape_only=False,
                                  cpu_embedding=False, prefix_name='',
                                  imatrix_data=None, embedding_qtype=None,
-                                 model_type=None):
+                                 model_type=None, torch_dtype=torch.float32):
     from bigdl.llm.transformers.low_bit_linear import LowBitLinear, FP4Params, \
         FP16Linear, BF16Linear
     from bigdl.llm.transformers.embedding import LLMEmbedding, LowBitEmbedding
@@ -326,6 +326,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                 _weight=module.weight.data,
             )
         elif type(module) == nn.Embedding and embedding_qtype is not None:
+            if torch_dtype == "auto":
+                torch_dtype = torch.float32
             q_embedding = LowBitEmbedding(
                 num_embeddings=module.num_embeddings,
                 embedding_dim=module.embedding_dim,
@@ -336,6 +338,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                 sparse=module.sparse,
                 _weight=module.weight.data,
                 qtype=embedding_qtype,
+                torch_dtype=torch_dtype
             )
             device = module.weight.data.device
             # Copy the weights
@@ -364,7 +367,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                 prefix_name=prefix_name + '.' + name if prefix_name != '' else name,
                 imatrix_data=imatrix_data,
                 embedding_qtype=embedding_qtype,
-                model_type=model_type
+                model_type=model_type,
+                torch_dtype=torch_dtype
             )
             has_been_replaced = _flag or has_been_replaced
     return model, has_been_replaced
@@ -571,7 +575,8 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
         None, convert_shape_only, cpu_embedding,
         imatrix_data=imatrix_data,
         embedding_qtype=embedding_qtype,
-        model_type=model_type
+        model_type=model_type,
+        torch_dtype=torch_dtype
     )
     if not has_been_replaced:
         warnings.warn(
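
The torch_dtype argument added above is threaded from ggml_convert_low_bit down into the embedding replacement, with the string value "auto" falling back to float32. A minimal sketch of that fallback, using plain torch and a hypothetical resolve_embedding_dtype helper that is not part of this patch:

import torch

def resolve_embedding_dtype(torch_dtype="auto"):
    # transformers-style torch_dtype may arrive as the string "auto";
    # the low-bit embedding needs a concrete dtype, so fall back to float32
    if torch_dtype == "auto":
        return torch.float32
    return torch_dtype

assert resolve_embedding_dtype() is torch.float32
assert resolve_embedding_dtype(torch.float16) is torch.float16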

View file

@@ -88,7 +88,8 @@ class LowBitEmbedding(torch.nn.Embedding):
                  _weight: Optional[Tensor] = None,
                  _freeze: bool = False,
                  device=None, dtype=None,
-                 qtype=None) -> None:
+                 qtype=None,
+                 torch_dtype=torch.float32) -> None:
         super().__init__(num_embeddings, embedding_dim, padding_idx,
                          max_norm, norm_type, scale_grad_by_freq, sparse,
                          _weight, device, dtype)
@@ -96,6 +97,7 @@ class LowBitEmbedding(torch.nn.Embedding):
                                 requires_grad=False,
                                 quantized=False, _shape=None, qtype=qtype)
         self.embedding_dim = embedding_dim
+        self.torch_dtype = torch_dtype
     def forward(self, x: Tensor):
         invalidInputError(x.device.type == "xpu",
@@ -109,4 +111,4 @@
         result = linear_q4_0.dequantize_rows(x.contiguous(), self.weight.data,
                                              self.weight.qtype, self.embedding_dim)
-        return result
+        return result.to(self.torch_dtype)
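
The forward change casts the dequantized rows to the stored torch_dtype, which is what gives an fp16 embedding output on top of low-bit weights. A rough stand-in for the idea, assuming plain torch on CPU (no XPU, no linear_q4_0 kernel) and an int8-plus-scale storage that is illustrative only:

import torch
import torch.nn as nn

class ToyLowBitEmbedding(nn.Module):
    # illustrative only: the real LowBitEmbedding stores ggml-quantized weights
    def __init__(self, num_embeddings, embedding_dim, torch_dtype=torch.float32):
        super().__init__()
        self.codes = torch.randint(-128, 127, (num_embeddings, embedding_dim), dtype=torch.int8)
        self.scales = torch.rand(num_embeddings, 1)
        self.torch_dtype = torch_dtype

    def forward(self, x):
        # "dequantize" only the selected rows, then cast to the configured dtype,
        # mirroring `return result.to(self.torch_dtype)` above
        rows = self.codes[x].float() * self.scales[x]
        return rows.to(self.torch_dtype)

emb = ToyLowBitEmbedding(1000, 64, torch_dtype=torch.float16)
print(emb(torch.tensor([1, 2, 3])).dtype)  # torch.float16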

View file

@@ -339,8 +339,8 @@ class _BaseAutoModelClass:
         invalidInputError(q_k in ggml_tensor_qtype,
                           f"Unknown load_in_low_bit value: {q_k}, expected:"
                           f" sym_int4, asym_int4, sym_int5, asym_int5, sym_int8, nf3, nf4, "
-                          f"fp4, fp8, fp8_e4m3, fp8_e5m2, fp16, bf16, iq2_xxs, iq2_xs, "
-                          f"mixed_fp4 or mixed_fp8.")
+                          f"fp4, fp8, fp8_e4m3, fp8_e5m2, fp16, bf16, gguf_iq2_xxs, "
+                          f"gguf_iq2_xs, mixed_fp4 or mixed_fp8.")
         qtype = ggml_tensor_qtype[q_k]
         # In case it needs a second try,
@@ -510,7 +510,8 @@
         optimize_model = kwargs.pop("optimize_model", True)
         qtype = ggml_tensor_qtype[bigdl_transformers_low_bit]
-        if bigdl_transformers_low_bit in ["iq2_xxs", "iq2_xs", "q2_k"] and not cpu_embedding:
+        if bigdl_transformers_low_bit in ["gguf_iq2_xxs", "gguf_iq2_xs", "q2_k"] and \
+                not cpu_embedding:
             embedding_qtype = "q2_k"
         if embedding_qtype is not None:
             embedding_qtype = ggml_tensor_qtype[embedding_qtype]
@@ -595,7 +596,7 @@
         model = ggml_convert_low_bit(model, qtype, optimize_model, device=quant_device,
                                      modules_to_not_convert=modules_to_not_convert,
                                      cpu_embedding=cpu_embedding, lightweight_bmm=lightweight_bmm,
-                                     embedding_qtype=embedding_qtype)
+                                     embedding_qtype=embedding_qtype, torch_dtype=torch_dtype)
         if is_sharded:
             loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
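
For the 2-bit formats the loader still forces a q2_k-quantized embedding unless cpu_embedding is set; only the qtype keys changed to the gguf_-prefixed names. A condensed sketch of that selection, with placeholder integer codes rather than the real ggml_tensor_qtype values:

# placeholder codes, not the real ggml_tensor_qtype mapping
ggml_tensor_qtype = {"gguf_iq2_xxs": 23, "gguf_iq2_xs": 24, "q2_k": 25}

def pick_embedding_qtype(low_bit, cpu_embedding=False):
    # 2-bit weight formats keep the embedding in q2_k unless it stays on the CPU
    if low_bit in ("gguf_iq2_xxs", "gguf_iq2_xs", "q2_k") and not cpu_embedding:
        return ggml_tensor_qtype["q2_k"]
    return None

assert pick_embedding_qtype("gguf_iq2_xxs") == 25
assert pick_embedding_qtype("gguf_iq2_xxs", cpu_embedding=True) is None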

View file

@@ -48,7 +48,7 @@ from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb, \
 from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
     is_enough_kv_cache_room_4_36
 from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5
-from bigdl.llm.transformers.models.utils import use_flash_attention
+from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@@ -242,6 +242,15 @@ def mistral_attention_forward(
         attn_weights = None
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(bsz, q_len, hidden_size)
+    elif use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
+        import linear_fp16_esimd
+        attn_output = linear_fp16_esimd.sdp_forward(query_states,
+                                                    key_states,
+                                                    value_states)
+        attn_output = attn_output.view(query_states.shape)
+        attn_weights = None
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, hidden_size)
     else:
         attn_output, attn_weights = compute_attn_outputs_weights(query_states,
                                                                   key_states,
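
The new elif gives mistral_attention_forward a third path: when use_esimd_sdp says the shapes and dtype qualify, the fused ESIMD kernel computes scaled dot-product attention instead of the flash-attention or generic matmul paths. A simplified dispatch sketch in stock PyTorch, where can_use_fused_sdp stands in for use_esimd_sdp and F.scaled_dot_product_attention stands in for linear_fp16_esimd.sdp_forward:

import torch
import torch.nn.functional as F

def can_use_fused_sdp(q, k):
    # stand-in check; the real use_esimd_sdp also looks at q_len, kv length and head_dim
    return q.device.type == "xpu" and q.dtype == torch.float16 and k.dtype == torch.float16

def attention(q, k, v):
    if can_use_fused_sdp(q, k):
        # fused kernel path (placeholder for linear_fp16_esimd.sdp_forward)
        return F.scaled_dot_product_attention(q, k, v)
    # generic fallback: explicit matmul + softmax
    scores = (q @ k.transpose(-2, -1)) / (q.shape[-1] ** 0.5)
    return torch.softmax(scores, dim=-1) @ v

q = k = v = torch.randn(1, 8, 16, 64)  # (batch, heads, seq_len, head_dim)
print(attention(q, k, v).shape)  # torch.Size([1, 8, 16, 64])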

View file

@@ -49,7 +49,7 @@ from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache,
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb,\
     apply_rotary_pos_emb_cache_freq_xpu, is_enough_kv_cache_room_4_36
 from bigdl.llm.transformers.models.mistral import should_use_fuse_rope, use_decoding_fast_path
-from bigdl.llm.transformers.models.utils import use_flash_attention
+from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
 from bigdl.llm.transformers.models.utils import mlp_fusion_check
@@ -272,6 +272,14 @@ def mixtral_attention_forward(
                                                      value_states,
                                                      is_causal=True)
         attn_weights = None
+    elif use_esimd_sdp(query_states.shape[2], key_states.shape[2],
+                       self.head_dim, query_states):
+        import linear_fp16_esimd
+        attn_output = linear_fp16_esimd.sdp_forward(query_states,
+                                                    key_states,
+                                                    value_states)
+        attn_output = attn_output.view(query_states.shape)
+        attn_weights = None
     else:
         attn_weights = torch.matmul(
             query_states,

View file

@@ -310,7 +310,8 @@ def mlp_fusion_check(x, qtype, training):
         return False
     if x.device.type != 'xpu':
         return False
-    if qtype not in [ggml_tensor_qtype["sym_int4"], ggml_tensor_qtype["fp8_e5m2"]]:
+    if qtype not in [ggml_tensor_qtype["sym_int4"], ggml_tensor_qtype["fp8_e5m2"],
+                     ggml_tensor_qtype["gguf_iq2_xxs"]]:
         return False
     if training or x.requires_grad:
         return False
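
mlp_fusion_check now treats gguf_iq2_xxs as fusable alongside sym_int4 and fp8_e5m2. The gate reduces to an allow-list plus device and training checks, roughly as in this sketch (placeholder qtype codes, not the real ggml_tensor_qtype values):

import torch

# placeholder codes, not the real ggml_tensor_qtype mapping
ggml_tensor_qtype = {"sym_int4": 2, "fp8_e5m2": 19, "gguf_iq2_xxs": 23}
FUSABLE_QTYPES = {ggml_tensor_qtype[k] for k in ("sym_int4", "fp8_e5m2", "gguf_iq2_xxs")}

def mlp_fusion_allowed(x: torch.Tensor, qtype: int, training: bool) -> bool:
    # fuse only on XPU, only for allow-listed quantization types,
    # and never when gradients might be needed
    if x.device.type != "xpu":
        return False
    if qtype not in FUSABLE_QTYPES:
        return False
    if training or x.requires_grad:
        return False
    return True

x = torch.randn(4, 4096)  # CPU tensor, so fusion is rejected
print(mlp_fusion_allowed(x, ggml_tensor_qtype["gguf_iq2_xxs"], training=False))  # False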