Enable fused layernorm (#9614)
* bloom layernorm * fix * layernorm * fix * fix * fix * style fix * fix * replace nn.LayerNorm
This commit is contained in:
parent
84a19705a6
commit
82255f9726
2 changed files with 19 additions and 0 deletions
|
|
@ -392,6 +392,12 @@ def _optimize_post(model, lightweight_bmm=False):
|
||||||
# todo implement 4.28.0 ~ 4.30.2
|
# todo implement 4.28.0 ~ 4.30.2
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# convert all nn.LayerNorm
|
||||||
|
from bigdl.llm.transformers.models.bloom import bloom_layer_norm_forward
|
||||||
|
convert_forward(model,
|
||||||
|
nn.LayerNorm,
|
||||||
|
bloom_layer_norm_forward)
|
||||||
|
|
||||||
if model.config.architectures is not None and model.config.architectures[0] == "ChatGLMModel":
|
if model.config.architectures is not None and model.config.architectures[0] == "ChatGLMModel":
|
||||||
if model.config.num_layers == 28 and hasattr(model.config, 'rope_ratio'):
|
if model.config.num_layers == 28 and hasattr(model.config, 'rope_ratio'):
|
||||||
# chatglm2-6b-32k
|
# chatglm2-6b-32k
|
||||||
|
|
|
||||||
|
|
@ -62,6 +62,19 @@ def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training:
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def bloom_layer_norm_forward(self, hidden_states):
|
||||||
|
if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
|
||||||
|
import linear_q4_0
|
||||||
|
hidden_states = linear_q4_0.fused_layer_norm(hidden_states,
|
||||||
|
[self.weight.size(0)],
|
||||||
|
self.weight,
|
||||||
|
self.bias,
|
||||||
|
self.eps)
|
||||||
|
return hidden_states
|
||||||
|
else:
|
||||||
|
return F.layer_norm(hidden_states, self.normalized_shape, self.weight, self.bias, self.eps)
|
||||||
|
|
||||||
|
|
||||||
def bloom_attention_forward(
|
def bloom_attention_forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue