optimize npu qwen2 (#12107)

parent 02399021d6
commit 03bd01c99c

2 changed files with 13 additions and 9 deletions
@@ -399,22 +399,22 @@ class LLMBaseNNFactory(NNFactory):
         self.setWeights(offset, op_id, *weights)

     @staticmethod
-    def run_decoders(inputs, decoders):
+    def run_decoders(inputs, decoders, models_ptr=None):
         x_np = [elem.to(torch.float16).numpy() for elem in inputs]

         num_decoders = len(decoders)
         num_inputs = len(x_np)

-        with record_function(f"npu_factory"):
-
+        if models_ptr is None:
             array_type = ctypes.POINTER(ctypes.c_char) * num_decoders
             models_ptr = array_type(
                 *[decoders[i]._mm for i in range(num_decoders)]
             )
-            inputs_ptr = (ctypes.c_void_p * num_inputs)(
-                *[x.ctypes.data_as(ctypes.c_void_p) for x in x_np]
-            )
-            backend_lib.run_decoders(models_ptr, inputs_ptr, num_decoders, num_inputs)
+
+        inputs_ptr = (ctypes.c_void_p * num_inputs)(
+            *[x.ctypes.data_as(ctypes.c_void_p) for x in x_np]
+        )
+        backend_lib.run_decoders(models_ptr, inputs_ptr, num_decoders, num_inputs)

         hidden_states = decoders[-1].torch_out[0]
         new_key_states = []
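
The change to run_decoders above: the ctypes array of model pointers is now built only when the caller does not supply one (models_ptr=None), and the record_function profiling scope around the native call is dropped. Below is a minimal, self-contained sketch of that lazy-build pattern; FakeDecoder and fake_backend_run are stand-ins for the real NPU decoder objects and for backend_lib.run_decoders, so only the shape of the ctypes plumbing mirrors the patched code.

import ctypes

class FakeDecoder:
    """Stand-in for an NPU backend decoder; _mm mimics the native model handle."""
    def __init__(self, payload: bytes):
        self._buf = ctypes.create_string_buffer(payload)                # keep the memory alive
        self._mm = ctypes.cast(self._buf, ctypes.POINTER(ctypes.c_char))

def fake_backend_run(models_ptr, num_decoders):
    """Stand-in for the native backend call: just touch each model pointer."""
    return [models_ptr[i][0] for i in range(num_decoders)]

def run_decoders(decoders, models_ptr=None):
    num_decoders = len(decoders)
    if models_ptr is None:                                               # lazy fallback, as in the patch
        array_type = ctypes.POINTER(ctypes.c_char) * num_decoders
        models_ptr = array_type(*[d._mm for d in decoders])
    return fake_backend_run(models_ptr, num_decoders)

decoders = [FakeDecoder(b"model-%d" % i) for i in range(4)]

# Old-style call: the pointer array is rebuilt inside every invocation.
print(run_decoders(decoders))

# New-style call: the caller builds the array once and keeps passing it in.
array_type = ctypes.POINTER(ctypes.c_char) * len(decoders)
cached_ptr = array_type(*[d._mm for d in decoders])
print(run_decoders(decoders, cached_ptr))
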
@@ -17,7 +17,7 @@
 import os
 import torch
 import time
-
+import ctypes
 from typing import Optional, Sequence, List, Union, Any, Tuple
 import numpy as np

@@ -379,6 +379,9 @@ class FusedQwenLowBitMultiDecoderlayer(torch.nn.Module):
             self.backend_decoders[i].set_weights(self.op_id, curr_parameters)
             offset = offset + curr_linear_ops

+        array_type = ctypes.POINTER(ctypes.c_char) * intra_stages
+        self.models_ptr = array_type(*[self.backend_decoders[i]._mm for i in range(intra_stages)])
+
     def forward(
         self,
         hidden_states: torch.Tensor,
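
The lines added to __init__ above build the pointer array once per decoder group, so the per-token forward call no longer reconstructs it on every decode step, which appears to be the point of this optimization. A rough, self-contained illustration of the cost difference follows, with no NPU involved; the decoder count of 28 and the fake _mm handles are assumptions for the sketch only.

import ctypes
import timeit

class FakeDecoder:
    def __init__(self):
        self._buf = ctypes.create_string_buffer(b"weights")             # keep the memory alive
        self._mm = ctypes.cast(self._buf, ctypes.POINTER(ctypes.c_char))

decoders = [FakeDecoder() for _ in range(28)]                            # e.g. one per layer
array_type = ctypes.POINTER(ctypes.c_char) * len(decoders)

def rebuild_every_step():
    # old behaviour: a fresh pointer array on every decode step
    return array_type(*[d._mm for d in decoders])

cached = array_type(*[d._mm for d in decoders])                          # new behaviour: built once
def reuse_cached():
    return cached

print("rebuild per step:", timeit.timeit(rebuild_every_step, number=100_000))
print("reuse cached    :", timeit.timeit(reuse_cached, number=100_000))
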
@@ -402,7 +405,8 @@ class FusedQwenLowBitMultiDecoderlayer(torch.nn.Module):

         hidden_states, new_keys, new_values = LowBitQwenMultiDecoderlayer.run_decoders(
             inputs,
-            decoders=self.backend_decoders)
+            self.backend_decoders,
+            self.models_ptr)

         if self.do_print:
             print("outputs:", hidden_states)
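
Usage note: because models_ptr defaults to None in run_decoders, call sites that still pass only inputs and decoders keep working through the lazy-build fallback; within this commit, only the Qwen2 fused multi-decoder layer is switched over to the cached self.models_ptr.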