Support IR and blob format for llama level0 pipeline (#12251)

Authored by binbin Deng on 2024-10-23 16:02:35 +08:00, committed by GitHub
parent 578aef245d
commit 567b77a76b
5 changed files with 48 additions and 40 deletions


@@ -396,7 +396,7 @@ class FusedBaichuanLowBitMultiDecoderlayer(torch.nn.Module):
         inputs = (
             hidden_states.to(torch.float16),
             attention_mask,
-            position_ids,
+            position_ids.to(torch.float16),
         )

         for i in range(self.intra_stages):
@@ -502,7 +502,7 @@ class FusedBaichuanLowBitDecoderlayer(torch.nn.Module):
         seq_len = hidden_states.shape[1]

         backend_cls = self.backend_cls_prefill
-        inputs = (hidden_states.to(torch.float16), attention_mask, position_ids)
+        inputs = (hidden_states.to(torch.float16), attention_mask, position_ids.to(torch.float16))
         inputs += (self.layer_norm_0, self.layer_norm_1)
         hidden_states, past_key, past_value = run_model(
             inputs, self.op_parameters, backend_cls, self.op_id, replica=2


@@ -106,31 +106,13 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):

         # Self Attention
         if mode == "decode":
-            attention_mask = self.create_input_op((self.batch_size, 1, 1, self.max_seq_len + 1))
+            attention_mask = self.create_input_op((self.batch_size, 1, 1, self.max_seq_len + 1),
+                                                  dtype=np.int64)
         else:
-            attention_mask = self.create_input_op((self.batch_size, 1, self.seq_len, self.seq_len))
+            attention_mask = self.create_input_op((self.batch_size, 1, self.seq_len, self.seq_len),
+                                                  dtype=np.int64)

-        position_ids = self.create_input_op((self.batch_size, self.seq_len))
+        position_ids = self.create_input_op((self.batch_size, self.seq_len), dtype=np.int64)
-        past_keys = []
-        past_values = []
-        if mode == "decode":
-            for i in range(num_layers):
-                past_key = self.create_cache_op(
-                    (self.batch_size, self.num_key_value_heads, self.max_seq_len, self.head_dim)
-                )
-                if transpose_value:
-                    past_value = self.create_cache_op(
-                        (self.batch_size, self.num_key_value_heads, self.head_dim, self.max_seq_len)
-                    )
-                else:
-                    past_value = self.create_cache_op(
-                        (self.batch_size, self.num_key_value_heads, self.max_seq_len, self.head_dim)
-                    )
-                past_keys.append(past_key)
-                past_values.append(past_value)
-        else:
-            past_keys = [None] * num_layers
-            past_values = [None] * num_layers

         if input_layernorm_weights is None:
             input_layernorm_weights = []
@@ -156,6 +138,27 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
             input_layernorm_weights = [self.constant(w) for w in input_layernorm_weights]
             post_attn_layernorm_weights = [self.constant(w) for w in post_attn_layernorm_weights]

+        past_keys = []
+        past_values = []
+        if mode == "decode":
+            for i in range(num_layers):
+                past_key = self.create_cache_op(
+                    (self.batch_size, self.num_key_value_heads, self.max_seq_len, self.head_dim)
+                )
+                if transpose_value:
+                    past_value = self.create_cache_op(
+                        (self.batch_size, self.num_key_value_heads, self.head_dim, self.max_seq_len)
+                    )
+                else:
+                    past_value = self.create_cache_op(
+                        (self.batch_size, self.num_key_value_heads, self.max_seq_len, self.head_dim)
+                    )
+                past_keys.append(past_key)
+                past_values.append(past_value)
+        else:
+            past_keys = [None] * num_layers
+            past_values = [None] * num_layers
+
         hidden_states = input

         curr_key_values = []
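Note: the past_keys/past_values construction above is moved, not changed; presumably this keeps every create_input_op call ahead of the create_cache_op calls, matching the "create_input_op should be called before any create_cache_op" check visible further down in this commit. A tiny sketch of that ordering rule, with assumed, simplified method bodies (not the real factory implementation):

    # Assumed, simplified version of the ordering constraint enforced by the factory:
    # all graph inputs must be declared before any KV-cache parameter op.
    class FactoryOrderSketch:
        def __init__(self):
            self.input_ops = []
            self.cache_parameter_ops = []

        def create_input_op(self, shape, dtype="float16"):
            assert not self.cache_parameter_ops, \
                "create_input_op should be called before any create_cache_op"
            self.input_ops.append((shape, dtype))

        def create_cache_op(self, shape):
            self.cache_parameter_ops.append(shape)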
@@ -310,8 +313,8 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
         inputs = (
             hidden_states.to(torch.float16),
-            attention_mask,
-            position_ids,
+            attention_mask.to(torch.int64),
+            position_ids.to(torch.int64),
         )

         for i in range(self.intra_stages):
@@ -419,7 +422,9 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module):
         seq_len = hidden_states.shape[1]

         backend_cls = self.backend_cls_prefill
-        inputs = (hidden_states.to(torch.float16), attention_mask, position_ids)
+        inputs = (hidden_states.to(torch.float16),
+                  attention_mask.to(torch.int64),
+                  position_ids.to(torch.int64))
         inputs += (self.layer_norm_0, self.layer_norm_1)
         hidden_states, past_key, past_value = run_model(
             inputs, self.op_parameters, backend_cls, self.op_id, replica=2
@@ -544,9 +549,9 @@ def run_decode(
                 pad_mask = (0, pad_len)
                 padded_causal_mask = F.pad(
-                    causal_mask.to(torch.float16), pad_mask, value=torch.finfo(torch.float16).min
+                    causal_mask.to(torch.int64), pad_mask, value=torch.iinfo(torch.int64).min
                 )
-                padded_causal_mask[:, :, :, -1] = 0.0
+                padded_causal_mask[:, :, :, -1] = 0
                 dist.recv(hidden_states, src=rank - 1)
                 layer_outputs = multi_decoder(
                     hidden_states,
@@ -796,9 +801,9 @@ class PrefillRunner:
         hidden_states = F.pad(hidden_states.to(torch.float16), (0, 0, 0, pad_len), value=0.0)
         position_ids = F.pad(position_ids, (0, pad_len), value=0)
         attention_mask = F.pad(
-            attention_mask.to(torch.float16),
+            attention_mask.to(torch.int64),
             (0, pad_len, 0, pad_len),
-            value=torch.finfo(torch.float16).min,
+            value=torch.iinfo(torch.int64).min,
         )

         args = (hidden_states, position_ids, attention_mask, past_key_value, cache_position)
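Both padding sites above switch the causal mask from float16 filled with torch.finfo(torch.float16).min to int64 filled with torch.iinfo(torch.int64).min. A standalone sketch of the new padding pattern, with made-up shapes and pad length (the real mask and pad_len come from the decoder inputs):

    import torch
    import torch.nn.functional as F

    causal_mask = torch.zeros(1, 1, 1, 8)   # illustrative shape only
    pad_len = 4

    # pad the mask on the right with the int64 "minus infinity" sentinel
    padded_causal_mask = F.pad(
        causal_mask.to(torch.int64), (0, pad_len),
        value=torch.iinfo(torch.int64).min,
    )
    padded_causal_mask[:, :, :, -1] = 0   # keep the last (current) position unmasked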


@@ -335,7 +335,7 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
         inputs = (
             hidden_states.to(torch.float16),
             attention_mask,
-            position_ids,
+            position_ids.to(torch.float16),
         )

         for i in range(self.intra_stages):
@@ -445,7 +445,7 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module):
         seq_len = hidden_states.shape[1]

         backend_cls = self.backend_cls_prefill
-        inputs = (hidden_states.to(torch.float16), attention_mask, position_ids)
+        inputs = (hidden_states.to(torch.float16), attention_mask, position_ids.to(torch.float16))
         inputs += (self.layer_norm_0, self.layer_norm_1)
         hidden_states, past_key, past_value = run_model(
             inputs, self.op_parameters, backend_cls, self.op_id, replica=2


@@ -54,7 +54,8 @@ def run_model(
     # Reshape input
     input_dtype = x[0].dtype
-    x_np = [set_contiguous(elem).to(torch.float16).numpy() for elem in x]
+    x_np = [set_contiguous(elem).numpy() if elem.dtype == torch.int64 else
+            set_contiguous(elem).to(torch.float16).numpy() for elem in x]

     op_args = []
     op_args_flatten = []
     for w in weights:
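The input conversion in run_model now leaves int64 tensors (the new mask and position-ids inputs) untouched instead of casting everything to float16. A minimal standalone sketch of the same logic, assuming set_contiguous simply calls .contiguous():

    import torch

    def to_numpy_inputs(x):
        # int64 inputs (masks, position ids) keep their dtype; everything else
        # is cast to float16 before being converted to numpy for the backend.
        return [t.contiguous().numpy() if t.dtype == torch.int64
                else t.contiguous().to(torch.float16).numpy() for t in x]

    hidden_states = torch.randn(1, 4, 32)
    position_ids = torch.arange(4, dtype=torch.int64).unsqueeze(0)
    x_np = to_numpy_inputs([hidden_states, position_ids])
    print([a.dtype for a in x_np])   # [dtype('float16'), dtype('int64')]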
@@ -279,6 +280,7 @@ class LLMBaseNNFactory(NNFactory):
         attn_weight = self.matmul(query_states, key_states, False, True) / (
             math.sqrt(head_dim)
         )
+        attention_mask = self.convert_to_fp16(attention_mask)
         attn_weight = self.eltwise_add(attn_weight, attention_mask)
         attn_weight = self.convert_to_fp32(attn_weight)
         attn_weight = self.softmax(attn_weight, -1)
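Because the attention mask now enters the graph as int64, it is converted to fp16 inside the factory just before the eltwise_add with the fp16 attention weights. A rough eager-mode equivalent of those two ops (array shapes are made up for illustration); note that the int64 minimum overflows to -inf when cast to float16, so masked positions still behave like -inf:

    import numpy as np

    attn_weight = np.zeros((1, 1, 1, 4), dtype=np.float16)
    mask = np.array([[[[0, 0, np.iinfo(np.int64).min, 0]]]], dtype=np.int64)

    # convert_to_fp16(attention_mask) followed by eltwise_add(attn_weight, attention_mask)
    attn_weight = attn_weight + mask.astype(np.float16)
    print(attn_weight)   # the masked position becomes -inf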
@@ -476,13 +478,13 @@ class LLMBaseNNFactory(NNFactory):
         self.cache_parameter_ops.append(op)
         return op

-    def create_input_op(self, shape):
+    def create_input_op(self, shape, dtype=np.float16):
         invalidInputError(len(self.cache_parameter_ops) == 0,
                           "create_input_op should be called before any create_cache_op")
         invalidInputError(len(self.linear_ops) == 0,
                           "create_input_op should be called before any linear op")

-        op = super().parameter(shape)
+        op = super().parameter(shape, dtype)
         self.input_ops.append(op)
         return op
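With the extra dtype parameter (default np.float16, so existing callers are unaffected), graph inputs can now be declared as int64, as the LowBitLlamaMultiDecoderlayer hunk earlier in this commit does. A hypothetical usage sketch; the factory argument stands in for a concrete LLMBaseNNFactory subclass and the shape names are illustrative:

    import numpy as np

    def declare_inputs(factory, batch_size, seq_len, hidden_size):
        # float16 stays the default; masks and position ids are declared as int64,
        # mirroring the llama decoder-layer hunk above.
        hidden_states = factory.create_input_op((batch_size, seq_len, hidden_size))
        attention_mask = factory.create_input_op((batch_size, 1, seq_len, seq_len),
                                                 dtype=np.int64)
        position_ids = factory.create_input_op((batch_size, seq_len), dtype=np.int64)
        return hidden_states, attention_mask, position_ids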
@@ -563,7 +565,8 @@ class LLMBaseNNFactory(NNFactory):

     @staticmethod
     def run_decoders(inputs, decoders, models_ptr=None):
-        x_np = [elem.to(torch.float16).numpy() for elem in inputs]
+        x_np = [elem.numpy() if elem.dtype == torch.int64 else
+                elem.to(torch.float16).numpy() for elem in inputs]

         num_decoders = len(decoders)
         num_inputs = len(x_np)


@@ -413,7 +413,7 @@ class FusedQwenLowBitMultiDecoderlayer(torch.nn.Module):
         inputs = (
             hidden_states.to(torch.float16),
             attention_mask,
-            position_ids,
+            position_ids.to(torch.float16),
         )

         for i in range(self.intra_stages):
@@ -530,7 +530,7 @@ class FusedQwenLowBitDecoderlayer(torch.nn.Module):
         seq_len = hidden_states.shape[1]
         backend_cls = self.backend_cls_prefill
-        inputs = (hidden_states.to(torch.float16), attention_mask, position_ids)
+        inputs = (hidden_states.to(torch.float16), attention_mask, position_ids.to(torch.float16))
         inputs += (self.layer_norm_0, self.layer_norm_1)
         inputs += (self.q_bias, self.k_bias, self.v_bias)
         hidden_states, past_key, past_value = run_model(
             inputs, self.op_parameters, backend_cls, self.op_id, replica=2