We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
We found the following error in a test:
error: undefined reference to `_Z27__spirv_GroupNonUniformIAddiibj' in function: '__spirv_GroupNonUniformIAdd(int, int, bool, unsigned int)' called by kernel: 'triton_per_fused_sort_0'
error: backend compiler failed build. Error during Intel loadBinary: Triton Error [ZE]: 0x70000004
Reproducer:
# Standalone reproducer emitted by TorchInductor for `msort` (ATen: aten.sort)
# on a bool XPU tensor of shape (5, 10, 5).
#
# The graph fragment below shows the sort runs along dim 0.  Inductor lowers
# it to a persistent-reduction Triton kernel with:
#   - 50 independent rows (xnumel = 10 * 5);
#   - 5 elements per row (r0_numel), padded to R0_BLOCK = 8;
#   - loads strided by 50 (the dim-0 stride).
#
# NOTE(review): per the issue report, building/loading this kernel on Intel
# XPU fails with an undefined `__spirv_GroupNonUniformIAdd(int, int, bool,
# unsigned int)` reference.  Presumably this comes from the i1 (bool) sort
# lowering; confirm against the Triton XPU backend.
import torch
from torch._inductor.async_compile import AsyncCompile
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import (
    grid,
)
from torch._C import _xpu_getCurrentRawStream as get_raw_stream

# Standard Inductor wrapper preamble; most of these aliases are unused here
# but are kept so the script matches what Inductor generates.
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
async_compile = AsyncCompile()
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p


# kernel path: /tmp/tmp9any7ia_/t2/ct2jkhg2tcwd7qnbqiwz5lc5lz6dyhgupll6v46pbsuxkmkphohv.py
# Topologically Sorted Source Nodes: [msort], Original ATen: [aten.sort]
# Source node to ATen node mapping:
#   msort => sort
# Graph fragment:
#   %sort : [num_users=1] = call_function[target=torch.ops.aten.sort.default](args = (%arg0_1, 0), kwargs = {})
#
# The kernel source below is handed to async_compile.triton as text, so its
# line structure and indentation must be intact for Triton to parse it.
triton_per_fused_sort_0 = async_compile.triton('triton_per_fused_sort_0', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.persistent_reduction(
    size_hints={'x': 64, 'r0_': 8},
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*i1', 'out_ptr0': '*i1', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='xpu', index=0, multi_processor_count=64, cc={'architecture': 13136561920, 'driver_version': '1.3.30049+10', 'gpu_eu_count': 512, 'gpu_subslice_count': 64, 'has_atomic64': True, 'has_bfloat16_conversions': True, 'has_fp16': True, 'has_fp64': True, 'has_subgroup_2d_block_io': True, 'has_subgroup_matrix_multiply_accumulate': True, 'has_subgroup_matrix_multiply_accumulate_tensor_float32': False, 'max_compute_units': 512, 'max_num_sub_groups': 64, 'max_work_group_size': 1024, 'name': 'Intel(R) Data Center GPU Max 1550', 'platform_name': 'Intel(R) oneAPI Unified Runtime over Level-Zero', 'sub_group_sizes': [16, 32], 'total_memory': 68719476736, 'type': 'gpu', 'vendor': 'Intel(R) Corporation', 'version': '12.60.7'}, major=None, regs_per_multiprocessor=None, max_threads_per_multi_processor=None, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': [(0,), (1,)], 'tt.equal_to': []}, 'cls': 'AttrsDescriptor'})]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused_sort_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'FB40EBCFCB9A06A14744E64088225D5F51204F154F792FD0F599F4DD0B53BAD5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': False, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_per_fused_sort_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr):
    xnumel = 50
    r0_numel = 5
    R0_BLOCK: tl.constexpr = 8
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_index = tl.arange(0, R0_BLOCK)[None, :]
    r0_offset = 0
    r0_mask = r0_index < r0_numel
    roffset = r0_offset
    rindex = r0_index
    r0_1 = r0_index
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0 + 50*r0_1), r0_mask & xmask, other=0.0).to(tl.int1)
    tmp1 = r0_1
    tmp2 = tmp1.to(tl.int16)
    tl.static_assert(tmp2.dtype == tl.int16)
    tmp3 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
    tmp4 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
    tmp5, tmp6, = triton_helpers.sort_with_index(tmp3, tmp4, rnumel, 1, stable=False, descending=False)
    tl.store(out_ptr0 + (x0 + 50*r0_1), tmp5, r0_mask & xmask)
''', device_str='xpu')


# Block until every kernel queued on async_compile has finished compiling and
# the compiled objects are bound in this module's globals.
async_compile.wait(globals())
del async_compile


def call(args):
    """Run the compiled msort graph.

    Args:
        args: single-element list holding the input bool XPU tensor of size
            (5, 10, 5) and stride (50, 5, 1).  The list is cleared in place
            so the caller no longer holds a reference to the input.

    Returns:
        A 1-tuple containing a new (5, 10, 5) bool XPU tensor sorted along
        dim 0.
    """
    arg0_1, = args
    args.clear()
    assert_size_stride(arg0_1, (5, 10, 5), (50, 5, 1))
    with torch.xpu._DeviceGuard(0):
        torch.xpu.set_device(0)
        buf0 = empty_strided_xpu((5, 10, 5), (50, 5, 1), torch.bool)
        # Topologically Sorted Source Nodes: [msort], Original ATen: [aten.sort]
        stream0 = get_raw_stream(0)
        # One program per group of XBLOCK rows; 50 rows (10 * 5) in total.
        triton_per_fused_sort_0.run(arg0_1, buf0, 50, 5, grid=grid(50), stream=stream0)
        del arg0_1
    return (buf0, )


def benchmark_compiled_module(times=10, repeat=10):
    """Benchmark `call` on a random bool XPU input of the traced shape."""
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    arg0_1 = rand_strided((5, 10, 5), (50, 5, 1), device='xpu:0', dtype=torch.bool)
    fn = lambda: call([arg0_1])
    return print_performance(fn, times=times, repeat=repeat)


if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)
Triton: latest main (b59bb9a)
PyTorch: latest main (d0f5df83a50d9bb630764c92ac63fcb2640b1f94) + triton patch
OS: Ubuntu 24.10 / Windows 11
The text was updated successfully, but these errors were encountered:
i1
victor-eds
Successfully merging a pull request may close this issue.
Describe the bug
We found the following error in a test:
Reproducer:
Environment details
Triton: latest main (b59bb9a)
PyTorch: latest main (d0f5df83a50d9bb630764c92ac63fcb2640b1f94) + triton patch
OS: Ubuntu 24.10 / Windows 11
The text was updated successfully, but these errors were encountered: