diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index d3ef345f..cd7c60db 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -192,7 +192,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                                  current_key_name=None, convert_shape_only=False,
                                  cpu_embedding=False, prefix_name='',
                                  imatrix_data=None, embedding_qtype=None,
-                                 model_type=None):
+                                 model_type=None, torch_dtype=torch.float32):
     from bigdl.llm.transformers.low_bit_linear import LowBitLinear, FP4Params, \
         FP16Linear, BF16Linear
     from bigdl.llm.transformers.embedding import LLMEmbedding, LowBitEmbedding
@@ -326,6 +326,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                     _weight=module.weight.data,
                 )
         elif type(module) == nn.Embedding and embedding_qtype is not None:
+            if torch_dtype == "auto":
+                torch_dtype = torch.float32
             q_embedding = LowBitEmbedding(
                 num_embeddings=module.num_embeddings,
                 embedding_dim=module.embedding_dim,
@@ -336,6 +338,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                 sparse=module.sparse,
                 _weight=module.weight.data,
                 qtype=embedding_qtype,
+                torch_dtype=torch_dtype
             )
             device = module.weight.data.device
             # Copy the weights
@@ -364,7 +367,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                 prefix_name=prefix_name + '.' + name if prefix_name != '' else name,
                 imatrix_data=imatrix_data,
                 embedding_qtype=embedding_qtype,
-                model_type=model_type
+                model_type=model_type,
+                torch_dtype=torch_dtype
             )
             has_been_replaced = _flag or has_been_replaced
     return model, has_been_replaced
@@ -571,7 +575,8 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
         None, convert_shape_only, cpu_embedding,
         imatrix_data=imatrix_data,
         embedding_qtype=embedding_qtype,
-        model_type=model_type
+        model_type=model_type,
+        torch_dtype=torch_dtype
     )
     if not has_been_replaced:
         warnings.warn(
diff --git a/python/llm/src/bigdl/llm/transformers/embedding.py b/python/llm/src/bigdl/llm/transformers/embedding.py
index 63cf3b81..4054418f 100644
--- a/python/llm/src/bigdl/llm/transformers/embedding.py
+++ b/python/llm/src/bigdl/llm/transformers/embedding.py
@@ -88,7 +88,8 @@ class LowBitEmbedding(torch.nn.Embedding):
                  _weight: Optional[Tensor] = None,
                  _freeze: bool = False,
                  device=None, dtype=None,
-                 qtype=None) -> None:
+                 qtype=None,
+                 torch_dtype=torch.float32) -> None:
         super().__init__(num_embeddings, embedding_dim, padding_idx,
                          max_norm, norm_type, scale_grad_by_freq, sparse,
                          _weight, device, dtype)
@@ -96,6 +97,7 @@ class LowBitEmbedding(torch.nn.Embedding):
                                 requires_grad=False,
                                 quantized=False, _shape=None, qtype=qtype)
         self.embedding_dim = embedding_dim
+        self.torch_dtype = torch_dtype
 
     def forward(self, x: Tensor):
         invalidInputError(x.device.type == "xpu",
@@ -109,4 +111,4 @@ class LowBitEmbedding(torch.nn.Embedding):
         result = linear_q4_0.dequantize_rows(x.contiguous(), self.weight.data,
                                              self.weight.qtype,
                                              self.embedding_dim)
-        return result
+        return result.to(self.torch_dtype)
diff --git a/python/llm/src/bigdl/llm/transformers/model.py b/python/llm/src/bigdl/llm/transformers/model.py
index 8a495991..f1b95a80 100644
--- a/python/llm/src/bigdl/llm/transformers/model.py
+++ b/python/llm/src/bigdl/llm/transformers/model.py
@@ -339,8 +339,8 @@ class _BaseAutoModelClass:
             invalidInputError(q_k in ggml_tensor_qtype,
                               f"Unknown load_in_low_bit value: {q_k}, expected:"
                               f" sym_int4, asym_int4, sym_int5, asym_int5, sym_int8, nf3, nf4, "
-                              f"fp4, fp8, fp8_e4m3, fp8_e5m2, fp16, bf16, iq2_xxs, iq2_xs, "
-                              f"mixed_fp4 or mixed_fp8.")
+                              f"fp4, fp8, fp8_e4m3, fp8_e5m2, fp16, bf16, gguf_iq2_xxs, "
+                              f"gguf_iq2_xs, mixed_fp4 or mixed_fp8.")
             qtype = ggml_tensor_qtype[q_k]
 
         # In case it needs a second try,
@@ -510,7 +510,8 @@ class _BaseAutoModelClass:
         optimize_model = kwargs.pop("optimize_model", True)
 
         qtype = ggml_tensor_qtype[bigdl_transformers_low_bit]
-        if bigdl_transformers_low_bit in ["iq2_xxs", "iq2_xs", "q2_k"] and not cpu_embedding:
+        if bigdl_transformers_low_bit in ["gguf_iq2_xxs", "gguf_iq2_xs", "q2_k"] and \
+                not cpu_embedding:
             embedding_qtype = "q2_k"
         if embedding_qtype is not None:
             embedding_qtype = ggml_tensor_qtype[embedding_qtype]
@@ -595,7 +596,7 @@ class _BaseAutoModelClass:
         model = ggml_convert_low_bit(model, qtype, optimize_model, device=quant_device,
                                      modules_to_not_convert=modules_to_not_convert,
                                      cpu_embedding=cpu_embedding, lightweight_bmm=lightweight_bmm,
-                                     embedding_qtype=embedding_qtype)
+                                     embedding_qtype=embedding_qtype, torch_dtype=torch_dtype)
 
         if is_sharded:
             loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
diff --git a/python/llm/src/bigdl/llm/transformers/models/mistral.py b/python/llm/src/bigdl/llm/transformers/models/mistral.py
index a6688082..38878b69 100644
--- a/python/llm/src/bigdl/llm/transformers/models/mistral.py
+++ b/python/llm/src/bigdl/llm/transformers/models/mistral.py
@@ -48,7 +48,7 @@ from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb, \
 from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
     is_enough_kv_cache_room_4_36
 from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5
-from bigdl.llm.transformers.models.utils import use_flash_attention
+from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
 
 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 
@@ -242,6 +242,15 @@ def mistral_attention_forward(
         attn_weights = None
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(bsz, q_len, hidden_size)
+    elif use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
+        import linear_fp16_esimd
+        attn_output = linear_fp16_esimd.sdp_forward(query_states,
+                                                    key_states,
+                                                    value_states)
+        attn_output = attn_output.view(query_states.shape)
+        attn_weights = None
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, hidden_size)
     else:
         attn_output, attn_weights = compute_attn_outputs_weights(query_states,
                                                                  key_states,
diff --git a/python/llm/src/bigdl/llm/transformers/models/mixtral.py b/python/llm/src/bigdl/llm/transformers/models/mixtral.py
index 42cfbfc2..fbf2341d 100644
--- a/python/llm/src/bigdl/llm/transformers/models/mixtral.py
+++ b/python/llm/src/bigdl/llm/transformers/models/mixtral.py
@@ -49,7 +49,7 @@ from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache,
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb,\
     apply_rotary_pos_emb_cache_freq_xpu, is_enough_kv_cache_room_4_36
 from bigdl.llm.transformers.models.mistral import should_use_fuse_rope, use_decoding_fast_path
-from bigdl.llm.transformers.models.utils import use_flash_attention
+from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
 from bigdl.llm.transformers.models.utils import mlp_fusion_check
 
 
@@ -272,6 +272,14 @@ def mixtral_attention_forward(
                                                      value_states,
                                                      is_causal=True)
         attn_weights = None
+    elif use_esimd_sdp(query_states.shape[2], key_states.shape[2],
+                       self.head_dim, query_states):
+        import linear_fp16_esimd
+        attn_output = linear_fp16_esimd.sdp_forward(query_states,
+                                                    key_states,
+                                                    value_states)
+        attn_output = attn_output.view(query_states.shape)
+        attn_weights = None
     else:
         attn_weights = torch.matmul(
             query_states,
diff --git a/python/llm/src/bigdl/llm/transformers/models/utils.py b/python/llm/src/bigdl/llm/transformers/models/utils.py
index 5f6f61a2..c238c595 100644
--- a/python/llm/src/bigdl/llm/transformers/models/utils.py
+++ b/python/llm/src/bigdl/llm/transformers/models/utils.py
@@ -310,7 +310,8 @@ def mlp_fusion_check(x, qtype, training):
         return False
     if x.device.type != 'xpu':
         return False
-    if qtype not in [ggml_tensor_qtype["sym_int4"], ggml_tensor_qtype["fp8_e5m2"]]:
+    if qtype not in [ggml_tensor_qtype["sym_int4"], ggml_tensor_qtype["fp8_e5m2"],
+                     ggml_tensor_qtype["gguf_iq2_xxs"]]:
         return False
     if training or x.requires_grad:
         return False
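
A rough usage sketch of the path these changes wire together: the `torch_dtype` passed to `from_pretrained` now reaches `LowBitEmbedding`, so the dequantized embedding output matches the model's dtype when a `gguf_iq2_*` qtype quantizes the embedding to q2_k. The checkpoint name and device below are illustrative assumptions, not taken from this diff, and real `gguf_iq2_*` loading may require extra arguments (such as an importance matrix) that are omitted here.

# Usage sketch only (not part of this patch); names below are illustrative.
import torch
import intel_extension_for_pytorch as ipex  # noqa: F401, enables the XPU backend
from bigdl.llm.transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",       # assumed example checkpoint
    load_in_low_bit="gguf_iq2_xxs",    # renamed from "iq2_xxs" in this change
    torch_dtype=torch.float16,         # now forwarded down to LowBitEmbedding
    trust_remote_code=True,
)
# With cpu_embedding left as False, the embedding is quantized to q2_k and its
# dequantized output is cast to torch.float16 instead of staying fp32.
model = model.to("xpu")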