add position_ids and fuse embedding for falcon (#9242)

* add position_ids for falcon

* add cpu

* add cpu

* add license
This commit is contained in:
Xin Qiu 2023-10-24 09:58:20 +08:00 committed by GitHub
parent 7f66bc5c14
commit 0c5055d38c
2 changed files with 92 additions and 16 deletions

View file

@ -18,9 +18,11 @@
#
# This file is adapted from
# https://huggingface.co/tiiuae/falcon-7b-instruct/blob/c7f670a03d987254220f343c6b026ea0c5147185/modelling_RW.py
# and https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/falcon/modeling_falcon.py
#
# Apache 2.0 license
# https://huggingface.co/tiiuae/falcon-7b-instruct#license
# https://github.com/huggingface/transformers/blob/v4.34.1/LICENSE
# ===========================================================================
#
@ -117,12 +119,13 @@ class RotaryEmbedding(torch.nn.Module):
return self.cos_cached, self.sin_cached
# def forward(self, q, k):
def forward(self, q, k, seq_len):
def forward(self, q, k, past_key_values_length, position_ids):
# batch, seq_len, head_dim = q.shape
_,q_len,_ = q.shape
seq_len = q_len + past_key_values_length
cos, sin = self.cos_sin(seq_len, q.device, q.dtype)
cos = cos[:,-q_len:]
sin = sin[:,-q_len:]
cos = cos.squeeze(0)[position_ids]
sin = sin.squeeze(0)[position_ids]
return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
@ -268,6 +271,7 @@ class Attention(nn.Module):
hidden_states: torch.Tensor,
alibi: torch.Tensor,
attention_mask: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
@ -289,12 +293,23 @@ class Attention(nn.Module):
value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_kv, q_length, self.head_dim)
# query_layer, key_layer = self.maybe_rotary(query_layer, key_layer)
_, seq_len, _ = query_layer.shape
if layer_past is not None:
_, seq_len_past, _ = layer_past[0].shape
past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
seq_len = seq_len + seq_len_past
query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, seq_len)
use_fuse_rope = query_layer.device.type == "xpu"
use_fuse_rope = use_fuse_rope and not (self.training and query_layer.requires_grad)
if use_fuse_rope:
# resize qk to 4D to match apply_rotary_pos_emb_no_cache_xpu's requirements.
query_layer = query_layer.reshape(batch_size, self.num_heads, q_length, self.head_dim)
key_layer = key_layer.reshape(batch_size, self.num_kv, q_length, self.head_dim)
from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
query_layer, key_layer = apply_rotary_pos_emb_no_cache_xpu(query_layer,
key_layer,
position_ids,
"gpt_neox")
query_layer = query_layer.reshape(batch_size * self.num_heads, q_length, self.head_dim)
key_layer = key_layer.reshape(batch_size * self.num_kv, q_length, self.head_dim)
else:
query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length, position_ids)
if layer_past is not None:
past_key, past_value = layer_past
@ -423,6 +438,7 @@ class DecoderLayer(nn.Module):
hidden_states: torch.Tensor,
alibi: torch.Tensor,
attention_mask: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
@ -437,6 +453,7 @@ class DecoderLayer(nn.Module):
layernorm_output,
layer_past=layer_past,
attention_mask=attention_mask,
position_ids=position_ids,
alibi=alibi,
head_mask=head_mask,
use_cache=use_cache,
@ -597,6 +614,7 @@ class RWModel(RWPreTrainedModel):
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
@ -664,6 +682,12 @@ class RWModel(RWPreTrainedModel):
alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
else:
alibi = None
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
position_ids = position_ids.unsqueeze(0)
causal_mask = self._prepare_attn_mask(
attention_mask,
@ -696,6 +720,7 @@ class RWModel(RWPreTrainedModel):
hidden_states,
alibi,
causal_mask,
position_ids,
head_mask[i],
)
else:
@ -703,6 +728,7 @@ class RWModel(RWPreTrainedModel):
hidden_states,
layer_past=layer_past,
attention_mask=causal_mask,
position_ids=position_ids,
head_mask=head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
@ -756,6 +782,7 @@ class RWForCausalLM(RWPreTrainedModel):
# past: Optional[torch.Tensor] = None,
past_key_values: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
**kwargs,
) -> dict:
# only last token for input_ids if past is not None
@ -769,9 +796,18 @@ class RWForCausalLM(RWPreTrainedModel):
if past_key_values[0][0].shape[0] == input_ids.shape[0]:
past_key_values = self._convert_to_rw_cache(past_key_values)
# Note: versions of Falcon with alibi do not use position_ids. It is used with RoPE.
if not self.transformer.alibi and attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
return {
"input_ids": input_ids,
# "past_key_values": past,
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
@ -782,6 +818,7 @@ class RWForCausalLM(RWPreTrainedModel):
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
@ -813,6 +850,7 @@ class RWForCausalLM(RWPreTrainedModel):
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,

View file

@ -18,9 +18,11 @@
#
# This file is adapted from
# https://huggingface.co/tiiuae/falcon-7b-instruct/blob/c7f670a03d987254220f343c6b026ea0c5147185/modelling_RW.py
# and https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/falcon/modeling_falcon.py
#
# Apache 2.0 license
# https://huggingface.co/tiiuae/falcon-7b-instruct#license
# https://github.com/huggingface/transformers/blob/v4.34.1/LICENSE
# ===========================================================================
#
@ -117,12 +119,13 @@ class RotaryEmbedding(torch.nn.Module):
return self.cos_cached, self.sin_cached
# def forward(self, q, k):
def forward(self, q, k, seq_len):
def forward(self, q, k, past_key_values_length, position_ids):
# batch, seq_len, head_dim = q.shape
_,q_len,_ = q.shape
seq_len = q_len + past_key_values_length
cos, sin = self.cos_sin(seq_len, q.device, q.dtype)
cos = cos[:,-q_len:]
sin = sin[:,-q_len:]
cos = cos.squeeze(0)[position_ids]
sin = sin.squeeze(0)[position_ids]
return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
@ -268,6 +271,7 @@ class Attention(nn.Module):
hidden_states: torch.Tensor,
alibi: torch.Tensor,
attention_mask: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
@ -289,12 +293,23 @@ class Attention(nn.Module):
value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_kv, q_length, self.head_dim)
# query_layer, key_layer = self.maybe_rotary(query_layer, key_layer)
_, seq_len, _ = query_layer.shape
if layer_past is not None:
_, seq_len_past, _ = layer_past[0].shape
past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
seq_len = seq_len + seq_len_past
query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, seq_len)
use_fuse_rope = query_layer.device.type == "xpu"
use_fuse_rope = use_fuse_rope and not (self.training and query_layer.requires_grad)
if use_fuse_rope:
# resize qk to 4D to match apply_rotary_pos_emb_no_cache_xpu's requirements.
query_layer = query_layer.reshape(batch_size, self.num_heads, q_length, self.head_dim)
key_layer = key_layer.reshape(batch_size, self.num_kv, q_length, self.head_dim)
from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
query_layer, key_layer = apply_rotary_pos_emb_no_cache_xpu(query_layer,
key_layer,
position_ids,
"gpt_neox")
query_layer = query_layer.reshape(batch_size * self.num_heads, q_length, self.head_dim)
key_layer = key_layer.reshape(batch_size * self.num_kv, q_length, self.head_dim)
else:
query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length, position_ids)
if layer_past is not None:
past_key, past_value = layer_past
@ -423,6 +438,7 @@ class DecoderLayer(nn.Module):
hidden_states: torch.Tensor,
alibi: torch.Tensor,
attention_mask: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
@ -437,6 +453,7 @@ class DecoderLayer(nn.Module):
layernorm_output,
layer_past=layer_past,
attention_mask=attention_mask,
position_ids=position_ids,
alibi=alibi,
head_mask=head_mask,
use_cache=use_cache,
@ -597,6 +614,7 @@ class RWModel(RWPreTrainedModel):
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
@ -664,6 +682,12 @@ class RWModel(RWPreTrainedModel):
alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
else:
alibi = None
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
position_ids = position_ids.unsqueeze(0)
causal_mask = self._prepare_attn_mask(
attention_mask,
@ -696,6 +720,7 @@ class RWModel(RWPreTrainedModel):
hidden_states,
alibi,
causal_mask,
position_ids,
head_mask[i],
)
else:
@ -703,6 +728,7 @@ class RWModel(RWPreTrainedModel):
hidden_states,
layer_past=layer_past,
attention_mask=causal_mask,
position_ids=position_ids,
head_mask=head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
@ -756,6 +782,7 @@ class RWForCausalLM(RWPreTrainedModel):
# past: Optional[torch.Tensor] = None,
past_key_values: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
**kwargs,
) -> dict:
# only last token for input_ids if past is not None
@ -769,9 +796,18 @@ class RWForCausalLM(RWPreTrainedModel):
if past_key_values[0][0].shape[0] == input_ids.shape[0]:
past_key_values = self._convert_to_rw_cache(past_key_values)
# Note: versions of Falcon with alibi do not use position_ids. It is used with RoPE.
if not self.transformer.alibi and attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
return {
"input_ids": input_ids,
# "past_key_values": past,
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
@ -782,6 +818,7 @@ class RWForCausalLM(RWPreTrainedModel):
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
@ -813,6 +850,7 @@ class RWForCausalLM(RWPreTrainedModel):
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,