diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py
index e5df7f2c7c..ae04d92970 100755
--- a/examples/text-generation/run_generation.py
+++ b/examples/text-generation/run_generation.py
@@ -719,22 +719,21 @@ def generate_dataset(batch):
             return prompt, outputs
 
         # warmup
-        if prompt_length > 0:
-            from optimum.habana.utils import HabanaProfile
+        from optimum.habana.utils import HabanaProfile
 
-            # compilation stage disable profiling
-            HabanaProfile.disable()
-            # Compilation
-            logger.info("Graph compilation...")
-            t0 = time.perf_counter()
-            for i, batch in enumerate(dataloader):
-                generate_dataset(batch)
-                # The first three iterations take longer because of graph compilation
-                if (i + 1) == 3:
-                    break
-            torch_hpu.synchronize()
-            compilation_duration = time.perf_counter() - t0
-            HabanaProfile.enable()
+        # compilation stage disable profiling
+        HabanaProfile.disable()
+        # Compilation
+        logger.info("Graph compilation...")
+        t0 = time.perf_counter()
+        for i, batch in enumerate(dataloader):
+            generate_dataset(batch)
+            # The first three iterations take longer because of graph compilation
+            if (i + 1) == 3:
+                break
+        torch_hpu.synchronize()
+        compilation_duration = time.perf_counter() - t0
+        HabanaProfile.enable()
 
         total_new_tokens_generated = 0
         duration = 0
@@ -770,8 +769,7 @@ def generate_dataset(batch):
             mem = get_hpu_memory_stats()
             for k, v in mem.items():
                 print("{:35} = {} GB".format(k[:-5].replace("_", " ").capitalize(), v))
-            if prompt_length > 0:
-                print(f"Graph compilation duration = {compilation_duration} seconds")
+            print(f"Graph compilation duration = {compilation_duration} seconds")
             print(separator)
     if args.quant_config:
         finalize_quantization(model)