Support IR and blob format for llama level0 pipeline (#12251)
This commit is contained in:
		
							parent
							
								
									578aef245d
								
							
						
					
					
						commit
						567b77a76b
					
				
					 5 changed files with 48 additions and 40 deletions
				
			
		| 
						 | 
				
			
			@ -396,7 +396,7 @@ class FusedBaichuanLowBitMultiDecoderlayer(torch.nn.Module):
 | 
			
		|||
        inputs = (
 | 
			
		||||
            hidden_states.to(torch.float16),
 | 
			
		||||
            attention_mask,
 | 
			
		||||
            position_ids,
 | 
			
		||||
            position_ids.to(torch.float16),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        for i in range(self.intra_stages):
 | 
			
		||||
| 
						 | 
				
			
			@ -502,7 +502,7 @@ class FusedBaichuanLowBitDecoderlayer(torch.nn.Module):
 | 
			
		|||
        seq_len = hidden_states.shape[1]
 | 
			
		||||
 | 
			
		||||
        backend_cls = self.backend_cls_prefill
 | 
			
		||||
        inputs = (hidden_states.to(torch.float16), attention_mask, position_ids)
 | 
			
		||||
        inputs = (hidden_states.to(torch.float16), attention_mask, position_ids.to(torch.float16))
 | 
			
		||||
        inputs += (self.layer_norm_0, self.layer_norm_1)
 | 
			
		||||
        hidden_states, past_key, past_value = run_model(
 | 
			
		||||
            inputs, self.op_parameters, backend_cls, self.op_id, replica=2
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -106,31 +106,13 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
 | 
			
		|||
 | 
			
		||||
        # Self Attention
 | 
			
		||||
        if mode == "decode":
 | 
			
		||||
            attention_mask = self.create_input_op((self.batch_size, 1, 1, self.max_seq_len + 1))
 | 
			
		||||
            attention_mask = self.create_input_op((self.batch_size, 1, 1, self.max_seq_len + 1),
 | 
			
		||||
                                                  dtype=np.int64)
 | 
			
		||||
        else:
 | 
			
		||||
            attention_mask = self.create_input_op((self.batch_size, 1, self.seq_len, self.seq_len))
 | 
			
		||||
            attention_mask = self.create_input_op((self.batch_size, 1, self.seq_len, self.seq_len),
 | 
			
		||||
                                                  dtype=np.int64)
 | 
			
		||||
 | 
			
		||||
        position_ids = self.create_input_op((self.batch_size, self.seq_len))
 | 
			
		||||
        past_keys = []
 | 
			
		||||
        past_values = []
 | 
			
		||||
        if mode == "decode":
 | 
			
		||||
            for i in range(num_layers):
 | 
			
		||||
                past_key = self.create_cache_op(
 | 
			
		||||
                    (self.batch_size, self.num_key_value_heads, self.max_seq_len, self.head_dim)
 | 
			
		||||
                )
 | 
			
		||||
                if transpose_value:
 | 
			
		||||
                    past_value = self.create_cache_op(
 | 
			
		||||
                        (self.batch_size, self.num_key_value_heads, self.head_dim, self.max_seq_len)
 | 
			
		||||
                    )
 | 
			
		||||
                else:
 | 
			
		||||
                    past_value = self.create_cache_op(
 | 
			
		||||
                        (self.batch_size, self.num_key_value_heads, self.max_seq_len, self.head_dim)
 | 
			
		||||
                    )
 | 
			
		||||
                past_keys.append(past_key)
 | 
			
		||||
                past_values.append(past_value)
 | 
			
		||||
        else:
 | 
			
		||||
            past_keys = [None] * num_layers
 | 
			
		||||
            past_values = [None] * num_layers
 | 
			
		||||
        position_ids = self.create_input_op((self.batch_size, self.seq_len), dtype=np.int64)
 | 
			
		||||
 | 
			
		||||
        if input_layernorm_weights is None:
 | 
			
		||||
            input_layernorm_weights = []
 | 
			
		||||
| 
						 | 
				
			
			@ -156,6 +138,27 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
 | 
			
		|||
            input_layernorm_weights = [self.constant(w) for w in input_layernorm_weights]
 | 
			
		||||
            post_attn_layernorm_weights = [self.constant(w) for w in post_attn_layernorm_weights]
 | 
			
		||||
 | 
			
		||||
        past_keys = []
 | 
			
		||||
        past_values = []
 | 
			
		||||
        if mode == "decode":
 | 
			
		||||
            for i in range(num_layers):
 | 
			
		||||
                past_key = self.create_cache_op(
 | 
			
		||||
                    (self.batch_size, self.num_key_value_heads, self.max_seq_len, self.head_dim)
 | 
			
		||||
                )
 | 
			
		||||
                if transpose_value:
 | 
			
		||||
                    past_value = self.create_cache_op(
 | 
			
		||||
                        (self.batch_size, self.num_key_value_heads, self.head_dim, self.max_seq_len)
 | 
			
		||||
                    )
 | 
			
		||||
                else:
 | 
			
		||||
                    past_value = self.create_cache_op(
 | 
			
		||||
                        (self.batch_size, self.num_key_value_heads, self.max_seq_len, self.head_dim)
 | 
			
		||||
                    )
 | 
			
		||||
                past_keys.append(past_key)
 | 
			
		||||
                past_values.append(past_value)
 | 
			
		||||
        else:
 | 
			
		||||
            past_keys = [None] * num_layers
 | 
			
		||||
            past_values = [None] * num_layers
 | 
			
		||||
 | 
			
		||||
        hidden_states = input
 | 
			
		||||
 | 
			
		||||
        curr_key_values = []
 | 
			
		||||
| 
						 | 
				
			
			@ -310,8 +313,8 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
 | 
			
		|||
 | 
			
		||||
        inputs = (
 | 
			
		||||
            hidden_states.to(torch.float16),
 | 
			
		||||
            attention_mask,
 | 
			
		||||
            position_ids,
 | 
			
		||||
            attention_mask.to(torch.int64),
 | 
			
		||||
            position_ids.to(torch.int64),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        for i in range(self.intra_stages):
 | 
			
		||||
| 
						 | 
				
			
			@ -419,7 +422,9 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module):
 | 
			
		|||
        seq_len = hidden_states.shape[1]
 | 
			
		||||
 | 
			
		||||
        backend_cls = self.backend_cls_prefill
 | 
			
		||||
        inputs = (hidden_states.to(torch.float16), attention_mask, position_ids)
 | 
			
		||||
        inputs = (hidden_states.to(torch.float16),
 | 
			
		||||
                  attention_mask.to(torch.int64),
 | 
			
		||||
                  position_ids.to(torch.int64))
 | 
			
		||||
        inputs += (self.layer_norm_0, self.layer_norm_1)
 | 
			
		||||
        hidden_states, past_key, past_value = run_model(
 | 
			
		||||
            inputs, self.op_parameters, backend_cls, self.op_id, replica=2
 | 
			
		||||
| 
						 | 
				
			
			@ -544,9 +549,9 @@ def run_decode(
 | 
			
		|||
 | 
			
		||||
                pad_mask = (0, pad_len)
 | 
			
		||||
                padded_causal_mask = F.pad(
 | 
			
		||||
                    causal_mask.to(torch.float16), pad_mask, value=torch.finfo(torch.float16).min
 | 
			
		||||
                    causal_mask.to(torch.int64), pad_mask, value=torch.iinfo(torch.int64).min
 | 
			
		||||
                )
 | 
			
		||||
                padded_causal_mask[:, :, :, -1] = 0.0
 | 
			
		||||
                padded_causal_mask[:, :, :, -1] = 0
 | 
			
		||||
                dist.recv(hidden_states, src=rank - 1)
 | 
			
		||||
                layer_outputs = multi_decoder(
 | 
			
		||||
                    hidden_states,
 | 
			
		||||
| 
						 | 
				
			
			@ -796,9 +801,9 @@ class PrefillRunner:
 | 
			
		|||
        hidden_states = F.pad(hidden_states.to(torch.float16), (0, 0, 0, pad_len), value=0.0)
 | 
			
		||||
        position_ids = F.pad(position_ids, (0, pad_len), value=0)
 | 
			
		||||
        attention_mask = F.pad(
 | 
			
		||||
            attention_mask.to(torch.float16),
 | 
			
		||||
            attention_mask.to(torch.int64),
 | 
			
		||||
            (0, pad_len, 0, pad_len),
 | 
			
		||||
            value=torch.finfo(torch.float16).min,
 | 
			
		||||
            value=torch.iinfo(torch.int64).min,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        args = (hidden_states, position_ids, attention_mask, past_key_value, cache_position)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -335,7 +335,7 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
 | 
			
		|||
        inputs = (
 | 
			
		||||
            hidden_states.to(torch.float16),
 | 
			
		||||
            attention_mask,
 | 
			
		||||
            position_ids,
 | 
			
		||||
            position_ids.to(torch.float16),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        for i in range(self.intra_stages):
 | 
			
		||||
| 
						 | 
				
			
			@ -445,7 +445,7 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module):
 | 
			
		|||
        seq_len = hidden_states.shape[1]
 | 
			
		||||
 | 
			
		||||
        backend_cls = self.backend_cls_prefill
 | 
			
		||||
        inputs = (hidden_states.to(torch.float16), attention_mask, position_ids)
 | 
			
		||||
        inputs = (hidden_states.to(torch.float16), attention_mask, position_ids.to(torch.float16))
 | 
			
		||||
        inputs += (self.layer_norm_0, self.layer_norm_1)
 | 
			
		||||
        hidden_states, past_key, past_value = run_model(
 | 
			
		||||
            inputs, self.op_parameters, backend_cls, self.op_id, replica=2
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -54,7 +54,8 @@ def run_model(
 | 
			
		|||
 | 
			
		||||
    # Reshape input
 | 
			
		||||
    input_dtype = x[0].dtype
 | 
			
		||||
    x_np = [set_contiguous(elem).to(torch.float16).numpy() for elem in x]
 | 
			
		||||
    x_np = [set_contiguous(elem).numpy() if elem.dtype == torch.int64 else
 | 
			
		||||
            set_contiguous(elem).to(torch.float16).numpy() for elem in x]
 | 
			
		||||
    op_args = []
 | 
			
		||||
    op_args_flatten = []
 | 
			
		||||
    for w in weights:
 | 
			
		||||
| 
						 | 
				
			
			@ -279,6 +280,7 @@ class LLMBaseNNFactory(NNFactory):
 | 
			
		|||
        attn_weight = self.matmul(query_states, key_states, False, True) / (
 | 
			
		||||
            math.sqrt(head_dim)
 | 
			
		||||
        )
 | 
			
		||||
        attention_mask = self.convert_to_fp16(attention_mask)
 | 
			
		||||
        attn_weight = self.eltwise_add(attn_weight, attention_mask)
 | 
			
		||||
        attn_weight = self.convert_to_fp32(attn_weight)
 | 
			
		||||
        attn_weight = self.softmax(attn_weight, -1)
 | 
			
		||||
| 
						 | 
				
			
			@ -476,13 +478,13 @@ class LLMBaseNNFactory(NNFactory):
 | 
			
		|||
        self.cache_parameter_ops.append(op)
 | 
			
		||||
        return op
 | 
			
		||||
 | 
			
		||||
    def create_input_op(self, shape):
 | 
			
		||||
    def create_input_op(self, shape, dtype=np.float16):
 | 
			
		||||
        invalidInputError(len(self.cache_parameter_ops) == 0,
 | 
			
		||||
                          "create_input_op should be called before any create_cache_op")
 | 
			
		||||
        invalidInputError(len(self.linear_ops) == 0,
 | 
			
		||||
                          "create_input_op should be called before any linear op")
 | 
			
		||||
 | 
			
		||||
        op = super().parameter(shape)
 | 
			
		||||
        op = super().parameter(shape, dtype)
 | 
			
		||||
        self.input_ops.append(op)
 | 
			
		||||
        return op
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -563,7 +565,8 @@ class LLMBaseNNFactory(NNFactory):
 | 
			
		|||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def run_decoders(inputs, decoders, models_ptr=None):
 | 
			
		||||
        x_np = [elem.to(torch.float16).numpy() for elem in inputs]
 | 
			
		||||
        x_np = [elem.numpy() if elem.dtype == torch.int64 else
 | 
			
		||||
                elem.to(torch.float16).numpy() for elem in inputs]
 | 
			
		||||
 | 
			
		||||
        num_decoders = len(decoders)
 | 
			
		||||
        num_inputs = len(x_np)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -413,7 +413,7 @@ class FusedQwenLowBitMultiDecoderlayer(torch.nn.Module):
 | 
			
		|||
        inputs = (
 | 
			
		||||
            hidden_states.to(torch.float16),
 | 
			
		||||
            attention_mask,
 | 
			
		||||
            position_ids,
 | 
			
		||||
            position_ids.to(torch.float16),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        for i in range(self.intra_stages):
 | 
			
		||||
| 
						 | 
				
			
			@ -530,7 +530,7 @@ class FusedQwenLowBitDecoderlayer(torch.nn.Module):
 | 
			
		|||
        seq_len = hidden_states.shape[1]
 | 
			
		||||
 | 
			
		||||
        backend_cls = self.backend_cls_prefill
 | 
			
		||||
        inputs = (hidden_states.to(torch.float16), attention_mask, position_ids)
 | 
			
		||||
        inputs = (hidden_states.to(torch.float16), attention_mask, position_ids.to(torch.float16))
 | 
			
		||||
        inputs += (self.layer_norm_0, self.layer_norm_1)
 | 
			
		||||
        inputs += (self.q_bias, self.k_bias, self.v_bias)
 | 
			
		||||
        hidden_states, past_key, past_value = run_model(
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue