From cee9eaf54236c47bb55e1c1ac6819162a2dca829 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cheen=20Hau=2C=20=E4=BF=8A=E8=B1=AA?= <33478814+chtanch@users.noreply.github.com>
Date: Mon, 30 Oct 2023 14:38:34 +0800
Subject: [PATCH] [LLM] Fix llm arc ut oom (#9300)

* Move model to cpu after testing so that gpu memory is deallocated

* Add code comment

---------

Co-authored-by: sgwhat
---
 python/llm/test/inference_gpu/test_optimize_model.py   | 2 ++
 python/llm/test/inference_gpu/test_transformers_api.py | 3 ++-
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/python/llm/test/inference_gpu/test_optimize_model.py b/python/llm/test/inference_gpu/test_optimize_model.py
index 515686f7..358f4e14 100644
--- a/python/llm/test/inference_gpu/test_optimize_model.py
+++ b/python/llm/test/inference_gpu/test_optimize_model.py
@@ -41,6 +41,7 @@ def test_optimize_model(Model, Tokenizer, model_path):
                                   trust_remote_code=True)
     model = model.to(device)
     logits_base_model = (model(input_ids)).logits
+    model.to('cpu')  # deallocate gpu memory

     model = Model.from_pretrained(model_path,
                                   load_in_4bit=True,
@@ -48,6 +49,7 @@ def test_optimize_model(Model, Tokenizer, model_path):
                                   trust_remote_code=True)
     model = model.to(device)
     logits_optimized_model = (model(input_ids)).logits
+    model.to('cpu')

     diff = abs(logits_base_model - logits_optimized_model).flatten()

diff --git a/python/llm/test/inference_gpu/test_transformers_api.py b/python/llm/test/inference_gpu/test_transformers_api.py
index 6ac9591a..c0f3dfc0 100644
--- a/python/llm/test/inference_gpu/test_transformers_api.py
+++ b/python/llm/test/inference_gpu/test_transformers_api.py
@@ -41,10 +41,11 @@ def test_completion(Model, Tokenizer, model_path, prompt, answer):
                                       load_in_4bit=True,
                                       optimize_model=True,
                                       trust_remote_code=True)
-    model = model.to(device)
+    model = model.to(device)  # deallocate gpu memory

     input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
     output = model.generate(input_ids, max_new_tokens=32)
+    model.to('cpu')
     output_str = tokenizer.decode(output[0], skip_special_tokens=True)

     assert answer in output_str
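
Note on the technique: the fix works because each test loads two models back to back, and the Arc GPU only has room for one of them at a time; moving the first model back to the CPU after its logits are captured lets the device memory be reclaimed before the second load. A minimal sketch of that pattern is below, outside the patched test files. The load_model callable, the device default, and the gc/empty_cache calls are illustrative assumptions, not part of the patch.

    # Sketch (assumption): free accelerator memory between two model loads
    # by moving the first model back to the CPU before loading the second.
    import gc

    import torch


    def compare_two_models(load_model, input_ids, device="xpu"):
        # First model: compute logits on the device, keep only a CPU copy.
        model = load_model().to(device)
        logits_a = model(input_ids.to(device)).logits.detach().to("cpu")

        # Move the weights off the device and drop the Python reference so
        # the device allocation can actually be released.
        model.to("cpu")
        del model
        gc.collect()
        if hasattr(torch, "xpu") and hasattr(torch.xpu, "empty_cache"):
            torch.xpu.empty_cache()  # release cached device memory if the backend exposes it

        # Second model now has the device memory to itself.
        model = load_model().to(device)
        logits_b = model(input_ids.to(device)).logits.detach().to("cpu")
        model.to("cpu")

        return (logits_a - logits_b).abs().flatten()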