Enable fused layernorm (#9614)

* bloom layernorm
* fix
* layernorm
* fix
* fix
* fix
* style fix
* fix
* replace nn.LayerNorm

parent 84a19705a6
commit 82255f9726

2 changed files with 19 additions and 0 deletions

@@ -392,6 +392,12 @@ def _optimize_post(model, lightweight_bmm=False):
         # todo implement 4.28.0 ~ 4.30.2
         pass
 
+    # convert all nn.LayerNorm
+    from bigdl.llm.transformers.models.bloom import bloom_layer_norm_forward
+    convert_forward(model,
+                    nn.LayerNorm,
+                    bloom_layer_norm_forward)
+
     if model.config.architectures is not None and model.config.architectures[0] == "ChatGLMModel":
         if model.config.num_layers == 28 and hasattr(model.config, 'rope_ratio'):
             # chatglm2-6b-32k
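
Note: the convert_forward helper called in this hunk is defined elsewhere in bigdl.llm.transformers.convert and is not part of this diff. The snippet below is only a minimal, hypothetical sketch of the module-patching behavior it is assumed to have (walk the model and rebind forward on every submodule of the target class); the name patch_forward_sketch is illustrative, not the real API.

import torch.nn as nn


def patch_forward_sketch(model: nn.Module, module_cls, new_forward):
    # For every submodule of the target class, bind new_forward as that
    # instance's forward method so module(x) dispatches to the replacement.
    for module in model.modules():
        if isinstance(module, module_cls):
            module.forward = new_forward.__get__(module, module.__class__)
    return model

Under that assumption, the added lines above amount to patch_forward_sketch(model, nn.LayerNorm, bloom_layer_norm_forward).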

@@ -62,6 +62,19 @@ def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training:
     return out
 
 
+def bloom_layer_norm_forward(self, hidden_states):
+    if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
+        import linear_q4_0
+        hidden_states = linear_q4_0.fused_layer_norm(hidden_states,
+                                                     [self.weight.size(0)],
+                                                     self.weight,
+                                                     self.bias,
+                                                     self.eps)
+        return hidden_states
+    else:
+        return F.layer_norm(hidden_states, self.normalized_shape, self.weight, self.bias, self.eps)
+
+
 def bloom_attention_forward(
         self,
         hidden_states: torch.Tensor,
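
For reference, a small CPU-only sanity sketch of the fallback branch added above: off XPU, or whenever autograd is needed, bloom_layer_norm_forward reduces to a plain F.layer_norm call, so its output should match nn.LayerNorm. This does not exercise the linear_q4_0.fused_layer_norm kernel, and the shapes used are arbitrary.

import torch
import torch.nn as nn
import torch.nn.functional as F

ln = nn.LayerNorm(64)
x = torch.randn(2, 16, 64)

# Same call as the else-branch of bloom_layer_norm_forward.
fallback = F.layer_norm(x, ln.normalized_shape, ln.weight, ln.bias, ln.eps)
assert torch.allclose(ln(x), fallback)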