parent
							
								
									714884414e
								
							
						
					
					
						commit
						bf51ec40b2
					
				
					 12 changed files with 22 additions and 4 deletions
				
			
		| 
						 | 
				
			
			@ -42,7 +42,6 @@ if __name__ == '__main__':
 | 
			
		|||
    # which convert the relevant layers in the model into INT4 format
 | 
			
		||||
    model = AutoModelForCausalLM.from_pretrained(model_path,
 | 
			
		||||
                                                 load_in_4bit=True,
 | 
			
		||||
                                                 optimize_model=False,
 | 
			
		||||
                                                 trust_remote_code=True,
 | 
			
		||||
                                                 use_cache=True)
 | 
			
		||||
    model = model.to('xpu')
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -46,7 +46,6 @@ if __name__ == '__main__':
 | 
			
		|||
    # to obtain optimal performance with BigDL-LLM INT4 optimizations
 | 
			
		||||
    model = AutoModelForCausalLM.from_pretrained(model_path,
 | 
			
		||||
                                                 load_in_4bit=True,
 | 
			
		||||
                                                 optimize_model=False,
 | 
			
		||||
                                                 trust_remote_code=True,
 | 
			
		||||
                                                 use_cache=True)
 | 
			
		||||
    model = model.to('xpu')
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -70,6 +70,8 @@ def baichuan_attention_forward_7b(
 | 
			
		|||
        cache_k = past_key_value[0]
 | 
			
		||||
        cache_v = past_key_value[1]
 | 
			
		||||
        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
 | 
			
		||||
            if device.type == 'xpu':
 | 
			
		||||
                torch.xpu.empty_cache()
 | 
			
		||||
            # allocate new
 | 
			
		||||
            new_cache_k, new_cache_v = create_kv_cache(bsz,
 | 
			
		||||
                                                       self.num_heads,
 | 
			
		||||
| 
						 | 
				
			
			@ -168,6 +170,8 @@ def baichuan_attention_forward_13b(
 | 
			
		|||
        cache_k = past_key_value[0]
 | 
			
		||||
        cache_v = past_key_value[1]
 | 
			
		||||
        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
 | 
			
		||||
            if device.type == 'xpu':
 | 
			
		||||
                torch.xpu.empty_cache()
 | 
			
		||||
            # allocate new
 | 
			
		||||
            new_cache_k, new_cache_v = create_kv_cache(bsz,
 | 
			
		||||
                                                       self.num_heads,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -82,6 +82,8 @@ def baichuan_attention_forward_7b(
 | 
			
		|||
        cache_k = past_key_value[0]
 | 
			
		||||
        cache_v = past_key_value[1]
 | 
			
		||||
        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
 | 
			
		||||
            if device.type == 'xpu':
 | 
			
		||||
                torch.xpu.empty_cache()
 | 
			
		||||
            # allocate new
 | 
			
		||||
            new_cache_k, new_cache_v = create_kv_cache(bsz,
 | 
			
		||||
                                                       self.num_heads,
 | 
			
		||||
| 
						 | 
				
			
			@ -177,6 +179,8 @@ def baichuan_attention_forward_13b(
 | 
			
		|||
        cache_k = past_key_value[0]
 | 
			
		||||
        cache_v = past_key_value[1]
 | 
			
		||||
        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
 | 
			
		||||
            if device.type == 'xpu':
 | 
			
		||||
                torch.xpu.empty_cache()
 | 
			
		||||
            # allocate new
 | 
			
		||||
            new_cache_k, new_cache_v = create_kv_cache(bsz,
 | 
			
		||||
                                                       self.num_heads,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -105,6 +105,8 @@ def bloom_attention_forward(
 | 
			
		|||
        cache_k = layer_past[0].transpose(1, 2).view(batch_size, self.num_heads, -1, self.head_dim)
 | 
			
		||||
        cache_v = layer_past[1].view(batch_size, self.num_heads, -1, self.head_dim)
 | 
			
		||||
        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
 | 
			
		||||
            if device.type == 'xpu':
 | 
			
		||||
                torch.xpu.empty_cache()
 | 
			
		||||
            # allocate new
 | 
			
		||||
            new_cache_k, new_cache_v = create_kv_cache(
 | 
			
		||||
                batch_size,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -67,6 +67,8 @@ def attention_fn(
 | 
			
		|||
        cache_v = cache_v.permute(1, 2, 0, 3)
 | 
			
		||||
        past_length = cache_k.size(2)
 | 
			
		||||
        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
 | 
			
		||||
            if device.type == 'xpu':
 | 
			
		||||
                torch.xpu.empty_cache()
 | 
			
		||||
            max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
 | 
			
		||||
            new_cache_k, new_cache_v = create_kv_cache(batch_size,
 | 
			
		||||
                                                       self.num_attention_heads_per_partition,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -151,6 +151,8 @@ def chatglm2_attention_forward_8eb45c(
 | 
			
		|||
        past_length = cache_k.size(2)
 | 
			
		||||
 | 
			
		||||
        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
 | 
			
		||||
            if device.type == 'xpu':
 | 
			
		||||
                torch.xpu.empty_cache()
 | 
			
		||||
            max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
 | 
			
		||||
            new_cache_k, new_cache_v = create_kv_cache(batch_size,
 | 
			
		||||
                                                       self.num_attention_heads_per_partition,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -97,6 +97,8 @@ def rw_attention_forward_7b(
 | 
			
		|||
        cache_k = layer_past[0].view(batch_size, self.num_kv, -1, self.head_dim)
 | 
			
		||||
        cache_v = layer_past[1].view(batch_size, self.num_kv, -1, self.head_dim)
 | 
			
		||||
        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
 | 
			
		||||
            if device.type == 'xpu':
 | 
			
		||||
                torch.xpu.empty_cache()
 | 
			
		||||
            # allocate new
 | 
			
		||||
            new_cache_k, new_cache_v = create_kv_cache(
 | 
			
		||||
                batch_size,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -144,6 +144,8 @@ def gptj_attention_forward(
 | 
			
		|||
        past_length = cache_k.size(2)
 | 
			
		||||
 | 
			
		||||
        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
 | 
			
		||||
            if device.type == 'xpu':
 | 
			
		||||
                torch.xpu.empty_cache()
 | 
			
		||||
            new_cache_k, new_cache_v = create_kv_cache(batch_size,
 | 
			
		||||
                                                       self.num_attention_heads,
 | 
			
		||||
                                                       self.head_dim,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -90,6 +90,8 @@ def gptneox_attention_forward(
 | 
			
		|||
        past_key = layer_past[0]
 | 
			
		||||
        past_value = layer_past[1]
 | 
			
		||||
        if past_key.stride()[1] <= past_key.size(2) * past_key.size(3):
 | 
			
		||||
            if device.type == 'xpu':
 | 
			
		||||
                torch.xpu.empty_cache()
 | 
			
		||||
            # allocate new
 | 
			
		||||
            new_past_key, new_past_value = create_kv_cache(bsz,
 | 
			
		||||
                                                           self.num_attention_heads,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -112,6 +112,8 @@ def llama_attention_forward_4_31(
 | 
			
		|||
        cache_k = past_key_value[0]
 | 
			
		||||
        cache_v = past_key_value[1]
 | 
			
		||||
        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
 | 
			
		||||
            if device.type == 'xpu':
 | 
			
		||||
                torch.xpu.empty_cache()
 | 
			
		||||
            # allocate new
 | 
			
		||||
            new_cache_k, new_cache_v = create_kv_cache(bsz,
 | 
			
		||||
                                                       self.num_key_value_heads,  # Support GQA
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -19,8 +19,6 @@ from bigdl.llm.utils.common import invalidInputError
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
def create_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device):
 | 
			
		||||
    if device.type == 'xpu':
 | 
			
		||||
        torch.xpu.empty_cache()
 | 
			
		||||
    key_cache_storage = torch.empty(batch_size, num_heads,
 | 
			
		||||
                                    max_length, head_dim,
 | 
			
		||||
                                    dtype=dtype, device=device)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue