Add seq len check for llama softmax upcast to fp32 (#10629)

Author: Kai Huang, 2024-04-03 12:05:13 +08:00 (committed by GitHub)
parent 1aef3bc0ab
commit c875b3c858
2 changed files with 21 additions and 7 deletions
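
The change itself is mechanical: in the three quantized Llama attention paths touched below, the unconditional fp32 softmax upcast is replaced by a branch on the KV sequence length. The standalone sketch below only illustrates that pattern; the 2048 threshold and the dim/dtype arguments mirror the diff, while the function name and signature are invented for the example and do not exist in the repository.

import torch
import torch.nn as nn

def gated_softmax(attn_weights: torch.Tensor, out_dtype: torch.dtype,
                  kv_seq_len: int, threshold: int = 2048) -> torch.Tensor:
    # Hypothetical helper mirroring the pattern added in this commit.
    if kv_seq_len >= threshold:
        # Long sequence: keep the working dtype (e.g. fp16) so the
        # [batch, heads, q_len, kv_seq_len] weight tensor is not doubled in size.
        return nn.functional.softmax(attn_weights, dim=-1)
    # Short sequence: upcast to fp32 for numerical stability, then cast back.
    return nn.functional.softmax(attn_weights, dim=-1,
                                 dtype=torch.float32).to(out_dtype)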


@@ -548,7 +548,6 @@ class LowBitLinear(nn.Linear):
     # on arc or IPEX_LLM_LOW_MEM is set to 1 at inference time.
     if self.device is None:
         self.device = get_xpu_device_type(self.weight.data)
-        # TODO: may remove IPEX_LLM_LOW_MEM here, probably not necessary
         self.low_memory_mode = \
             self.low_memory_mode and \
             (self.device == "arc" or os.environ.get("IPEX_LLM_LOW_MEM", None) == "1")

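The first hunk only drops a stale TODO next to the low-memory-mode gate; the gate itself is unchanged. In isolation it amounts to the check sketched below (a rough sketch; the function name is invented, and the repository's get_xpu_device_type helper is replaced by a plain string argument):

import os

def keep_low_memory_mode(requested: bool, device_type: str) -> bool:
    # Sketch of the condition kept by this hunk: low-memory mode stays enabled
    # only when it was already requested and either the device is Arc or
    # IPEX_LLM_LOW_MEM=1 is exported in the environment.
    return requested and (
        device_type == "arc" or os.environ.get("IPEX_LLM_LOW_MEM", None) == "1"
    )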

@@ -1001,8 +1001,13 @@ def llama_attention_forward_4_36_quantized(
             )
         attn_weights = attn_weights + attention_mask
-    # at inference time, for memory considerations, may not need to upcast attention to fp32
-    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+    if kv_seq_len >= 2048:
+        # for memory considerations, do not upcast attention to fp32 for long sequences
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+    else:
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1,
+                                             dtype=torch.float32).to(query_states.dtype)
     attn_output = torch.matmul(attn_weights, value_states)
     if use_cache:
         cache_kwargs = None
@@ -1041,8 +1046,13 @@ def llama_attention_forward_4_36_quantized(
             )
         attn_weights = attn_weights + attention_mask
-    # at inference time, for memory considerations, may not need to upcast attention to fp32
-    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+    if kv_seq_len >= 2048:
+        # for memory considerations, do not upcast attention to fp32 for long sequences
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+    else:
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1,
+                                             dtype=torch.float32).to(query_states.dtype)
     attn_output = torch.matmul(attn_weights, value_states)
 else:
     import linear_q4_0
@@ -1326,8 +1336,13 @@ def native_sdp(query, key, value, attention_mask,
                              f"but is {attention_mask.size()}")
         attn_weights = attn_weights + attention_mask
-    # at inference time, for memory considerations, may not need to upcast attention to fp32
-    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+    if kv_seq_len >= 2048:
+        # for memory considerations, do not upcast attention to fp32 for long sequences
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+    else:
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1,
+                                             dtype=torch.float32).to(value.dtype)
     attn_output = torch.matmul(attn_weights, value)
     return attn_output, attn_weights
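
For a sense of why the upcast is gated, consider the attention-weight tensor of shape [batch, num_heads, q_len, kv_seq_len]. With illustrative shapes (batch 1, 32 heads, q_len == kv_seq_len == 2048; these numbers are assumptions for the example, not taken from the diff), fp32 doubles the footprint relative to fp16:

# Back-of-the-envelope memory for the attention weights at the 2048 threshold.
batch, num_heads, q_len, kv_seq_len = 1, 32, 2048, 2048  # assumed example shapes
elems = batch * num_heads * q_len * kv_seq_len
print(f"fp16: {elems * 2 / 1024**2:.0f} MiB")  # 256 MiB
print(f"fp32: {elems * 4 / 1024**2:.0f} MiB")  # 512 MiB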