fix llama2 (#10710)
This commit is contained in:
parent e10040b7f1
commit 8f45e22072
2 changed files with 6 additions and 2 deletions
@@ -1011,8 +1011,10 @@ def llama_attention_forward_4_36_quantized(
     kv_seq_len = key_states.shape[-2]
     if len(past_key_value.key_cache) <= self.layer_idx:
+        repeated_key_states = repeat_kv(key_states, self.num_key_value_groups)
+        repeated_value_states = repeat_kv(value_states, self.num_key_value_groups)
         attn_weights = torch.matmul(query_states,
-                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+                                    repeated_key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
         if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
             invalidInputError(
@@ -1038,7 +1040,7 @@ def llama_attention_forward_4_36_quantized(
         # upcast attention to fp32
         attn_weights = nn.functional.softmax(attn_weights, dim=-1,
                                              dtype=torch.float32).to(query_states.dtype)
-        attn_output = torch.matmul(attn_weights, value_states)
+        attn_output = torch.matmul(attn_weights, repeated_value_states)
         if use_cache:
             cache_kwargs = None
             key_states, value_states = past_key_value.update(key_states, value_states,
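Why the change matters: Llama 2 variants with grouped-query attention carry fewer key/value heads than query heads, so key_states and value_states must be expanded with repeat_kv before the attention matmuls; using the raw tensors breaks this quantized path for such models. Below is a minimal sketch of that expansion, mirroring the standard repeat_kv helper from transformers; the tensor shapes (64 query heads, 8 KV heads) are illustrative assumptions, not values from this patch.

import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Expand (batch, num_kv_heads, seq_len, head_dim) to
    # (batch, num_kv_heads * n_rep, seq_len, head_dim) so the KV heads
    # line up with the query heads in the attention matmul.
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_kv_heads, n_rep, seq_len, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)

# Illustrative shapes only: 64 query heads attending over 8 KV heads (n_rep = 8).
query_states = torch.randn(1, 64, 16, 128)
key_states = torch.randn(1, 8, 16, 128)
repeated_key_states = repeat_kv(key_states, 64 // 8)
attn_weights = torch.matmul(query_states, repeated_key_states.transpose(2, 3))
print(attn_weights.shape)  # torch.Size([1, 64, 16, 16])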
@@ -395,6 +395,8 @@ def use_fused_layer_norm(x: torch.Tensor, training: bool):
 
 def fp16_fusion_check(proj, x, training):
     # only use fp16 fusion on PVC inference
+    if not hasattr(proj, "qtype"):
+        return False
     if proj.qtype != ggml_tensor_qtype["fp16"]:
         return False
     if proj.weight_type != 2:
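The second file adds a defensive check: projection modules that were never quantized have no qtype attribute, so fp16_fusion_check now returns False for them instead of raising AttributeError. A minimal sketch of the guard pattern under that assumption, with a plain nn.Linear standing in for an unquantized projection and a simplified string comparison in place of the ggml_tensor_qtype lookup:

import torch.nn as nn

def fp16_fusion_check_sketch(proj) -> bool:
    # New guard: plain (unquantized) layers expose no `qtype`, so bail out
    # early instead of letting proj.qtype raise AttributeError.
    if not hasattr(proj, "qtype"):
        return False
    # Simplified stand-in for the real ggml_tensor_qtype["fp16"] comparison.
    return proj.qtype == "fp16"

plain_proj = nn.Linear(4096, 4096)           # no qtype attribute
print(fp16_fusion_check_sketch(plain_proj))  # False, and no exception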