Fix 083 lm_head error (#13132)
* fix no quantize error
* update
* update style
parent 685a749adb
commit 01bc7e9eb9

2 changed files with 5 additions and 6 deletions
@@ -94,10 +94,9 @@ else:
     dtype = "float16"
 if "gemma-3" in model_path:
-    mm_processor_kwarg = {"do_pan_and_scan": True}
     dtype = "float32"
 else:
-    mm_processor_kwarg = None
+    pass


 llm = LLM(
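With this change the example no longer builds an mm_processor_kwarg dict for Gemma-3, so pan-and-scan stays disabled and None is passed to the engine below. For reference, a minimal sketch of how the kwarg could still be supplied inline, assuming vLLM's standard LLM entrypoint (the model argument and values here are illustrative, not part of this commit):

    # Hypothetical: re-enable Gemma-3 pan-and-scan by passing the dict directly
    llm = LLM(
        model=model_path,  # model_path as used earlier in this example
        dtype=dtype,
        mm_processor_kwargs={"do_pan_and_scan": True},  # the behavior this commit removes
        load_in_low_bit="sym_int4",
    )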
@@ -106,7 +105,7 @@ llm = LLM(
     dtype=dtype,
     enforce_eager=True,
     hf_overrides=hf_override,
-    mm_processor_kwargs=mm_processor_kwarg,
+    mm_processor_kwargs=None,
     load_in_low_bit="sym_int4",
     tensor_parallel_size=2,
     disable_async_output_proc=True,
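For context, a minimal sketch of how the engine built above is typically driven, using the standard vLLM generate API (the prompt and sampling values are placeholders, not part of this commit):

    from vllm import SamplingParams

    sampling_params = SamplingParams(temperature=0.0, max_tokens=64)  # placeholder settings
    outputs = llm.generate(["What is AI?"], sampling_params)  # placeholder prompt
    for output in outputs:
        print(output.outputs[0].text)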
@@ -41,7 +41,7 @@ def _sample_get_logits(
     # HINT: we do not support other types of quantization for now
     # TODO: we may encounter tie-word-embedding problems
     if isinstance(lm_head, VocabParallelEmbedding):
-        logits = lm_head.linear_method.apply(lm_head,
+        logits = lm_head.quant_method.apply(lm_head,
                                              hidden_states,
                                              bias=embedding_bias)
     else:
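The one-line change above tracks an attribute rename in vLLM: the quantization hook on VocabParallelEmbedding is exposed as quant_method rather than linear_method in recent releases (the "083" in the title likely refers to vLLM 0.8.3), and per the commit message the old name failed on the no-quantize path. A version-tolerant sketch, assuming only the names already visible in this hunk:

    # Hypothetical shim: prefer the new attribute, fall back to the old name
    method = getattr(lm_head, "quant_method", None) or getattr(lm_head, "linear_method", None)
    if method is not None:
        logits = method.apply(lm_head, hidden_states, bias=embedding_bias)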