fix fp16 (#9970)
This commit is contained in:
		
							parent
							
								
									052962dfa5
								
							
						
					
					
						commit
						da4687c917
					
				
					 1 changed file with 4 additions and 0 deletions
				
			
		| 
						 | 
				
			
			@ -97,6 +97,8 @@ def qwen_attention_forward(
 | 
			
		|||
            rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
 | 
			
		||||
            if use_fuse_rope:
 | 
			
		||||
                cos, sin = rotary_pos_emb
 | 
			
		||||
                cos = cos.to(query.dtype)
 | 
			
		||||
                sin = sin.to(query.dtype)
 | 
			
		||||
                query, key = apply_rotary_pos_emb_cache_freq_xpu(query, key, sin, cos, "qwen")
 | 
			
		||||
            else:
 | 
			
		||||
                rotary_pos_emb = (rotary_pos_emb,) * 2
 | 
			
		||||
| 
						 | 
				
			
			@ -111,6 +113,8 @@ def qwen_attention_forward(
 | 
			
		|||
                rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
 | 
			
		||||
                if use_fuse_rope:
 | 
			
		||||
                    cos, sin = rotary_pos_emb
 | 
			
		||||
                    cos = cos.to(query.dtype)
 | 
			
		||||
                    sin = sin.to(query.dtype)
 | 
			
		||||
                    query, key = apply_rotary_pos_emb_cache_freq_xpu(query, key, sin, cos, "qwen")
 | 
			
		||||
                    query_list += [query]
 | 
			
		||||
                    key_list += [key]
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue