Combine apply_rotary_pos_emb for gpt-neox (#9074)
This commit is contained in:
		
							parent
							
								
									0b40ef8261
								
							
						
					
					
						commit
						78ea7ddb1c
					
				
					 2 changed files with 2 additions and 10 deletions
				
			
		| 
						 | 
					@ -5,7 +5,7 @@ Before running, make sure to have [bigdl-llm](../../../README.md) and [bigdl-nan
 | 
				
			||||||
 | 
					
 | 
				
			||||||
## Dependencies
 | 
					## Dependencies
 | 
				
			||||||
```bash
 | 
					```bash
 | 
				
			||||||
pip install omageconfig
 | 
					pip install omegaconf
 | 
				
			||||||
pip install pandas
 | 
					pip install pandas
 | 
				
			||||||
```
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -71,7 +71,7 @@ def rotate_every_two(x):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, model_family):
 | 
					def apply_rotary_pos_emb(q, k, cos, sin, position_ids, model_family):
 | 
				
			||||||
    if model_family in ["llama", "baichuan", "internlm", "aquila"]:
 | 
					    if model_family in ["llama", "baichuan", "internlm", "aquila", "gpt_neox"]:
 | 
				
			||||||
        # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
 | 
					        # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
 | 
				
			||||||
        cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
 | 
					        cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
 | 
				
			||||||
        sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
 | 
					        sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
 | 
				
			||||||
| 
						 | 
					@ -86,14 +86,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, model_family):
 | 
				
			||||||
        q_embed = (q * cos) + (rotate_every_two(q) * sin)
 | 
					        q_embed = (q * cos) + (rotate_every_two(q) * sin)
 | 
				
			||||||
        k_embed = (k * cos) + (rotate_every_two(k) * sin)
 | 
					        k_embed = (k * cos) + (rotate_every_two(k) * sin)
 | 
				
			||||||
        return q_embed, k_embed
 | 
					        return q_embed, k_embed
 | 
				
			||||||
    elif model_family == "gpt_neox":
 | 
					 | 
				
			||||||
        gather_indices = position_ids[:, None, :, None]  # [bs, 1, seq_len, 1]
 | 
					 | 
				
			||||||
        gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
 | 
					 | 
				
			||||||
        cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
 | 
					 | 
				
			||||||
        sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
 | 
					 | 
				
			||||||
        q_embed = (q * cos) + (rotate_half(q) * sin)
 | 
					 | 
				
			||||||
        k_embed = (k * cos) + (rotate_half(k) * sin)
 | 
					 | 
				
			||||||
        return q_embed, k_embed
 | 
					 | 
				
			||||||
    else:
 | 
					    else:
 | 
				
			||||||
        invalidInputError(False,
 | 
					        invalidInputError(False,
 | 
				
			||||||
                          f"{model_family} is not supported.")
 | 
					                          f"{model_family} is not supported.")
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue