parent 8edcdeb0e7
commit 14b1e6b699
2 changed files with 32 additions and 20 deletions
@@ -449,40 +449,50 @@ def run_transformer_int4_gpu(repo_id,
     model_path = get_model_path(repo_id, local_model_hub)
     # Load model in 4 bit,
     # which convert the relevant layers in the model into INT4 format
+    if fp16:
+        torch_dtype = torch.float16
+    else:
+        torch_dtype = 'auto'
     st = time.perf_counter()
     origin_repo_id = repo_id.replace("-4bit", "")
     if origin_repo_id in CHATGLM_IDS:
         if "4bit" in repo_id:
             model = AutoModel.load_low_bit(model_path, optimize_model=True,
-                                           trust_remote_code=True, use_cache=True, cpu_embedding=cpu_embedding).eval()
+                                           trust_remote_code=True, use_cache=True,
+                                           cpu_embedding=cpu_embedding,
+                                           torch_dtype=torch_dtype).eval()
         else:
             model = AutoModel.from_pretrained(model_path, load_in_low_bit=low_bit, optimize_model=True,
-                                              trust_remote_code=True, use_cache=True).eval()
+                                              trust_remote_code=True, use_cache=True,
+                                              torch_dtype=torch_dtype).eval()
         tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, cpu_embedding=cpu_embedding)
     elif origin_repo_id in LLAMA_IDS:
         model = AutoModelForCausalLM.from_pretrained(model_path, load_in_low_bit=low_bit, trust_remote_code=True,
-                                                     use_cache=True, cpu_embedding=cpu_embedding).eval()
+                                                     use_cache=True, cpu_embedding=cpu_embedding,
+                                                     torch_dtype=torch_dtype).eval()
         tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)
     else:
         if "4bit" in repo_id:
             model = AutoModelForCausalLM.load_low_bit(model_path, optimize_model=True,
-                                                      trust_remote_code=True, use_cache=True, cpu_embedding=cpu_embedding).eval()
+                                                      trust_remote_code=True, use_cache=True,
+                                                      cpu_embedding=cpu_embedding,
+                                                      torch_dtype=torch_dtype).eval()
         else:
             if 'starcoder' in repo_id:
                 # Load starcoder-15.5b model in bf16 format to avoid CPU OOM.
                 model = AutoModelForCausalLM.from_pretrained(model_path, optimize_model=True, load_in_low_bit=low_bit,
-                                                             trust_remote_code=True, use_cache=True, cpu_embedding=cpu_embedding, torch_dtype=torch.bfloat16).eval()
+                                                             trust_remote_code=True, use_cache=True, cpu_embedding=cpu_embedding,
+                                                             torch_dtype=torch.bfloat16 if not fp16 else torch.float16).eval()
                 # Convert the low-bit model back to fp32 for performance considerations.
-                model = model.float()
+                if not fp16:
+                    model = model.float()
             else:
                 model = AutoModelForCausalLM.from_pretrained(model_path, optimize_model=True, load_in_low_bit=low_bit,
-                                                             trust_remote_code=True, use_cache=True, cpu_embedding=cpu_embedding).eval()
+                                                             trust_remote_code=True, use_cache=True,
+                                                             cpu_embedding=cpu_embedding,
+                                                             torch_dtype=torch_dtype).eval()
         tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

-    if fp16:
-        model = model.half()
-        print("Convert model to half precision")
-
     model = model.to('xpu')

     end = time.perf_counter()
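The hunk above switches run_transformer_int4_gpu from an after-the-fact model.half() call to choosing the load dtype up front via torch_dtype. Below is a minimal sketch of that load path, assuming the BigDL-LLM transformers wrapper this script already uses; the import path, the placeholder model path, and the "sym_int4" value for low_bit are assumptions for illustration, not taken from the diff.

import torch
from bigdl.llm.transformers import AutoModelForCausalLM  # assumed import path

fp16 = True
# Pick the dtype of non-quantized layers at load time instead of calling
# model.half() afterwards.
torch_dtype = torch.float16 if fp16 else 'auto'

model = AutoModelForCausalLM.from_pretrained(
    "/path/to/model",            # placeholder model path
    load_in_low_bit="sym_int4",  # example low-bit format for the weights
    optimize_model=True,
    trust_remote_code=True,
    use_cache=True,
    torch_dtype=torch_dtype,
).eval()
model = model.to('xpu')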
@@ -984,30 +994,30 @@ def run_transformer_int4_fp16_gpu_win(repo_id,
     st = time.perf_counter()
     if repo_id in CHATGLM_IDS:
         model = AutoModel.from_pretrained(model_path, load_in_low_bit=low_bit, optimize_model=True,
-                                          trust_remote_code=True, use_cache=True, cpu_embedding=cpu_embedding).eval()
+                                          trust_remote_code=True, use_cache=True, cpu_embedding=cpu_embedding,
+                                          torch_dtype=torch.float16).eval()
         tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
-        model = model.half()
         model = model.to('xpu')
     elif repo_id in LLAMA_IDS:
         model = AutoModelForCausalLM.from_pretrained(model_path, load_in_low_bit=low_bit, optimize_model=True,
-                                                     trust_remote_code=True, use_cache=True, cpu_embedding=cpu_embedding).eval()
+                                                     trust_remote_code=True, use_cache=True, cpu_embedding=cpu_embedding,
+                                                     torch_dtype=torch.float16).eval()
         tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)
-        model = model.half()
         model = model.to('xpu')
     elif repo_id in LLAVA_IDS:
         llava_repo_dir = os.environ.get('LLAVA_REPO_DIR')
         sys.path.append(rf"{llava_repo_dir}")
         from llava.model.language_model.llava_llama import LlavaLlamaForCausalLM
         model = AutoModelForCausalLM.from_pretrained(model_path, load_in_low_bit=low_bit, optimize_model=True,
-                                                     trust_remote_code=True, use_cache=True, cpu_embedding=cpu_embedding).eval()
+                                                     trust_remote_code=True, use_cache=True, cpu_embedding=cpu_embedding,
+                                                     torch_dtype=torch.float16).eval()
         tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
-        model = model.half()
         model = model.to('xpu')
     else:
         model = AutoModelForCausalLM.from_pretrained(model_path, optimize_model=True, load_in_low_bit=low_bit,
-                                                     trust_remote_code=True, use_cache=True, cpu_embedding=cpu_embedding).eval()
+                                                     trust_remote_code=True, use_cache=True, cpu_embedding=cpu_embedding,
+                                                     torch_dtype=torch.float16).eval()
         tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
-        model = model.half()
         model = model.to('xpu')
     end = time.perf_counter()
     load_time = end - st
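In run_transformer_int4_fp16_gpu_win the explicit model.half() calls are dropped because torch_dtype=torch.float16 is now passed directly to from_pretrained. A toy plain-PyTorch check (not the benchmark API) of why the two approaches end up with the same parameter dtypes:

import torch

lin_a = torch.nn.Linear(8, 8, dtype=torch.float16)  # dtype fixed at construction
lin_b = torch.nn.Linear(8, 8).half()                # converted after construction

# Both leave the parameters in fp16, so converting after loading is redundant
# once the dtype is requested at load time.
assert lin_a.weight.dtype == lin_b.weight.dtype == torch.float16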
@@ -97,6 +97,7 @@ class LowBitEmbedding(torch.nn.Embedding):
                                 requires_grad=False,
                                 quantized=False, _shape=None, qtype=qtype)
         self.embedding_dim = embedding_dim
+        self.num_embeddings = num_embeddings
         self.torch_dtype = torch_dtype

     def forward(self, x: Tensor):
@@ -110,5 +111,6 @@ class LowBitEmbedding(torch.nn.Embedding):
                           "Please `pip install bigdl_core_xe` first.")

         result = xe_linear.dequantize_rows(x.contiguous(), self.weight.data,
-                                           self.weight.qtype, self.embedding_dim)
+                                           self.weight.qtype, self.embedding_dim,
+                                           self.num_embeddings)
         return result.to(self.torch_dtype)
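LowBitEmbedding.forward dequantizes the low-bit weight rows through xe_linear.dequantize_rows and then casts the result to the module's torch_dtype. A rough plain-PyTorch stand-in for that behaviour is sketched below; the ToyCastingEmbedding class and its ordinary fp32 embedding table are assumptions for illustration only, not the real kernel.

import torch

class ToyCastingEmbedding(torch.nn.Embedding):
    """Looks rows up in an ordinary table, then casts like LowBitEmbedding does."""
    def __init__(self, num_embeddings, embedding_dim, torch_dtype=torch.float32):
        super().__init__(num_embeddings, embedding_dim)
        self.torch_dtype = torch_dtype

    def forward(self, x):
        result = super().forward(x)          # stands in for the dequantized rows
        return result.to(self.torch_dtype)   # e.g. fp16 when the model runs in fp16

emb = ToyCastingEmbedding(10, 4, torch_dtype=torch.float16)
print(emb(torch.tensor([1, 2, 3])).dtype)    # torch.float16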