optimize yuan 2.0 again (#10252)

This commit is contained in:
Yishuo Wang 2024-02-27 14:51:42 +08:00 committed by GitHub
parent 03b9c4930a
commit b4fa4ab46f
2 changed files with 51 additions and 44 deletions

View file

@ -1187,7 +1187,6 @@ def _optimize_post(model, lightweight_bmm=False):
module = importlib.import_module(modeling_module_name)
from bigdl.llm.transformers.models.yuan import yuan_attention_forward
from bigdl.llm.transformers.models.yuan import yuan_mlp_forward
from bigdl.llm.transformers.models.yuan import yuan_localized_filtering_forward
convert_forward(model,
module.YuanAttention,
yuan_attention_forward
@ -1196,7 +1195,4 @@ def _optimize_post(model, lightweight_bmm=False):
module.YuanMLP,
yuan_mlp_forward
)
convert_forward(model,
module.LocalizedFiltering,
yuan_localized_filtering_forward)
return model

View file

@ -54,22 +54,39 @@ def yuan_localized_filtering_forward(
self,
inputs: torch.Tensor,
before_hidden_states: torch.Tensor,
dtype: torch.dtype,
):
if self.conv1.weight.dtype != torch.half:
self.half()
inputs = inputs.half()
if before_hidden_states is not None:
before_hidden_states = before_hidden_states.half()
invalidInputError(self.lf_conv2d_num_pad == 1, "padding must be 1")
if self.training:
lf_output = self._train_forward(inputs)
invalidInputError(not self.training, ("training is not supported for now, "
"please call model.eval() before inference"))
if before_hidden_states is None:
inputs = inputs.half()
lf_output = self._inference_forward(inputs, None)
else:
lf_output = self._inference_forward(inputs, before_hidden_states)
# only change next token logic
bsz, seq_len, embed_dim = inputs.size()
seq_len_before, _, _ = before_hidden_states.size()
invalidInputError(seq_len == 1 and seq_len_before == 3,
f"wrong sequence length: {seq_len} {seq_len_before}")
lf_output = lf_output.to(inputs.dtype)
residual = before_hidden_states[-1:, :, :]
inputs = before_hidden_states.view(3, 1, bsz, embed_dim).permute(2, 3, 0, 1)
output1 = self.conv1(inputs)
output2 = self.conv2(output1[:, :, 1:-1, :])
output2 = output2[:, :, 1:-1, :]
output2 = output2.view(1, bsz, embed_dim)
invalidInputError(output2.shape == residual.shape,
f"wrong shape: {output2.shape} {residual.shape}")
lf_output = self.output_layernorm(output2 + residual)
lf_output = lf_output.transpose(0, 1)
lf_output = lf_output.to(dtype)
return lf_output
@ -137,44 +154,38 @@ def yuan_attention_forward(
enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value)
if use_cache:
if past_key_value is None:
inference_hidden_states_memory = torch.empty(bsz, 2,
hidden_states.shape[2],
dtype=hidden_states.dtype)
is_first_step = True
else:
before_hidden_states = past_key_value[2]
invalidInputError(use_cache, "use_cache=True is needed")
invalidInputError(not self.use_shareqk, "use_shareqk is not supported for now")
if use_cache:
if is_first_step:
if q_len >= 2:
inference_hidden_states_memory = hidden_states[:, -2:, :]
else:
inference_hidden_states_memory[:, :, :] = 0
inference_hidden_states_memory[:, -1:, :] = hidden_states[:, -1:, :]
if past_key_value is None:
is_first_step = True
if q_len >= 2:
before_hidden_states = hidden_states[:, -2:, :].transpose(0, 1).half()
else:
hidden_states_tmp = before_hidden_states[:, -1:, :]
inference_hidden_states_memory = torch.cat((hidden_states_tmp,
hidden_states), dim=1)
before_hidden_states = torch.zeros(2, bsz, self.hidden_size,
dtype=torch.half, device=hidden_states.device)
before_hidden_states[-1:, :, :] = hidden_states[:, -1:, :].transpose(0, 1)
else:
before_hidden_states = past_key_value[2]
this_hidden_states = torch.cat([
before_hidden_states,
hidden_states.transpose(0, 1).half(),
], dim=0)
before_hidden_states = this_hidden_states[-2:, :, ]
value_states = \
self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
if self.use_shareqk:
# use_shareqk is disabled for now
qk_states = self.qk_proj(hidden_states).view(bsz, q_len, self.num_heads*self.head_dim)
query_key = qk_states.unsqueeze(2) * self.qk_weight + self.qk_bias
query_states, key_states = torch.unbind(query_key, dim=2)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
if is_first_step:
hidden_states = yuan_localized_filtering_forward(self.lf_gate, hidden_states,
None, hidden_states.dtype)
else:
hidden_states = self.lf_gate(hidden_states, before_hidden_states)
qk_states = self.merged_qk_proj(hidden_states)
(query_states, key_states) = torch.chunk(qk_states, 2, dim=-1)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
hidden_states = yuan_localized_filtering_forward(self.lf_gate, hidden_states,
this_hidden_states, hidden_states.dtype)
qk_states = self.merged_qk_proj(hidden_states)
(query_states, key_states) = torch.chunk(qk_states, 2, dim=-1)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
@ -229,7 +240,7 @@ def yuan_attention_forward(
value_states = new_value_states
past_key_value = \
(key_states, value_states, inference_hidden_states_memory) if use_cache else None
(key_states, value_states, before_hidden_states) if use_cache else None
attn_weights = \
torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)