diff --git a/python/llm/src/bigdl/llm/ggml/model/gptneox/gptneox.py b/python/llm/src/bigdl/llm/ggml/model/gptneox/gptneox.py index 1c9a5e3c..d6221cc2 100644 --- a/python/llm/src/bigdl/llm/ggml/model/gptneox/gptneox.py +++ b/python/llm/src/bigdl/llm/ggml/model/gptneox/gptneox.py @@ -55,6 +55,7 @@ import ctypes from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque, Tuple from collections import deque, OrderedDict from bigdl.llm.utils.common import invalidInputError +from bigdl.llm.ggml.model.generation import GenerationMixin from . import gptneox_cpp from .gptneox_types import * @@ -121,7 +122,7 @@ class GptneoxState: self.gptneox_state_size = gptneox_state_size -class Gptneox: +class Gptneox(GenerationMixin): """High-level Python wrapper for a gptneox.cpp model.""" def __init__(