Fix 083 lm_head error (#13132)

* fix no quantize error

* update

* update style
Wang, Jian4 2025-05-06 15:47:20 +08:00 committed by GitHub
parent 685a749adb
commit 01bc7e9eb9
2 changed files with 5 additions and 6 deletions

@@ -94,10 +94,9 @@ else:
     dtype = "float16"
 if "gemma-3" in model_path:
-    mm_processor_kwarg = {"do_pan_and_scan": True}
     dtype = "float32"
 else:
-    mm_processor_kwarg = None
+    pass
 llm = LLM(
@@ -106,7 +105,7 @@ llm = LLM(
     dtype=dtype,
     enforce_eager=True,
     hf_overrides=hf_override,
-    mm_processor_kwargs=mm_processor_kwarg,
+    mm_processor_kwargs=None,
     load_in_low_bit="sym_int4",
     tensor_parallel_size=2,
     disable_async_output_proc=True,

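Net effect of the first file's change: the gemma-3 branch no longer sets mm_processor_kwarg at all; only the dtype override remains, and the LLM(...) call below it now hard-codes mm_processor_kwargs=None. Below is a minimal, self-contained sketch of the post-patch control flow (model_path is a hypothetical placeholder here; the real script derives it from its own arguments):

    model_path = "google/gemma-3-4b-it"   # hypothetical example value

    dtype = "float16"
    if "gemma-3" in model_path:
        # gemma-3 still needs float32, but no longer overrides mm_processor_kwargs
        dtype = "float32"
    else:
        pass

    mm_processor_kwargs = None   # now always None, regardless of the model
    print(dtype, mm_processor_kwargs)   # -> float32 None on the gemma-3 path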

@@ -41,9 +41,9 @@ def _sample_get_logits(
     # HINT: we do not support other types of quantization for now
     # TODO: we may encounter tie-word-embedding problems
     if isinstance(lm_head, VocabParallelEmbedding):
-        logits = lm_head.linear_method.apply(lm_head,
-                                             hidden_states,
-                                             bias=embedding_bias)
+        logits = lm_head.quant_method.apply(lm_head,
+                                            hidden_states,
+                                            bias=embedding_bias)
     else:
         logits = lm_head(hidden_states)
         if embedding_bias is not None: