optimize npu qwen2 (#12107)

Ruonan Wang authored on 2024-09-20 04:46:16 -07:00, committed by GitHub
parent 02399021d6
commit 03bd01c99c
2 changed files with 13 additions and 9 deletions


@@ -399,22 +399,22 @@ class LLMBaseNNFactory(NNFactory):
         self.setWeights(offset, op_id, *weights)
 
     @staticmethod
-    def run_decoders(inputs, decoders):
+    def run_decoders(inputs, decoders, models_ptr=None):
         x_np = [elem.to(torch.float16).numpy() for elem in inputs]
 
         num_decoders = len(decoders)
         num_inputs = len(x_np)
 
-        with record_function(f"npu_factory"):
+        if models_ptr is None:
             array_type = ctypes.POINTER(ctypes.c_char) * num_decoders
             models_ptr = array_type(
                 *[decoders[i]._mm for i in range(num_decoders)]
             )
-            inputs_ptr = (ctypes.c_void_p * num_inputs)(
-                *[x.ctypes.data_as(ctypes.c_void_p) for x in x_np]
-            )
-            backend_lib.run_decoders(models_ptr, inputs_ptr, num_decoders, num_inputs)
+
+        inputs_ptr = (ctypes.c_void_p * num_inputs)(
+            *[x.ctypes.data_as(ctypes.c_void_p) for x in x_np]
+        )
+        backend_lib.run_decoders(models_ptr, inputs_ptr, num_decoders, num_inputs)
 
         hidden_states = decoders[-1].torch_out[0]
         new_key_states = []
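
The new `models_ptr=None` parameter lets a caller build the ctypes array of decoder handles once and pass it in on every decode step, instead of `run_decoders` reconstructing it per call. A minimal sketch of that calling pattern, assuming a stand-in decoder object (the `_FakeDecoder` class and its buffer are illustrative; only `_mm` and the `POINTER(c_char)` array construction come from the diff):

import ctypes

class _FakeDecoder:
    """Illustrative stand-in: only the `_mm` handle that run_decoders needs."""
    def __init__(self):
        # In the real backend, `_mm` is a POINTER(c_char) handle to a compiled model.
        self._buf = ctypes.create_string_buffer(b"compiled-model")
        self._mm = ctypes.cast(self._buf, ctypes.POINTER(ctypes.c_char))

def build_models_ptr(decoders):
    # The same construction run_decoders performs when models_ptr is None,
    # hoisted out so it can run once instead of on every token.
    array_type = ctypes.POINTER(ctypes.c_char) * len(decoders)
    return array_type(*[d._mm for d in decoders])

decoders = [_FakeDecoder() for _ in range(4)]
cached_models_ptr = build_models_ptr(decoders)   # built once, e.g. in __init__

# The per-token call site would then look like:
#   LLMBaseNNFactory.run_decoders(inputs, decoders, cached_models_ptr)
# and run_decoders skips its own array construction because models_ptr is not None.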


@@ -17,7 +17,7 @@
 import os
 import torch
 import time
-
+import ctypes
 from typing import Optional, Sequence, List, Union, Any, Tuple
 import numpy as np
@@ -379,6 +379,9 @@ class FusedQwenLowBitMultiDecoderlayer(torch.nn.Module):
             self.backend_decoders[i].set_weights(self.op_id, curr_parameters)
             offset = offset + curr_linear_ops
 
+        array_type = ctypes.POINTER(ctypes.c_char) * intra_stages
+        self.models_ptr = array_type(*[self.backend_decoders[i]._mm for i in range(intra_stages)])
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -402,7 +405,8 @@ class FusedQwenLowBitMultiDecoderlayer(torch.nn.Module):
         hidden_states, new_keys, new_values = LowBitQwenMultiDecoderlayer.run_decoders(
             inputs,
-            decoders=self.backend_decoders)
+            self.backend_decoders,
+            self.models_ptr)
 
         if self.do_print:
             print("outputs:", hidden_states)