diff --git a/python/llm/src/ipex_llm/transformers/low_bit_linear.py b/python/llm/src/ipex_llm/transformers/low_bit_linear.py
index 7c6697f3..96f57167 100644
--- a/python/llm/src/ipex_llm/transformers/low_bit_linear.py
+++ b/python/llm/src/ipex_llm/transformers/low_bit_linear.py
@@ -548,7 +548,6 @@ class LowBitLinear(nn.Linear):
         # on arc or IPEX_LLM_LOW_MEM is set to 1 at inference time.
         if self.device is None:
             self.device = get_xpu_device_type(self.weight.data)
-            # TODO: may remove IPEX_LLM_LOW_MEM here, probably not necessary
             self.low_memory_mode = \
                 self.low_memory_mode and \
                 (self.device == "arc" or os.environ.get("IPEX_LLM_LOW_MEM", None) == "1")
diff --git a/python/llm/src/ipex_llm/transformers/models/llama.py b/python/llm/src/ipex_llm/transformers/models/llama.py
index a4044e1a..ec554874 100644
--- a/python/llm/src/ipex_llm/transformers/models/llama.py
+++ b/python/llm/src/ipex_llm/transformers/models/llama.py
@@ -1001,8 +1001,13 @@ def llama_attention_forward_4_36_quantized(
             )
             attn_weights = attn_weights + attention_mask
 
-        # at inference time, for memory considerations, may not need to upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+        if kv_seq_len >= 2048:
+            # for memory considerations, do not upcast attention to fp32 for long sequences
+            attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+        else:
+            # upcast attention to fp32
+            attn_weights = nn.functional.softmax(attn_weights, dim=-1,
+                                                 dtype=torch.float32).to(query_states.dtype)
         attn_output = torch.matmul(attn_weights, value_states)
         if use_cache:
             cache_kwargs = None
@@ -1041,8 +1046,13 @@
                 )
                 attn_weights = attn_weights + attention_mask
 
-            # at inference time, for memory considerations, may not need to upcast attention to fp32
-            attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+            if kv_seq_len >= 2048:
+                # for memory considerations, do not upcast attention to fp32 for long sequences
+                attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+            else:
+                # upcast attention to fp32
+                attn_weights = nn.functional.softmax(attn_weights, dim=-1,
+                                                     dtype=torch.float32).to(query_states.dtype)
             attn_output = torch.matmul(attn_weights, value_states)
         else:
             import linear_q4_0
@@ -1326,8 +1336,13 @@ def native_sdp(query, key, value, attention_mask,
                               f"but is {attention_mask.size()}")
         attn_weights = attn_weights + attention_mask
 
-    # at inference time, for memory considerations, may not need to upcast attention to fp32
-    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+    if kv_seq_len >= 2048:
+        # for memory considerations, do not upcast attention to fp32 for long sequences
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+    else:
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1,
+                                             dtype=torch.float32).to(value.dtype)
     attn_output = torch.matmul(attn_weights, value)
     return attn_output, attn_weights
 
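
For reference, a minimal standalone sketch of the pattern the three llama.py hunks apply: softmax stays in the working dtype once the key/value sequence length reaches the threshold (avoiding a large fp32 temporary), and upcasts to fp32 otherwise for numerical stability. This is not part of the patch; the helper name gated_softmax and the constant UPCAST_THRESHOLD are illustrative only.

    import torch
    import torch.nn.functional as F

    UPCAST_THRESHOLD = 2048  # mirrors the `kv_seq_len >= 2048` check in the diff

    def gated_softmax(attn_weights: torch.Tensor, kv_seq_len: int) -> torch.Tensor:
        if kv_seq_len >= UPCAST_THRESHOLD:
            # long sequences: keep the working dtype (e.g. fp16) to save memory
            return F.softmax(attn_weights, dim=-1)
        # short sequences: upcast to fp32, then cast back to the working dtype
        return F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(attn_weights.dtype)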