optimize npu qwen2 (#12107)
This commit is contained in:
parent
02399021d6
commit
03bd01c99c
2 changed files with 13 additions and 9 deletions
|
|
@@ -399,22 +399,22 @@ class LLMBaseNNFactory(NNFactory):
|
|||
self.setWeights(offset, op_id, *weights)
|
||||
|
||||
@staticmethod
|
||||
def run_decoders(inputs, decoders):
|
||||
def run_decoders(inputs, decoders, models_ptr=None):
|
||||
x_np = [elem.to(torch.float16).numpy() for elem in inputs]
|
||||
|
||||
num_decoders = len(decoders)
|
||||
num_inputs = len(x_np)
|
||||
|
||||
with record_function(f"npu_factory"):
|
||||
|
||||
if models_ptr is None:
|
||||
array_type = ctypes.POINTER(ctypes.c_char) * num_decoders
|
||||
models_ptr = array_type(
|
||||
*[decoders[i]._mm for i in range(num_decoders)]
|
||||
)
|
||||
inputs_ptr = (ctypes.c_void_p * num_inputs)(
|
||||
*[x.ctypes.data_as(ctypes.c_void_p) for x in x_np]
|
||||
)
|
||||
backend_lib.run_decoders(models_ptr, inputs_ptr, num_decoders, num_inputs)
|
||||
|
||||
inputs_ptr = (ctypes.c_void_p * num_inputs)(
|
||||
*[x.ctypes.data_as(ctypes.c_void_p) for x in x_np]
|
||||
)
|
||||
backend_lib.run_decoders(models_ptr, inputs_ptr, num_decoders, num_inputs)
|
||||
|
||||
hidden_states = decoders[-1].torch_out[0]
|
||||
new_key_states = []
|
||||
|
|
|
|||
|
|
@@ -17,7 +17,7 @@
|
|||
import os
|
||||
import torch
|
||||
import time
|
||||
|
||||
import ctypes
|
||||
from typing import Optional, Sequence, List, Union, Any, Tuple
|
||||
import numpy as np
|
||||
|
||||
|
|
@@ -379,6 +379,9 @@ class FusedQwenLowBitMultiDecoderlayer(torch.nn.Module):
|
|||
self.backend_decoders[i].set_weights(self.op_id, curr_parameters)
|
||||
offset = offset + curr_linear_ops
|
||||
|
||||
array_type = ctypes.POINTER(ctypes.c_char) * intra_stages
|
||||
self.models_ptr = array_type(*[self.backend_decoders[i]._mm for i in range(intra_stages)])
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
|
|
@@ -402,7 +405,8 @@ class FusedQwenLowBitMultiDecoderlayer(torch.nn.Module):
|
|||
|
||||
hidden_states, new_keys, new_values = LowBitQwenMultiDecoderlayer.run_decoders(
|
||||
inputs,
|
||||
decoders=self.backend_decoders)
|
||||
self.backend_decoders,
|
||||
self.models_ptr)
|
||||
|
||||
if self.do_print:
|
||||
print("outputs:", hidden_states)
|
||||
|
|
|
|||
Loading…
Reference in a new issue