parent 714884414e
commit bf51ec40b2
12 changed files with 22 additions and 4 deletions
@@ -42,7 +42,6 @@ if __name__ == '__main__':
    # which convert the relevant layers in the model into INT4 format
    model = AutoModelForCausalLM.from_pretrained(model_path,
                                                 load_in_4bit=True,
                                                 optimize_model=False,
                                                 trust_remote_code=True,
                                                 use_cache=True)
    model = model.to('xpu')

@@ -46,7 +46,6 @@ if __name__ == '__main__':
    # to obtain optimal performance with BigDL-LLM INT4 optimizations
    model = AutoModelForCausalLM.from_pretrained(model_path,
                                                 load_in_4bit=True,
                                                 optimize_model=False,
                                                 trust_remote_code=True,
                                                 use_cache=True)
    model = model.to('xpu')

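For context, these example scripts are typically run end-to-end as sketched below; the prompt, the tokenizer call, and the generation length are illustrative assumptions layered on top of the diff's own model_path and model, not part of this commit.

    # Hypothetical usage sketch, continuing from the example code above.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    prompt = "What is AI?"  # assumed example prompt
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to('xpu')
    output = model.generate(input_ids, max_new_tokens=32)
    print(tokenizer.decode(output[0], skip_special_tokens=True))
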
@@ -70,6 +70,8 @@ def baichuan_attention_forward_7b(
        cache_k = past_key_value[0]
        cache_v = past_key_value[1]
        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
            if device.type == 'xpu':
                torch.xpu.empty_cache()
            # allocate new
            new_cache_k, new_cache_v = create_kv_cache(bsz,
                                                       self.num_heads,

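The stride check that appears in this and every following hunk detects when the pre-allocated key/value buffer is out of free rows: a cache view carved out of a (batch, heads, max_length, head_dim) buffer has stride()[1] == max_length * head_dim, while size(2) * size(3) == current_length * head_dim, so the two values only meet once the buffer is full. A minimal stand-alone illustration (the shapes are made up for the example):

    import torch

    storage = torch.empty(1, 2, 8, 4)   # room for max_length = 8 positions
    partial = storage[:, :, :3, :]      # cache currently holding 3 positions
    print(partial.stride()[1], partial.size(2) * partial.size(3))  # 32 > 12: still has headroom
    full = storage[:, :, :8, :]         # all 8 positions in use
    print(full.stride()[1], full.size(2) * full.size(3))           # 32 <= 32: buffer exhausted, reallocate
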
@@ -168,6 +170,8 @@ def baichuan_attention_forward_13b(
        cache_k = past_key_value[0]
        cache_v = past_key_value[1]
        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
            if device.type == 'xpu':
                torch.xpu.empty_cache()
            # allocate new
            new_cache_k, new_cache_v = create_kv_cache(bsz,
                                                       self.num_heads,

@@ -82,6 +82,8 @@ def baichuan_attention_forward_7b(
        cache_k = past_key_value[0]
        cache_v = past_key_value[1]
        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
            if device.type == 'xpu':
                torch.xpu.empty_cache()
            # allocate new
            new_cache_k, new_cache_v = create_kv_cache(bsz,
                                                       self.num_heads,

@@ -177,6 +179,8 @@ def baichuan_attention_forward_13b(
        cache_k = past_key_value[0]
        cache_v = past_key_value[1]
        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
            if device.type == 'xpu':
                torch.xpu.empty_cache()
            # allocate new
            new_cache_k, new_cache_v = create_kv_cache(bsz,
                                                       self.num_heads,

@@ -105,6 +105,8 @@ def bloom_attention_forward(
        cache_k = layer_past[0].transpose(1, 2).view(batch_size, self.num_heads, -1, self.head_dim)
        cache_v = layer_past[1].view(batch_size, self.num_heads, -1, self.head_dim)
        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
            if device.type == 'xpu':
                torch.xpu.empty_cache()
            # allocate new
            new_cache_k, new_cache_v = create_kv_cache(
                batch_size,

@@ -67,6 +67,8 @@ def attention_fn(
        cache_v = cache_v.permute(1, 2, 0, 3)
        past_length = cache_k.size(2)
        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
            if device.type == 'xpu':
                torch.xpu.empty_cache()
            max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
            new_cache_k, new_cache_v = create_kv_cache(batch_size,
                                                       self.num_attention_heads_per_partition,

@@ -151,6 +151,8 @@ def chatglm2_attention_forward_8eb45c(
        past_length = cache_k.size(2)

        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
            if device.type == 'xpu':
                torch.xpu.empty_cache()
            max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
            new_cache_k, new_cache_v = create_kv_cache(batch_size,
                                                       self.num_attention_heads_per_partition,

@@ -97,6 +97,8 @@ def rw_attention_forward_7b(
        cache_k = layer_past[0].view(batch_size, self.num_kv, -1, self.head_dim)
        cache_v = layer_past[1].view(batch_size, self.num_kv, -1, self.head_dim)
        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
            if device.type == 'xpu':
                torch.xpu.empty_cache()
            # allocate new
            new_cache_k, new_cache_v = create_kv_cache(
                batch_size,

@@ -144,6 +144,8 @@ def gptj_attention_forward(
        past_length = cache_k.size(2)

        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
            if device.type == 'xpu':
                torch.xpu.empty_cache()
            new_cache_k, new_cache_v = create_kv_cache(batch_size,
                                                       self.num_attention_heads,
                                                       self.head_dim,

@@ -90,6 +90,8 @@ def gptneox_attention_forward(
        past_key = layer_past[0]
        past_value = layer_past[1]
        if past_key.stride()[1] <= past_key.size(2) * past_key.size(3):
            if device.type == 'xpu':
                torch.xpu.empty_cache()
            # allocate new
            new_past_key, new_past_value = create_kv_cache(bsz,
                                                           self.num_attention_heads,

@@ -112,6 +112,8 @@ def llama_attention_forward_4_31(
        cache_k = past_key_value[0]
        cache_v = past_key_value[1]
        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
            if device.type == 'xpu':
                torch.xpu.empty_cache()
            # allocate new
            new_cache_k, new_cache_v = create_kv_cache(bsz,
                                                       self.num_key_value_heads,  # Support GQA

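Taken together, the hunks above show the same pattern in every attention implementation: torch.xpu.empty_cache() is invoked right before a new, larger key/value buffer is requested, so that cached allocator blocks can be reclaimed for the bigger allocation. A self-contained sketch of the overall grow-and-copy pattern follows; the function name, the growth constant, and the copy step are illustrative, not the repo's exact code:

    import torch

    KV_CACHE_ALLOC_BLOCK_LENGTH = 256  # assumed growth step, mirroring the constant used above

    def grow_kv_cache(cache_k, cache_v, block=KV_CACHE_ALLOC_BLOCK_LENGTH):
        # Illustrative sketch: allocate a larger buffer and keep the old contents.
        bsz, num_heads, cur_len, head_dim = cache_k.shape
        if cache_k.device.type == 'xpu':
            torch.xpu.empty_cache()  # release cached blocks before the big allocation
        new_len = cur_len + block
        new_k = torch.empty(bsz, num_heads, new_len, head_dim,
                            dtype=cache_k.dtype, device=cache_k.device)
        new_v = torch.empty_like(new_k)
        new_k[:, :, :cur_len, :] = cache_k  # keep the already-computed keys
        new_v[:, :, :cur_len, :] = cache_v  # and values
        # Return views sized to the current length over the larger storage,
        # so the stride check above stays false until the new headroom is used up.
        return new_k[:, :, :cur_len, :], new_v[:, :, :cur_len, :]
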
@@ -19,8 +19,6 @@ from bigdl.llm.utils.common import invalidInputError


def create_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device):
    if device.type == 'xpu':
        torch.xpu.empty_cache()
    key_cache_storage = torch.empty(batch_size, num_heads,
                                    max_length, head_dim,
                                    dtype=dtype, device=device)
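
The hunk shows only the key-side allocation. A plausible continuation of create_kv_cache, assuming the value buffer mirrors the key buffer and both are handed back as current_length-sized views over the max_length storage (an assumption; the rest of the function is not shown in this diff):

    value_cache_storage = torch.empty(batch_size, num_heads,
                                      max_length, head_dim,
                                      dtype=dtype, device=device)
    key_cache = key_cache_storage.as_strided(
        (batch_size, num_heads, current_length, head_dim),
        key_cache_storage.stride(), storage_offset=0)
    value_cache = value_cache_storage.as_strided(
        (batch_size, num_heads, current_length, head_dim),
        value_cache_storage.stride(), storage_offset=0)
    return key_cache, value_cache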