LLM: integrate sdp kernel for FP16 rest token inference on GPU [DG2/ATSM] (#9633)

* integrate sdp

* update api

* fix style

* meet code review

* fix

* distinguish mtl from arc

* small fix
Ruonan Wang, 2023-12-13 11:29:57 +08:00, committed by GitHub
parent 5b0e7e308c
commit dc5b1d7e9d
2 changed files with 39 additions and 2 deletions


@@ -249,8 +249,12 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                 # convert here
                 m, n = module.weight.data.shape
-                trans_weight = module.weight.data.reshape(m//16, 16, n)
-                trans_weight = trans_weight.transpose(1, 2).contiguous()
+                if module.in_features == 11008:
+                    trans_weight = module.weight.data.reshape(m//8, 8, n)
+                    trans_weight = trans_weight.transpose(1, 2).contiguous()
+                elif module.in_features == 4096:
+                    trans_weight = module.weight.data.reshape(m//16, 16, n)
+                    trans_weight = trans_weight.transpose(1, 2).contiguous()
                 new_linear._parameters['weight'] = nn.Parameter(trans_weight)
                 if module.bias is not None:
                     new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
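
As a side note on the re-layout above: for an nn.Linear, module.weight has shape [out_features, in_features], so the patch now picks the row-block size from the input width, 8 for the 11008-wide projections (e.g. LLaMA-7B's down_proj) and 16 for the 4096-wide ones, before transposing each block into the layout the ESIMD kernel expects. A minimal standalone sketch, with the helper name relayout_fp16_weight invented for illustration:

import torch

def relayout_fp16_weight(weight: torch.Tensor) -> torch.Tensor:
    # Illustrative helper mirroring the diff above (not part of the patch).
    # weight: [out_features, in_features]; split the output rows into blocks
    # of 8 (in_features == 11008) or 16 (in_features == 4096), then transpose
    # each block so the kernel walks a contiguous block axis.
    m, n = weight.shape
    block = 8 if n == 11008 else 16
    return weight.reshape(m // block, block, n).transpose(1, 2).contiguous()

w = torch.randn(4096, 4096, dtype=torch.float16)
print(relayout_fp16_weight(w).shape)  # torch.Size([256, 4096, 16])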


@@ -219,6 +219,13 @@ def llama_attention_forward_4_31(
                                                     value_states,
                                                     is_causal=True)
         attn_weights = None
+    elif use_esimd_sdp(q_len, self.head_dim, query_states):
+        import linear_fp16_esimd
+        attn_output = linear_fp16_esimd.sdp_forward(query_states,
+                                                    key_states.contiguous(),
+                                                    value_states.contiguous())
+        attn_output = attn_output.view(query_states.shape)
+        attn_weights = None
     else:
         # otherwise, use native attention
         attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
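
For intuition, the fused linear_fp16_esimd.sdp_forward call above computes ordinary scaled-dot-product attention for the single rest (decode) token; key_states and value_states are made .contiguous() first, presumably because the fused kernel assumes densely laid-out KV-cache tensors. A plain-PyTorch reference of the same math, written in float32 so the sketch runs on CPU (the kernel itself targets FP16 XPU tensors, and no causal mask is needed since the lone query attends to every cached position):

import math
import torch

def reference_sdp(query, key, value):
    # softmax(Q @ K^T / sqrt(head_dim)) @ V -- what the fused kernel computes
    # for the decode step. query: [bsz, num_heads, 1, head_dim];
    # key/value: [bsz, num_heads, kv_len, head_dim].
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))
    return torch.matmul(torch.softmax(scores, dim=-1), value)

q = torch.randn(1, 32, 1, 128)    # one rest token, head_dim 128
k = torch.randn(1, 32, 256, 128)  # 256 cached positions
v = torch.randn(1, 32, 256, 128)
print(reference_sdp(q, k, v).shape)  # torch.Size([1, 32, 1, 128])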
@@ -266,6 +273,32 @@ def check_flash_attention_available(query):
     return True
 
 
+def use_esimd_sdp(q_len, head_dim, query_states):
+    if head_dim != 128:
+        # esimd_sdp only support head_dim = 128 now
+        return False
+    elif q_len != 1:
+        # esimd_sdp only support rest token now
+        return False
+    elif query_states.device.type != "xpu":
+        # esimd_sdp only support GPU now
+        return False
+    elif query_states.dtype != torch.float16:
+        # esimd_sdp only has optimization for FP16 now
+        return False
+    else:
+        device_name = torch.xpu.get_device_name(query_states.device.index)
+        if device_name.startswith("Intel(R) Arc(TM) A") or \
+                device_name.startswith("Intel(R) Data Center GPU Flex"):
+            import linear_fp16_esimd
+            if hasattr(linear_fp16_esimd, "sdp_forward"):
+                return True
+            else:
+                return False
+        else:
+            return False
+
+
 def native_sdp(query, key, value, attention_mask,
                bsz, q_len, kv_seq_len, head_dim, num_heads):
     attn_weights = torch.matmul(query,
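
The new use_esimd_sdp gate is deliberately conservative: any condition that cannot be verified routes execution to the native_sdp fallback. The device-name prefixes are also what the "distinguish mtl from arc" bullet refers to: DG2 (Arc A-series) and ATSM (Data Center GPU Flex) match one of the two prefixes, while other XPUs such as Meteor Lake iGPUs report different names and keep the default path. A condensed, illustrative restatement of the predicate (would_use_esimd_sdp is a made-up name; it can only return True on an XPU-enabled PyTorch build, and it omits the final hasattr(linear_fp16_esimd, "sdp_forward") availability check):

import torch

def would_use_esimd_sdp(q_len, head_dim, query_states):
    # Illustrative condensation of use_esimd_sdp above, minus the check that
    # the linear_fp16_esimd module actually exports sdp_forward.
    if (head_dim != 128                               # kernel specialized for head_dim 128
            or q_len != 1                             # rest-token (decode) steps only
            or query_states.device.type != "xpu"      # Intel GPU only
            or query_states.dtype != torch.float16):  # FP16 path only
        return False
    name = torch.xpu.get_device_name(query_states.device.index)
    return (name.startswith("Intel(R) Arc(TM) A")
            or name.startswith("Intel(R) Data Center GPU Flex"))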