Enable warmup also for full prompt length case in text generation (#1676)

yeonsily authored Jan 17, 2025
1 parent f10d5b0 commit 24addbc
Showing 1 changed file with 15 additions and 17 deletions.
32 changes: 15 additions & 17 deletions examples/text-generation/run_generation.py
@@ -719,22 +719,21 @@ def generate_dataset(batch):
         return prompt, outputs
 
     # warmup
-    if prompt_length > 0:
-        from optimum.habana.utils import HabanaProfile
+    from optimum.habana.utils import HabanaProfile
 
-        # compilation stage disable profiling
-        HabanaProfile.disable()
-        # Compilation
-        logger.info("Graph compilation...")
-        t0 = time.perf_counter()
-        for i, batch in enumerate(dataloader):
-            generate_dataset(batch)
-            # The first three iterations take longer because of graph compilation
-            if (i + 1) == 3:
-                break
-        torch_hpu.synchronize()
-        compilation_duration = time.perf_counter() - t0
-        HabanaProfile.enable()
+    # compilation stage disable profiling
+    HabanaProfile.disable()
+    # Compilation
+    logger.info("Graph compilation...")
+    t0 = time.perf_counter()
+    for i, batch in enumerate(dataloader):
+        generate_dataset(batch)
+        # The first three iterations take longer because of graph compilation
+        if (i + 1) == 3:
+            break
+    torch_hpu.synchronize()
+    compilation_duration = time.perf_counter() - t0
+    HabanaProfile.enable()
 
     total_new_tokens_generated = 0
     duration = 0
@@ -770,8 +769,7 @@ def generate_dataset(batch):
         mem = get_hpu_memory_stats()
         for k, v in mem.items():
             print("{:35} = {} GB".format(k[:-5].replace("_", " ").capitalize(), v))
-        if prompt_length > 0:
-            print(f"Graph compilation duration = {compilation_duration} seconds")
+        print(f"Graph compilation duration = {compilation_duration} seconds")
         print(separator)
     if args.quant_config:
         finalize_quantization(model)
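For context, the changed block implements a common warmup pattern: run a few batches so the lazily compiled graphs get built, time that phase separately with profiling disabled, and only then profile and measure steady-state generation. Below is a minimal sketch of the same pattern, where generate_fn, dataloader, synchronize, and profiler are hypothetical stand-ins for generate_dataset, the dataset DataLoader, torch_hpu.synchronize, and optimum.habana.utils.HabanaProfile:

import time

def warmup_then_measure(generate_fn, dataloader, synchronize, profiler, warmup_steps=3):
    # Keep profiling off while graphs compile, so any trace captured
    # later covers only steady-state generation.
    profiler.disable()
    t0 = time.perf_counter()
    for i, batch in enumerate(dataloader):
        generate_fn(batch)
        # The first few iterations take longer because each graph is
        # compiled on its first execution.
        if (i + 1) == warmup_steps:
            break
    synchronize()
    compilation_duration = time.perf_counter() - t0
    profiler.enable()
    return compilation_duration

With the "if prompt_length > 0:" guard removed, this warmup now runs in the full-prompt-length case as well, which also guarantees that compilation_duration is defined by the time the unconditional print in the second hunk reports it.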
