Fix cohere model on transformers>=4.41 (#11575)
* fix cohere model for 4-41
This commit is contained in:
parent
5b6eb85b85
commit
d64711900a
6 changed files with 151 additions and 12 deletions
|
|
@ -17,7 +17,7 @@ conda activate llm
|
|||
|
||||
# install ipex-llm with 'all' option
|
||||
pip install --pre --upgrade ipex-llm[all] --extra-index-url https://download.pytorch.org/whl/cpu
|
||||
pip install transformers==4.40.0
|
||||
pip install "transformers>=4.40.0"
|
||||
```
|
||||
|
||||
On Windows:
|
||||
|
|
@ -27,7 +27,7 @@ conda create -n llm python=3.11
|
|||
conda activate llm
|
||||
|
||||
pip install --pre --upgrade ipex-llm[all]
|
||||
pip install transformers==4.40.0
|
||||
pip install "transformers>=4.40.0"
|
||||
```
|
||||
|
||||
### 2. Run
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ conda activate llm
|
|||
|
||||
# install the latest ipex-llm nightly build with 'all' option
|
||||
pip install --pre --upgrade ipex-llm[all] --extra-index-url https://download.pytorch.org/whl/cpu
|
||||
pip install transformers==4.40.0
|
||||
pip install "transformers>=4.40.0"
|
||||
```
|
||||
|
||||
On Windows:
|
||||
|
|
@ -28,7 +28,7 @@ conda create -n llm python=3.11
|
|||
conda activate llm
|
||||
|
||||
pip install --pre --upgrade ipex-llm[all]
|
||||
pip install transformers==4.40.0
|
||||
pip install "transformers>=4.40.0"
|
||||
```
|
||||
|
||||
### 2. Run
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ conda create -n llm python=3.11
|
|||
conda activate llm
|
||||
# below command will install intel_extension_for_pytorch==2.1.10+xpu as default
|
||||
pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
|
||||
pip install transformers==4.40.0
|
||||
pip install "transformers>=4.40.0"
|
||||
conda install -c conda-forge -y gperftools=2.10 # to enable tcmalloc
|
||||
```
|
||||
|
||||
|
|
@ -29,7 +29,7 @@ conda activate llm
|
|||
|
||||
# below command will install intel_extension_for_pytorch==2.1.10+xpu as default
|
||||
pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
|
||||
pip install transformers==4.40.0
|
||||
pip install "transformers>=4.40.0"
|
||||
```
|
||||
|
||||
### 2. Configures OneAPI environment variables for Linux
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ conda create -n llm python=3.11
|
|||
conda activate llm
|
||||
# below command will install intel_extension_for_pytorch==2.1.10+xpu as default
|
||||
pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
|
||||
pip install transformers==4.40.0
|
||||
pip install "transformers>=4.40.0"
|
||||
conda install -c conda-forge -y gperftools=2.10 # to enable tcmalloc
|
||||
```
|
||||
|
||||
|
|
@ -29,7 +29,7 @@ conda activate llm
|
|||
|
||||
# below command will install intel_extension_for_pytorch==2.1.10+xpu as default
|
||||
pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
|
||||
pip install transformers==4.40.0
|
||||
pip install "transformers>=4.40.0"
|
||||
```
|
||||
|
||||
### 2. Configures OneAPI environment variables for Linux
|
||||
|
|
|
|||
|
|
@ -1382,13 +1382,23 @@ def _optimize_post(model, lightweight_bmm=False):
|
|||
qwen2_attention_forward)
|
||||
elif model.config.model_type == "cohere":
|
||||
# for CohereForAI/c4ai-command-r-v01
|
||||
invalidInputError(version.parse(trans_version) >= version.parse("4.40.0"),
|
||||
"Please upgrade transformers to 4.40.0 or higher version "
|
||||
"to run Mixtral models.")
|
||||
modeling_module_name = model.__class__.__module__
|
||||
module = importlib.import_module(modeling_module_name)
|
||||
if version.parse(trans_version) >= version.parse("4.41.0"):
|
||||
from ipex_llm.transformers.models.cohere import cohere_model_forward_4_41
|
||||
convert_forward(model,
|
||||
module.CohereModel,
|
||||
cohere_model_forward_4_41)
|
||||
else:
|
||||
from ipex_llm.transformers.models.cohere import cohere_model_forward
|
||||
convert_forward(model,
|
||||
module.CohereModel,
|
||||
cohere_model_forward)
|
||||
|
||||
from ipex_llm.transformers.models.cohere import cohere_attention_forward
|
||||
from ipex_llm.transformers.models.cohere import cohere_model_forward
|
||||
convert_forward(model,
|
||||
module.CohereModel,
|
||||
cohere_model_forward)
|
||||
convert_forward(model,
|
||||
module.CohereAttention,
|
||||
cohere_attention_forward)
|
||||
|
|
|
|||
|
|
@ -191,6 +191,135 @@ def cohere_model_forward(
|
|||
)
|
||||
|
||||
|
||||
def cohere_model_forward_4_41(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
):
|
||||
use_cache = use_cache if use_cache is not None \
|
||||
else self.config.use_cache
|
||||
if use_cache and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids):
|
||||
if not isinstance(past_key_values, DynamicFp8Cache):
|
||||
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
|
||||
output_attentions = output_attentions if output_attentions is not None \
|
||||
else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
invalidInputError(False,
|
||||
"You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
|
||||
if self.gradient_checkpointing and self.training and use_cache:
|
||||
invalidInputError(False,
|
||||
"`use_cache=True` is incompatible "
|
||||
"with gradient checkpointing. Setting `use_cache=False`.")
|
||||
use_cache = False
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
past_seen_tokens = 0
|
||||
return_legacy_cache = False
|
||||
# kept for BC (non `Cache` `past_key_values` inputs)
|
||||
if use_cache and not isinstance(past_key_values, Cache):
|
||||
return_legacy_cache = True
|
||||
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||
|
||||
if cache_position is None:
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
cache_position = torch.arange(
|
||||
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
||||
)
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
)
|
||||
|
||||
# embed positions
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = None
|
||||
|
||||
for decoder_layer in self.layers:
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
causal_mask,
|
||||
position_ids,
|
||||
past_key_values,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
)
|
||||
else:
|
||||
# ipex-llm changes
|
||||
curr_device = decoder_layer.input_layernorm.weight.device
|
||||
if causal_mask is not None:
|
||||
causal_mask = causal_mask.to(curr_device)
|
||||
if position_ids is not None:
|
||||
position_ids = position_ids.to(curr_device)
|
||||
# ipex-llm changes end
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=causal_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
if return_legacy_cache:
|
||||
next_cache = next_cache.to_legacy_cache()
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, next_cache,
|
||||
all_hidden_states, all_self_attns] if v is not None)
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
|
||||
def cohere_attention_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
|
|
|
|||
Loading…
Reference in a new issue