Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the possibility to store only cubin/hsaco files in the cache directory #5827

Merged
merged 14 commits into from
Feb 10, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion python/triton/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ def compile(src, target=None, options=None):
# core changes to make it easier to track kernels by hash.
enable_override = os.environ.get("TRITON_KERNEL_OVERRIDE", "0") == "1"
enable_ir_dump = os.environ.get("TRITON_KERNEL_DUMP", "0") == "1"
store_only_binary = os.environ.get("TRITON_STORE_BINARY_ONLY", "0") == "1"
fn_override_manager = get_override_manager(src.hash()) if enable_override else None
fn_dump_manager = get_dump_manager(src.hash()) if enable_ir_dump else None
# Pre-truncate the file name here to avoid hitting the 255 character limit on common platforms.
Expand Down Expand Up @@ -284,7 +285,9 @@ def compile(src, target=None, options=None):
if (fn_override_manager is not None and (full_name := fn_override_manager.get_file(ir_filename)) is not None):
print(f"\nOverriding kernel with file {full_name}")
next_module = parse(full_name, ext, context)
metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename)
# If TRITON_STORE_BINARY_ONLY is 1, only store cubin/hsaco/json
if (not store_only_binary) or (ext in ("cubin", "hsaco", "json")):
metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename)
if fn_dump_manager is not None:
fn_dump_manager.put(next_module, ir_filename)
# use an env variable to parse ir from file
Expand Down
Loading