fix llama2 (#10710)
This commit is contained in:

parent e10040b7f1
commit 8f45e22072

2 changed files with 6 additions and 2 deletions
@@ -1011,8 +1011,10 @@ def llama_attention_forward_4_36_quantized(
     kv_seq_len = key_states.shape[-2]
 
     if len(past_key_value.key_cache) <= self.layer_idx:
+        repeated_key_states = repeat_kv(key_states, self.num_key_value_groups)
+        repeated_value_states = repeat_kv(value_states, self.num_key_value_groups)
         attn_weights = torch.matmul(query_states,
-                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+                                    repeated_key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 
         if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
             invalidInputError(
@@ -1038,7 +1040,7 @@ def llama_attention_forward_4_36_quantized(
             # upcast attention to fp32
             attn_weights = nn.functional.softmax(attn_weights, dim=-1,
                                                  dtype=torch.float32).to(query_states.dtype)
-        attn_output = torch.matmul(attn_weights, value_states)
+        attn_output = torch.matmul(attn_weights, repeated_value_states)
         if use_cache:
             cache_kwargs = None
             key_states, value_states = past_key_value.update(key_states, value_states,
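In the quantized KV-cache branch, query_states carries self.num_heads attention heads while the cached key_states/value_states only carry the model's key/value heads, so the score and output matmuls must use repeat_kv-expanded tensors to satisfy the (bsz, self.num_heads, q_len, kv_seq_len) size check above. Below is a minimal sketch of the repeat_kv helper as it appears in Hugging Face transformers' Llama modeling code; the example shapes are illustrative, not taken from this patch:

# Sketch of transformers' repeat_kv: expands KV heads for grouped-query attention.
import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (batch, num_key_value_heads, seq_len, head_dim)
    # -> (batch, num_key_value_heads * n_rep, seq_len, head_dim)
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

# Illustrative shapes for a GQA model with 32 query heads and 8 KV heads:
key_states = torch.randn(1, 8, 16, 128)
print(repeat_kv(key_states, 4).shape)  # torch.Size([1, 32, 16, 128])

With num_key_value_groups = num_heads // num_key_value_heads, the repeated tensors line up with query_states, which is why the matmuls now read from repeated_key_states and repeated_value_states instead of the raw cached states.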
@@ -395,6 +395,8 @@ def use_fused_layer_norm(x: torch.Tensor, training: bool):
 
 def fp16_fusion_check(proj, x, training):
     # only use fp16 fusion on PVC inference
+    if not hasattr(proj, "qtype"):
+        return False
     if proj.qtype != ggml_tensor_qtype["fp16"]:
         return False
     if proj.weight_type != 2:
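The second file's change hardens fp16_fusion_check: a projection that is not an ipex-llm quantized linear layer may not have a qtype attribute at all, so the new guard returns False before proj.qtype is touched. A minimal sketch of the resulting control flow, with a hypothetical FP16_QTYPE stand-in for ggml_tensor_qtype["fp16"]:

import torch

FP16_QTYPE = object()  # hypothetical stand-in for ggml_tensor_qtype["fp16"]

def fp16_fusion_check_sketch(proj) -> bool:
    # Guard added by this commit: a plain torch.nn.Linear has no .qtype,
    # so return False instead of raising AttributeError on the next check.
    if not hasattr(proj, "qtype"):
        return False
    if proj.qtype != FP16_QTYPE:
        return False
    return True

print(fp16_fusion_check_sketch(torch.nn.Linear(4, 4)))  # False: no qtype attribute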