Combine apply_rotary_pos_emb for gpt-neox (#9074)
This commit is contained in:
parent
0b40ef8261
commit
78ea7ddb1c
2 changed files with 2 additions and 10 deletions
|
|
@ -5,7 +5,7 @@ Before running, make sure to have [bigdl-llm](../../../README.md) and [bigdl-nan
|
||||||
|
|
||||||
## Dependencies
|
## Dependencies
|
||||||
```bash
|
```bash
|
||||||
pip install omageconfig
|
pip install omegaconf
|
||||||
pip install pandas
|
pip install pandas
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -71,7 +71,7 @@ def rotate_every_two(x):
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, model_family):
|
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, model_family):
|
||||||
if model_family in ["llama", "baichuan", "internlm", "aquila"]:
|
if model_family in ["llama", "baichuan", "internlm", "aquila", "gpt_neox"]:
|
||||||
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
|
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
|
||||||
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
|
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
|
||||||
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
|
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
|
||||||
|
|
@ -86,14 +86,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, model_family):
|
||||||
q_embed = (q * cos) + (rotate_every_two(q) * sin)
|
q_embed = (q * cos) + (rotate_every_two(q) * sin)
|
||||||
k_embed = (k * cos) + (rotate_every_two(k) * sin)
|
k_embed = (k * cos) + (rotate_every_two(k) * sin)
|
||||||
return q_embed, k_embed
|
return q_embed, k_embed
|
||||||
elif model_family == "gpt_neox":
|
|
||||||
gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1]
|
|
||||||
gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
|
|
||||||
cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
|
|
||||||
sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
|
|
||||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
|
||||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
|
||||||
return q_embed, k_embed
|
|
||||||
else:
|
else:
|
||||||
invalidInputError(False,
|
invalidInputError(False,
|
||||||
f"{model_family} is not supported.")
|
f"{model_family} is not supported.")
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue