Add seq len check for llama softmax upcast to fp32 (#10629)
parent 1aef3bc0ab
commit c875b3c858
2 changed files with 21 additions and 7 deletions
@@ -548,7 +548,6 @@ class LowBitLinear(nn.Linear):
             # on arc or IPEX_LLM_LOW_MEM is set to 1 at inference time.
             if self.device is None:
                 self.device = get_xpu_device_type(self.weight.data)
-                # TODO: may remove IPEX_LLM_LOW_MEM here, probably not necessary
                 self.low_memory_mode = \
                     self.low_memory_mode and \
                     (self.device == "arc" or os.environ.get("IPEX_LLM_LOW_MEM", None) == "1")

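For context, the surrounding LowBitLinear logic keeps low-memory mode only on Arc GPUs or when the user exports IPEX_LLM_LOW_MEM=1. A minimal standalone sketch of that gating (the helper name resolve_low_memory_mode and the non-"arc" device string are illustrative, not the repo's API):

import os

def resolve_low_memory_mode(requested: bool, device_type: str) -> bool:
    # Keep low-memory mode only when it was requested AND the tensor sits on
    # an Arc GPU or the user explicitly opted in via IPEX_LLM_LOW_MEM=1.
    env_opt_in = os.environ.get("IPEX_LLM_LOW_MEM", None) == "1"
    return requested and (device_type == "arc" or env_opt_in)

print(resolve_low_memory_mode(True, "arc"))   # True
print(resolve_low_memory_mode(True, "pvc"))   # False unless IPEX_LLM_LOW_MEM=1
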
@@ -1001,8 +1001,13 @@ def llama_attention_forward_4_36_quantized(
                 )
             attn_weights = attn_weights + attention_mask

-        # at inference time, for memory considerations, may not need to upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+        if kv_seq_len >= 2048:
+            # for memory considerations, do not upcast attention to fp32 for long sequences
+            attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+        else:
+            # upcast attention to fp32
+            attn_weights = nn.functional.softmax(attn_weights, dim=-1,
+                                                 dtype=torch.float32).to(query_states.dtype)
         attn_output = torch.matmul(attn_weights, value_states)
         if use_cache:
             cache_kwargs = None

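The change above gates the softmax upcast on kv_seq_len: for contexts of 2048 tokens or more, the attention probabilities stay in the model's native half-precision dtype to avoid materializing a large fp32 copy of the scores, while shorter contexts keep the fp32 upcast for numerical stability. A minimal sketch of the same pattern as a standalone helper (softmax_with_optional_upcast is an illustrative name, not a function in the repo):

import torch
import torch.nn as nn

def softmax_with_optional_upcast(attn_weights, kv_seq_len, out_dtype):
    if kv_seq_len >= 2048:
        # Long context: softmax in the native dtype; no fp32 copy of the scores.
        return nn.functional.softmax(attn_weights, dim=-1)
    # Short context: upcast to fp32 for stability, then cast back.
    return nn.functional.softmax(attn_weights, dim=-1,
                                 dtype=torch.float32).to(out_dtype)

# Decode-step scores for 32 heads over a 1024-token context (fp16 input).
scores = torch.randn(1, 32, 1, 1024, dtype=torch.float16)
probs = softmax_with_optional_upcast(scores, kv_seq_len=1024, out_dtype=scores.dtype)
print(probs.dtype)  # torch.float16
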
@@ -1041,8 +1046,13 @@ def llama_attention_forward_4_36_quantized(
                 )
             attn_weights = attn_weights + attention_mask

-        # at inference time, for memory considerations, may not need to upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+        if kv_seq_len >= 2048:
+            # for memory considerations, do not upcast attention to fp32 for long sequences
+            attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+        else:
+            # upcast attention to fp32
+            attn_weights = nn.functional.softmax(attn_weights, dim=-1,
+                                                 dtype=torch.float32).to(query_states.dtype)
         attn_output = torch.matmul(attn_weights, value_states)
     else:
         import linear_q4_0

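The same check is applied on the second code path of llama_attention_forward_4_36_quantized. Rough arithmetic illustrates why the upcast is skipped for long contexts: the score tensor has shape (bsz, num_heads, q_len, kv_seq_len), so a full 2048-token prefill with 32 heads (illustrative numbers, not taken from the patch) roughly doubles from ~256 MiB in fp16 to ~512 MiB when upcast to fp32:

# Back-of-the-envelope memory estimate for the attention-score tensor.
bsz, num_heads, q_len, kv_seq_len = 1, 32, 2048, 2048   # assumed example shape
elements = bsz * num_heads * q_len * kv_seq_len
print(f"fp16 scores: {elements * 2 / 2**20:.0f} MiB")   # 256 MiB
print(f"fp32 upcast: {elements * 4 / 2**20:.0f} MiB")   # 512 MiB
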
@@ -1326,8 +1336,13 @@ def native_sdp(query, key, value, attention_mask,
                               f"but is {attention_mask.size()}")
         attn_weights = attn_weights + attention_mask

-    # at inference time, for memory considerations, may not need to upcast attention to fp32
-    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+    if kv_seq_len >= 2048:
+        # for memory considerations, do not upcast attention to fp32 for long sequences
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+    else:
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1,
+                                             dtype=torch.float32).to(value.dtype)
     attn_output = torch.matmul(attn_weights, value)
     return attn_output, attn_weights

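native_sdp gets the same treatment, casting back to value.dtype rather than query_states.dtype. For short sequences the fp32 upcast is kept because half-precision softmax loses a little accuracy; a quick, illustrative comparison on random scores (not part of the patch; it assumes a PyTorch build with fp16 softmax support on the chosen device):

import torch
import torch.nn as nn

scores = (torch.randn(1, 32, 1, 512) * 8.0).to(torch.float16)
probs_fp16 = nn.functional.softmax(scores, dim=-1)
probs_fp32 = nn.functional.softmax(scores, dim=-1,
                                   dtype=torch.float32).to(torch.float16)
# Maximum elementwise gap between the two variants (small but nonzero).
print((probs_fp16.float() - probs_fp32.float()).abs().max().item())
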