From 380717f50d732f1ebf56aa2bdbe6e6eb5607c805 Mon Sep 17 00:00:00 2001
From: Guoqiong Song
Date: Thu, 18 Jul 2024 15:02:50 -0700
Subject: [PATCH] fix gemma for 4.41 (#11531)

* fix gemma for 4.41
---
 .../Model/codegemma/README.md                 |   4 +-
 .../Model/gemma/README.md                     |   4 +-
 .../PyTorch-Models/Model/codegemma/README.md  |   4 +-
 .../GPU/HuggingFace/LLM/codegemma/README.md   |   4 +-
 .../GPU/HuggingFace/LLM/gemma/README.md       |   4 +-
 .../PyTorch-Models/Model/codegemma/README.md  |   4 +-
 .../llm/src/ipex_llm/transformers/convert.py  |  21 ++-
 .../src/ipex_llm/transformers/models/gemma.py | 153 ++++++++++++++++++
 8 files changed, 181 insertions(+), 17 deletions(-)

diff --git a/python/llm/example/CPU/HF-Transformers-AutoModels/Model/codegemma/README.md b/python/llm/example/CPU/HF-Transformers-AutoModels/Model/codegemma/README.md
index 500e4027..000aa348 100644
--- a/python/llm/example/CPU/HF-Transformers-AutoModels/Model/codegemma/README.md
+++ b/python/llm/example/CPU/HF-Transformers-AutoModels/Model/codegemma/README.md
@@ -21,7 +21,7 @@ conda activate llm
 pip install --pre --upgrade ipex-llm[all] --extra-index-url https://download.pytorch.org/whl/cpu
 
 # According to CodeGemma's requirement, please make sure you are using a stable version of Transformers, 4.38.1 or newer.
-pip install transformers==4.38.1
+pip install "transformers>=4.38.1"
 ```
 
 On Windows:
 
 ```
 conda create -n llm python=3.11
 conda activate llm
 
 pip install --pre --upgrade ipex-llm[all]
 
-pip install transformers==4.38.1
+pip install "transformers>=4.38.1"
 ```
 
 ### 2. Run
diff --git a/python/llm/example/CPU/HF-Transformers-AutoModels/Model/gemma/README.md b/python/llm/example/CPU/HF-Transformers-AutoModels/Model/gemma/README.md
index bb09f15e..93dd2047 100644
--- a/python/llm/example/CPU/HF-Transformers-AutoModels/Model/gemma/README.md
+++ b/python/llm/example/CPU/HF-Transformers-AutoModels/Model/gemma/README.md
@@ -22,7 +22,7 @@ conda activate llm
 pip install --pre --upgrade ipex-llm[all] --extra-index-url https://download.pytorch.org/whl/cpu
 
 # According to Gemma's requirement, please make sure you are using a stable version of Transformers, 4.38.1 or newer.
-pip install transformers==4.38.1
+pip install "transformers>=4.38.1"
 ```
 
 On Windows:
 
 ```
 conda create -n llm python=3.11
 conda activate llm
 
 pip install --pre --upgrade ipex-llm[all]
 
-pip install transformers==4.38.1
+pip install "transformers>=4.38.1"
 ```
 
 ### 2. Run
diff --git a/python/llm/example/CPU/PyTorch-Models/Model/codegemma/README.md b/python/llm/example/CPU/PyTorch-Models/Model/codegemma/README.md
index 81831376..543b71b7 100644
--- a/python/llm/example/CPU/PyTorch-Models/Model/codegemma/README.md
+++ b/python/llm/example/CPU/PyTorch-Models/Model/codegemma/README.md
@@ -21,7 +21,7 @@ conda activate llm
 # install the latest ipex-llm nightly build with 'all' option
 pip install --pre --upgrade ipex-llm[all] --extra-index-url https://download.pytorch.org/whl/cpu
 # According to CodeGemma's requirement, please make sure you are using a stable version of Transformers, 4.38.1 or newer.
-pip install transformers==4.38.1
+pip install "transformers>=4.38.1"
 ```
 
 On Windows:
 
 ```
 conda create -n llm python=3.11
 conda activate llm
 
 pip install --pre --upgrade ipex-llm[all]
 
-pip install transformers==4.38.1
+pip install "transformers>=4.38.1"
 ```
 
 ### 2. Run
diff --git a/python/llm/example/GPU/HuggingFace/LLM/codegemma/README.md b/python/llm/example/GPU/HuggingFace/LLM/codegemma/README.md
index 96a0b804..d0f9f5e0 100644
--- a/python/llm/example/GPU/HuggingFace/LLM/codegemma/README.md
+++ b/python/llm/example/GPU/HuggingFace/LLM/codegemma/README.md
@@ -20,7 +20,7 @@ conda activate llm
 pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
 
 # According to CodeGemma's requirement, please make sure you are using a stable version of Transformers, 4.38.1 or newer.
-pip install transformers==4.38.1
+pip install "transformers>=4.38.1"
 ```
 
 #### 1.2 Installation on Windows
@@ -33,7 +33,7 @@ conda activate llm
 pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
 
 # According to CodeGemma's requirement, please make sure you are using a stable version of Transformers, 4.38.1 or newer.
-pip install transformers==4.38.1
+pip install "transformers>=4.38.1"
 ```
 
 ### 2. Configures OneAPI environment variables for Linux
diff --git a/python/llm/example/GPU/HuggingFace/LLM/gemma/README.md b/python/llm/example/GPU/HuggingFace/LLM/gemma/README.md
index 14aa69db..eff2f97a 100644
--- a/python/llm/example/GPU/HuggingFace/LLM/gemma/README.md
+++ b/python/llm/example/GPU/HuggingFace/LLM/gemma/README.md
@@ -18,7 +18,7 @@ conda activate llm
 pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
 
 # According to Gemma's requirement, please make sure you are using a stable version of Transformers, 4.38.1 or newer.
-pip install transformers==4.38.1
+pip install "transformers>=4.38.1"
 ```
 
 #### 1.2 Installation on Windows
@@ -31,7 +31,7 @@ conda activate llm
 pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
 
 # According to Gemma's requirement, please make sure you are using a stable version of Transformers, 4.38.1 or newer.
-pip install transformers==4.38.1
+pip install "transformers>=4.38.1"
 ```
 
 ### 2. Configures OneAPI environment variables for Linux
diff --git a/python/llm/example/GPU/PyTorch-Models/Model/codegemma/README.md b/python/llm/example/GPU/PyTorch-Models/Model/codegemma/README.md
index 1c145def..33da3966 100644
--- a/python/llm/example/GPU/PyTorch-Models/Model/codegemma/README.md
+++ b/python/llm/example/GPU/PyTorch-Models/Model/codegemma/README.md
@@ -20,7 +20,7 @@ conda activate llm
 pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
 
 # According to CodeGemma's requirement, please make sure you are using a stable version of Transformers, 4.38.1 or newer.
-pip install transformers==4.38.1
+pip install "transformers>=4.38.1"
 ```
 
 #### 1.2 Installation on Windows
@@ -33,7 +33,7 @@ conda activate llm
 pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
 
 # According to CodeGemma's requirement, please make sure you are using a stable version of Transformers, 4.38.1 or newer.
-pip install transformers==4.38.1
+pip install "transformers>=4.38.1"
 ```
 
 ### 2. Configures OneAPI environment variables for Linux
diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index 1b0f0972..9fcb49db 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -1481,21 +1481,32 @@ def _optimize_post(model, lightweight_bmm=False):
                         module.MistralMLP,
                         llama_mlp_forward)
     elif model.config.model_type == "gemma":
+        invalidInputError(version.parse(trans_version) >= version.parse("4.38.0"),
+                          "Please upgrade transformers to 4.38.0 or higher version "
+                          "to run Gemma models.")
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
-        from ipex_llm.transformers.models.gemma import gemma_attention_forward
+        if version.parse(trans_version) >= version.parse("4.39.0"):
+            from ipex_llm.transformers.models.gemma import gemma_attention_forward_4_39
+            convert_forward(model,
+                            module.GemmaAttention,
+                            gemma_attention_forward_4_39
+                            )
+        else:
+            from ipex_llm.transformers.models.gemma import gemma_attention_forward
+            convert_forward(model,
+                            module.GemmaAttention,
+                            gemma_attention_forward,
+                            )
         from ipex_llm.transformers.models.gemma import gemma_rms_norm_forward
         from ipex_llm.transformers.models.gemma import gemma_mlp_forward
-        convert_forward(model,
-                        module.GemmaAttention,
-                        gemma_attention_forward,
-                        )
         convert_forward(model,
                         module.GemmaRMSNorm,
                         gemma_rms_norm_forward)
         convert_forward(model,
                         module.GemmaMLP,
                         gemma_mlp_forward)
+
     elif model.config.model_type == "gemma2":
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
diff --git a/python/llm/src/ipex_llm/transformers/models/gemma.py b/python/llm/src/ipex_llm/transformers/models/gemma.py
index 71c8dfc5..f542a429 100644
--- a/python/llm/src/ipex_llm/transformers/models/gemma.py
+++ b/python/llm/src/ipex_llm/transformers/models/gemma.py
@@ -267,3 +267,156 @@ def gemma_attention_forward(
         attn_weights = None
 
     return attn_output.to(original_dtype), attn_weights, past_key_value
+
+
+def gemma_attention_forward_4_39(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor]=None,
+    position_ids: Optional[torch.LongTensor]=None,
+    past_key_value: Optional[Tuple[torch.Tensor]]=None,
+    output_attentions: bool=False,
+    use_cache: bool=False,
+    cache_position: Optional[torch.Tensor]=None,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    bsz, q_len, hidden_size = hidden_states.size()
+    device = hidden_states.device
+    # for flash attention
+    original_dtype = hidden_states.dtype
+
+    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
+    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
+    decoding_fast_path = use_decoding_fast_path(self.q_proj,
+                                                use_fuse_rope,
+                                                enough_kv_room,
+                                                bsz * q_len)
+
+    if decoding_fast_path:
+        hidden_states = hidden_states.view(1, -1)
+
+        cache_k = past_key_value.key_cache[self.layer_idx]
+        cache_v = past_key_value.value_cache[self.layer_idx]
+
+        kv_seq_len = cache_k.shape[-2]
+
+        import xe_linear
+        query_states, key_states, value_states = xe_linear.forward_qkv(hidden_states,
+                                                                       self.q_proj.weight,
+                                                                       self.k_proj.weight,
+                                                                       self.v_proj.weight,
+                                                                       position_ids,
+                                                                       cache_k, cache_v,
+                                                                       self.q_proj.weight.qtype,
+                                                                       self.v_proj.weight.qtype,
+                                                                       kv_seq_len,
+                                                                       self.head_dim)
+        kv_seq_len += 1
+
+        # update past_key_value's seen_tokens and kv caches.
+        if self.layer_idx == 0:
+            past_key_value._seen_tokens = kv_seq_len
+        past_key_value.key_cache[self.layer_idx] = key_states
+        past_key_value.value_cache[self.layer_idx] = value_states
+
+    else:
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len,
+                                     self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len,
+                                         self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        kv_seq_len = key_states.shape[-2]
+
+        if past_key_value is not None:
+            if self.layer_idx is None:
+                invalidInputError(False,
+                                  "The cache structure has changed since version v4.36. "
+                                  f"If you are using {self.__class__.__name__} for "
+                                  "auto-regressive decoding with k/v caching, please make sure "
+                                  "to initialize the attention class with a layer index.")
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+
+        if use_fuse_rope:
+            cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
+            query_states, key_states = apply_rotary_pos_emb_cache_freq_xpu(query_states, key_states,
+                                                                           sin, cos, "gemma")
+        else:
+            cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
+                                                            cos, sin, None)
+
+        if past_key_value is not None:
+            # update the number of seen tokens
+            if self.layer_idx == 0:
+                past_key_value._seen_tokens += key_states.shape[-2]
+
+            # reuse k, v, self_attention
+            # update `past_key_value` with `key_states` and `value_states` for layer `layer_idx`
+            if len(past_key_value.key_cache) <= self.layer_idx:
+                past_key_value.key_cache.append(key_states)
+                past_key_value.value_cache.append(value_states)
+            else:
+                cache_k = past_key_value.key_cache[self.layer_idx]
+                cache_v = past_key_value.value_cache[self.layer_idx]
+
+                if not enough_kv_room:
+                    # allocate new
+                    new_c_k, new_c_v = extend_kv_cache(bsz,
+                                                       self.num_key_value_heads,  # Support GQA
+                                                       self.head_dim,
+                                                       cache_k.size(2),
+                                                       kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
+                                                       dtype=cache_k.dtype,
+                                                       device=device)
+
+                    new_c_k[:] = cache_k
+                    new_c_v[:] = cache_v
+                    cache_k = new_c_k
+                    cache_v = new_c_v
+
+                key_states, value_states = append_kv_cache(cache_k, cache_v,
+                                                           key_states, value_states)
+
+                # update past_key_value
+                past_key_value.key_cache[self.layer_idx] = key_states
+                past_key_value.value_cache[self.layer_idx] = value_states
+
+    # repeat k/v heads if n_kv_heads < n_heads
+    key_states = repeat_kv(key_states, self.num_key_value_groups)
+    value_states = repeat_kv(value_states, self.num_key_value_groups)
+    attn_weights = torch.matmul(query_states,
+                                key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+    if attention_mask is not None:  # no matter the length, we just slice it
+        if cache_position is not None:
+            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+        else:
+            causal_mask = attention_mask
+        attn_weights = attn_weights + causal_mask
+
+    # upcast attention to fp32
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1,
+                                         dtype=torch.float32).to(query_states.dtype)
+    attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout,
+                                         training=self.training)
+    attn_output = torch.matmul(attn_weights, value_states)
+
+    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+        invalidInputError(
+            False,
+            f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+            f" {attn_output.size()}"
+        )
+
+    attn_output = attn_output.transpose(1, 2).contiguous()
+
+    attn_output = attn_output.view(bsz, q_len, -1)
+    attn_output = self.o_proj(attn_output)
+
+    if not output_attentions:
+        attn_weights = None
+
+    return attn_output.to(original_dtype), attn_weights, past_key_value
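
Note: for reference, a minimal usage sketch consistent with the READMEs touched above. The checkpoint id `google/gemma-7b-it` and the prompt are placeholders; loading with `load_in_4bit=True` is what routes the model through `_optimize_post`, where the version-gated `gemma_attention_forward` / `gemma_attention_forward_4_39` replacement in convert.py is applied (the `_4_39` variant accepts `cache_position` and updates the cache's `_seen_tokens` field, matching the Cache API of transformers >= 4.39).

```python
# Minimal sketch, assuming ipex-llm and `pip install "transformers>=4.38.1"`
# as installed per the READMEs above; the model path is a placeholder.
import torch
from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModelForCausalLM

model_path = "google/gemma-7b-it"

# load_in_4bit=True applies ipex-llm INT4 optimizations, including the
# patched Gemma attention forward selected by the transformers version check.
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_4bit=True,
                                             trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

with torch.inference_mode():
    input_ids = tokenizer.encode("What is AI?", return_tensors="pt")
    output = model.generate(input_ids, max_new_tokens=32)
    print(tokenizer.decode(output[0], skip_special_tokens=True))
```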