LLM: add esimd sdp for pvc (#10543)
* add esimd sdp for pvc
* update
* fix
* fix batch
parent 817ef2d1de
commit ea4bc450c4
3 changed files with 20 additions and 14 deletions
@@ -121,7 +121,7 @@ def run_model(repo_id, test_api, in_out_pairs, local_model_hub=None, warm_up=1,
                         low_bit,
                         cpu_embedding if 'win' in test_api else 'N/A',
                         round(result[in_out_pair][-1][5], 2),
-                        result[in_out_pair][-1][6] if any(keyword in test_api for keyword in ['int4_gpu', 'int4_fp16_gpu_win', 'int4_loadlowbit_gpu' ]) else 'N/A',
+                        result[in_out_pair][-1][6] if any(keyword in test_api for keyword in ['int4_gpu', 'int4_fp16_gpu_win', 'int4_loadlowbit_gpu', 'fp16_gpu']) else 'N/A',
                         streaming if 'win' in test_api else 'N/A'],
                         )
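Note on the hunk above: the guarded expression is the metric stored at index 6 of each result row, which for the fp16 GPU runs is the model.peak_memory value appended in the next hunk; adding 'fp16_gpu' to the keyword list lets those runs report it instead of 'N/A'. A minimal sketch of the substring check, with the keyword list copied from the new line; the sample test_api strings are illustrative guesses, not values confirmed by the harness:

    # Mirrors the `any(keyword in test_api ...)` guard from the hunk above.
    def reports_peak_memory(test_api: str) -> bool:
        keywords = ['int4_gpu', 'int4_fp16_gpu_win', 'int4_loadlowbit_gpu', 'fp16_gpu']
        return any(keyword in test_api for keyword in keywords)

    assert reports_peak_memory('bigdl_fp16_gpu')         # newly covered: contains 'fp16_gpu'
    assert reports_peak_memory('transformer_int4_gpu')   # already covered: contains 'int4_gpu'
    assert not reports_peak_memory('transformer_int4')   # no GPU keyword, column stays 'N/A'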
@@ -716,7 +716,7 @@ def run_bigdl_fp16_gpu(repo_id,
            print(output[0])
            if i >= warm_up:
                result[in_out].append([model.first_cost, model.rest_cost_mean, model.encoder_time,
-                                      actual_in_len, actual_out_len, load_time])
+                                      actual_in_len, actual_out_len, load_time, model.peak_memory])
    del model
    torch.xpu.empty_cache()
    return result
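The appended model.peak_memory is what the 'fp16_gpu' keyword added above surfaces as column index 6. How peak_memory itself is measured is outside this diff; as a rough sketch only, peak XPU memory around a generation step could be captured with the CUDA-style memory counters that IPEX exposes under torch.xpu. The helper name and the exact API availability are assumptions here, not the harness's own code:

    import torch
    import intel_extension_for_pytorch as ipex  # registers the torch.xpu backend

    def generate_and_measure_peak(model, input_ids, **generate_kwargs):
        """Hypothetical helper: run generate() and report peak XPU memory in GiB."""
        torch.xpu.reset_peak_memory_stats()   # assumed to exist, mirroring the torch.cuda API
        torch.xpu.synchronize()
        output = model.generate(input_ids, **generate_kwargs)
        torch.xpu.synchronize()
        peak_gib = torch.xpu.max_memory_allocated() / (1024 ** 3)
        return output, peak_gib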
@@ -626,7 +626,7 @@ def llama_attention_forward_4_31_original(
                                                              is_causal=True)
        attn_weights = None
    elif not self.training and not hidden_states.requires_grad and \
-            use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
+            use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states, attention_mask):
        import linear_fp16_esimd
        attn_output = linear_fp16_esimd.sdp_forward(query_states,
                                                    key_states,
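For context on the call site above: under the usual Hugging Face llama layout (an assumption here, not something the diff states), query_states and key_states are (batch, num_heads, seq_len, head_dim) tensors, so key_states.shape[2] is the cached key length and q_len is 1 during token-by-token decoding. A small sketch of the shapes the gate sees on a llama-2-7b-sized model:

    import torch

    # Hypothetical single-token decode step; llama-2-7b uses head_dim = 128.
    batch, num_heads, head_dim, kv_len = 1, 32, 128, 1024
    query_states = torch.randn(batch, num_heads, 1, head_dim, dtype=torch.float16)     # q_len == 1
    key_states = torch.randn(batch, num_heads, kv_len, head_dim, dtype=torch.float16)  # cached keys

    q_len, k_len = query_states.shape[2], key_states.shape[2]
    assert (q_len, k_len, head_dim) == (1, 1024, 128)  # the values passed to use_esimd_sdp above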
@@ -301,7 +301,7 @@ def use_flash_attention(query, key, attention_mask=None):
            return True


-def use_esimd_sdp(q_len, k_len, head_dim, query_states):
+def use_esimd_sdp(q_len, k_len, head_dim, query_states, attention_mask=None):
    if head_dim != 128:
        # esimd_sdp only support head_dim = 128 now
        return False
@@ -317,17 +317,23 @@ def use_esimd_sdp(q_len, k_len, head_dim, query_states):
    elif query_states.dtype != torch.float16:
        # esimd_sdp only has optimization for FP16 now
        return False
-    else:
-        device_name = torch.xpu.get_device_name(query_states.device.index)
-        if device_name.startswith("Intel(R) Arc(TM) A") or \
-                device_name.startswith("Intel(R) Data Center GPU Flex"):
-            import linear_fp16_esimd
-            if hasattr(linear_fp16_esimd, "sdp_forward"):
-                return True
-            else:
-                return False
-        else:
+    elif query_states.shape[0] > 1 and attention_mask is not None:
+        # for batched input, can't accept attention_mask
+        # TODO: this check needs some time
+        if not torch.all(attention_mask.eq(0)):
            return False
+
+    device_name = torch.xpu.get_device_name(query_states.device.index)
+    if device_name.startswith("Intel(R) Arc(TM) A") or \
+            device_name.startswith("Intel(R) Data Center GPU Flex") or \
+            device_name.startswith("Intel(R) Data Center GPU Max"):
+        import linear_fp16_esimd
+        if not hasattr(linear_fp16_esimd, "sdp_forward"):
+            return False
+    else:
+        return False
+
+    return True


def mlp_fusion_check(x, qtype, training):
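The new elif branch is presumably the "fix batch" item from the commit message: per the in-code comment, the ESIMD kernel cannot accept an attention mask for batched input, so the gate only stays open when the additive mask is all zeros, i.e. there is no padding to respect. A small, self-contained illustration of that check; shapes and values are hypothetical:

    import torch

    batch, kv_len = 4, 16
    # Additive attention masks: 0 keeps a position, a large negative value masks it out.
    no_padding = torch.zeros(batch, 1, 1, kv_len, dtype=torch.float16)
    with_padding = no_padding.clone()
    with_padding[0, ..., :5] = torch.finfo(torch.float16).min  # pad 5 positions in sequence 0

    # All-zero mask: nothing is masked, so the ESIMD path remains eligible.
    assert torch.all(no_padding.eq(0))
    # Any non-zero entry trips the new check and use_esimd_sdp returns False (regular path).
    assert not torch.all(with_padding.eq(0))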