use new rotary two in chatglm4 (#11312)
* use new rotary two in chatglm4 * rempve
This commit is contained in:
		
							parent
							
								
									f1410d6823
								
							
						
					
					
						commit
						1b0c4c8cb8
					
				
					 1 changed files with 32 additions and 65 deletions
				
			
		| 
						 | 
				
			
			@ -47,10 +47,6 @@ def chatglm4_model_forward(
 | 
			
		|||
) -> Union[Tuple, BaseModelOutputWithPast]:
 | 
			
		||||
    from ipex_llm.transformers.kv import DynamicFp8Cache
 | 
			
		||||
    use_cache = use_cache if use_cache is not None else self.config.use_cache
 | 
			
		||||
    # if use_cache and use_quantize_kv_cache(
 | 
			
		||||
    #         self.encoder.layers[0].self_attention.query_key_value, input_ids):
 | 
			
		||||
    #     if not isinstance(past_key_values, DynamicFp8Cache):
 | 
			
		||||
    #         past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
 | 
			
		||||
    return chatglm4_model_forward_internal(
 | 
			
		||||
        self=self,
 | 
			
		||||
        input_ids=input_ids,
 | 
			
		||||
| 
						 | 
				
			
			@ -108,25 +104,17 @@ def chatglm4_model_forward_internal(
 | 
			
		|||
                                        dtype=torch.int64, device=inputs_embeds.device)
 | 
			
		||||
        position_ids = position_ids.repeat(batch_size, 1)
 | 
			
		||||
 | 
			
		||||
    use_fuse_rope = input_ids.device.type == "xpu"
 | 
			
		||||
    use_fuse_rope = use_fuse_rope and not self.training
 | 
			
		||||
 | 
			
		||||
    # Rotary positional embeddings
 | 
			
		||||
    rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
 | 
			
		||||
    if position_ids is not None:
 | 
			
		||||
        rotary_pos_emb = rotary_pos_emb[position_ids]
 | 
			
		||||
    else:
 | 
			
		||||
        rotary_pos_emb = rotary_pos_emb[None, :seq_length]
 | 
			
		||||
    if use_fuse_rope:
 | 
			
		||||
        # Repeat cos sin here, call only once for each token.
 | 
			
		||||
        # Chatglm2's rotary embedding is similar to gptj's, is rotate_every_two.
 | 
			
		||||
        # If put this to attension forward, it will generate too many times.
 | 
			
		||||
        cos, sin = rotary_pos_emb.split(rotary_pos_emb.shape[-1] // 2, dim=-1)
 | 
			
		||||
        cos = cos.squeeze(-1)
 | 
			
		||||
        sin = sin.squeeze(-1)
 | 
			
		||||
        cos = torch.repeat_interleave(cos[:, :, None, :], 2, 3)
 | 
			
		||||
        sin = torch.repeat_interleave(sin[:, :, None, :], 2, 3)
 | 
			
		||||
        rotary_pos_emb = (cos, sin)
 | 
			
		||||
    if getattr(self.rotary_pos_emb, "cached_dtype", None) != inputs_embeds.dtype:
 | 
			
		||||
        rot_dim = self.rotary_pos_emb.dim
 | 
			
		||||
        base = 10000 * getattr(self.rotary_pos_emb, "rope_ratio", 1)
 | 
			
		||||
        # We should generate float inv_freq to avoid overflow, as base is too large.
 | 
			
		||||
        inv_freq = 1.0 / (base ** (torch.arange(0, rot_dim, 2,
 | 
			
		||||
                                                dtype=torch.float,
 | 
			
		||||
                                                device=inputs_embeds.device) / rot_dim))
 | 
			
		||||
        self.rotary_pos_emb.register_buffer("inv_freq",
 | 
			
		||||
                                            inv_freq.to(inputs_embeds.dtype),
 | 
			
		||||
                                            persistent=False)
 | 
			
		||||
        self.rotary_pos_emb.cached = True
 | 
			
		||||
 | 
			
		||||
    # `full_attention_mask` is not None only when
 | 
			
		||||
    #  `past_key_values` is not None and `seq_length` > 1
 | 
			
		||||
| 
						 | 
				
			
			@ -148,7 +136,7 @@ def chatglm4_model_forward_internal(
 | 
			
		|||
 | 
			
		||||
    hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
 | 
			
		||||
        inputs_embeds, causal_mask,
 | 
			
		||||
        rotary_pos_emb=rotary_pos_emb,
 | 
			
		||||
        rotary_pos_emb=(self.rotary_pos_emb.inv_freq, position_ids),
 | 
			
		||||
        kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states
 | 
			
		||||
    )
 | 
			
		||||
    # ipex-llm changes end
 | 
			
		||||
| 
						 | 
				
			
			@ -172,26 +160,6 @@ def chatglm4_model_forward_internal(
 | 
			
		|||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
 | 
			
		||||
    # x: [b, np, sq, hn]
 | 
			
		||||
    b, np, sq, hn = x.size(0), x.size(1), x.size(2), x.size(3)
 | 
			
		||||
    rot_dim = rope_cache.shape[-2] * 2
 | 
			
		||||
    x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
 | 
			
		||||
    # truncate to support variable sizes
 | 
			
		||||
    rope_cache = rope_cache[:, :sq]
 | 
			
		||||
    xshaped = x.reshape(b, np, sq, rot_dim // 2, 2)
 | 
			
		||||
    rope_cache = rope_cache.view(-1, 1, sq, xshaped.size(3), 2)
 | 
			
		||||
    x_out2 = torch.stack(
 | 
			
		||||
        [
 | 
			
		||||
            xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
 | 
			
		||||
            xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
 | 
			
		||||
        ],
 | 
			
		||||
        -1,
 | 
			
		||||
    )
 | 
			
		||||
    x_out2 = x_out2.flatten(3)
 | 
			
		||||
    return torch.cat((x_out2, x_pass), dim=-1)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def chatglm4_attention_forward(
 | 
			
		||||
    self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
 | 
			
		||||
):
 | 
			
		||||
| 
						 | 
				
			
			@ -209,34 +177,33 @@ def chatglm4_attention_forward(
 | 
			
		|||
    qkv = self.query_key_value(hidden_states)
 | 
			
		||||
    # [bs, q_len, np * 3 * hn] -> [bsz, n_head, seq_len, head_dim]
 | 
			
		||||
    qkv = qkv.view(bsz, q_len, n_head + 2 * n_kv_head, head_dim)
 | 
			
		||||
    qkv = qkv.transpose(1, 2)
 | 
			
		||||
 | 
			
		||||
    query_states, key_states, value_states = qkv.split([n_head,
 | 
			
		||||
                                                        n_kv_head,
 | 
			
		||||
                                                        n_kv_head], dim=2)
 | 
			
		||||
                                                        n_kv_head], dim=1)
 | 
			
		||||
 | 
			
		||||
    kv_seq_len = key_states.shape[1]
 | 
			
		||||
    kv_seq_len = key_states.shape[2]
 | 
			
		||||
    if past_key_value is not None:
 | 
			
		||||
        kv_seq_len += past_key_value[0].shape[2]
 | 
			
		||||
 | 
			
		||||
    if isinstance(rotary_pos_emb, tuple) and len(rotary_pos_emb) == 2:
 | 
			
		||||
        # use_fuse_rope, see chatglm4_model_forward
 | 
			
		||||
        cos, sin = rotary_pos_emb
 | 
			
		||||
        rot_dim = cos.shape[-1]
 | 
			
		||||
        query_layer_cur = query_states[..., :rot_dim]
 | 
			
		||||
        key_layer_cur = key_states[..., :rot_dim]
 | 
			
		||||
        # ipex_llm's apply_rotary_embedding can change the origin storage,
 | 
			
		||||
        # so query_layer will get the result directly.
 | 
			
		||||
        torch.ops.torch_ipex.apply_rotary_embedding(query_layer_cur, sin, cos, query_layer_cur)
 | 
			
		||||
        torch.ops.torch_ipex.apply_rotary_embedding(key_layer_cur, sin, cos, key_layer_cur)
 | 
			
		||||
        query_states = query_states.transpose(1, 2)
 | 
			
		||||
        key_states = key_states.transpose(1, 2)
 | 
			
		||||
        value_states = value_states.transpose(1, 2)
 | 
			
		||||
    elif rotary_pos_emb is not None:
 | 
			
		||||
        query_states = query_states.transpose(1, 2)
 | 
			
		||||
        key_states = key_states.transpose(1, 2)
 | 
			
		||||
        value_states = value_states.transpose(1, 2)
 | 
			
		||||
        query_states = apply_rotary_pos_emb(query_states, rotary_pos_emb)
 | 
			
		||||
        key_states = apply_rotary_pos_emb(key_states, rotary_pos_emb)
 | 
			
		||||
    # IPEX-LLM OPT: fuse rope
 | 
			
		||||
    inv_freq, position_ids = rotary_pos_emb
 | 
			
		||||
    rot_dim = inv_freq.size(-1) * 2
 | 
			
		||||
    if should_use_fuse_rope(hidden_states, rotary_pos_emb[1], self.training):
 | 
			
		||||
        import xe_addons
 | 
			
		||||
        xe_addons.rotary_two_inplaced(inv_freq, position_ids,
 | 
			
		||||
                                      query_states[..., :rot_dim], key_states[..., :rot_dim])
 | 
			
		||||
    else:
 | 
			
		||||
        idx_theta = torch.outer(position_ids[0].float(),
 | 
			
		||||
                                inv_freq.float()).to(hidden_states.dtype)
 | 
			
		||||
        idx_theta = idx_theta.unsqueeze(0).unsqueeze(0)
 | 
			
		||||
        cos = torch.cos(idx_theta).repeat_interleave(2, -1)
 | 
			
		||||
        sin = torch.sin(idx_theta).repeat_interleave(2, -1)
 | 
			
		||||
        q_rot, k_rot = apply_rotary_pos_emb(query_states[..., :rot_dim], key_states[..., :rot_dim],
 | 
			
		||||
                                            cos, sin, position_ids, "chatglm")
 | 
			
		||||
        query_states[..., :rot_dim] = q_rot[...]
 | 
			
		||||
        key_states[..., :rot_dim] = k_rot[...]
 | 
			
		||||
 | 
			
		||||
    # IPEX-LLM OPT: kv cache and quantize kv
 | 
			
		||||
    use_quantize_kv = use_quantize_kv_cache(self.query_key_value, hidden_states)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue