use new fused layer norm (#12553)
This commit is contained in:
parent
680ea7e4a8
commit
a608f26cc8
4 changed files with 38 additions and 41 deletions
|
|
@ -1296,10 +1296,9 @@ def _optimize_post(model, lightweight_bmm=False):
|
||||||
trans_version = transformers.__version__
|
trans_version = transformers.__version__
|
||||||
|
|
||||||
# convert all nn.LayerNorm
|
# convert all nn.LayerNorm
|
||||||
from ipex_llm.transformers.models.bloom import bloom_layer_norm_forward
|
from ipex_llm.transformers.models.common import layer_norm_forward
|
||||||
convert_forward(model,
|
convert_forward(model, nn.LayerNorm, layer_norm_forward)
|
||||||
nn.LayerNorm,
|
|
||||||
bloom_layer_norm_forward)
|
|
||||||
from ipex_llm.transformers.models.llama import llama_rms_norm_forward
|
from ipex_llm.transformers.models.llama import llama_rms_norm_forward
|
||||||
from ipex_llm.transformers.models.llama import llama_mlp_forward
|
from ipex_llm.transformers.models.llama import llama_mlp_forward
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -64,23 +64,6 @@ def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training:
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
def bloom_layer_norm_forward(self, hidden_states):
|
|
||||||
if use_fused_layer_norm(hidden_states, self.training):
|
|
||||||
import xe_addons
|
|
||||||
result = xe_addons.fused_layer_norm(hidden_states,
|
|
||||||
[self.weight.size(0)],
|
|
||||||
self.weight,
|
|
||||||
self.bias,
|
|
||||||
self.eps)
|
|
||||||
# if nelement == 0, means fused norm failed, go back to python implement.
|
|
||||||
if result.nelement != 0:
|
|
||||||
return result
|
|
||||||
input_dtype = hidden_states.dtype
|
|
||||||
result = F.layer_norm(hidden_states.to(self.weight.dtype),
|
|
||||||
self.normalized_shape, self.weight, self.bias, self.eps)
|
|
||||||
return result.to(input_dtype)
|
|
||||||
|
|
||||||
|
|
||||||
def bloom_attention_forward(
|
def bloom_attention_forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import math
|
||||||
import torch
|
import torch
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
|
@ -159,7 +160,7 @@ def rms_norm_forward(self, hidden_states: torch.Tensor):
|
||||||
else:
|
else:
|
||||||
eps = self.epsilon
|
eps = self.epsilon
|
||||||
|
|
||||||
if hidden_states.device.type == 'xpu':
|
if hidden_states.device.type == 'xpu' and hidden_states.dtype in [torch.float, torch.half]:
|
||||||
import xe_addons
|
import xe_addons
|
||||||
x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous()
|
x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous()
|
||||||
output = xe_addons.rms_norm(weight, x_2d, eps)
|
output = xe_addons.rms_norm(weight, x_2d, eps)
|
||||||
|
|
@ -169,3 +170,17 @@ def rms_norm_forward(self, hidden_states: torch.Tensor):
|
||||||
variance = hidden_states.to(torch.float32).pow(2).mean(dim=-1, keepdim=True)
|
variance = hidden_states.to(torch.float32).pow(2).mean(dim=-1, keepdim=True)
|
||||||
hidden_states = hidden_states * torch.rsqrt(variance + eps)
|
hidden_states = hidden_states * torch.rsqrt(variance + eps)
|
||||||
return weight * hidden_states.to(input_dtype)
|
return weight * hidden_states.to(input_dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def layer_norm_forward(self, hidden_states: torch.Tensor):
|
||||||
|
if hidden_states.device.type == 'xpu' and hidden_states.dtype in [torch.float, torch.half]:
|
||||||
|
import xe_addons
|
||||||
|
hidden_size = math.prod(self.normalized_shape)
|
||||||
|
x_2d = hidden_states.reshape(-1, hidden_size).contiguous()
|
||||||
|
output = xe_addons.layer_norm(x_2d, self.weight, self.bias, self.eps)
|
||||||
|
return output.reshape(hidden_states.shape)
|
||||||
|
else:
|
||||||
|
return torch.nn.functional.layer_norm(
|
||||||
|
hidden_states, self.normalized_shape,
|
||||||
|
self.weight, self.bias, self.eps
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -113,5 +113,5 @@ class Test_Optimize_Gpu_Model:
|
||||||
# currently only compare the output of the last LayerNorm layer.
|
# currently only compare the output of the last LayerNorm layer.
|
||||||
layer_before_LayerNorm = "transformer.h.30"
|
layer_before_LayerNorm = "transformer.h.30"
|
||||||
LayerNorm_layer = "transformer.h.31.input_layernorm"
|
LayerNorm_layer = "transformer.h.31.input_layernorm"
|
||||||
lower_bound = 0
|
lower_bound = 1e-5
|
||||||
self.run_optimize_gpu_model(Name, Model, Tokenizer, model_path, LayerNorm_layer, layer_before_LayerNorm, lower_bound)
|
self.run_optimize_gpu_model(Name, Model, Tokenizer, model_path, LayerNorm_layer, layer_before_LayerNorm, lower_bound)
|
||||||
Loading…
Reference in a new issue