LLM: fix AttributeError of FP16Linear (#10740)
parent 1256a2cc4e
commit 70ed9397f9
1 changed file with 4 additions and 2 deletions
@@ -729,7 +729,7 @@ class LowBitLinear(nn.Linear):
 
 class FP16Linear(nn.Linear):
     def __init__(self, input_features, output_features, bias=True,
-                 mp_group=None, weight_type=1,
+                 mp_group=None, weight_type=1, enable_xetla=False,
                  optimize_lm_head=False):
         super().__init__(input_features, output_features, bias)
         self.in_len = input_features
@@ -743,6 +743,7 @@ class FP16Linear(nn.Linear):
         # weigh_type = 3 means weight has been transposed by esimd method
         self.weight_type = 1
         self.optimize_lm_head = optimize_lm_head
+        self.enable_xetla = enable_xetla
 
     def forward(self, x: torch.Tensor):
         # only work for GPU
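Why this raised an AttributeError: code elsewhere in the library reads the module's `enable_xetla` flag, but before this commit FP16Linear never set it. A minimal, self-contained sketch of the failure, with a simplified stand-in class (the real FP16Linear does far more, and the probing caller shown here is hypothetical):

import torch.nn as nn

# Simplified stand-in for the pre-fix FP16Linear: the constructor never
# assigns self.enable_xetla, so any attribute read on it fails.
class FP16LinearBefore(nn.Linear):
    def __init__(self, input_features, output_features, bias=True,
                 mp_group=None, weight_type=1, optimize_lm_head=False):
        super().__init__(input_features, output_features, bias)
        self.mp_group = mp_group
        self.weight_type = weight_type
        self.optimize_lm_head = optimize_lm_head

layer = FP16LinearBefore(16, 16)
try:
    _ = layer.enable_xetla   # what a conversion/dispatch pass might probe
except AttributeError as e:
    print(e)  # 'FP16LinearBefore' object has no attribute 'enable_xetla'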
@@ -849,7 +850,7 @@ class FP16Linear(nn.Linear):
 
 class BF16Linear(nn.Linear):
     def __init__(self, input_features, output_features, bias=True,
-                 mp_group=None, compute_dtype=None,
+                 mp_group=None, compute_dtype=None, enable_xetla=False,
                  optimize_lm_head=False):
         super().__init__(input_features, output_features, bias)
         self.in_len = input_features
@@ -860,6 +861,7 @@ class BF16Linear(nn.Linear):
         self.mp_group = mp_group
         self.compute_dtype = compute_dtype
         self.optimize_lm_head = optimize_lm_head
+        self.enable_xetla = enable_xetla
 
     def forward(self, x: torch.Tensor):
         if self.optimize_lm_head:
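With the two added assignments, FP16Linear and BF16Linear always expose the flag with a safe default of False, presumably matching the `enable_xetla` keyword that LowBitLinear already accepts (an assumption based on the surrounding code, which this diff does not show). A caller-side workaround would have been a defensive read, sketched below, but setting the attribute in `__init__` as this commit does keeps the attribute surface of the linear variants uniform:

# Hypothetical caller-side guard (not what this commit does): tolerate
# modules that predate the flag by falling back to False.
def wants_xetla(module) -> bool:
    return getattr(module, "enable_xetla", False)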