Add seq len check for llama softmax upcast to fp32 (#10629)
This commit is contained in:
parent
1aef3bc0ab
commit
c875b3c858
2 changed files with 21 additions and 7 deletions
|
|
@ -548,7 +548,6 @@ class LowBitLinear(nn.Linear):
|
|||
# on arc or IPEX_LLM_LOW_MEM is set to 1 at inference time.
|
||||
if self.device is None:
|
||||
self.device = get_xpu_device_type(self.weight.data)
|
||||
# TODO: may remove IPEX_LLM_LOW_MEM here, probably not necessary
|
||||
self.low_memory_mode = \
|
||||
self.low_memory_mode and \
|
||||
(self.device == "arc" or os.environ.get("IPEX_LLM_LOW_MEM", None) == "1")
|
||||
|
|
|
|||
|
|
@ -1001,8 +1001,13 @@ def llama_attention_forward_4_36_quantized(
|
|||
)
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
# at inference time, for memory considerations, may not need to upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
if kv_seq_len >= 2048:
|
||||
# for memory considerations, do not upcast attention to fp32 for long sequences
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
else:
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1,
|
||||
dtype=torch.float32).to(query_states.dtype)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
if use_cache:
|
||||
cache_kwargs = None
|
||||
|
|
@ -1041,8 +1046,13 @@ def llama_attention_forward_4_36_quantized(
|
|||
)
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
# at inference time, for memory considerations, may not need to upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
if kv_seq_len >= 2048:
|
||||
# for memory considerations, do not upcast attention to fp32 for long sequences
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
else:
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1,
|
||||
dtype=torch.float32).to(query_states.dtype)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
else:
|
||||
import linear_q4_0
|
||||
|
|
@ -1326,8 +1336,13 @@ def native_sdp(query, key, value, attention_mask,
|
|||
f"but is {attention_mask.size()}")
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
# at inference time, for memory considerations, may not need to upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
if kv_seq_len >= 2048:
|
||||
# for memory considerations, do not upcast attention to fp32 for long sequences
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
else:
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1,
|
||||
dtype=torch.float32).to(value.dtype)
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue