Fix cohere model on transformers>=4.41 (#11575)
* fix cohere model for 4-41
This commit is contained in:
		
							parent
							
								
									5b6eb85b85
								
							
						
					
					
						commit
						d64711900a
					
				
					 6 changed files with 151 additions and 12 deletions
				
			
		| 
						 | 
					@ -17,7 +17,7 @@ conda activate llm
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# install ipex-llm with 'all' option
 | 
					# install ipex-llm with 'all' option
 | 
				
			||||||
pip install --pre --upgrade ipex-llm[all] --extra-index-url https://download.pytorch.org/whl/cpu
 | 
					pip install --pre --upgrade ipex-llm[all] --extra-index-url https://download.pytorch.org/whl/cpu
 | 
				
			||||||
pip install transformers==4.40.0
 | 
					pip install "transformers>=4.40.0"
 | 
				
			||||||
```
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
On Windows:
 | 
					On Windows:
 | 
				
			||||||
| 
						 | 
					@ -27,7 +27,7 @@ conda create -n llm python=3.11
 | 
				
			||||||
conda activate llm
 | 
					conda activate llm
 | 
				
			||||||
 | 
					
 | 
				
			||||||
pip install --pre --upgrade ipex-llm[all]
 | 
					pip install --pre --upgrade ipex-llm[all]
 | 
				
			||||||
pip install transformers==4.40.0
 | 
					pip install "transformers>=4.40.0"
 | 
				
			||||||
```
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
### 2. Run
 | 
					### 2. Run
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -18,7 +18,7 @@ conda activate llm
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# install the latest ipex-llm nightly build with 'all' option
 | 
					# install the latest ipex-llm nightly build with 'all' option
 | 
				
			||||||
pip install --pre --upgrade ipex-llm[all] --extra-index-url https://download.pytorch.org/whl/cpu
 | 
					pip install --pre --upgrade ipex-llm[all] --extra-index-url https://download.pytorch.org/whl/cpu
 | 
				
			||||||
pip install transformers==4.40.0
 | 
					pip install "transformers>=4.40.0"
 | 
				
			||||||
```
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
On Windows:
 | 
					On Windows:
 | 
				
			||||||
| 
						 | 
					@ -28,7 +28,7 @@ conda create -n llm python=3.11
 | 
				
			||||||
conda activate llm
 | 
					conda activate llm
 | 
				
			||||||
 | 
					
 | 
				
			||||||
pip install --pre --upgrade ipex-llm[all]
 | 
					pip install --pre --upgrade ipex-llm[all]
 | 
				
			||||||
pip install transformers==4.40.0
 | 
					pip install "transformers>=4.40.0"
 | 
				
			||||||
```
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
### 2. Run
 | 
					### 2. Run
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -17,7 +17,7 @@ conda create -n llm python=3.11
 | 
				
			||||||
conda activate llm
 | 
					conda activate llm
 | 
				
			||||||
# below command will install intel_extension_for_pytorch==2.1.10+xpu as default
 | 
					# below command will install intel_extension_for_pytorch==2.1.10+xpu as default
 | 
				
			||||||
pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
 | 
					pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
 | 
				
			||||||
pip install transformers==4.40.0
 | 
					pip install "transformers>=4.40.0"
 | 
				
			||||||
conda install -c conda-forge -y gperftools=2.10 # to enable tcmalloc
 | 
					conda install -c conda-forge -y gperftools=2.10 # to enable tcmalloc
 | 
				
			||||||
```
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -29,7 +29,7 @@ conda activate llm
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# below command will install intel_extension_for_pytorch==2.1.10+xpu as default
 | 
					# below command will install intel_extension_for_pytorch==2.1.10+xpu as default
 | 
				
			||||||
pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
 | 
					pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
 | 
				
			||||||
pip install transformers==4.40.0
 | 
					pip install "transformers>=4.40.0"
 | 
				
			||||||
```
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
### 2. Configures OneAPI environment variables for Linux
 | 
					### 2. Configures OneAPI environment variables for Linux
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -17,7 +17,7 @@ conda create -n llm python=3.11
 | 
				
			||||||
conda activate llm
 | 
					conda activate llm
 | 
				
			||||||
# below command will install intel_extension_for_pytorch==2.1.10+xpu as default
 | 
					# below command will install intel_extension_for_pytorch==2.1.10+xpu as default
 | 
				
			||||||
pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
 | 
					pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
 | 
				
			||||||
pip install transformers==4.40.0
 | 
					pip install "transformers>=4.40.0"
 | 
				
			||||||
conda install -c conda-forge -y gperftools=2.10 # to enable tcmalloc
 | 
					conda install -c conda-forge -y gperftools=2.10 # to enable tcmalloc
 | 
				
			||||||
```
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -29,7 +29,7 @@ conda activate llm
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# below command will install intel_extension_for_pytorch==2.1.10+xpu as default
 | 
					# below command will install intel_extension_for_pytorch==2.1.10+xpu as default
 | 
				
			||||||
pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
 | 
					pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
 | 
				
			||||||
pip install transformers==4.40.0
 | 
					pip install "transformers>=4.40.0"
 | 
				
			||||||
```
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
### 2. Configures OneAPI environment variables for Linux
 | 
					### 2. Configures OneAPI environment variables for Linux
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1382,13 +1382,23 @@ def _optimize_post(model, lightweight_bmm=False):
 | 
				
			||||||
                        qwen2_attention_forward)
 | 
					                        qwen2_attention_forward)
 | 
				
			||||||
    elif model.config.model_type == "cohere":
 | 
					    elif model.config.model_type == "cohere":
 | 
				
			||||||
        # for CohereForAI/c4ai-command-r-v01
 | 
					        # for CohereForAI/c4ai-command-r-v01
 | 
				
			||||||
 | 
					        invalidInputError(version.parse(trans_version) >= version.parse("4.40.0"),
 | 
				
			||||||
 | 
					                          "Please upgrade transformers to 4.40.0 or higher version "
 | 
				
			||||||
 | 
					                          "to run Mixtral models.")
 | 
				
			||||||
        modeling_module_name = model.__class__.__module__
 | 
					        modeling_module_name = model.__class__.__module__
 | 
				
			||||||
        module = importlib.import_module(modeling_module_name)
 | 
					        module = importlib.import_module(modeling_module_name)
 | 
				
			||||||
 | 
					        if version.parse(trans_version) >= version.parse("4.41.0"):
 | 
				
			||||||
 | 
					            from ipex_llm.transformers.models.cohere import cohere_model_forward_4_41
 | 
				
			||||||
 | 
					            convert_forward(model,
 | 
				
			||||||
 | 
					                            module.CohereModel,
 | 
				
			||||||
 | 
					                            cohere_model_forward_4_41)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            from ipex_llm.transformers.models.cohere import cohere_model_forward
 | 
				
			||||||
 | 
					            convert_forward(model,
 | 
				
			||||||
 | 
					                            module.CohereModel,
 | 
				
			||||||
 | 
					                            cohere_model_forward)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        from ipex_llm.transformers.models.cohere import cohere_attention_forward
 | 
					        from ipex_llm.transformers.models.cohere import cohere_attention_forward
 | 
				
			||||||
        from ipex_llm.transformers.models.cohere import cohere_model_forward
 | 
					 | 
				
			||||||
        convert_forward(model,
 | 
					 | 
				
			||||||
                        module.CohereModel,
 | 
					 | 
				
			||||||
                        cohere_model_forward)
 | 
					 | 
				
			||||||
        convert_forward(model,
 | 
					        convert_forward(model,
 | 
				
			||||||
                        module.CohereAttention,
 | 
					                        module.CohereAttention,
 | 
				
			||||||
                        cohere_attention_forward)
 | 
					                        cohere_attention_forward)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -191,6 +191,135 @@ def cohere_model_forward(
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def cohere_model_forward_4_41(
 | 
				
			||||||
 | 
					    self,
 | 
				
			||||||
 | 
					    input_ids: torch.LongTensor = None,
 | 
				
			||||||
 | 
					    attention_mask: Optional[torch.Tensor] = None,
 | 
				
			||||||
 | 
					    position_ids: Optional[torch.LongTensor] = None,
 | 
				
			||||||
 | 
					    past_key_values: Optional[List[torch.FloatTensor]] = None,
 | 
				
			||||||
 | 
					    inputs_embeds: Optional[torch.FloatTensor] = None,
 | 
				
			||||||
 | 
					    use_cache: Optional[bool] = None,
 | 
				
			||||||
 | 
					    output_attentions: Optional[bool] = None,
 | 
				
			||||||
 | 
					    output_hidden_states: Optional[bool] = None,
 | 
				
			||||||
 | 
					    return_dict: Optional[bool] = None,
 | 
				
			||||||
 | 
					    cache_position: Optional[torch.LongTensor] = None,
 | 
				
			||||||
 | 
					):
 | 
				
			||||||
 | 
					    use_cache = use_cache if use_cache is not None \
 | 
				
			||||||
 | 
					        else self.config.use_cache
 | 
				
			||||||
 | 
					    if use_cache and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids):
 | 
				
			||||||
 | 
					        if not isinstance(past_key_values, DynamicFp8Cache):
 | 
				
			||||||
 | 
					            past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
 | 
				
			||||||
 | 
					    output_attentions = output_attentions if output_attentions is not None \
 | 
				
			||||||
 | 
					        else self.config.output_attentions
 | 
				
			||||||
 | 
					    output_hidden_states = (
 | 
				
			||||||
 | 
					        output_hidden_states if output_hidden_states is not None
 | 
				
			||||||
 | 
					        else self.config.output_hidden_states
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    use_cache = use_cache if use_cache is not None else self.config.use_cache
 | 
				
			||||||
 | 
					    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if input_ids is not None and inputs_embeds is not None:
 | 
				
			||||||
 | 
					        invalidInputError(False,
 | 
				
			||||||
 | 
					                          "You cannot specify both input_ids and inputs_embeds at the same time")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if self.gradient_checkpointing and self.training and use_cache:
 | 
				
			||||||
 | 
					        invalidInputError(False,
 | 
				
			||||||
 | 
					                          "`use_cache=True` is incompatible "
 | 
				
			||||||
 | 
					                          "with gradient checkpointing. Setting `use_cache=False`.")
 | 
				
			||||||
 | 
					        use_cache = False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if inputs_embeds is None:
 | 
				
			||||||
 | 
					        inputs_embeds = self.embed_tokens(input_ids)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    past_seen_tokens = 0
 | 
				
			||||||
 | 
					    return_legacy_cache = False
 | 
				
			||||||
 | 
					    # kept for BC (non `Cache` `past_key_values` inputs)
 | 
				
			||||||
 | 
					    if use_cache and not isinstance(past_key_values, Cache):
 | 
				
			||||||
 | 
					        return_legacy_cache = True
 | 
				
			||||||
 | 
					        past_key_values = DynamicCache.from_legacy_cache(past_key_values)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if cache_position is None:
 | 
				
			||||||
 | 
					        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
 | 
				
			||||||
 | 
					        cache_position = torch.arange(
 | 
				
			||||||
 | 
					            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if position_ids is None:
 | 
				
			||||||
 | 
					        position_ids = cache_position.unsqueeze(0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    causal_mask = self._update_causal_mask(
 | 
				
			||||||
 | 
					        attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # embed positions
 | 
				
			||||||
 | 
					    hidden_states = inputs_embeds
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # decoder layers
 | 
				
			||||||
 | 
					    all_hidden_states = () if output_hidden_states else None
 | 
				
			||||||
 | 
					    all_self_attns = () if output_attentions else None
 | 
				
			||||||
 | 
					    next_decoder_cache = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    for decoder_layer in self.layers:
 | 
				
			||||||
 | 
					        if output_hidden_states:
 | 
				
			||||||
 | 
					            all_hidden_states += (hidden_states,)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if self.gradient_checkpointing and self.training:
 | 
				
			||||||
 | 
					            layer_outputs = self._gradient_checkpointing_func(
 | 
				
			||||||
 | 
					                decoder_layer.__call__,
 | 
				
			||||||
 | 
					                hidden_states,
 | 
				
			||||||
 | 
					                causal_mask,
 | 
				
			||||||
 | 
					                position_ids,
 | 
				
			||||||
 | 
					                past_key_values,
 | 
				
			||||||
 | 
					                output_attentions,
 | 
				
			||||||
 | 
					                use_cache,
 | 
				
			||||||
 | 
					                cache_position,
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            # ipex-llm changes
 | 
				
			||||||
 | 
					            curr_device = decoder_layer.input_layernorm.weight.device
 | 
				
			||||||
 | 
					            if causal_mask is not None:
 | 
				
			||||||
 | 
					                causal_mask = causal_mask.to(curr_device)
 | 
				
			||||||
 | 
					            if position_ids is not None:
 | 
				
			||||||
 | 
					                position_ids = position_ids.to(curr_device)
 | 
				
			||||||
 | 
					            # ipex-llm changes end
 | 
				
			||||||
 | 
					            layer_outputs = decoder_layer(
 | 
				
			||||||
 | 
					                hidden_states,
 | 
				
			||||||
 | 
					                attention_mask=causal_mask,
 | 
				
			||||||
 | 
					                position_ids=position_ids,
 | 
				
			||||||
 | 
					                past_key_value=past_key_values,
 | 
				
			||||||
 | 
					                output_attentions=output_attentions,
 | 
				
			||||||
 | 
					                use_cache=use_cache,
 | 
				
			||||||
 | 
					                cache_position=cache_position,
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        hidden_states = layer_outputs[0]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if use_cache:
 | 
				
			||||||
 | 
					            next_decoder_cache = layer_outputs[2 if output_attentions else 1]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if output_attentions:
 | 
				
			||||||
 | 
					            all_self_attns += (layer_outputs[1],)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    hidden_states = self.norm(hidden_states)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # add hidden states from the last decoder layer
 | 
				
			||||||
 | 
					    if output_hidden_states:
 | 
				
			||||||
 | 
					        all_hidden_states += (hidden_states,)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    next_cache = next_decoder_cache if use_cache else None
 | 
				
			||||||
 | 
					    if return_legacy_cache:
 | 
				
			||||||
 | 
					        next_cache = next_cache.to_legacy_cache()
 | 
				
			||||||
 | 
					    if not return_dict:
 | 
				
			||||||
 | 
					        return tuple(v for v in [hidden_states, next_cache,
 | 
				
			||||||
 | 
					                                 all_hidden_states, all_self_attns] if v is not None)
 | 
				
			||||||
 | 
					    return BaseModelOutputWithPast(
 | 
				
			||||||
 | 
					        last_hidden_state=hidden_states,
 | 
				
			||||||
 | 
					        past_key_values=next_cache,
 | 
				
			||||||
 | 
					        hidden_states=all_hidden_states,
 | 
				
			||||||
 | 
					        attentions=all_self_attns,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def cohere_attention_forward(
 | 
					def cohere_attention_forward(
 | 
				
			||||||
    self,
 | 
					    self,
 | 
				
			||||||
    hidden_states: torch.Tensor,
 | 
					    hidden_states: torch.Tensor,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue