commit bf51ec40b2 (parent 714884414e)
12 changed files with 22 additions and 4 deletions
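Summary (inferred from the hunks below): this commit moves the torch.xpu.empty_cache() call out of create_kv_cache() in the shared KV-cache utilities and into each model's attention forward, so the XPU allocator cache is flushed only on the reallocation path, immediately before a larger KV-cache buffer is requested. It also removes optimize_model=False from two GPU example scripts.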
@@ -42,7 +42,6 @@ if __name__ == '__main__':
     # which convert the relevant layers in the model into INT4 format
     model = AutoModelForCausalLM.from_pretrained(model_path,
                                                  load_in_4bit=True,
-                                                 optimize_model=False,
                                                  trust_remote_code=True,
                                                  use_cache=True)
     model = model.to('xpu')

@@ -46,7 +46,6 @@ if __name__ == '__main__':
     # to obtain optimal performance with BigDL-LLM INT4 optimizations
     model = AutoModelForCausalLM.from_pretrained(model_path,
                                                  load_in_4bit=True,
-                                                 optimize_model=False,
                                                  trust_remote_code=True,
                                                  use_cache=True)
    model = model.to('xpu')

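For context, the two example hunks above show only fragments of the scripts being edited. A minimal end-to-end sketch of the same usage pattern follows; the model path, prompt, and generation arguments are illustrative placeholders, not values from this commit.

    import torch
    import intel_extension_for_pytorch as ipex  # registers the 'xpu' device
    from bigdl.llm.transformers import AutoModelForCausalLM
    from transformers import AutoTokenizer

    model_path = 'path/to/model'  # hypothetical local checkpoint

    # Load the model with BigDL-LLM INT4 optimizations and move it to the GPU.
    model = AutoModelForCausalLM.from_pretrained(model_path,
                                                 load_in_4bit=True,
                                                 trust_remote_code=True,
                                                 use_cache=True)
    model = model.to('xpu')
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    with torch.inference_mode():
        input_ids = tokenizer.encode('What is AI?', return_tensors='pt').to('xpu')
        output = model.generate(input_ids, max_new_tokens=32)
        print(tokenizer.decode(output[0], skip_special_tokens=True))
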
@@ -70,6 +70,8 @@ def baichuan_attention_forward_7b(
         cache_k = past_key_value[0]
         cache_v = past_key_value[1]
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+            if device.type == 'xpu':
+                torch.xpu.empty_cache()
             # allocate new
             new_cache_k, new_cache_v = create_kv_cache(bsz,
                                                        self.num_heads,

@@ -168,6 +170,8 @@ def baichuan_attention_forward_13b(
         cache_k = past_key_value[0]
         cache_v = past_key_value[1]
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+            if device.type == 'xpu':
+                torch.xpu.empty_cache()
             # allocate new
             new_cache_k, new_cache_v = create_kv_cache(bsz,
                                                        self.num_heads,

@@ -82,6 +82,8 @@ def baichuan_attention_forward_7b(
         cache_k = past_key_value[0]
         cache_v = past_key_value[1]
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+            if device.type == 'xpu':
+                torch.xpu.empty_cache()
             # allocate new
             new_cache_k, new_cache_v = create_kv_cache(bsz,
                                                        self.num_heads,

@@ -177,6 +179,8 @@ def baichuan_attention_forward_13b(
         cache_k = past_key_value[0]
         cache_v = past_key_value[1]
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+            if device.type == 'xpu':
+                torch.xpu.empty_cache()
             # allocate new
             new_cache_k, new_cache_v = create_kv_cache(bsz,
                                                        self.num_heads,

@@ -105,6 +105,8 @@ def bloom_attention_forward(
         cache_k = layer_past[0].transpose(1, 2).view(batch_size, self.num_heads, -1, self.head_dim)
         cache_v = layer_past[1].view(batch_size, self.num_heads, -1, self.head_dim)
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+            if device.type == 'xpu':
+                torch.xpu.empty_cache()
             # allocate new
             new_cache_k, new_cache_v = create_kv_cache(
                 batch_size,

@@ -67,6 +67,8 @@ def attention_fn(
         cache_v = cache_v.permute(1, 2, 0, 3)
         past_length = cache_k.size(2)
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+            if device.type == 'xpu':
+                torch.xpu.empty_cache()
             max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
             new_cache_k, new_cache_v = create_kv_cache(batch_size,
                                                        self.num_attention_heads_per_partition,

@@ -151,6 +151,8 @@ def chatglm2_attention_forward_8eb45c(
         past_length = cache_k.size(2)

         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+            if device.type == 'xpu':
+                torch.xpu.empty_cache()
             max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
             new_cache_k, new_cache_v = create_kv_cache(batch_size,
                                                        self.num_attention_heads_per_partition,

@@ -97,6 +97,8 @@ def rw_attention_forward_7b(
         cache_k = layer_past[0].view(batch_size, self.num_kv, -1, self.head_dim)
         cache_v = layer_past[1].view(batch_size, self.num_kv, -1, self.head_dim)
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+            if device.type == 'xpu':
+                torch.xpu.empty_cache()
             # allocate new
             new_cache_k, new_cache_v = create_kv_cache(
                 batch_size,

@@ -144,6 +144,8 @@ def gptj_attention_forward(
         past_length = cache_k.size(2)

         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+            if device.type == 'xpu':
+                torch.xpu.empty_cache()
             new_cache_k, new_cache_v = create_kv_cache(batch_size,
                                                        self.num_attention_heads,
                                                        self.head_dim,

@@ -90,6 +90,8 @@ def gptneox_attention_forward(
         past_key = layer_past[0]
         past_value = layer_past[1]
         if past_key.stride()[1] <= past_key.size(2) * past_key.size(3):
+            if device.type == 'xpu':
+                torch.xpu.empty_cache()
             # allocate new
             new_past_key, new_past_value = create_kv_cache(bsz,
                                                            self.num_attention_heads,

@@ -112,6 +112,8 @@ def llama_attention_forward_4_31(
         cache_k = past_key_value[0]
         cache_v = past_key_value[1]
         if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+            if device.type == 'xpu':
+                torch.xpu.empty_cache()
             # allocate new
             new_cache_k, new_cache_v = create_kv_cache(bsz,
                                                        self.num_key_value_heads,  # Support GQA

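Note that, unlike the other models above, the llama path passes self.num_key_value_heads rather than self.num_heads to create_kv_cache (per the "Support GQA" comment), so grouped-query-attention checkpoints reserve cache memory only for their shared key/value heads.
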
@@ -19,8 +19,6 @@ from bigdl.llm.utils.common import invalidInputError


 def create_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device):
-    if device.type == 'xpu':
-        torch.xpu.empty_cache()
     key_cache_storage = torch.empty(batch_size, num_heads,
                                     max_length, head_dim,
                                     dtype=dtype, device=device)
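
Taken together, the hunks implement one pattern: each attention forward keeps its KV cache in a pre-allocated buffer, uses the tensor's stride to test whether that buffer still has free rows along the sequence dimension, and reallocates only when it is full, now preceded on XPU by torch.xpu.empty_cache() so the caching allocator releases its cached blocks before the large new request. Below is a minimal standalone sketch of that reallocate-and-append pattern; KV_CACHE_ALLOC_BLOCK_LENGTH's value, the append_kv helper, and the exact shapes are illustrative assumptions, not code from this commit.

    import torch

    KV_CACHE_ALLOC_BLOCK_LENGTH = 256  # assumed headroom, not from this commit

    def create_kv_cache(batch_size, num_heads, head_dim,
                        current_length, max_length, dtype, device):
        # Allocate storage for max_length tokens, but return views exposing
        # only the first current_length tokens.
        k_storage = torch.empty(batch_size, num_heads, max_length, head_dim,
                                dtype=dtype, device=device)
        v_storage = torch.empty(batch_size, num_heads, max_length, head_dim,
                                dtype=dtype, device=device)
        size = (batch_size, num_heads, current_length, head_dim)
        k = k_storage.as_strided(size, k_storage.stride(), storage_offset=0)
        v = v_storage.as_strided(size, v_storage.stride(), storage_offset=0)
        return k, v

    def append_kv(cache_k, cache_v, key_states, value_states):
        # Hypothetical helper showing where the new empty_cache() call sits.
        bsz, num_heads, past_len, head_dim = cache_k.shape
        cur_len = key_states.size(2)
        # stride()[1] is the per-head stride of the underlying storage
        # (max_length * head_dim); once size(2) * size(3) reaches it, the
        # pre-allocated buffer has no free sequence rows left.
        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
            if cache_k.device.type == 'xpu':
                # Release cached allocator blocks before the big request.
                torch.xpu.empty_cache()
            new_k, new_v = create_kv_cache(
                bsz, num_heads, head_dim, past_len,
                past_len + cur_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
                cache_k.dtype, cache_k.device)
            new_k.copy_(cache_k)
            new_v.copy_(cache_v)
            cache_k, cache_v = new_k, new_v
        # Grow the views by cur_len rows and write the new tokens in place.
        new_size = (bsz, num_heads, past_len + cur_len, head_dim)
        cache_k = cache_k.as_strided(new_size, cache_k.stride(), storage_offset=0)
        cache_k[:, :, past_len:, :] = key_states
        cache_v = cache_v.as_strided(new_size, cache_v.stride(), storage_offset=0)
        cache_v[:, :, past_len:, :] = value_states
        return cache_k, cache_v

Because the empty_cache() branch is guarded by the device type, the sketch can be exercised with CPU tensors, matching how the guarded calls added in this commit behave on non-XPU devices.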