diff --git a/python/llm/test/inference/test_transformers_api.py b/python/llm/test/inference/test_transformers_api.py index f04f4257..d4460336 100644 --- a/python/llm/test/inference/test_transformers_api.py +++ b/python/llm/test/inference/test_transformers_api.py @@ -166,7 +166,7 @@ def test_optimize_model(Model, Tokenizer, model_path, prompt): logits_optimized_model = (model(input_ids)).logits diff = abs(logits_base_model - logits_optimized_model).flatten() - assert any(diff) is False + assert (diff/logits_base_model.flatten()).mean()<0.05 if __name__ == '__main__':