LLM: support fp16 embedding & add mlp fusion for iq2_xxs (#10219)
* add fp16 embed
* small fixes
* fix style
* fix style
* fix comment
parent eeecd9fc08
commit 28513f3978
6 changed files with 38 additions and 12 deletions
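At the user level, the new options surface roughly like this (a hedged sketch: the model path is a placeholder, and the behaviour follows the hunks below, where "gguf_iq2_xxs" becomes an accepted load_in_low_bit value and torch_dtype is forwarded to the converted embedding):

    import torch
    from bigdl.llm.transformers import AutoModelForCausalLM

    # placeholder model path; torch_dtype is forwarded so the low-bit embedding emits fp16
    model = AutoModelForCausalLM.from_pretrained("path/to/model",
                                                 load_in_low_bit="gguf_iq2_xxs",
                                                 torch_dtype=torch.float16,
                                                 trust_remote_code=True)
    model = model.to("xpu")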
@@ -192,7 +192,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                                  current_key_name=None, convert_shape_only=False,
                                  cpu_embedding=False, prefix_name='',
                                  imatrix_data=None, embedding_qtype=None,
-                                 model_type=None):
+                                 model_type=None, torch_dtype=torch.float32):
     from bigdl.llm.transformers.low_bit_linear import LowBitLinear, FP4Params, \
         FP16Linear, BF16Linear
     from bigdl.llm.transformers.embedding import LLMEmbedding, LowBitEmbedding
@@ -326,6 +326,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                     _weight=module.weight.data,
                 )
             elif type(module) == nn.Embedding and embedding_qtype is not None:
+                if torch_dtype == "auto":
+                    torch_dtype = torch.float32
                 q_embedding = LowBitEmbedding(
                     num_embeddings=module.num_embeddings,
                     embedding_dim=module.embedding_dim,
@@ -336,6 +338,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                     sparse=module.sparse,
                     _weight=module.weight.data,
                     qtype=embedding_qtype,
+                    torch_dtype=torch_dtype
                 )
                 device = module.weight.data.device
                 # Copy the weights
@@ -364,7 +367,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                 prefix_name=prefix_name + '.' + name if prefix_name != '' else name,
                 imatrix_data=imatrix_data,
                 embedding_qtype=embedding_qtype,
-                model_type=model_type
+                model_type=model_type,
+                torch_dtype=torch_dtype
             )
             has_been_replaced = _flag or has_been_replaced
     return model, has_been_replaced
@@ -571,7 +575,8 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
         None, convert_shape_only, cpu_embedding,
         imatrix_data=imatrix_data,
         embedding_qtype=embedding_qtype,
-        model_type=model_type
+        model_type=model_type,
+        torch_dtype=torch_dtype
     )
     if not has_been_replaced:
         warnings.warn(
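These hunks thread a torch_dtype argument through the recursive module-replacement pass so the quantized embedding knows which dtype its outputs should use. Below is a minimal, self-contained sketch of that plumbing pattern; ToyLowBitEmbedding and replace_embeddings are toy names, not BigDL's actual implementation:

    import torch
    import torch.nn as nn

    class ToyLowBitEmbedding(nn.Embedding):
        # hypothetical stand-in for LowBitEmbedding: remembers the requested output dtype
        def __init__(self, *args, torch_dtype=torch.float32, **kwargs):
            super().__init__(*args, **kwargs)
            self.torch_dtype = torch_dtype

        def forward(self, x):
            return super().forward(x).to(self.torch_dtype)

    def replace_embeddings(model, torch_dtype=torch.float32):
        if torch_dtype == "auto":          # mirror the "auto" fallback added in the hunk above
            torch_dtype = torch.float32
        for name, module in model.named_children():
            if type(module) == nn.Embedding:
                new = ToyLowBitEmbedding(module.num_embeddings, module.embedding_dim,
                                         _weight=module.weight.data, torch_dtype=torch_dtype)
                model._modules[name] = new
            else:
                replace_embeddings(module, torch_dtype)   # recurse, passing the dtype down
        return model

    toy = replace_embeddings(nn.Sequential(nn.Embedding(10, 4)), torch_dtype=torch.float16)
    print(toy(torch.tensor([1, 2])).dtype)   # torch.float16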
@@ -88,7 +88,8 @@ class LowBitEmbedding(torch.nn.Embedding):
                  _weight: Optional[Tensor] = None,
                  _freeze: bool = False,
                  device=None, dtype=None,
-                 qtype=None) -> None:
+                 qtype=None,
+                 torch_dtype=torch.float32) -> None:
         super().__init__(num_embeddings, embedding_dim, padding_idx,
                          max_norm, norm_type, scale_grad_by_freq, sparse,
                          _weight, device, dtype)
@@ -96,6 +97,7 @@ class LowBitEmbedding(torch.nn.Embedding):
                                  requires_grad=False,
                                  quantized=False, _shape=None, qtype=qtype)
         self.embedding_dim = embedding_dim
+        self.torch_dtype = torch_dtype
 
     def forward(self, x: Tensor):
         invalidInputError(x.device.type == "xpu",
@@ -109,4 +111,4 @@ class LowBitEmbedding(torch.nn.Embedding):
 
         result = linear_q4_0.dequantize_rows(x.contiguous(), self.weight.data,
                                              self.weight.qtype, self.embedding_dim)
-        return result
+        return result.to(self.torch_dtype)
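The forward change in LowBitEmbedding matters because the dequantized rows come back in a fixed dtype, while an fp16 model expects fp16 activations downstream. A toy illustration of the cast, with toy_dequantize_rows as a hypothetical stand-in for the XPU-only linear_q4_0.dequantize_rows:

    import torch

    def toy_dequantize_rows(ids, table_fp32):
        # hypothetical stand-in for linear_q4_0.dequantize_rows: gather rows, result stays fp32
        return table_fp32[ids]

    table = torch.randn(100, 64)                  # pretend: dequantized embedding table
    ids = torch.tensor([[1, 2, 3]])
    out = toy_dequantize_rows(ids, table)         # fp32 rows
    out = out.to(torch.float16)                   # mirrors `return result.to(self.torch_dtype)`
    print(out.dtype)                              # torch.float16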
@@ -339,8 +339,8 @@ class _BaseAutoModelClass:
             invalidInputError(q_k in ggml_tensor_qtype,
                               f"Unknown load_in_low_bit value: {q_k}, expected:"
                               f" sym_int4, asym_int4, sym_int5, asym_int5, sym_int8, nf3, nf4, "
-                              f"fp4, fp8, fp8_e4m3, fp8_e5m2, fp16, bf16, iq2_xxs, iq2_xs, "
-                              f"mixed_fp4 or mixed_fp8.")
+                              f"fp4, fp8, fp8_e4m3, fp8_e5m2, fp16, bf16, gguf_iq2_xxs, "
+                              f"gguf_iq2_xs, mixed_fp4 or mixed_fp8.")
             qtype = ggml_tensor_qtype[q_k]
 
             # In case it needs a second try,
@@ -510,7 +510,8 @@ class _BaseAutoModelClass:
         optimize_model = kwargs.pop("optimize_model", True)
 
         qtype = ggml_tensor_qtype[bigdl_transformers_low_bit]
-        if bigdl_transformers_low_bit in ["iq2_xxs", "iq2_xs", "q2_k"] and not cpu_embedding:
+        if bigdl_transformers_low_bit in ["gguf_iq2_xxs", "gguf_iq2_xs", "q2_k"] and \
+                not cpu_embedding:
             embedding_qtype = "q2_k"
         if embedding_qtype is not None:
             embedding_qtype = ggml_tensor_qtype[embedding_qtype]
@@ -595,7 +596,7 @@ class _BaseAutoModelClass:
         model = ggml_convert_low_bit(model, qtype, optimize_model, device=quant_device,
                                      modules_to_not_convert=modules_to_not_convert,
                                      cpu_embedding=cpu_embedding, lightweight_bmm=lightweight_bmm,
-                                     embedding_qtype=embedding_qtype)
+                                     embedding_qtype=embedding_qtype, torch_dtype=torch_dtype)
 
         if is_sharded:
             loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
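A minimal sketch of the validation and fallback these hunks implement: unknown load_in_low_bit strings are rejected, and the 2-bit GGUF formats fall back to a q2_k embedding unless cpu_embedding keeps the embedding on the CPU. The numeric ids below are illustrative placeholders, not the real ggml_tensor_qtype values:

    # illustrative ids only; the real mapping lives in ggml_tensor_qtype
    SUPPORTED = {"sym_int4": 2, "gguf_iq2_xxs": 16, "gguf_iq2_xs": 17, "q2_k": 10}

    def resolve_qtypes(load_in_low_bit, cpu_embedding=False):
        if load_in_low_bit not in SUPPORTED:
            raise ValueError(f"Unknown load_in_low_bit value: {load_in_low_bit}")
        qtype = SUPPORTED[load_in_low_bit]
        embedding_qtype = None
        # 2-bit GGUF weights keep a q2_k embedding unless the embedding stays on the CPU
        if load_in_low_bit in ("gguf_iq2_xxs", "gguf_iq2_xs", "q2_k") and not cpu_embedding:
            embedding_qtype = SUPPORTED["q2_k"]
        return qtype, embedding_qtype

    print(resolve_qtypes("gguf_iq2_xxs"))     # (16, 10)
    print(resolve_qtypes("sym_int4"))         # (2, None)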
@@ -48,7 +48,7 @@ from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb, \
 from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
     is_enough_kv_cache_room_4_36
 from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5
-from bigdl.llm.transformers.models.utils import use_flash_attention
+from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
 
 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 
@@ -242,6 +242,15 @@ def mistral_attention_forward(
         attn_weights = None
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(bsz, q_len, hidden_size)
+    elif use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
+        import linear_fp16_esimd
+        attn_output = linear_fp16_esimd.sdp_forward(query_states,
+                                                    key_states,
+                                                    value_states)
+        attn_output = attn_output.view(query_states.shape)
+        attn_weights = None
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, hidden_size)
     else:
         attn_output, attn_weights = compute_attn_outputs_weights(query_states,
                                                                  key_states,
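The new branch follows a familiar dispatch pattern: take the most specialized kernel whose preconditions hold, otherwise fall back to the generic path. A runnable sketch in plain PyTorch, where F.scaled_dot_product_attention stands in for the XPU-only linear_fp16_esimd.sdp_forward and prefer_fused stands in for the use_esimd_sdp(...) check:

    import torch
    import torch.nn.functional as F

    def sdp_dispatch(query, key, value, prefer_fused=True):
        # query/key/value: [bsz, n_heads, seq_len, head_dim]
        if prefer_fused:
            # stand-in for the fused kernel branch (linear_fp16_esimd.sdp_forward on XPU)
            attn_output = F.scaled_dot_product_attention(query, key, value)
            attn_weights = None
        else:
            # explicit matmul + softmax fallback, as in the else-branch of the hunk
            scores = query @ key.transpose(-2, -1) / query.shape[-1] ** 0.5
            attn_weights = scores.softmax(dim=-1)
            attn_output = attn_weights @ value
        bsz, _, q_len, _ = query.shape
        return attn_output.transpose(1, 2).reshape(bsz, q_len, -1), attn_weights

    q = k = v = torch.randn(1, 8, 16, 64)
    out, _ = sdp_dispatch(q, k, v)
    print(out.shape)   # torch.Size([1, 16, 512])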
@@ -49,7 +49,7 @@ from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache,
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb,\
     apply_rotary_pos_emb_cache_freq_xpu, is_enough_kv_cache_room_4_36
 from bigdl.llm.transformers.models.mistral import should_use_fuse_rope, use_decoding_fast_path
-from bigdl.llm.transformers.models.utils import use_flash_attention
+from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
 from bigdl.llm.transformers.models.utils import mlp_fusion_check
 
 
@@ -272,6 +272,14 @@ def mixtral_attention_forward(
                                                      value_states,
                                                      is_causal=True)
         attn_weights = None
+    elif use_esimd_sdp(query_states.shape[2], key_states.shape[2],
+                       self.head_dim, query_states):
+        import linear_fp16_esimd
+        attn_output = linear_fp16_esimd.sdp_forward(query_states,
+                                                    key_states,
+                                                    value_states)
+        attn_output = attn_output.view(query_states.shape)
+        attn_weights = None
     else:
         attn_weights = torch.matmul(
             query_states,
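The mixtral changes mirror the mistral ones above (see the dispatch sketch there): the same use_esimd_sdp gate and linear_fp16_esimd.sdp_forward call, except that the sequence length is taken from query_states.shape[2] rather than a local q_len, and the transpose/reshape back to (bsz, q_len, hidden_size) is not repeated inside the branch, presumably because the code following this hunk already performs it.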
@@ -310,7 +310,8 @@ def mlp_fusion_check(x, qtype, training):
         return False
     if x.device.type != 'xpu':
         return False
-    if qtype not in [ggml_tensor_qtype["sym_int4"], ggml_tensor_qtype["fp8_e5m2"]]:
+    if qtype not in [ggml_tensor_qtype["sym_int4"], ggml_tensor_qtype["fp8_e5m2"],
+                     ggml_tensor_qtype["gguf_iq2_xxs"]]:
         return False
     if training or x.requires_grad:
         return False
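Finally, mlp_fusion_check gains gguf_iq2_xxs as a fusable qtype. A hedged sketch of how such a gate is typically consumed (illustrative names; the real function also has earlier checks that are elided from this hunk):

    import torch

    ALLOWED_FUSION_QTYPES = {"sym_int4", "fp8_e5m2", "gguf_iq2_xxs"}   # mirrors the updated list

    def can_fuse_mlp(x: torch.Tensor, qtype: str, training: bool) -> bool:
        # fuse only for inference on XPU with a supported quantization type
        if x.device.type != "xpu":
            return False
        if qtype not in ALLOWED_FUSION_QTYPES:
            return False
        if training or x.requires_grad:
            return False
        return True

    # on a CPU tensor the device gate already rejects fusion
    print(can_fuse_mlp(torch.randn(1, 16), "gguf_iq2_xxs", training=False))   # False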