Fix 083 lm_head error (#13132)

* fix no quantize error

* update

* update style
Wang, Jian4 2025-05-06 15:47:20 +08:00, committed by GitHub
parent 685a749adb
commit 01bc7e9eb9
2 changed files with 5 additions and 6 deletions


@@ -94,10 +94,9 @@ else:
     dtype = "float16"
 if "gemma-3" in model_path:
-    mm_processor_kwarg = {"do_pan_and_scan": True}
     dtype = "float32"
 else:
-    mm_processor_kwarg = None
+    pass
 llm = LLM(
@@ -106,7 +105,7 @@ llm = LLM(
     dtype=dtype,
     enforce_eager=True,
     hf_overrides=hf_override,
-    mm_processor_kwargs=mm_processor_kwarg,
+    mm_processor_kwargs=None,
     load_in_low_bit="sym_int4",
     tensor_parallel_size=2,
     disable_async_output_proc=True,
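
For context, a minimal sketch of how the example script reads after this change. model_path and hf_override are assumed to be defined earlier in the script, and any arguments not shown in the hunks above are omitted.

from vllm import LLM

dtype = "float16"
if "gemma-3" in model_path:
    # gemma-3 now only switches the dtype; the do_pan_and_scan override was dropped
    dtype = "float32"

llm = LLM(
    model=model_path,              # assumed argument name; not shown in the hunk
    dtype=dtype,
    enforce_eager=True,
    hf_overrides=hf_override,
    mm_processor_kwargs=None,      # always None now instead of a per-model dict
    load_in_low_bit="sym_int4",    # ipex-llm low-bit weight option
    tensor_parallel_size=2,
    disable_async_output_proc=True,
)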


@@ -41,7 +41,7 @@ def _sample_get_logits(
     # HINT: we do not support other types of quantization for now
     # TODO: we may encounter tie-word-embedding problems
     if isinstance(lm_head, VocabParallelEmbedding):
-        logits = lm_head.linear_method.apply(lm_head,
+        logits = lm_head.quant_method.apply(lm_head,
                                              hidden_states,
                                              bias=embedding_bias)
     else:
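
A hedged sketch of the patched helper, assuming a vLLM version (such as 0.8.3, per the commit title) where VocabParallelEmbedding exposes quant_method rather than the older linear_method attribute, which is what broke the unquantized lm_head path. The else branch below is illustrative only, since the real branch is outside the hunk.

import torch
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding


def _sample_get_logits(hidden_states, lm_head, embedding_bias=None):
    # HINT: we do not support other types of quantization for now
    # TODO: we may encounter tie-word-embedding problems
    if isinstance(lm_head, VocabParallelEmbedding):
        # quant_method.apply covers both the unquantized and quantized lm_head cases
        logits = lm_head.quant_method.apply(lm_head,
                                            hidden_states,
                                            bias=embedding_bias)
    else:
        # Illustrative fallback (not part of the hunk): plain linear projection
        # against the lm_head weight.
        logits = torch.nn.functional.linear(hidden_states, lm_head.weight,
                                            embedding_bias)
    return logits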