add position_ids and fuse embedding for falcon (#9242)
* add position_ids for falcon
* add cpu
* add cpu
* add license
parent 7f66bc5c14
commit 0c5055d38c
2 changed files with 92 additions and 16 deletions
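At the core of the change, RotaryEmbedding.forward now takes past_key_values_length and position_ids and gathers the cached cos/sin rows per absolute position instead of slicing the last q_len entries. Below is a minimal sketch of that gather with a toy cache and hypothetical shapes (plain PyTorch, not the BigDL kernels):

import torch

def rotate_half(x):
    # standard RoPE helper: split the last dim in half, swap and negate
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

# hypothetical sizes; the real code works on [batch * num_heads, q_len, head_dim]
batch, q_len, head_dim, past_len = 2, 3, 8, 5
q = torch.randn(batch, q_len, head_dim)
k = torch.randn(batch, q_len, head_dim)

# toy cos/sin cache covering past + current positions, shaped [1, seq_len, head_dim]
seq_len = q_len + past_len
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.einsum("i,j->ij", torch.arange(seq_len).float(), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)[None]
cos_cached, sin_cached = emb.cos(), emb.sin()

# as in the patched forward: index by absolute position instead of slicing the tail
position_ids = torch.arange(past_len, past_len + q_len).unsqueeze(0).expand(batch, -1)
cos = cos_cached.squeeze(0)[position_ids]   # [batch, q_len, head_dim]
sin = sin_cached.squeeze(0)[position_ids]
q_rot = (q * cos) + (rotate_half(q) * sin)
k_rot = (k * cos) + (rotate_half(k) * sin)

Indexing by position_ids is what lets left-padded batches and cached decoding share one code path; the diff below threads position_ids from RWForCausalLM down to this point.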
@@ -18,9 +18,11 @@
 #
 # This file is adapted from
 # https://huggingface.co/tiiuae/falcon-7b-instruct/blob/c7f670a03d987254220f343c6b026ea0c5147185/modelling_RW.py
+# and https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/falcon/modeling_falcon.py
 #
 # Apache 2.0 license
 # https://huggingface.co/tiiuae/falcon-7b-instruct#license
+# https://github.com/huggingface/transformers/blob/v4.34.1/LICENSE

 # ===========================================================================
 #
@@ -117,12 +119,13 @@ class RotaryEmbedding(torch.nn.Module):
         return self.cos_cached, self.sin_cached

     # def forward(self, q, k):
-    def forward(self, q, k, seq_len):
+    def forward(self, q, k, past_key_values_length, position_ids):
         # batch, seq_len, head_dim = q.shape
         _,q_len,_ = q.shape
+        seq_len = q_len + past_key_values_length
         cos, sin = self.cos_sin(seq_len, q.device, q.dtype)
-        cos = cos[:,-q_len:]
-        sin = sin[:,-q_len:]
+        cos = cos.squeeze(0)[position_ids]
+        sin = sin.squeeze(0)[position_ids]

         return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)

@@ -268,6 +271,7 @@ class Attention(nn.Module):
         hidden_states: torch.Tensor,
         alibi: torch.Tensor,
         attention_mask: torch.Tensor,
+        position_ids: Optional[torch.LongTensor] = None,
         layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         head_mask: Optional[torch.Tensor] = None,
         use_cache: bool = False,
@@ -289,12 +293,23 @@ class Attention(nn.Module):
         value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_kv, q_length, self.head_dim)

         # query_layer, key_layer = self.maybe_rotary(query_layer, key_layer)
-        _, seq_len, _ = query_layer.shape
-        if layer_past is not None:
-            _, seq_len_past, _ = layer_past[0].shape
+        past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]

-            seq_len = seq_len + seq_len_past
-        query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, seq_len)
+        use_fuse_rope = query_layer.device.type == "xpu"
+        use_fuse_rope = use_fuse_rope and not (self.training and query_layer.requires_grad)
+        if use_fuse_rope:
+            # resize qk to 4D to match apply_rotary_pos_emb_no_cache_xpu's requirements.
+            query_layer = query_layer.reshape(batch_size, self.num_heads, q_length, self.head_dim)
+            key_layer = key_layer.reshape(batch_size, self.num_kv, q_length, self.head_dim)
+            from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
+            query_layer, key_layer = apply_rotary_pos_emb_no_cache_xpu(query_layer,
+                                                                       key_layer,
+                                                                       position_ids,
+                                                                       "gpt_neox")
+            query_layer = query_layer.reshape(batch_size * self.num_heads, q_length, self.head_dim)
+            key_layer = key_layer.reshape(batch_size * self.num_kv, q_length, self.head_dim)
+        else:
+            query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length, position_ids)

         if layer_past is not None:
             past_key, past_value = layer_past
@@ -423,6 +438,7 @@ class DecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         alibi: torch.Tensor,
         attention_mask: torch.Tensor,
+        position_ids: Optional[torch.LongTensor] = None,
         layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         head_mask: Optional[torch.Tensor] = None,
         use_cache: bool = False,
@@ -437,6 +453,7 @@ class DecoderLayer(nn.Module):
             layernorm_output,
             layer_past=layer_past,
             attention_mask=attention_mask,
+            position_ids=position_ids,
             alibi=alibi,
             head_mask=head_mask,
             use_cache=use_cache,
@@ -597,6 +614,7 @@ class RWModel(RWPreTrainedModel):
         input_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
         head_mask: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
@@ -664,6 +682,12 @@ class RWModel(RWPreTrainedModel):
             alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
         else:
             alibi = None
+        if position_ids is None:
+            device = input_ids.device if input_ids is not None else inputs_embeds.device
+            position_ids = torch.arange(
+                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+            )
+            position_ids = position_ids.unsqueeze(0)

         causal_mask = self._prepare_attn_mask(
             attention_mask,
@@ -696,6 +720,7 @@ class RWModel(RWPreTrainedModel):
                     hidden_states,
                     alibi,
                     causal_mask,
+                    position_ids,
                     head_mask[i],
                 )
             else:
@@ -703,6 +728,7 @@ class RWModel(RWPreTrainedModel):
                     hidden_states,
                     layer_past=layer_past,
                     attention_mask=causal_mask,
+                    position_ids=position_ids,
                     head_mask=head_mask[i],
                     use_cache=use_cache,
                     output_attentions=output_attentions,
@@ -756,6 +782,7 @@ class RWForCausalLM(RWPreTrainedModel):
         # past: Optional[torch.Tensor] = None,
         past_key_values: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> dict:
         # only last token for input_ids if past is not None
@@ -769,9 +796,18 @@ class RWForCausalLM(RWPreTrainedModel):
             if past_key_values[0][0].shape[0] == input_ids.shape[0]:
                 past_key_values = self._convert_to_rw_cache(past_key_values)

+        # Note: versions of Falcon with alibi do not use position_ids. It is used with RoPE.
+        if not self.transformer.alibi and attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past_key_values:
+                position_ids = position_ids[:, -input_ids.shape[1] :]
+
         return {
             "input_ids": input_ids,
             # "past_key_values": past,
+            "position_ids": position_ids,
             "past_key_values": past_key_values,
             "use_cache": kwargs.get("use_cache"),
             "attention_mask": attention_mask,
@@ -782,6 +818,7 @@ class RWForCausalLM(RWPreTrainedModel):
         input_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
         head_mask: Optional[torch.Tensor] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
         labels: Optional[torch.Tensor] = None,
@@ -813,6 +850,7 @@ class RWForCausalLM(RWPreTrainedModel):
             input_ids,
             past_key_values=past_key_values,
             attention_mask=attention_mask,
+            position_ids=position_ids,
             head_mask=head_mask,
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
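Those are the hunks for the first of the two changed files; the second file below receives the identical changes. As context for the prepare_inputs_for_generation hunk above, here is a minimal, self-contained sketch of how position_ids are built on the fly from a padded attention_mask (plain PyTorch with made-up tensors, not the BigDL code itself):

import torch

# hypothetical left-padded batch: 0 marks padding, 1 marks real tokens
attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])

# same recipe as the hunk above: cumulative sum of the mask shifted to start at 0,
# with padded positions clamped to 1 so they still index a valid rotary slot
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])

# during incremental decoding with a KV cache, only the ids for the new tokens are kept
new_token_ids = position_ids[:, -1:]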
@@ -18,9 +18,11 @@
 #
 # This file is adapted from
 # https://huggingface.co/tiiuae/falcon-7b-instruct/blob/c7f670a03d987254220f343c6b026ea0c5147185/modelling_RW.py
+# and https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/falcon/modeling_falcon.py
 #
 # Apache 2.0 license
 # https://huggingface.co/tiiuae/falcon-7b-instruct#license
+# https://github.com/huggingface/transformers/blob/v4.34.1/LICENSE

 # ===========================================================================
 #
@@ -117,12 +119,13 @@ class RotaryEmbedding(torch.nn.Module):
         return self.cos_cached, self.sin_cached

     # def forward(self, q, k):
-    def forward(self, q, k, seq_len):
+    def forward(self, q, k, past_key_values_length, position_ids):
         # batch, seq_len, head_dim = q.shape
         _,q_len,_ = q.shape
+        seq_len = q_len + past_key_values_length
         cos, sin = self.cos_sin(seq_len, q.device, q.dtype)
-        cos = cos[:,-q_len:]
-        sin = sin[:,-q_len:]
+        cos = cos.squeeze(0)[position_ids]
+        sin = sin.squeeze(0)[position_ids]

         return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)

@@ -268,6 +271,7 @@ class Attention(nn.Module):
         hidden_states: torch.Tensor,
         alibi: torch.Tensor,
         attention_mask: torch.Tensor,
+        position_ids: Optional[torch.LongTensor] = None,
         layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         head_mask: Optional[torch.Tensor] = None,
         use_cache: bool = False,
@@ -289,12 +293,23 @@ class Attention(nn.Module):
         value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_kv, q_length, self.head_dim)

         # query_layer, key_layer = self.maybe_rotary(query_layer, key_layer)
-        _, seq_len, _ = query_layer.shape
-        if layer_past is not None:
-            _, seq_len_past, _ = layer_past[0].shape
+        past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]

-            seq_len = seq_len + seq_len_past
-        query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, seq_len)
+        use_fuse_rope = query_layer.device.type == "xpu"
+        use_fuse_rope = use_fuse_rope and not (self.training and query_layer.requires_grad)
+        if use_fuse_rope:
+            # resize qk to 4D to match apply_rotary_pos_emb_no_cache_xpu's requirements.
+            query_layer = query_layer.reshape(batch_size, self.num_heads, q_length, self.head_dim)
+            key_layer = key_layer.reshape(batch_size, self.num_kv, q_length, self.head_dim)
+            from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
+            query_layer, key_layer = apply_rotary_pos_emb_no_cache_xpu(query_layer,
+                                                                       key_layer,
+                                                                       position_ids,
+                                                                       "gpt_neox")
+            query_layer = query_layer.reshape(batch_size * self.num_heads, q_length, self.head_dim)
+            key_layer = key_layer.reshape(batch_size * self.num_kv, q_length, self.head_dim)
+        else:
+            query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length, position_ids)

         if layer_past is not None:
             past_key, past_value = layer_past
@@ -423,6 +438,7 @@ class DecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         alibi: torch.Tensor,
         attention_mask: torch.Tensor,
+        position_ids: Optional[torch.LongTensor] = None,
         layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         head_mask: Optional[torch.Tensor] = None,
         use_cache: bool = False,
@@ -437,6 +453,7 @@ class DecoderLayer(nn.Module):
             layernorm_output,
             layer_past=layer_past,
             attention_mask=attention_mask,
+            position_ids=position_ids,
             alibi=alibi,
             head_mask=head_mask,
             use_cache=use_cache,
@@ -597,6 +614,7 @@ class RWModel(RWPreTrainedModel):
         input_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
         head_mask: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
@@ -664,6 +682,12 @@ class RWModel(RWPreTrainedModel):
             alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
         else:
             alibi = None
+        if position_ids is None:
+            device = input_ids.device if input_ids is not None else inputs_embeds.device
+            position_ids = torch.arange(
+                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+            )
+            position_ids = position_ids.unsqueeze(0)

         causal_mask = self._prepare_attn_mask(
             attention_mask,
@@ -696,6 +720,7 @@ class RWModel(RWPreTrainedModel):
                     hidden_states,
                     alibi,
                     causal_mask,
+                    position_ids,
                     head_mask[i],
                 )
             else:
@@ -703,6 +728,7 @@ class RWModel(RWPreTrainedModel):
                     hidden_states,
                     layer_past=layer_past,
                     attention_mask=causal_mask,
+                    position_ids=position_ids,
                     head_mask=head_mask[i],
                     use_cache=use_cache,
                     output_attentions=output_attentions,
@@ -756,6 +782,7 @@ class RWForCausalLM(RWPreTrainedModel):
         # past: Optional[torch.Tensor] = None,
         past_key_values: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> dict:
         # only last token for input_ids if past is not None
@@ -769,9 +796,18 @@ class RWForCausalLM(RWPreTrainedModel):
             if past_key_values[0][0].shape[0] == input_ids.shape[0]:
                 past_key_values = self._convert_to_rw_cache(past_key_values)

+        # Note: versions of Falcon with alibi do not use position_ids. It is used with RoPE.
+        if not self.transformer.alibi and attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past_key_values:
+                position_ids = position_ids[:, -input_ids.shape[1] :]
+
         return {
             "input_ids": input_ids,
             # "past_key_values": past,
+            "position_ids": position_ids,
             "past_key_values": past_key_values,
             "use_cache": kwargs.get("use_cache"),
             "attention_mask": attention_mask,
@@ -782,6 +818,7 @@ class RWForCausalLM(RWPreTrainedModel):
         input_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
         head_mask: Optional[torch.Tensor] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
         labels: Optional[torch.Tensor] = None,
@@ -813,6 +850,7 @@ class RWForCausalLM(RWPreTrainedModel):
             input_ids,
             past_key_values=past_key_values,
             attention_mask=attention_mask,
+            position_ids=position_ids,
             head_mask=head_mask,
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
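The Attention.forward hunk routes RoPE through two paths: on XPU (and outside training) it reshapes q/k to 4D and calls BigDL's fused apply_rotary_pos_emb_no_cache_xpu, otherwise it falls back to the adapted RotaryEmbedding (maybe_rotary) on CPU. A rough sketch of that dispatch and shape round-trip, with the fused kernel stubbed out so the snippet runs on any device (the stub and shapes are illustrative, not the BigDL API):

import torch

def fused_rope_stub(q, k, position_ids):
    # placeholder so the sketch is runnable without an XPU build of BigDL
    return q, k

def rope_dispatch(query_layer, key_layer, position_ids, past_kv_length,
                  batch_size, num_heads, num_kv, q_length, head_dim,
                  maybe_rotary, training=False):
    # mirror the condition in the hunk: fuse only on XPU and never during training
    use_fuse_rope = query_layer.device.type == "xpu"
    use_fuse_rope = use_fuse_rope and not (training and query_layer.requires_grad)
    if use_fuse_rope:
        # the fused kernel expects 4D [batch, heads, q_len, head_dim] ...
        q4 = query_layer.reshape(batch_size, num_heads, q_length, head_dim)
        k4 = key_layer.reshape(batch_size, num_kv, q_length, head_dim)
        # ... real code: apply_rotary_pos_emb_no_cache_xpu(q4, k4, position_ids, "gpt_neox")
        q4, k4 = fused_rope_stub(q4, k4, position_ids)
        # ... and downstream attention expects the flattened layout again
        query_layer = q4.reshape(batch_size * num_heads, q_length, head_dim)
        key_layer = k4.reshape(batch_size * num_kv, q_length, head_dim)
    else:
        # CPU (and training) path: the patched RotaryEmbedding indexes cos/sin by position_ids
        query_layer, key_layer = maybe_rotary(query_layer, key_layer, past_kv_length, position_ids)
    return query_layer, key_layer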