LLM: optimize chatglm2 8k input. (#10723)
* LLM: optimize chatglm2 8k input.
* rename.
parent cd22cb8257
commit 4b024b7aac

1 changed file with 36 additions and 11 deletions
@@ -252,10 +252,31 @@ def chatglm2_quantized_attention_forward_8eb45c(
         else:
             key, value = key_layer, value_layer
 
-        if attention_mask is None:
-            context_layer = F.scaled_dot_product_attention(query_layer, key, value, is_causal=True)
+        # split tensor for memory block limitation
+        # support fp16 and set input length threshold at 5000 for now
+        if query_layer.dtype == torch.float16 and query_layer.shape[2] >= 5000:
+            # split second dim to block size = 8
+            block_size = 8
+            query_split = torch.split(query_layer, block_size, dim=1)
+            key_split = torch.split(key, block_size, dim=1)
+            value_split = torch.split(value, block_size, dim=1)
+            context_layer = torch.empty(batch_size, n_head,
+                                        seq_len, head_dim).to(query_layer.device)
+            idx = 0
+            for q, k, v in zip(query_split, key_split, value_split):
+                if attention_mask is None:
+                    result = F.scaled_dot_product_attention(q, k, v, is_causal=True)
+                else:
+                    result = F.scaled_dot_product_attention(q, k, v, attention_mask)
+                context_layer[:, idx:idx+q.shape[1], :, :] = result
+                idx = idx + q.shape[1]
         else:
-            context_layer = F.scaled_dot_product_attention(query_layer, key, value, attention_mask)
+            if attention_mask is None:
+                context_layer = F.scaled_dot_product_attention(query_layer, key,
+                                                               value, is_causal=True)
+            else:
+                context_layer = F.scaled_dot_product_attention(query_layer, key,
+                                                               value, attention_mask)
         context_layer = context_layer.to(query_layer.dtype)
 
         if use_cache:
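The hunk above replaces a single full-sequence F.scaled_dot_product_attention call with a blocked version: for fp16 inputs at or past the 5000-token threshold, query/key/value are split into blocks of 8 heads along dim=1, attention runs per block, and each block's output is written into a preallocated buffer. A minimal self-contained sketch of that pattern, assuming PyTorch 2.0+ for F.scaled_dot_product_attention; the helper name blocked_sdpa and the toy shapes are illustrative, not part of the patch:

    import torch
    import torch.nn.functional as F

    def blocked_sdpa(query_layer, key, value, attention_mask=None, block_size=8):
        # Shapes follow the patch: [batch_size, n_head, seq_len, head_dim];
        # the split happens along the head dimension (dim=1).
        batch_size, n_head, seq_len, head_dim = query_layer.shape
        query_split = torch.split(query_layer, block_size, dim=1)
        key_split = torch.split(key, block_size, dim=1)
        value_split = torch.split(value, block_size, dim=1)
        # Preallocate the output and fill it block by block instead of
        # collecting per-block results and concatenating at the end.
        context_layer = torch.empty(batch_size, n_head, seq_len, head_dim,
                                    dtype=query_layer.dtype,
                                    device=query_layer.device)
        idx = 0
        for q, k, v in zip(query_split, key_split, value_split):
            if attention_mask is None:
                result = F.scaled_dot_product_attention(q, k, v, is_causal=True)
            else:
                result = F.scaled_dot_product_attention(q, k, v, attention_mask)
            context_layer[:, idx:idx + q.shape[1], :, :] = result
            idx += q.shape[1]
        return context_layer

    # Splitting over heads is exact: each head attends independently,
    # so the blocked result matches a single full call.
    q = torch.randn(1, 32, 64, 16)
    k = torch.randn(1, 32, 64, 16)
    v = torch.randn(1, 32, 64, 16)
    full = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    assert torch.allclose(blocked_sdpa(q, k, v), full, atol=1e-5)

The second hunk below applies the same pattern to core_attn_forward_8eb45c, which already had a head-split path.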
@@ -517,15 +538,19 @@ def core_attn_forward_8eb45c(query_layer, key_layer, value_layer, attention_mask
             # split tensor for memory block limitation
             # support fp16 and set input length threshold at 5000 for now
             if query_layer.dtype == torch.float16 and L >= 5000:
-                # split first dim 32 -> 8
-                query_sp = torch.split(query_layer.to(key_layer.dtype), 8, dim=1)
-                key_sp = torch.split(key_layer, 8, dim=1)
-                value_sp = torch.split(value_layer, 8, dim=1)
-                results = []
-                for q, k, v in zip(query_sp, key_sp, value_sp):
+                # split second dim to block size = 8
+                block_size = 8
+                query_split = torch.split(query_layer.to(key_layer.dtype), block_size, dim=1)
+                key_split = torch.split(key_layer, block_size, dim=1)
+                value_split = torch.split(value_layer, block_size, dim=1)
+                batch_size, n_head, seq_len, head_dim = query_layer.shape
+                context_layer = torch.empty(batch_size, n_head, seq_len,
+                                            head_dim).to(query_layer.device).to(key_layer.dtype)
+                idx = 0
+                for q, k, v in zip(query_split, key_split, value_split):
                     result = F.scaled_dot_product_attention(q, k, v, is_causal=True).to(k.dtype)
-                    results.append(result)
-                context_layer = torch.cat(results, dim=1)
+                    context_layer[:, idx:idx+q.shape[1], :, :] = result
+                    idx = idx + q.shape[1]
             else:
                 context_layer = F.scaled_dot_product_attention(query_layer.to(key_layer.dtype),
                                                                key_layer,
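The design change in this hunk is the output buffer: the previous version appended each block's result to a Python list and concatenated with torch.cat at the end, which briefly holds both the per-block outputs and their concatenated copy. Writing into a context_layer preallocated with torch.empty avoids that second peak allocation, which matters precisely in the long-input fp16 regime this path targets.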