Support latest transformer version (#8923)

* Support latest transformer version

* fix style
This commit is contained in:
Yang Wang 2023-09-08 10:01:32 +08:00 committed by GitHub
parent 25428b22b4
commit ee98cdd85c

View file

@@ -85,23 +85,23 @@ def llama_attention_forward_4_31(
bsz, q_len, _ = hidden_states.size()
device = hidden_states.device
if self.pretraining_tp > 1:
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.pretraining_tp
if self.config.pretraining_tp > 1:
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim)
// self.pretraining_tp, dim=0)
// self.config.pretraining_tp, dim=0)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
query_states = [F.linear(hidden_states, query_slices[i])
for i in range(self.pretraining_tp)]
for i in range(self.config.pretraining_tp)]
query_states = torch.cat(query_states, dim=-1)
key_states = [F.linear(hidden_states, key_slices[i])
for i in range(self.pretraining_tp)]
for i in range(self.config.pretraining_tp)]
key_states = torch.cat(key_states, dim=-1)
value_states = [F.linear(hidden_states, value_slices[i])
for i in range(self.pretraining_tp)]
for i in range(self.config.pretraining_tp)]
value_states = torch.cat(value_states, dim=-1)
else:
@@ -194,11 +194,12 @@ def llama_attention_forward_4_31(
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
if self.pretraining_tp > 1:
attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.pretraining_tp, dim=1)
if self.config.pretraining_tp > 1:
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp,
dim=1)
attn_output = sum([F.linear(attn_output[i], o_proj_slices[i])
for i in range(self.pretraining_tp)])
for i in range(self.config.pretraining_tp)])
else:
attn_output = self.o_proj(attn_output)