Revert prefill logic of qwen2-7b (#11992)

parent 659d15defc
commit 01099f08ee

1 changed file with 44 additions and 123 deletions
@@ -801,13 +801,13 @@ def run_prefill(
     input_layer_norm_weights = []
     post_attn_layernorm_weights = []
     layer_indexs = range(layer_start, layer_end)
-    if model.config.intermediate_size == 8960:
-        # for qwen2-1.5b
-        for layer_idx in layer_indexs:
-            curr_layer = model.model.layers[layer_idx]
-            attn_layer = curr_layer.self_attn
-            mlp_layer = curr_layer.mlp
+    for layer_idx in layer_indexs:
+        curr_layer = model.model.layers[layer_idx]
+        attn_layer = curr_layer.self_attn
+        mlp_layer = curr_layer.mlp
 
+        if model.config.intermediate_size == 8960:
+            # for qwen2-1.5b
             weights = [
                 (attn_layer.q_proj.weight, attn_layer.q_proj.scale),
                 (attn_layer.k_proj.weight, attn_layer.k_proj.scale),
@@ -817,53 +817,52 @@ def run_prefill(
                 (mlp_layer.up_proj.weight, mlp_layer.up_proj.scale),
                 (mlp_layer.down_proj.weight, mlp_layer.down_proj.scale),
             ]
+        elif model.config.intermediate_size == 18944:
+            # for qwen2-7b
+            weights = [
+                (attn_layer.q_proj.weight, attn_layer.q_proj.scale),
+                (attn_layer.k_proj.weight, attn_layer.k_proj.scale),
+                (attn_layer.v_proj.weight, attn_layer.v_proj.scale),
+                (attn_layer.o_proj.weight, attn_layer.o_proj.scale),
+                (mlp_layer.gate_proj.weight, mlp_layer.gate_proj.scale),
+                (mlp_layer.up_proj.weight, mlp_layer.up_proj.scale),
+                (mlp_layer.down_proj_0.weight, mlp_layer.down_proj_0.scale),
+                (mlp_layer.down_proj_1.weight, mlp_layer.down_proj_1.scale)
+            ]
 
-            cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
-            cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
+        cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
+        cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
 
-            layer_norm_0 = curr_layer.input_layernorm.weight.to(torch.float16)
-            layer_norm_1 = curr_layer.post_attention_layernorm.weight.to(torch.float16)
+        layer_norm_0 = curr_layer.input_layernorm.weight.to(torch.float16)
+        layer_norm_1 = curr_layer.post_attention_layernorm.weight.to(torch.float16)
 
-            new_decoderlayer = FusedQwenLowBitDecoderlayer(
-                weights,
-                num_heads=num_heads,
-                num_key_value_heads=num_key_value_heads,
-                cached_cos=cached_cos,
-                cached_sin=cached_sin,
-                layer_norm_0=layer_norm_0,
-                layer_norm_1=layer_norm_1,
-                q_bias=attn_layer.q_proj.bias.to(torch.float16),
-                k_bias=attn_layer.k_proj.bias.to(torch.float16),
-                v_bias=attn_layer.v_proj.bias.to(torch.float16),
-                layer_idx=layer_idx,
-                rms_norm_eps=rms_norm_eps,
-                intermediate_size=intermediate_size,
-                max_seq_len=max_output_len,
-                transpose_value=transpose_value_cache,
-            )
+        new_decoderlayer = FusedQwenLowBitDecoderlayer(
+            weights,
+            num_heads=num_heads,
+            num_key_value_heads=num_key_value_heads,
+            cached_cos=cached_cos,
+            cached_sin=cached_sin,
+            layer_norm_0=layer_norm_0,
+            layer_norm_1=layer_norm_1,
+            q_bias=attn_layer.q_proj.bias.to(torch.float16),
+            k_bias=attn_layer.k_proj.bias.to(torch.float16),
+            v_bias=attn_layer.v_proj.bias.to(torch.float16),
+            layer_idx=layer_idx,
+            rms_norm_eps=rms_norm_eps,
+            intermediate_size=intermediate_size,
+            max_seq_len=max_output_len,
+            transpose_value=transpose_value_cache,
+        )
 
-            layer_weights.extend(weights)
-            input_layer_norm_weights.append(layer_norm_0)
-            post_attn_layernorm_weights.append(layer_norm_1)
-            model.model.layers[layer_idx] = new_decoderlayer
-            deocderlayers.append(new_decoderlayer)
+        layer_weights.extend(weights)
+        input_layer_norm_weights.append(layer_norm_0)
+        post_attn_layernorm_weights.append(layer_norm_1)
+        model.model.layers[layer_idx] = new_decoderlayer
+        deocderlayers.append(new_decoderlayer)
 
     print("finish creating all decode layers in prefill")
     result_queue.put("loading finish")
 
-    if model.config.intermediate_size == 18944:
-        # for qwen2-7b
-        from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention
-        from ipex_llm.transformers.npu_models.convert_mp import convert_forward
-        qwen2_attention_forward = generate_qwen2_attention_forward(
-            max_seq_len=max_output_len,
-            transpose_value=transpose_value_cache
-        )
-        convert_forward(model, Qwen2Attention, qwen2_attention_forward)
-        from transformers.models.qwen2.modeling_qwen2 import Qwen2MLP
-        convert_forward(model, Qwen2MLP, split_mlp_forward)
-        deocderlayers = model.model.layers
-
     while True:
 
         result = input_queue.get()
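Note on the restored control flow: after this revert, run_prefill builds every decoder layer inside one loop over layer_indexs, and model.config.intermediate_size is only used to pick the per-layer weight list (8960 for Qwen2-1.5B, which has a single down_proj; 18944 for Qwen2-7B, whose down_proj is split into down_proj_0 and down_proj_1). The snippet below is a minimal sketch of that selection and not part of the patch; select_layer_weights is a hypothetical helper, and attn_layer/mlp_layer stand for the model's quantized attention and MLP modules.

def select_layer_weights(attn_layer, mlp_layer, intermediate_size):
    # (weight, scale) pairs shared by both model sizes: q/k/v/o plus gate/up.
    common = [
        (attn_layer.q_proj.weight, attn_layer.q_proj.scale),
        (attn_layer.k_proj.weight, attn_layer.k_proj.scale),
        (attn_layer.v_proj.weight, attn_layer.v_proj.scale),
        (attn_layer.o_proj.weight, attn_layer.o_proj.scale),
        (mlp_layer.gate_proj.weight, mlp_layer.gate_proj.scale),
        (mlp_layer.up_proj.weight, mlp_layer.up_proj.scale),
    ]
    if intermediate_size == 8960:       # qwen2-1.5b: single down_proj
        return common + [(mlp_layer.down_proj.weight, mlp_layer.down_proj.scale)]
    if intermediate_size == 18944:      # qwen2-7b: down_proj split in two
        return common + [
            (mlp_layer.down_proj_0.weight, mlp_layer.down_proj_0.scale),
            (mlp_layer.down_proj_1.weight, mlp_layer.down_proj_1.scale),
        ]
    raise ValueError(f"unsupported intermediate_size: {intermediate_size}")

In the reverted function this selection happens inline, and the chosen weights feed FusedQwenLowBitDecoderlayer together with the cached rotary embeddings and the two layer norms.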
@@ -1136,81 +1135,3 @@ def qwen2_casullm_forward(
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
     )
-
-
-from transformers.models.qwen2.modeling_qwen2 import apply_rotary_pos_emb, repeat_kv
-import math
-
-
-def generate_qwen2_attention_forward(max_seq_len, transpose_value):
-    def qwen2_attention_forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Cache] = None,
-        output_attentions: bool = False,
-        use_cache: bool = False,
-        **kwargs,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        bsz, q_len, _ = hidden_states.size()
-
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
-                                     self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
-                                         self.head_dim).transpose(1, 2)
-
-        kv_seq_len = key_states.shape[-2]
-        if past_key_value is not None:
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
-                                                        cos, sin, position_ids)
-        cache_kwargs = {"max_seq_len": max_seq_len, "transpose": transpose_value, }
-
-        if past_key_value is not None:
-            if transpose_value:
-                value_states = value_states.transpose(-1, -2)
-            key_states, value_states = past_key_value.update(key_states, value_states,
-                                                             self.layer_idx, cache_kwargs)
-
-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-        attn_weights = None
-        if query_states.size(2) == key_states.size(2):
-            # first token
-            from intel_npu_acceleration_library.functional import scaled_dot_product_attention
-            attn_output = scaled_dot_product_attention(
-                query_states,
-                key_states,
-                value_states,
-                attn_mask=attention_mask,
-                is_causal=q_len > 1 and bsz == 1,
-            )
-        else:
-            attn_weights = torch.matmul(query_states,
-                                        key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-            if attention_mask is not None:
-                attn_weights = attn_weights + attention_mask
-            # upcast attention to fp32
-            attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
-                                                       dtype=torch.float32).to(query_states.dtype)
-            attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
-                                                       training=self.training)
-            attn_output = torch.matmul(attn_weights, value_states)
-
-        attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-
-        attn_output = self.o_proj(attn_output)
-
-        if not output_attentions:
-            attn_weights = None
-        return attn_output, attn_weights, past_key_value
-    return qwen2_attention_forward
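Note on the removed code path: this hunk drops the module-level generate_qwen2_attention_forward factory (and its apply_rotary_pos_emb/repeat_kv/math imports) that the Qwen2-7B branch installed onto Qwen2Attention via convert_forward, alongside the split_mlp_forward patch for Qwen2MLP. For orientation only, the sketch below shows the general patch-the-forward pattern, assuming convert_forward rebinds forward on matching submodules; this convert_forward and generate_scaled_forward are toy stand-ins, not the ipex_llm implementation.

# Toy sketch of the forward-patching pattern removed by this hunk.
# Assumption: convert_forward rebinds `forward` on every submodule of the
# target class; this is NOT the ipex_llm implementation.
import types
import torch


def convert_forward(model, module_cls, new_forward):
    for module in model.modules():
        if isinstance(module, module_cls):
            # Bind the generated function to this instance and shadow forward.
            module.forward = types.MethodType(new_forward, module)


def generate_scaled_forward(scale):
    # Closure factory mirroring generate_qwen2_attention_forward: the
    # generated forward captures its configuration (here just `scale`).
    def scaled_forward(self, x):
        return torch.nn.functional.linear(x, self.weight, self.bias) * scale
    return scaled_forward


model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
convert_forward(model, torch.nn.Linear, generate_scaled_forward(scale=0.5))
print(model(torch.randn(1, 4)).shape)  # torch.Size([1, 4])

The removed code had the same shape: the factory closed over max_seq_len and transpose_value, and the generated forward was routed to Qwen2Attention through convert_forward before prefill ran. After the revert, neither Qwen2Attention nor Qwen2MLP is patched; both model sizes go through the FusedQwenLowBitDecoderlayer loop above.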