diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py index 558ebda9..68d7c63f 100644 --- a/python/llm/src/bigdl/llm/transformers/models/llama.py +++ b/python/llm/src/bigdl/llm/transformers/models/llama.py @@ -85,23 +85,23 @@ def llama_attention_forward_4_31( bsz, q_len, _ = hidden_states.size() device = hidden_states.device - if self.pretraining_tp > 1: - key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.pretraining_tp + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) - // self.pretraining_tp, dim=0) + // self.config.pretraining_tp, dim=0) key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) query_states = [F.linear(hidden_states, query_slices[i]) - for i in range(self.pretraining_tp)] + for i in range(self.config.pretraining_tp)] query_states = torch.cat(query_states, dim=-1) key_states = [F.linear(hidden_states, key_slices[i]) - for i in range(self.pretraining_tp)] + for i in range(self.config.pretraining_tp)] key_states = torch.cat(key_states, dim=-1) value_states = [F.linear(hidden_states, value_slices[i]) - for i in range(self.pretraining_tp)] + for i in range(self.config.pretraining_tp)] value_states = torch.cat(value_states, dim=-1) else: @@ -194,11 +194,12 @@ def llama_attention_forward_4_31( attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - if self.pretraining_tp > 1: - attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2) - o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.pretraining_tp, dim=1) + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, + dim=1) attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) - for i in range(self.pretraining_tp)]) + for i in range(self.config.pretraining_tp)]) else: attn_output = self.o_proj(attn_output)