add position_ids and fuse embedding for falcon (#9242)

* add position_ids for falcon
* add cpu
* add cpu
* add license

parent 7f66bc5c14
commit 0c5055d38c

2 changed files with 92 additions and 16 deletions
@@ -18,9 +18,11 @@
 #
 # This file is adapted from
 # https://huggingface.co/tiiuae/falcon-7b-instruct/blob/c7f670a03d987254220f343c6b026ea0c5147185/modelling_RW.py
+# and https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/falcon/modeling_falcon.py
 #
 # Apache 2.0 license
 # https://huggingface.co/tiiuae/falcon-7b-instruct#license
+# https://github.com/huggingface/transformers/blob/v4.34.1/LICENSE
 
 # ===========================================================================
 #
@@ -117,12 +119,13 @@ class RotaryEmbedding(torch.nn.Module):
         return self.cos_cached, self.sin_cached
 
     # def forward(self, q, k):
-    def forward(self, q, k, seq_len):
+    def forward(self, q, k, past_key_values_length, position_ids):
         # batch, seq_len, head_dim = q.shape
         _,q_len,_ = q.shape
+        seq_len = q_len + past_key_values_length
         cos, sin = self.cos_sin(seq_len, q.device, q.dtype)
-        cos = cos[:,-q_len:]
-        sin = sin[:,-q_len:]
+        cos = cos.squeeze(0)[position_ids]
+        sin = sin.squeeze(0)[position_ids]
 
         return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
 
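
Not part of the diff: a minimal, self-contained sketch of what the new cos/sin lookup does. It assumes, for illustration only, a cached table of shape (1, max_seq_len, head_dim) as built by RotaryEmbedding; the point is that gathering rows by position_ids reproduces the old "last q_len rows" slice in the common case while also allowing per-sequence positions (e.g. with left padding).

import torch

# Build a small rotary cos table with illustrative sizes.
max_seq_len, head_dim = 16, 8
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
t = torch.arange(max_seq_len).float()
freqs = torch.einsum("i,j->ij", t, inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos_cached = emb.cos()[None, :, :]            # assumed shape (1, max_seq_len, head_dim)

# Decoding one new token with 5 tokens already cached: its position is 5.
position_ids = torch.tensor([[5], [5]])       # (batch=2, q_len=1)

# Old behaviour: slice the last q_len rows of a table built up to seq_len.
cos_old = cos_cached[:, :6][:, -1:]           # (1, 1, head_dim)
# New behaviour: gather rows by position_ids, so each sequence can use its own positions.
cos_new = cos_cached.squeeze(0)[position_ids] # (2, 1, head_dim)

assert torch.allclose(cos_old[0], cos_new[0])
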
@@ -268,6 +271,7 @@ class Attention(nn.Module):
         hidden_states: torch.Tensor,
         alibi: torch.Tensor,
         attention_mask: torch.Tensor,
+        position_ids: Optional[torch.LongTensor] = None,
         layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         head_mask: Optional[torch.Tensor] = None,
         use_cache: bool = False,
@@ -289,12 +293,23 @@ class Attention(nn.Module):
         value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_kv, q_length, self.head_dim)
 
         # query_layer, key_layer = self.maybe_rotary(query_layer, key_layer)
-        _, seq_len, _ = query_layer.shape
-        if layer_past is not None:
-            _, seq_len_past, _ = layer_past[0].shape
+        past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
 
-            seq_len = seq_len + seq_len_past
-        query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, seq_len)
+        use_fuse_rope = query_layer.device.type == "xpu"
+        use_fuse_rope = use_fuse_rope and not (self.training and query_layer.requires_grad)
+        if use_fuse_rope:
+            # resize qk to 4D to match apply_rotary_pos_emb_no_cache_xpu's requirements.
+            query_layer = query_layer.reshape(batch_size, self.num_heads, q_length, self.head_dim)
+            key_layer = key_layer.reshape(batch_size, self.num_kv, q_length, self.head_dim)
+            from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
+            query_layer, key_layer = apply_rotary_pos_emb_no_cache_xpu(query_layer,
+                                                                       key_layer,
+                                                                       position_ids,
+                                                                       "gpt_neox")
+            query_layer = query_layer.reshape(batch_size * self.num_heads, q_length, self.head_dim)
+            key_layer = key_layer.reshape(batch_size * self.num_kv, q_length, self.head_dim)
+        else:
+            query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length, position_ids)
 
         if layer_past is not None:
             past_key, past_value = layer_past
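
Not part of the diff: the fused path above is only taken on XPU devices and outside training, and it feeds apply_rotary_pos_emb_no_cache_xpu 4-D tensors. The sketch below, with purely illustrative shapes and names, shows the layout round trip the surrounding code relies on: (batch * n_heads, q_len, head_dim) is reshaped to (batch, n_heads, q_len, head_dim) for the fused kernel and back again afterwards.

import torch

# Illustrative sizes only.
batch, n_heads, q_len, head_dim = 2, 4, 3, 8
q_flat = torch.randn(batch * n_heads, q_len, head_dim)   # layout used by this Falcon port

q_4d = q_flat.reshape(batch, n_heads, q_len, head_dim)   # layout handed to the fused kernel
# ... on XPU, the fused rotary embedding would be applied to q_4d here ...
q_back = q_4d.reshape(batch * n_heads, q_len, head_dim)  # restore the layout the rest of
                                                         # Attention.forward expects

assert torch.equal(q_flat, q_back)  # the reshape round trip itself is lossless
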
@@ -423,6 +438,7 @@ class DecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         alibi: torch.Tensor,
         attention_mask: torch.Tensor,
+        position_ids: Optional[torch.LongTensor] = None,
         layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         head_mask: Optional[torch.Tensor] = None,
         use_cache: bool = False,
@@ -437,6 +453,7 @@ class DecoderLayer(nn.Module):
             layernorm_output,
             layer_past=layer_past,
             attention_mask=attention_mask,
+            position_ids=position_ids,
             alibi=alibi,
             head_mask=head_mask,
             use_cache=use_cache,
@@ -597,6 +614,7 @@ class RWModel(RWPreTrainedModel):
         input_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
         head_mask: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
@@ -664,6 +682,12 @@ class RWModel(RWPreTrainedModel):
             alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
         else:
             alibi = None
+            if position_ids is None:
+                device = input_ids.device if input_ids is not None else inputs_embeds.device
+                position_ids = torch.arange(
+                    past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+                )
+                position_ids = position_ids.unsqueeze(0)
 
         causal_mask = self._prepare_attn_mask(
             attention_mask,
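
Not part of the diff: a standalone example of the default position_ids constructed above for the RoPE (non-alibi) branch. With a KV cache already holding past_key_values_length tokens, the new tokens get consecutive positions starting at that offset; the numbers below are illustrative.

import torch

past_key_values_length = 4   # tokens already in the KV cache (illustrative)
seq_length = 2               # new input tokens this step (illustrative)

position_ids = torch.arange(
    past_key_values_length, seq_length + past_key_values_length, dtype=torch.long
).unsqueeze(0)               # shape (1, seq_length)

print(position_ids)          # tensor([[4, 5]])
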
@@ -696,6 +720,7 @@ class RWModel(RWPreTrainedModel):
                     hidden_states,
                     alibi,
                     causal_mask,
+                    position_ids,
                     head_mask[i],
                 )
             else:
@@ -703,6 +728,7 @@ class RWModel(RWPreTrainedModel):
                     hidden_states,
                     layer_past=layer_past,
                     attention_mask=causal_mask,
+                    position_ids=position_ids,
                     head_mask=head_mask[i],
                     use_cache=use_cache,
                     output_attentions=output_attentions,
@@ -756,6 +782,7 @@ class RWForCausalLM(RWPreTrainedModel):
         # past: Optional[torch.Tensor] = None,
         past_key_values: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> dict:
         # only last token for input_ids if past is not None
@@ -769,9 +796,18 @@ class RWForCausalLM(RWPreTrainedModel):
             if past_key_values[0][0].shape[0] == input_ids.shape[0]:
                 past_key_values = self._convert_to_rw_cache(past_key_values)
 
+        # Note: versions of Falcon with alibi do not use position_ids. It is used with RoPE.
+        if not self.transformer.alibi and attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past_key_values:
+                position_ids = position_ids[:, -input_ids.shape[1] :]
+
         return {
             "input_ids": input_ids,
             # "past_key_values": past,
+            "position_ids": position_ids,
             "past_key_values": past_key_values,
             "use_cache": kwargs.get("use_cache"),
             "attention_mask": attention_mask,
@@ -782,6 +818,7 @@ class RWForCausalLM(RWPreTrainedModel):
         input_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
         head_mask: Optional[torch.Tensor] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
         labels: Optional[torch.Tensor] = None,
@@ -813,6 +850,7 @@ class RWForCausalLM(RWPreTrainedModel):
             input_ids,
             past_key_values=past_key_values,
             attention_mask=attention_mask,
+            position_ids=position_ids,
             head_mask=head_mask,
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,

@@ -18,9 +18,11 @@
 #
 # This file is adapted from
 # https://huggingface.co/tiiuae/falcon-7b-instruct/blob/c7f670a03d987254220f343c6b026ea0c5147185/modelling_RW.py
+# and https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/falcon/modeling_falcon.py
 #
 # Apache 2.0 license
 # https://huggingface.co/tiiuae/falcon-7b-instruct#license
+# https://github.com/huggingface/transformers/blob/v4.34.1/LICENSE
 
 # ===========================================================================
 #
@@ -117,12 +119,13 @@ class RotaryEmbedding(torch.nn.Module):
         return self.cos_cached, self.sin_cached
 
     # def forward(self, q, k):
-    def forward(self, q, k, seq_len):
+    def forward(self, q, k, past_key_values_length, position_ids):
         # batch, seq_len, head_dim = q.shape
         _,q_len,_ = q.shape
+        seq_len = q_len + past_key_values_length
         cos, sin = self.cos_sin(seq_len, q.device, q.dtype)
-        cos = cos[:,-q_len:]
-        sin = sin[:,-q_len:]
+        cos = cos.squeeze(0)[position_ids]
+        sin = sin.squeeze(0)[position_ids]
 
         return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
 
@@ -268,6 +271,7 @@ class Attention(nn.Module):
         hidden_states: torch.Tensor,
         alibi: torch.Tensor,
         attention_mask: torch.Tensor,
+        position_ids: Optional[torch.LongTensor] = None,
         layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         head_mask: Optional[torch.Tensor] = None,
         use_cache: bool = False,
@@ -289,12 +293,23 @@ class Attention(nn.Module):
         value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_kv, q_length, self.head_dim)
 
         # query_layer, key_layer = self.maybe_rotary(query_layer, key_layer)
-        _, seq_len, _ = query_layer.shape
-        if layer_past is not None:
-            _, seq_len_past, _ = layer_past[0].shape
+        past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
 
-            seq_len = seq_len + seq_len_past
-        query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, seq_len)
+        use_fuse_rope = query_layer.device.type == "xpu"
+        use_fuse_rope = use_fuse_rope and not (self.training and query_layer.requires_grad)
+        if use_fuse_rope:
+            # resize qk to 4D to match apply_rotary_pos_emb_no_cache_xpu's requirements.
+            query_layer = query_layer.reshape(batch_size, self.num_heads, q_length, self.head_dim)
+            key_layer = key_layer.reshape(batch_size, self.num_kv, q_length, self.head_dim)
+            from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
+            query_layer, key_layer = apply_rotary_pos_emb_no_cache_xpu(query_layer,
+                                                                       key_layer,
+                                                                       position_ids,
+                                                                       "gpt_neox")
+            query_layer = query_layer.reshape(batch_size * self.num_heads, q_length, self.head_dim)
+            key_layer = key_layer.reshape(batch_size * self.num_kv, q_length, self.head_dim)
+        else:
+            query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length, position_ids)
 
         if layer_past is not None:
             past_key, past_value = layer_past
@@ -423,6 +438,7 @@ class DecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         alibi: torch.Tensor,
         attention_mask: torch.Tensor,
+        position_ids: Optional[torch.LongTensor] = None,
         layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         head_mask: Optional[torch.Tensor] = None,
         use_cache: bool = False,
@@ -437,6 +453,7 @@ class DecoderLayer(nn.Module):
             layernorm_output,
             layer_past=layer_past,
             attention_mask=attention_mask,
+            position_ids=position_ids,
             alibi=alibi,
             head_mask=head_mask,
             use_cache=use_cache,
@@ -597,6 +614,7 @@ class RWModel(RWPreTrainedModel):
         input_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
         head_mask: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
@@ -664,6 +682,12 @@ class RWModel(RWPreTrainedModel):
             alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
         else:
             alibi = None
+            if position_ids is None:
+                device = input_ids.device if input_ids is not None else inputs_embeds.device
+                position_ids = torch.arange(
+                    past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+                )
+                position_ids = position_ids.unsqueeze(0)
 
         causal_mask = self._prepare_attn_mask(
             attention_mask,
@@ -696,6 +720,7 @@ class RWModel(RWPreTrainedModel):
                     hidden_states,
                     alibi,
                     causal_mask,
+                    position_ids,
                     head_mask[i],
                 )
             else:
@@ -703,6 +728,7 @@ class RWModel(RWPreTrainedModel):
                     hidden_states,
                     layer_past=layer_past,
                     attention_mask=causal_mask,
+                    position_ids=position_ids,
                     head_mask=head_mask[i],
                     use_cache=use_cache,
                     output_attentions=output_attentions,
@@ -756,6 +782,7 @@ class RWForCausalLM(RWPreTrainedModel):
         # past: Optional[torch.Tensor] = None,
         past_key_values: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> dict:
         # only last token for input_ids if past is not None
@@ -769,9 +796,18 @@ class RWForCausalLM(RWPreTrainedModel):
             if past_key_values[0][0].shape[0] == input_ids.shape[0]:
                 past_key_values = self._convert_to_rw_cache(past_key_values)
 
+        # Note: versions of Falcon with alibi do not use position_ids. It is used with RoPE.
+        if not self.transformer.alibi and attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past_key_values:
+                position_ids = position_ids[:, -input_ids.shape[1] :]
+
         return {
             "input_ids": input_ids,
             # "past_key_values": past,
+            "position_ids": position_ids,
             "past_key_values": past_key_values,
             "use_cache": kwargs.get("use_cache"),
             "attention_mask": attention_mask,
@@ -782,6 +818,7 @@ class RWForCausalLM(RWPreTrainedModel):
         input_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
         head_mask: Optional[torch.Tensor] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
         labels: Optional[torch.Tensor] = None,
@@ -813,6 +850,7 @@ class RWForCausalLM(RWPreTrainedModel):
             input_ids,
             past_key_values=past_key_values,
             attention_mask=attention_mask,
+            position_ids=position_ids,
             head_mask=head_mask,
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,