From c876d9b5cadc0a0a4b36a0d714bebe09a34aefe1 Mon Sep 17 00:00:00 2001
From: Heyang Sun <60865256+Uxito-Ada@users.noreply.github.com>
Date: Thu, 22 Feb 2024 15:16:31 +0800
Subject: [PATCH] Support for MPT rotary embedding (#10208)

---
 .../src/bigdl/llm/transformers/models/mpt.py | 41 ++++++++++++++++++-
 1 file changed, 40 insertions(+), 1 deletion(-)

diff --git a/python/llm/src/bigdl/llm/transformers/models/mpt.py b/python/llm/src/bigdl/llm/transformers/models/mpt.py
index 7b32a4bc..a09ef771 100644
--- a/python/llm/src/bigdl/llm/transformers/models/mpt.py
+++ b/python/llm/src/bigdl/llm/transformers/models/mpt.py
@@ -30,7 +30,9 @@ KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 
 
 def mpt_multihead_attention_forward(self, x, past_key_value=None, attn_bias=None,
-                                    attention_mask=None, is_causal=True, needs_weights=False):
+                                    attention_mask=None, is_causal=True,
+                                    needs_weights=False, rotary_emb_w_meta_info=None,
+                                    **kwargs):
     qkv = self.Wqkv(x)
     if self.clip_qkv:
         qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
@@ -40,6 +42,43 @@ def mpt_multihead_attention_forward(self, x, past_key_value=None, attn_bias=None
         dtype = query.dtype
         query = self.q_ln(query).to(dtype)
         key = self.k_ln(key).to(dtype)
+
+    if rotary_emb_w_meta_info is not None:
+        rotary_emb = rotary_emb_w_meta_info['rotary_emb']
+        seq_len = rotary_emb_w_meta_info['seq_len']
+        offset_info = rotary_emb_w_meta_info['offset_info']
+        bsz, seqlen = query.shape[:2]
+        query = query.view(bsz, seqlen, -1, self.head_dim)
+        key = key.view(bsz, seqlen, -1, self.head_dim)
+
+        if rotary_emb_w_meta_info['impl'] == 'dail':
+            value = value.view(bsz, seqlen, -1, self.head_dim)
+
+            kv = torch.stack([key, value], dim=2)
+            query, kv = rotary_emb(query,
+                                   kv,
+                                   seqlen_offset=offset_info,
+                                   max_seqlen=seq_len)
+            [key, value] = torch.unbind(kv, dim=2)
+
+            value = value.view(bsz, seqlen, self.kv_n_heads * self.head_dim)
+        elif rotary_emb_w_meta_info['impl'] == 'hf':
+            (cos, sin) = rotary_emb(value, seq_len)
+            if is_transformers_version_gte('4.36'):
+                query, key = apply_rotary_pos_emb(query,
+                                                  key,
+                                                  cos,
+                                                  sin,
+                                                  offset_info,
+                                                  unsqueeze_dim=2)
+            else:
+                query = query.transpose(1, 2)
+                key = key.transpose(1, 2)
+                query, key = apply_rotary_pos_emb(query, key, cos, sin,
+                                                  offset_info)
+                query = query.transpose(1, 2)
+                key = key.transpose(1, 2)
+
     (context, attn_weights, past_key_value) = \
         mpt_scaled_multihead_dot_product_attention(query, key, value, self.n_heads,
                                                    past_key_value=past_key_value,
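
Note (reviewer commentary, not part of the patch): the rotary_emb_w_meta_info
dict consumed above appears to follow the calling convention of MosaicML's
llm-foundry attention blocks: 'impl' selects the rotary implementation ('dail'
for the flash-attn kernel, 'hf' for the HuggingFace-style helper),
'rotary_emb' is the rotary module itself, 'seq_len' is the total sequence
length, and 'offset_info' carries either position ids ('hf') or a
sequence-length offset ('dail'). Below is a minimal, runnable sketch of how a
caller might assemble the dict for the 'hf' branch; SimpleRotaryEmbedding is a
hypothetical stand-in for a LLaMA-style rotary module, not a BigDL API:

    import torch

    class SimpleRotaryEmbedding(torch.nn.Module):
        # Precomputes LLaMA-style cos/sin tables for a given head_dim.
        def __init__(self, head_dim, max_seq_len=4096, base=10000.0):
            super().__init__()
            inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
            t = torch.arange(max_seq_len).float()
            freqs = torch.outer(t, inv_freq)         # (max_seq_len, head_dim / 2)
            emb = torch.cat((freqs, freqs), dim=-1)  # (max_seq_len, head_dim)
            self.register_buffer('cos_cached', emb.cos())
            self.register_buffer('sin_cached', emb.sin())

        def forward(self, x, seq_len):
            # Same call shape as in the patch: (cos, sin) = rotary_emb(value, seq_len)
            return (self.cos_cached[:seq_len].to(x.dtype),
                    self.sin_cached[:seq_len].to(x.dtype))

    seq_len, head_dim = 32, 64
    rotary_emb_w_meta_info = {
        'impl': 'hf',                                       # take the HF branch
        'rotary_emb': SimpleRotaryEmbedding(head_dim),
        'offset_info': torch.arange(seq_len).unsqueeze(0),  # position ids, shape (1, seq_len)
        'seq_len': seq_len,
    }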
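
Two shape details behind the branches, as read from the diff: in the 'dail'
branch, key and value are stacked into a single kv tensor of shape
(bsz, seqlen, 2, n_heads, head_dim) and the rotary module is called as
rotary_emb(query, kv, seqlen_offset=..., max_seqlen=...), which matches the
flash-attn (Dao-AILab) rotary interface that rotates query and the key half
of kv in one call. In the 'hf' branch, the is_transformers_version_gte('4.36')
gate exists because apply_rotary_pos_emb gained an unsqueeze_dim argument
around transformers 4.36; with unsqueeze_dim=2 the cos/sin tables broadcast
over the heads axis while query and key stay laid out as
(bsz, seqlen, n_heads, head_dim), whereas the older helper expects
(bsz, n_heads, seqlen, head_dim), hence the transpose(1, 2) pairs wrapping
the legacy call.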