Codegeex support (#12303)

* new codegeex attn
* use kv cache
* add compress/quantize kv
* remove compress/quantize kv
* fix style check
* fix style
* fix codegeex
This commit is contained in:

parent 72605c7016
commit 97a0f7fd35

2 changed files with 233 additions and 1 deletion
ipex_llm/transformers/convert.py

@@ -1364,7 +1364,7 @@ def _optimize_post(model, lightweight_bmm=False):
         and model.config.architectures[0] in ["ChatGLMModel", "ChatGLMForConditionalGeneration"]
     ):
         if hasattr(model.config, 'padded_vocab_size') and \
-                model.config.padded_vocab_size in [65024, 64896]:
+                model.config.padded_vocab_size == 65024:
             # chatglm2-6b, chatglm2-6b-32k, chatglm3-6b, chatglm3-6b-32k, chatglm3-6b-128k
             modeling_module_name = model.__class__.__module__
             module = importlib.import_module(modeling_module_name)
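GLM-family checkpoints are fingerprinted by vocabulary size here: 65024 now stays on the ChatGLM2/3 path, while 64896 (codegeex-nano) falls through to the new branch added in the next hunk. A standalone sketch of the same check (the helper name is ours, not part of ipex-llm):

    from transformers import AutoConfig

    def glm_variant(model_path: str) -> str:
        # mirrors the vocab-size dispatch in _optimize_post
        config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
        padded = getattr(config, "padded_vocab_size", None)
        if padded == 65024:
            return "chatglm2/chatglm3"  # incl. 32k/128k variants
        if padded == 64896:
            return "codegeex-nano"
        if getattr(config, "vocab_size", None) == 130528:
            return "chatglm-6b"
        return "other"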
@@ -1384,6 +1384,27 @@ def _optimize_post(model, lightweight_bmm=False):
             convert_forward(model,
                             module.RMSNorm,
                             chatglm_rms_norm_forward)
+        elif hasattr(model.config, 'padded_vocab_size') and \
+                model.config.padded_vocab_size == 64896:
+            # codegeex-nano
+            modeling_module_name = model.__class__.__module__
+            module = importlib.import_module(modeling_module_name)
+            from ipex_llm.transformers.models.chatglm2 import codegeex_attention_forward
+            from ipex_llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
+            from ipex_llm.transformers.models.chatglm2 import chatglm2_encoder_forward
+            from ipex_llm.transformers.models.chatglm2 import codegeex_model_forward
+            convert_forward(model,
+                            module.SelfAttention,
+                            codegeex_attention_forward)
+            convert_forward(model,
+                            module.GLMTransformer,
+                            chatglm2_encoder_forward)
+            convert_forward(model,
+                            module.ChatGLMModel,
+                            codegeex_model_forward)
+            convert_forward(model,
+                            module.RMSNorm,
+                            chatglm_rms_norm_forward)
         elif hasattr(model.config, 'vocab_size') and model.config.vocab_size == 130528:
             # chatglm-6b
             modeling_module_name = model.__class__.__module__
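`convert_forward` (used throughout the hunk above) swaps the `forward` of every instance of a target module class for the optimized one. A minimal sketch of that pattern, under the assumption that ipex-llm's helper works roughly this way (its real version may differ in details):

    import torch.nn as nn

    def convert_forward_sketch(model: nn.Module, target_cls, new_forward):
        # walk the module tree and rebind `forward` on matching submodules
        for child in model.children():
            if isinstance(child, target_cls):
                child.forward = new_forward.__get__(child, child.__class__)
            convert_forward_sketch(child, target_cls, new_forward)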
ipex_llm/transformers/models/chatglm2.py

@@ -359,3 +359,214 @@ def chatglm2_attention_forward(
     output = self.dense(attn_output)

     return output, past_key_value
+
+
+@torch.jit.script
+def apply_rotary_pos_emb_original(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
+    # x: [sq, b, np, hn]
+    sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
+    rot_dim = rope_cache.shape[-2] * 2
+    x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
+    # truncate to support variable sizes
+    rope_cache = rope_cache[:sq]
+    xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
+    rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
+    x_out2 = torch.stack(
+        [
+            xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
+            xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
+        ],
+        -1,
+    )
+    x_out2 = x_out2.flatten(3)
+    return torch.cat((x_out2, x_pass), dim=-1)
+
+
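This is the stock GLM rotary helper, kept for the non-fused path (training, or non-XPU devices): within the first `rot_dim` channels, each adjacent pair (x0, x1) is rotated to (x0*cos - x1*sin, x1*cos + x0*sin); the remaining channels pass through unchanged. A quick shape check with the function above in scope (sizes are arbitrary; the cache layout `[sq, b, rot_dim/2, 2]` matches how `codegeex_model_forward` below indexes it):

    import torch

    sq, b, n_head, head_dim, rot_dim = 5, 2, 4, 8, 4
    x = torch.randn(sq, b, n_head, head_dim)
    rope_cache = torch.randn(sq, b, rot_dim // 2, 2)  # (cos, sin) per channel pair
    out = apply_rotary_pos_emb_original(x, rope_cache)
    assert out.shape == (sq, b, n_head, head_dim)
    # channels beyond rot_dim are passed through untouched
    assert torch.equal(out[..., rot_dim:], x[..., rot_dim:])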
+def codegeex_model_forward(
+    self,
+    input_ids,
+    position_ids: Optional[torch.Tensor] = None,
+    attention_mask: Optional[torch.BoolTensor] = None,
+    full_attention_mask: Optional[torch.BoolTensor] = None,
+    past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+    inputs_embeds: Optional[torch.Tensor] = None,
+    use_cache: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+):
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None
+        else self.config.output_hidden_states
+    )
+    use_cache = use_cache if use_cache is not None else self.config.use_cache
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    if inputs_embeds is None:
+        batch_size, seq_length = input_ids.shape
+        inputs_embeds = self.embedding(input_ids)
+    else:
+        inputs_embeds = inputs_embeds.transpose(0, 1).contiguous()
+        seq_length, batch_size, _ = inputs_embeds.shape
+        input_ids = torch.empty((batch_size, seq_length),
+                                dtype=inputs_embeds.dtype, device=inputs_embeds.device)
+
+    if full_attention_mask is None:
+        if (attention_mask is not None and not attention_mask.all()) or (
+                past_key_values and seq_length != 1):
+            full_attention_mask = self.get_masks(input_ids,
+                                                 past_key_values,
+                                                 padding_mask=attention_mask)
+
+    # ipex-llm changes begin
+    # 1. replace `rotary_pos_emb` with `inv_freq` and `position_ids`
+    # 2. generate `causal_mask` and replace `full_attention_mask` with it
+    if position_ids is None:
+        if past_key_values is None:
+            position_ids = torch.arange(seq_length, dtype=torch.int64, device=inputs_embeds.device)
+        else:
+            if isinstance(past_key_values, DynamicCompressCache):
+                kv_length = past_key_values.get_seq_length()
+            else:
+                kv_length = past_key_values[0][0].size(0)
+            position_ids = torch.arange(kv_length, kv_length + seq_length,
+                                        dtype=torch.int64, device=inputs_embeds.device)
+        position_ids = position_ids.repeat(batch_size, 1)
+    use_fuse_rope = input_ids.device.type == "xpu" and not self.training
+
+    # Rotary positional embeddings
+    rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
+    if position_ids is not None:
+        rotary_pos_emb = rotary_pos_emb[position_ids]
+    else:
+        rotary_pos_emb = rotary_pos_emb[None, :seq_length]
+    if use_fuse_rope:
+        # Repeat cos/sin here so they are built only once per token.
+        # ChatGLM2's rotary embedding is rotate-every-two, like GPT-J's;
+        # doing this inside the attention forward would redo it for every layer.
+        cos, sin = rotary_pos_emb.split(rotary_pos_emb.shape[-1] // 2, dim=-1)
+        cos = cos.squeeze(-1)
+        sin = sin.squeeze(-1)
+        cos = torch.repeat_interleave(cos[:, :, None, :], 2, 3)
+        sin = torch.repeat_interleave(sin[:, :, None, :], 2, 3)
+        rotary_pos_emb = (cos, sin)
+    else:
+        rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
+
+    # `full_attention_mask` is not None only when
+    # `past_key_values` is not None and `seq_length` > 1
+    if full_attention_mask is not None:
+        causal_mask = torch.zeros([batch_size, 1, seq_length, full_attention_mask.size(-1)],
+                                  dtype=inputs_embeds.dtype, device=inputs_embeds.device)
+        mask_value = torch.finfo(inputs_embeds.dtype).min
+        causal_mask.masked_fill_(full_attention_mask, mask_value)
+    elif self.training or (inputs_embeds.device.type != "xpu" and past_key_values is None):
+        full_attention_mask = self.get_masks(input_ids,
+                                             past_key_values,
+                                             padding_mask=attention_mask)
+        causal_mask = torch.zeros([batch_size, 1, seq_length, full_attention_mask.size(-1)],
+                                  dtype=inputs_embeds.dtype, device=inputs_embeds.device)
+        mask_value = torch.finfo(inputs_embeds.dtype).min
+        causal_mask.masked_fill_(full_attention_mask, mask_value)
+    else:
+        causal_mask = None
+
+    # Run encoder.
+    hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
+        inputs_embeds, causal_mask,
+        rotary_pos_emb=rotary_pos_emb,
+        kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states
+    )
+    # ipex-llm changes end
+
+    if not return_dict:
+        return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions]
+                     if v is not None)
+
+    return BaseModelOutputWithPast(
+        last_hidden_state=hidden_states,
+        past_key_values=presents,
+        hidden_states=all_hidden_states,
+        attentions=all_self_attentions,
+    )
+
+
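Both mask branches materialize the boolean `full_attention_mask` as an additive float mask: zero where attention is allowed, dtype-min where it is blocked, so it can be handed straight to the attention kernels. A self-contained illustration of the construction:

    import torch

    # True marks positions a query must not attend to
    blocked = torch.tensor([[False, True],
                            [False, False]])[None, None]   # [bsz=1, 1, q=2, kv=2]
    causal_mask = torch.zeros(blocked.shape, dtype=torch.float16)
    causal_mask.masked_fill_(blocked, torch.finfo(torch.float16).min)
    # causal_mask[0, 0] is now [[0, -65504], [0, 0]]: adding it to raw
    # attention scores drives masked logits toward -inf before softmax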
+def codegeex_attention_forward(
+    self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
+):
+    q_len, bsz, _ = hidden_states.size()
+    n_head = self.num_attention_heads_per_partition
+    n_kv_head = self.num_multi_query_groups_per_partition if self.multi_query_attention else n_head
+    head_dim = self.hidden_size_per_attention_head
+
+    past_key_value = None if kv_cache is None else (kv_cache[0].permute(1, 2, 0, 3),
+                                                    kv_cache[1].permute(1, 2, 0, 3))
+    qkv = self.query_key_value(hidden_states)
+    qkv = qkv.view(q_len, bsz, n_head + 2 * n_kv_head, head_dim)
+    # [seq_len, bsz, n_head, head_dim] -> [bsz, n_head, seq_len, head_dim]
+    qkv = qkv.permute(1, 2, 0, 3)
+    query_layer, key_layer, value_layer = qkv.split([n_head,
+                                                     n_kv_head,
+                                                     n_kv_head], dim=1)
+    kv_seq_len = key_layer.shape[2]
+    if past_key_value is not None:
+        kv_seq_len += past_key_value[0].shape[2]
+
+    # apply relative positional encoding (rotary embedding)
+    if isinstance(rotary_pos_emb, tuple) and len(rotary_pos_emb) == 2:
+        # fused rope path: rotary_pos_emb is the (cos, sin) pair
+        # prepared in codegeex_model_forward
+        cos, sin = rotary_pos_emb
+        rot_dim = cos.shape[-1]
+        query_layer = query_layer.transpose(1, 2)
+        key_layer = key_layer.transpose(1, 2)
+        query_layer_cur = query_layer[..., :rot_dim]
+        key_layer_cur = key_layer[..., :rot_dim]
+        # ipex-llm's apply_rotary_embedding writes into the original storage,
+        # so query_layer and key_layer receive the result directly.
+        torch.ops.torch_ipex.apply_rotary_embedding(query_layer_cur, sin, cos, query_layer_cur)
+        torch.ops.torch_ipex.apply_rotary_embedding(key_layer_cur, sin, cos, key_layer_cur)
+        query_layer = query_layer.transpose(1, 2)
+        key_layer = key_layer.transpose(1, 2)
+    else:
+        query_layer = apply_rotary_pos_emb_original(query_layer, rotary_pos_emb)
+        key_layer = apply_rotary_pos_emb_original(key_layer, rotary_pos_emb)
+
+    key_layer, value_layer = update_past_key_value(
+        past_key_value, key_layer, value_layer,
+        kv_seq_len, False, hidden_states.device
+    )
+    # past_key_value: [bsz, n_kv_head, seq_len, head_dim] -> [seq_len, bsz, n_kv_head, head_dim]
+    past_key_value = (key_layer.permute(2, 0, 1, 3),
+                      value_layer.permute(2, 0, 1, 3)) if use_cache else None
+
+    # =================
+    # Output. [sq, b, h]
+    # =================
+    context_layer = None
+    if use_sdp(q_len, kv_seq_len, head_dim, query_layer):
+        import xe_addons
+        context_layer = xe_addons.sdp(query_layer, key_layer, value_layer, attention_mask)
+    elif use_sdp_causal(q_len, kv_seq_len, head_dim, query_layer, self.training):
+        import xe_addons
+        context_layer = xe_addons.sdp_causal(query_layer, key_layer, value_layer, attention_mask)
+    else:
+        # repeat k/v heads if n_kv_heads < n_heads
+        key_layer = repeat_kv(key_layer, n_head // n_kv_head)
+        value_layer = repeat_kv(value_layer, n_head // n_kv_head)
+        if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
+            context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
+                                                                             key_layer,
+                                                                             value_layer,
+                                                                             is_causal=True)
+        else:
+            if attention_mask is not None:
+                attention_mask = ~attention_mask
+            context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
+                                                                             key_layer,
+                                                                             value_layer,
+                                                                             attention_mask)
+
+    context_layer = context_layer.permute(2, 0, 1, 3).contiguous().view(q_len,
+                                                                        bsz,
+                                                                        n_head * head_dim)
+    output = self.dense(context_layer)
+
+    return output, past_key_value
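When neither fused kernel applies, the grouped-query KV heads are expanded to match the query heads before calling PyTorch's `scaled_dot_product_attention`. ipex-llm ships its own `repeat_kv`; a sketch of the standard expand-and-reshape form it is assumed to follow:

    import torch

    def repeat_kv_sketch(x: torch.Tensor, n_rep: int) -> torch.Tensor:
        # [bsz, n_kv_head, seq_len, head_dim] -> [bsz, n_kv_head * n_rep, seq_len, head_dim]
        bsz, n_kv_head, seq_len, head_dim = x.shape
        if n_rep == 1:
            return x
        x = x[:, :, None, :, :].expand(bsz, n_kv_head, n_rep, seq_len, head_dim)
        return x.reshape(bsz, n_kv_head * n_rep, seq_len, head_dim)

    k = torch.randn(1, 2, 5, 8)                             # 2 KV heads
    assert repeat_kv_sketch(k, 16).shape == (1, 32, 5, 8)   # 32 query heads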