fix qwen2 cpu (#11240)
This commit is contained in:
		
							parent
							
								
									e738ec38f4
								
							
						
					
					
						commit
						2e4ccd541c
					
				
					 2 changed files with 6 additions and 0 deletions
				
			
		| 
						 | 
				
			
			@ -1279,6 +1279,9 @@ def _optimize_post(model, lightweight_bmm=False):
 | 
			
		|||
        convert_forward(model,
 | 
			
		||||
                        module.Qwen2Attention,
 | 
			
		||||
                        qwen2_attention_forward)
 | 
			
		||||
        convert_forward(model,
 | 
			
		||||
                        module.Qwen2SdpaAttention,
 | 
			
		||||
                        qwen2_attention_forward)
 | 
			
		||||
    elif model.config.model_type == "qwen2_moe":
 | 
			
		||||
        # for Qwen1.5-MOE-A2.7B
 | 
			
		||||
        modeling_module_name = model.__class__.__module__
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -326,6 +326,9 @@ def qwen2_attention_forward(
 | 
			
		|||
 | 
			
		||||
    attn_weights = None
 | 
			
		||||
    if query_states.device.type == "cpu":
 | 
			
		||||
        # repeat k/v heads if n_kv_heads < n_heads
 | 
			
		||||
        key_states = repeat_kv(key_states, self.num_key_value_groups)
 | 
			
		||||
        value_states = repeat_kv(value_states, self.num_key_value_groups)
 | 
			
		||||
        attn_output = sdpa(query_states,
 | 
			
		||||
                           key_states,
 | 
			
		||||
                           value_states,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue