LLM: integrate sdp kernel for FP16 rest token inference on GPU [DG2/ATSM] (#9633)
* integrate sdp
* update api
* fix style
* meet code review
* fix
* distinguish mtl from arc
* small fix
parent 5b0e7e308c
commit dc5b1d7e9d

2 changed files with 39 additions and 2 deletions
@@ -249,8 +249,12 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
             # convert here
             m, n = module.weight.data.shape
-            trans_weight = module.weight.data.reshape(m//16, 16, n)
-            trans_weight = trans_weight.transpose(1, 2).contiguous()
+            if module.in_features == 11008:
+                trans_weight = module.weight.data.reshape(m//8, 8, n)
+                trans_weight = trans_weight.transpose(1, 2).contiguous()
+            elif module.in_features == 4096:
+                trans_weight = module.weight.data.reshape(m//16, 16, n)
+                trans_weight = trans_weight.transpose(1, 2).contiguous()
             new_linear._parameters['weight'] = nn.Parameter(trans_weight)
             if module.bias is not None:
                 new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
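For context, a standalone sketch (not from the patch) of what this blocked-weight relayout does, and that it is lossless; the toy nn.Linear and the block-size selection are illustrative assumptions. Note that the patched branch leaves trans_weight undefined for any other in_features, so it relies on every converted layer having one of these two widths.

import torch
import torch.nn as nn

# Toy stand-in for the module being converted; sizes chosen so that
# out_features is divisible by the block size, as the real layers are.
linear = nn.Linear(in_features=4096, out_features=64, bias=False).half()

m, n = linear.weight.data.shape              # (out_features, in_features)
block = 8 if linear.in_features == 11008 else 16
trans_weight = linear.weight.data.reshape(m // block, block, n)
trans_weight = trans_weight.transpose(1, 2).contiguous()
# Shape is now (m // block, n, block): output rows grouped into blocks of
# `block` and interleaved with the input dimension.

# The transform is a pure relayout, so the original weight is recoverable:
restored = trans_weight.transpose(1, 2).reshape(m, n)
assert torch.equal(restored, linear.weight.data)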
@@ -219,6 +219,13 @@ def llama_attention_forward_4_31(
                                                      value_states,
                                                      is_causal=True)
         attn_weights = None
+    elif use_esimd_sdp(q_len, self.head_dim, query_states):
+        import linear_fp16_esimd
+        attn_output = linear_fp16_esimd.sdp_forward(query_states,
+                                                    key_states.contiguous(),
+                                                    value_states.contiguous())
+        attn_output = attn_output.view(query_states.shape)
+        attn_weights = None
     else:
         # otherwise, use native attention
         attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
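For reference, a minimal sketch of the attention math this new branch offloads for the single rest token. With q_len == 1 the causal mask is a no-op, since the one query may attend to every cached key, which is why the branch can skip masking entirely. Shapes and float32 below are illustrative; the kernel itself runs in FP16 on XPU.

import math
import torch

# Rest-token shapes: one query attends over the full KV cache.
bsz, num_heads, kv_seq_len, head_dim = 1, 32, 64, 128
query_states = torch.randn(bsz, num_heads, 1, head_dim)
key_states = torch.randn(bsz, num_heads, kv_seq_len, head_dim)
value_states = torch.randn(bsz, num_heads, kv_seq_len, head_dim)

# Standard scaled dot-product attention; no mask needed when q_len == 1.
attn_weights = torch.matmul(query_states,
                            key_states.transpose(2, 3)) / math.sqrt(head_dim)
attn_weights = torch.softmax(attn_weights, dim=-1)
attn_output = torch.matmul(attn_weights, value_states)
assert attn_output.shape == query_states.shape  # matches the .view() above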
@@ -266,6 +273,32 @@ def check_flash_attention_available(query):
     return True


+def use_esimd_sdp(q_len, head_dim, query_states):
+    if head_dim != 128:
+        # esimd_sdp only supports head_dim = 128 for now
+        return False
+    elif q_len != 1:
+        # esimd_sdp only supports rest-token (q_len == 1) inference for now
+        return False
+    elif query_states.device.type != "xpu":
+        # esimd_sdp only supports GPU (XPU) for now
+        return False
+    elif query_states.dtype != torch.float16:
+        # esimd_sdp is only optimized for FP16 for now
+        return False
+    else:
+        device_name = torch.xpu.get_device_name(query_states.device.index)
+        if device_name.startswith("Intel(R) Arc(TM) A") or \
+                device_name.startswith("Intel(R) Data Center GPU Flex"):
+            import linear_fp16_esimd
+            if hasattr(linear_fp16_esimd, "sdp_forward"):
+                return True
+            else:
+                return False
+        else:
+            return False
+
+
 def native_sdp(query, key, value, attention_mask,
                bsz, q_len, kv_seq_len, head_dim, num_heads):
     attn_weights = torch.matmul(query,
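A small sanity-check sketch of the gating logic, assuming use_esimd_sdp from the hunk above is in scope; the import path shown is hypothetical. The early branches can all be exercised on CPU, since the function returns before touching torch.xpu.

import torch
# Hypothetical import path for the patched module:
# from bigdl.llm.transformers.models.llama import use_esimd_sdp

q_cpu = torch.zeros(1, 32, 1, 128, dtype=torch.float16)  # q_len == 1, head_dim == 128

assert use_esimd_sdp(q_len=1, head_dim=64, query_states=q_cpu) is False   # head_dim != 128
assert use_esimd_sdp(q_len=8, head_dim=128, query_states=q_cpu) is False  # first token, not rest
assert use_esimd_sdp(q_len=1, head_dim=128, query_states=q_cpu) is False  # CPU tensor, not XPU
# On an Arc A-series or Data Center GPU Flex device with linear_fp16_esimd
# providing sdp_forward, the same FP16 call on an XPU tensor would return True.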