[NPU] Llama2 prefill use ov sdp (#12310)

* prefill use sdp
* add param
* update
* fix style
* fix style
* meet comments
parent  eda764909c
commit  05c5d0267a

2 changed files with 46 additions and 19 deletions
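What this change does: for the prefill pass, the Llama2 NPU decoder graph now computes attention with a fused scaled-dot-product-attention (SDP) op instead of the explicit matmul / mask-add / softmax / matmul sequence; decode is unchanged. A new `use_prefill_sdp` flag is threaded from the decoder-layer builders down into the shared `attention()` helper in `LLMBaseNNFactory`, the prefill graph drops its attention-mask input when the flag is set (the SDP call applies causal masking itself), and the flag is keyed off `intermediate_size == 11008`, i.e. Llama2-7B, since other models have not been validated yet (per the in-code comment).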
				
			
Changes in the first file (`LowBitLlamaMultiDecoderlayer` and `FusedLlamaLowBitDecoderlayer`):

```diff
@@ -110,13 +110,20 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
         # define input, the order self.parameter matters
         input = self.create_input_op((self.batch_size, self.seq_len, self.hidden_size))
 
+        # llama2 use ov sdp, other models need to test
+        use_prefill_sdp = self.intermediate_size == 11008
+
         # Self Attention
         if mode == "decode":
             attention_mask = self.create_input_op((self.batch_size, 1, 1, self.max_seq_len + 1),
                                                   dtype=np.int64)
         else:
-            attention_mask = self.create_input_op((self.batch_size, 1, self.seq_len, self.seq_len),
-                                                  dtype=np.int64)
+            if use_prefill_sdp:
+                attention_mask = None
+            else:
+                attention_mask = self.create_input_op((self.batch_size, 1, self.seq_len,
+                                                       self.seq_len),
+                                                      dtype=np.int64)
 
         position_ids = self.create_input_op((self.batch_size, self.seq_len), dtype=np.int64)
 
```
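When `use_prefill_sdp` is set, the prefill graph takes no attention-mask input at all; the causal mask is applied inside the SDP op instead (see the last hunk below, where the mask argument is `None` and the causal flag is `True`). For reference, the sketch below shows the kind of additive causal mask the non-SDP prefill path would otherwise be fed as its `(batch, 1, seq_len, seq_len)` int64 input. It is an illustration of the mask semantics only, not code from this repo, and the exact fill value used by the runtime may differ.

```python
import numpy as np

def make_causal_attention_mask(batch_size: int, seq_len: int,
                               min_value: int = np.iinfo(np.int64).min) -> np.ndarray:
    """Additive causal mask: 0 on and below the diagonal, a large negative
    value above it, broadcast to (batch, 1, seq_len, seq_len)."""
    mask = np.triu(np.full((seq_len, seq_len), min_value, dtype=np.int64), k=1)
    return np.broadcast_to(mask, (batch_size, 1, seq_len, seq_len)).copy()

# Example: for seq_len=4, row i allows positions 0..i and blocks the rest.
print(make_causal_attention_mask(1, 4)[0, 0])
```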
```diff
@@ -177,6 +184,7 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
                 post_attention_layernorm_weight=post_attn_layernorm_weights[i],
                 past_key=past_keys[i],
                 past_value=past_values[i],
+                use_prefill_sdp=use_prefill_sdp,
             )
             curr_key_values.append((new_key_states, new_value_states))
 
```
```diff
@@ -202,6 +210,7 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
         post_attention_layernorm_weight,
         past_key=None,
         past_value=None,
+        use_prefill_sdp=False,
     ):
 
         residual = hidden_states
```
```diff
@@ -220,6 +229,7 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
             num_key_value_heads=self.num_key_value_heads,
             head_dim=self.head_dim,
             seq_len=self.seq_len,
+            use_prefill_sdp=use_prefill_sdp,
         )
         hidden_states = self.eltwise_add(residual, attn_output)
         residual = hidden_states
```
```diff
@@ -427,6 +437,7 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module):
         )
         self.layer_norm_0 = layer_norm_0
         self.layer_norm_1 = layer_norm_1
+        self.use_prefill_sdp = intermediate_size == 11008
 
     def forward(
         self,
```
```diff
@@ -451,9 +462,13 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module):
         seq_len = hidden_states.shape[1]
 
         backend_cls = self.backend_cls_prefill
-        inputs = (hidden_states.to(torch.float16),
-                  attention_mask.to(torch.int64),
-                  position_ids.to(torch.int64))
+        if self.use_prefill_sdp:
+            inputs = (hidden_states.to(torch.float16),
+                      position_ids.to(torch.int64))
+        else:
+            inputs = (hidden_states.to(torch.float16),
+                      attention_mask.to(torch.int64),
+                      position_ids.to(torch.int64))
         inputs += (self.layer_norm_0, self.layer_norm_1)
         hidden_states, past_key, past_value = run_model(
             inputs, self.op_parameters, backend_cls, self.op_id, replica=2
```
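On the host side, `FusedLlamaLowBitDecoderlayer` mirrors the graph-side flag: `use_prefill_sdp` is derived from `intermediate_size == 11008`, which matches the Llama-2-7B MLP width, so the SDP prefill path is effectively limited to that model for now, and the prefill call then simply omits `attention_mask` from the inputs passed to `run_model`. A minimal sketch of that gating check against a Hugging Face config; the threshold comes from this diff, and the use of `transformers.LlamaConfig` here is an assumption for illustration:

```python
from transformers import LlamaConfig

def should_use_prefill_sdp(config: LlamaConfig) -> bool:
    # Same heuristic as in this PR: only enable the fused SDP prefill path
    # for Llama-2-7B-sized models (MLP intermediate size 11008); other
    # models still need to be tested.
    return getattr(config, "intermediate_size", None) == 11008

cfg_7b = LlamaConfig(intermediate_size=11008)   # Llama-2-7B-style width
cfg_other = LlamaConfig(intermediate_size=13824)  # a larger, non-gated width
print(should_use_prefill_sdp(cfg_7b))    # True
print(should_use_prefill_sdp(cfg_other)) # False
```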
Changes in the second file (`LLMBaseNNFactory`):

```diff
@@ -135,10 +135,10 @@ class LLMBaseNNFactory(NNFactory):
                   seq_len,
                   q_bias=None,
                   k_bias=None,
-                  v_bias=None):
+                  v_bias=None,
+                  use_prefill_sdp=False):
         hidden_size = num_heads * head_dim
         num_key_value_groups = num_heads // num_key_value_heads
         groupsize = hidden_size // self.n_splits_linear
         if self.n_splits_linear == 1:
             query_states = self.linear(
                 hidden_states,
```
```diff
@@ -200,8 +200,13 @@ class LLMBaseNNFactory(NNFactory):
 
         query_states = self.transpose(query_states, [0, 2, 1, 3])
         key_states = self.transpose(key_states, [0, 2, 1, 3])
+        use_ov_sdp = (mode == "prefill") and use_prefill_sdp
         if self.transpose_value:
-            value_states = self.transpose(value_states, [0, 2, 3, 1])
+            new_value_states = self.transpose(value_states, [0, 2, 3, 1])
+            if use_ov_sdp:
+                value_states = self.transpose(value_states, [0, 2, 1, 3])
+            else:
+                value_states = new_value_states
         else:
             value_states = self.transpose(value_states, [0, 2, 1, 3])
 
```
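The KV-cache layout and the SDP layout differ when `transpose_value` is enabled: the cached values (used by the decode-time matmul) keep the `[0, 2, 3, 1]` permutation, i.e. (batch, kv_heads, head_dim, seq), while a fused SDP call expects `[0, 2, 1, 3]`, i.e. (batch, kv_heads, seq, head_dim). The hunk above therefore keeps `new_value_states` in the cache layout and, only when `use_ov_sdp` is active, re-transposes `value_states` for the SDP call. A small plain-PyTorch sketch of the two layouts; the tensor sizes are illustrative, not taken from the repo:

```python
import torch

batch, seq_len, num_kv_heads, head_dim = 1, 16, 32, 128  # illustrative sizes
# value_states as produced by the V projection: (batch, seq, kv_heads, head_dim)
value_states = torch.randn(batch, seq_len, num_kv_heads, head_dim, dtype=torch.float16)

# Layout kept in the KV cache when transpose_value is on: permutation [0, 2, 3, 1]
cache_layout = value_states.permute(0, 2, 3, 1)  # (batch, kv_heads, head_dim, seq)
# Layout expected by a fused SDP op: permutation [0, 2, 1, 3]
sdp_layout = value_states.permute(0, 2, 1, 3)    # (batch, kv_heads, seq, head_dim)

print(cache_layout.shape)  # torch.Size([1, 32, 128, 16])
print(sdp_layout.shape)    # torch.Size([1, 32, 16, 128])
```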
				
			
```diff
@@ -216,7 +221,6 @@ class LLMBaseNNFactory(NNFactory):
             head_dim=head_dim,
         )
         new_key_states = key_states
-        new_value_states = value_states
 
         if mode == "decode":
             key_states = self.concat(past_key, key_states, axis=-2)
```
```diff
@@ -238,16 +242,24 @@ class LLMBaseNNFactory(NNFactory):
                                       num_key_value_heads=num_key_value_heads,
                                       kv_seq_len=kv_seq_len,
                                       head_dim=head_dim,
-                                      transpose=self.transpose_value)
-        attn_weight = self.matmul(query_states, key_states, False, True) / (
-            math.sqrt(head_dim)
-        )
-        attention_mask = self.convert_to_fp16(attention_mask)
-        attn_weight = self.eltwise_add(attn_weight, attention_mask)
-        attn_weight = self.convert_to_fp32(attn_weight)
-        attn_weight = self.softmax(attn_weight, -1)
-        attn_weight = self.convert_to_fp16(attn_weight)
-        attn_output = self.matmul(attn_weight, value_states, False, self.transpose_value)
+                                      transpose=(self.transpose_value and (not use_ov_sdp)))
+        if use_ov_sdp:
+            value_states = self.convert_to_fp32(value_states)
+            key_states = self.convert_to_fp32(key_states)
+            query_states = self.convert_to_fp32(query_states)
+            attn_output = self.scaled_dot_product_attention(
+                query_states, key_states, value_states, None, True)
+            attn_output = self.convert_to_fp16(attn_output)
+        else:
+            attn_weight = self.matmul(query_states, key_states, False, True) / (
+                math.sqrt(head_dim)
+            )
+            attention_mask = self.convert_to_fp16(attention_mask)
+            attn_weight = self.eltwise_add(attn_weight, attention_mask)
+            attn_weight = self.convert_to_fp32(attn_weight)
+            attn_weight = self.softmax(attn_weight, -1)
+            attn_weight = self.convert_to_fp16(attn_weight)
+            attn_output = self.matmul(attn_weight, value_states, False, self.transpose_value)
 
         attn_output = self.transpose(attn_output, [0, 2, 1, 3])
         attn_output = self.reshape(attn_output, [1, seq_len, hidden_size])
```
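The core of the change is this last hunk: on the prefill path, Q/K/V are converted to fp32 and handed to the factory's `scaled_dot_product_attention` with no explicit mask and the causal flag set, then the result is converted back to fp16; decode (and prefill for non-gated models) keeps the original matmul + mask-add + softmax + matmul pipeline. The sketch below illustrates, in plain PyTorch rather than the OpenVINO graph ops used here, why the two formulations agree numerically when the mask is the standard additive causal mask:

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, heads, seq_len, head_dim = 1, 4, 8, 16
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)
v = torch.randn(batch, heads, seq_len, head_dim)

# Explicit prefill attention: scores + additive causal mask + softmax + weighted sum.
mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
attn_weight = q @ k.transpose(-1, -2) / math.sqrt(head_dim)
attn_weight = torch.softmax(attn_weight + mask, dim=-1)
manual_out = attn_weight @ v

# Fused path: let the SDP kernel apply causal masking itself.
fused_out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)

print(torch.allclose(manual_out, fused_out, atol=1e-5))  # True
```

Converting Q/K/V to fp32 around the fused call mirrors what the non-fused path already does by moving the attention weights to fp32 before the softmax.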