Fix cohere model on transformers>=4.41 (#11575)

* fix cohere model for transformers 4.41
Guoqiong Song 2024-07-17 17:18:59 -07:00 committed by GitHub
parent 5b6eb85b85
commit d64711900a
6 changed files with 151 additions and 12 deletions
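
For context, the example these documentation and code changes target is CohereForAI/c4ai-command-r-v01. Below is a minimal, hedged sketch of how such a model is typically run with ipex-llm's low-bit `AutoModelForCausalLM` API; the prompt and generation settings are illustrative and not part of this commit.

```python
# Minimal sketch, not part of this commit. Assumes ipex-llm is installed per the
# README changes below and that transformers >= 4.40.0 is installed
# (>= 4.41.0 exercises the new cohere_model_forward_4_41 path added here).
import torch
from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModelForCausalLM

model_path = "CohereForAI/c4ai-command-r-v01"  # checkpoint named in the diff below
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_4bit=True,  # low-bit optimization
                                             trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

with torch.inference_mode():
    input_ids = tokenizer.encode("What is AI?", return_tensors="pt")
    output = model.generate(input_ids, max_new_tokens=32)
    print(tokenizer.decode(output[0], skip_special_tokens=True))
```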


@@ -17,7 +17,7 @@ conda activate llm
# install ipex-llm with 'all' option
pip install --pre --upgrade ipex-llm[all] --extra-index-url https://download.pytorch.org/whl/cpu
-pip install transformers==4.40.0
+pip install "transformers>=4.40.0"
```
On Windows:
@@ -27,7 +27,7 @@ conda create -n llm python=3.11
conda activate llm
pip install --pre --upgrade ipex-llm[all]
-pip install transformers==4.40.0
+pip install "transformers>=4.40.0"
```
### 2. Run


@@ -18,7 +18,7 @@ conda activate llm
# install the latest ipex-llm nightly build with 'all' option
pip install --pre --upgrade ipex-llm[all] --extra-index-url https://download.pytorch.org/whl/cpu
-pip install transformers==4.40.0
+pip install "transformers>=4.40.0"
```
On Windows:
@@ -28,7 +28,7 @@ conda create -n llm python=3.11
conda activate llm
pip install --pre --upgrade ipex-llm[all]
-pip install transformers==4.40.0
+pip install "transformers>=4.40.0"
```
### 2. Run


@@ -17,7 +17,7 @@ conda create -n llm python=3.11
conda activate llm
# below command will install intel_extension_for_pytorch==2.1.10+xpu as default
pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
-pip install transformers==4.40.0
+pip install "transformers>=4.40.0"
conda install -c conda-forge -y gperftools=2.10 # to enable tcmalloc
```
@@ -29,7 +29,7 @@ conda activate llm
# below command will install intel_extension_for_pytorch==2.1.10+xpu as default
pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
-pip install transformers==4.40.0
+pip install "transformers>=4.40.0"
```
### 2. Configures OneAPI environment variables for Linux


@@ -17,7 +17,7 @@ conda create -n llm python=3.11
conda activate llm
# below command will install intel_extension_for_pytorch==2.1.10+xpu as default
pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
-pip install transformers==4.40.0
+pip install "transformers>=4.40.0"
conda install -c conda-forge -y gperftools=2.10 # to enable tcmalloc
```
@@ -29,7 +29,7 @@ conda activate llm
# below command will install intel_extension_for_pytorch==2.1.10+xpu as default
pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
-pip install transformers==4.40.0
+pip install "transformers>=4.40.0"
```
### 2. Configures OneAPI environment variables for Linux
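
A note on the change repeated in the four install snippets above: the quotes around `transformers>=4.40.0` keep the shell from interpreting `>` as output redirection, and relaxing `==4.40.0` to `>=4.40.0` is what allows transformers 4.41+ (the versions this commit fixes) to be installed. A quick, hedged way to check which code path an existing environment will take, mirroring the version gate added in `_optimize_post` below:

```python
# Small environment probe; the 4.41.0 threshold mirrors the dispatch added in this commit.
from packaging import version
import transformers

trans_version = transformers.__version__
print("transformers", trans_version)
if version.parse(trans_version) >= version.parse("4.41.0"):
    print("ipex-llm will patch CohereModel with cohere_model_forward_4_41")
elif version.parse(trans_version) >= version.parse("4.40.0"):
    print("ipex-llm will patch CohereModel with cohere_model_forward")
else:
    print("upgrade: the Cohere example requires transformers>=4.40.0")
```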


@@ -1382,13 +1382,23 @@ def _optimize_post(model, lightweight_bmm=False):
                        qwen2_attention_forward)
    elif model.config.model_type == "cohere":
        # for CohereForAI/c4ai-command-r-v01
        invalidInputError(version.parse(trans_version) >= version.parse("4.40.0"),
                          "Please upgrade transformers to 4.40.0 or higher version "
                          "to run Cohere models.")
        modeling_module_name = model.__class__.__module__
        module = importlib.import_module(modeling_module_name)
+        if version.parse(trans_version) >= version.parse("4.41.0"):
+            from ipex_llm.transformers.models.cohere import cohere_model_forward_4_41
+            convert_forward(model,
+                            module.CohereModel,
+                            cohere_model_forward_4_41)
+        else:
+            from ipex_llm.transformers.models.cohere import cohere_model_forward
+            convert_forward(model,
+                            module.CohereModel,
+                            cohere_model_forward)
        from ipex_llm.transformers.models.cohere import cohere_attention_forward
-        from ipex_llm.transformers.models.cohere import cohere_model_forward
-        convert_forward(model,
-                        module.CohereModel,
-                        cohere_model_forward)
        convert_forward(model,
                        module.CohereAttention,
                        cohere_attention_forward)
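
The dispatch above monkey-patches the Hugging Face Cohere classes at load time, picking the transformers-4.41 variant of the model forward when the installed version requires it. As a rough illustration only (assuming the helper simply rebinds the target class's forward; the real `convert_forward` in ipex-llm may instead bind per module instance), the pattern looks like this, with `convert_forward_sketch` as a hypothetical stand-in:

```python
# Hypothetical stand-in for ipex-llm's convert_forward, shown only to illustrate the
# monkey-patching pattern used in _optimize_post above; the real helper may differ
# (e.g. it may bind the new forward on each matching module instance).
def convert_forward_sketch(model, target_class, new_forward):
    # Rebinding the class attribute means every target_class instance inside `model`
    # (e.g. the CohereModel backbone) now dispatches to the optimized forward.
    target_class.forward = new_forward
    return model
```

With that stand-in, the version check above simply decides whether `cohere_model_forward_4_41` or `cohere_model_forward` is the function that ends up bound to `module.CohereModel`.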


@@ -191,6 +191,135 @@ def cohere_model_forward(
)


def cohere_model_forward_4_41(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
):
    use_cache = use_cache if use_cache is not None \
        else self.config.use_cache
    if use_cache and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids):
        if not isinstance(past_key_values, DynamicFp8Cache):
            past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
    output_attentions = output_attentions if output_attentions is not None \
        else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if input_ids is not None and inputs_embeds is not None:
        invalidInputError(False,
                          "You cannot specify both input_ids and inputs_embeds at the same time")

    if self.gradient_checkpointing and self.training and use_cache:
        invalidInputError(False,
                          "`use_cache=True` is incompatible "
                          "with gradient checkpointing. Setting `use_cache=False`.")
        use_cache = False

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

    past_seen_tokens = 0
    return_legacy_cache = False
    # kept for BC (non `Cache` `past_key_values` inputs)
    if use_cache and not isinstance(past_key_values, Cache):
        return_legacy_cache = True
        past_key_values = DynamicCache.from_legacy_cache(past_key_values)

    if cache_position is None:
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        cache_position = torch.arange(
            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
        )
    if position_ids is None:
        position_ids = cache_position.unsqueeze(0)

    causal_mask = self._update_causal_mask(
        attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
    )

    # embed positions
    hidden_states = inputs_embeds

    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    next_decoder_cache = None

    for decoder_layer in self.layers:
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if self.gradient_checkpointing and self.training:
            layer_outputs = self._gradient_checkpointing_func(
                decoder_layer.__call__,
                hidden_states,
                causal_mask,
                position_ids,
                past_key_values,
                output_attentions,
                use_cache,
                cache_position,
            )
        else:
            # ipex-llm changes
            curr_device = decoder_layer.input_layernorm.weight.device
            if causal_mask is not None:
                causal_mask = causal_mask.to(curr_device)
            if position_ids is not None:
                position_ids = position_ids.to(curr_device)
            # ipex-llm changes end
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
            )

        hidden_states = layer_outputs[0]

        if use_cache:
            next_decoder_cache = layer_outputs[2 if output_attentions else 1]

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

    hidden_states = self.norm(hidden_states)

    # add hidden states from the last decoder layer
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    next_cache = next_decoder_cache if use_cache else None
    if return_legacy_cache:
        next_cache = next_cache.to_legacy_cache()

    if not return_dict:
        return tuple(v for v in [hidden_states, next_cache,
                                 all_hidden_states, all_self_attns] if v is not None)
    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )


def cohere_attention_forward(
    self,
    hidden_states: torch.Tensor,