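Optimize NPU Qwen2: reuse a precomputed decoder model-pointer array instead of rebuilding it on every run_decoders call (#12107)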

This commit is contained in:
Ruonan Wang 2024-09-20 04:46:16 -07:00 committed by GitHub
parent 02399021d6
commit 03bd01c99c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 13 additions and 9 deletions

View file

@@ -399,18 +399,18 @@ class LLMBaseNNFactory(NNFactory):
self.setWeights(offset, op_id, *weights)
@staticmethod
def run_decoders(inputs, decoders):
def run_decoders(inputs, decoders, models_ptr=None):
x_np = [elem.to(torch.float16).numpy() for elem in inputs]
num_decoders = len(decoders)
num_inputs = len(x_np)
with record_function(f"npu_factory"):
if models_ptr is None:
array_type = ctypes.POINTER(ctypes.c_char) * num_decoders
models_ptr = array_type(
*[decoders[i]._mm for i in range(num_decoders)]
)
inputs_ptr = (ctypes.c_void_p * num_inputs)(
*[x.ctypes.data_as(ctypes.c_void_p) for x in x_np]
)

View file

@@ -17,7 +17,7 @@
import os
import torch
import time
import ctypes
from typing import Optional, Sequence, List, Union, Any, Tuple
import numpy as np
@@ -379,6 +379,9 @@ class FusedQwenLowBitMultiDecoderlayer(torch.nn.Module):
self.backend_decoders[i].set_weights(self.op_id, curr_parameters)
offset = offset + curr_linear_ops
array_type = ctypes.POINTER(ctypes.c_char) * intra_stages
self.models_ptr = array_type(*[self.backend_decoders[i]._mm for i in range(intra_stages)])
def forward(
self,
hidden_states: torch.Tensor,
@@ -402,7 +405,8 @@ class FusedQwenLowBitMultiDecoderlayer(torch.nn.Module):
hidden_states, new_keys, new_values = LowBitQwenMultiDecoderlayer.run_decoders(
inputs,
decoders=self.backend_decoders)
self.backend_decoders,
self.models_ptr)
if self.do_print:
print("outputs:", hidden_states)