From 78ea7ddb1c505f7b5e6909ea87085fc0ef237ec5 Mon Sep 17 00:00:00 2001 From: Kai Huang Date: Sat, 7 Oct 2023 16:27:46 +0800 Subject: [PATCH] Combine apply_rotary_pos_emb for gpt-neox (#9074) --- python/llm/dev/benchmark/all-in-one/README.md | 2 +- python/llm/src/bigdl/llm/transformers/models/utils.py | 10 +--------- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/python/llm/dev/benchmark/all-in-one/README.md b/python/llm/dev/benchmark/all-in-one/README.md index f40da6b0..d5b5457b 100644 --- a/python/llm/dev/benchmark/all-in-one/README.md +++ b/python/llm/dev/benchmark/all-in-one/README.md @@ -5,7 +5,7 @@ Before running, make sure to have [bigdl-llm](../../../README.md) and [bigdl-nan ## Dependencies ```bash -pip install omageconfig +pip install omegaconf pip install pandas ``` diff --git a/python/llm/src/bigdl/llm/transformers/models/utils.py b/python/llm/src/bigdl/llm/transformers/models/utils.py index 1aed301f..0357ac2a 100644 --- a/python/llm/src/bigdl/llm/transformers/models/utils.py +++ b/python/llm/src/bigdl/llm/transformers/models/utils.py @@ -71,7 +71,7 @@ def rotate_every_two(x): def apply_rotary_pos_emb(q, k, cos, sin, position_ids, model_family): - if model_family in ["llama", "baichuan", "internlm", "aquila"]: + if model_family in ["llama", "baichuan", "internlm", "aquila", "gpt_neox"]: # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] @@ -86,14 +86,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, model_family): q_embed = (q * cos) + (rotate_every_two(q) * sin) k_embed = (k * cos) + (rotate_every_two(k) * sin) return q_embed, k_embed - elif model_family == "gpt_neox": - gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1] - gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3]) - cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices) - sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed else: invalidInputError(False, f"{model_family} is not supported.")