Add ukernel selection logic + clean up KleidiAI integration #1652
Conversation
@@ -98,7 +98,7 @@ LinearTilingParams get_default_linear_tiling_params(
   TORCHAO_CHECK(num_threads >= 1, "num_threads must be >= 1");

   tiling_params.mc_by_mr = 1;
-  int mc = tiling_params.mc_by_mr * ukernel_config.mr;
+  int mc = tiling_params.mc_by_mr * ukernel_config.kernels[0].mr;
ukernel_config now includes an array of kernels based on mr. Still need to add mr selection logic here, for now it just selects the first one.
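A minimal sketch of what that selection could look like, assuming the kernels array is sorted by ascending mr and that unused slots have mr == 0 (both assumptions; kernels[0] stays the fallback):

// Sketch only: pick the largest registered mr that does not exceed m.
// Assumes kernels is sorted by ascending mr and unused slots have mr == 0.
inline int select_kernel_idx(const UKernelConfig& ukernel_config, int m) {
  int idx = 0;
  for (size_t i = 0; i < ukernel_config.kernels.size(); i++) {
    const auto& kernel = ukernel_config.kernels[i];
    if (kernel.mr > 0 && kernel.mr <= m) {
      idx = static_cast<int>(i);
    }
  }
  return idx;
}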
static UKernelConfigCacheType ukernel_config_cache;

// Check cache
auto it = ukernel_config_cache.find(header);
If we want uarch specific kernel per core, we can add uarch to cache key and look up uarch before looking in cache, e.g.,
auto uarch = get_current_core_uarch();
auto it = ukernel_config_cache.find({header, uarch});
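A rough sketch of that idea, reusing the existing cache type and falling back to core 0's uarch where cpuinfo cannot report the current core (the per-uarch map is illustrative, not what the PR does; needs <map>):

// Sketch only: one header-keyed cache per uarch. UKernelConfigCacheType is
// the existing cache type in this file; the outer key is the cpuinfo uarch.
static std::map<int, UKernelConfigCacheType> ukernel_config_cache_by_uarch;

const struct cpuinfo_core* core = cpuinfo_get_current_core();  // may be null off-Linux
auto uarch = core ? core->uarch : cpuinfo_get_core(0)->uarch;
auto& cache = ukernel_config_cache_by_uarch[static_cast<int>(uarch)];
auto it = cache.find(header);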
torchao/experimental/CMakeLists.txt (Outdated)

@@ -22,7 +22,7 @@ if(NOT TORCHAO_INCLUDE_DIRS)
   set(TORCHAO_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/../..)
 endif()

-option(TORCHAO_BUILD_KLEIDIAI "Download, build, and link against Arm KleidiAI library (arm64 only)" OFF)
+option(TORCHAO_BUILD_KLEIDIAI "Download, build, and link against Arm KleidiAI library (arm64 only)" ON)
TODO: nocommit
# print(f"actual_val={actual_val}, expected_val={expected_val}")
# self.assertTrue(torch.allclose(actual_val, expected_val, atol=1e-6))

self.assertTrue(torch.abs(actual_val - expected_val) < 0.05)
Do not commit this change. It is needed because kleidi uses bf16 instead of fp32.
    0});
}

struct KleidiAIPackingParams {
TODO: check if these packing params are sufficient for all kleidi.
ukernel_config_cache[key] = torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig{
    /*preferred_alignment*/16,
    /*weight_packing*/
    {
We can rework the kleidiai integration to share weight packing, rather than repeat in each namespace.
It is shared in code, but exposed along with the kernel so you don't have to map it back to the kernel at call sites.
It is in shared code, but not in a way that is convenient to access with shared mr kernels because the same packing function (indexed by nr, kr, sr) is given 4 different names (based on namespace).
So we could refactor it to make one packing function in kai_matmul_clamp_f32_qai8dxp_qsi4c32p, rather than have them in further specific namespaces?
    /*kernels*/
    {{
      {
        /*mr*/static_cast<int>(uk.get_m_step()),
List of methods indexed by mr.
Good start. Please also think some more about code organization for taking on a lot more kernels, and about scalability in general.
@@ -8,13 +8,23 @@ cmake_minimum_required(VERSION 3.19)

include(${CMAKE_CURRENT_SOURCE_DIR}/../../Utils.cmake)

add_compile_options(-Wno-unused-function -Wno-unused-variable) # For some reason cpuinfo package has unused functions/variables
Fix it upstream?
assert (sr == uk.get_sr());

ukernel_config_cache[key] = torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig{
    /*preferred_alignment*/16,
Nit:

Suggested change:
- /*preferred_alignment*/16,
+ /*preferred_alignment*/uk.get_preferred_alignment(),
bucket_size = get_bucket_size(uarch)
if bucket_size == 0 && cpu_info_has_i8mm() {
}
#if defined(TORCHAO_ENABLE_KLEIDI)
  if (!target || *target == "kleidi_ai") {
    if (weight_nbit == 4 && !has_weight_zeros) {
      return torchao::ops::linear_8bit_act_xbit_weight::get_packed_weights_format_kleidi_ai(weight_nbit, has_weight_zeros, /*has_bias*/true, /*nr*/8, /*kr*/16, /*sr*/2);
In the future we would have to make a choice for nr based on CPU type (or some static choice for AOT weight packing like this), and register [mr] kernels, which you are already planning.
Yes, we can use any method in cpuinfo to select packed_weights_format, including any packing params like nr. This is not entirely static because universal is only selected if cpuinfo_has_arm_neon_dot is available. We could also use fields from uarch to select things here I guess?
I wonder if we should pass n and k as params in addition to target. Implementers can then take into account matrix size when selecting nr?
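A sketch of what such a selection could look like, using n as a hint (the specific nr values and thresholds below are illustrative placeholders, not what the PR selects):

// Illustrative only: pick nr from cpuinfo features plus the weight matrix
// width n. Values and thresholds are placeholders.
int select_nr(int n) {
  if (cpuinfo_has_arm_i8mm() && n >= 8) {
    return 8;   // wider tiles pay off once n is large enough
  }
  if (cpuinfo_has_arm_neon_dot()) {
    return 4;   // placeholder value for dotprod-only cores
  }
  return 8;     // fall back to the current static choice
}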
  // ukernel must behave correctly no matter how buffers are aligned
  size_t preferred_alignment{0};
  weight_packing_config weight_packing;
  std::array<kernel_config, 4> kernels;
Nit:

Suggested change:
- std::array<kernel_config, 4> kernels;
+ std::array<kernel_config, MAX_MR_TYPES> kernels;
  weight_data_size_fn_type weight_data_size_fn{nullptr};
  prepare_weight_data_fn_type prepare_weight_data_fn{nullptr};
};
struct kernel_config {
It makes sense that you have one packing kernel and N GEMM kernels indexed by mr, but the naming makes this confusing to read, i.e. ukernel->kernel[mr].mr.
  // preferred_alignment for activation and weight data
  // Integration surfaces are not required to respect this alignment, and the
  // ukernel must behave correctly no matter how buffers are aligned
  size_t preferred_alignment{0};
We have to make sure this is the same for all mr values, i.e. document and test it.
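One way to enforce that at registration time, assuming every mr variant exposes the get_preferred_alignment() accessor suggested above (a hypothetical accessor; KleidiAI's real interface may differ; needs <array> and <cassert>):

// Hypothetical guard: all mr variants registered into one UKernelConfig must
// agree on preferred_alignment before the shared field is written once.
template <typename... Ukernels>
size_t common_preferred_alignment(const Ukernels&... uks) {
  const std::array<size_t, sizeof...(uks)> alignments{uks.get_preferred_alignment()...};
  for (size_t a : alignments) {
    assert(a == alignments[0] && "mr variants disagree on preferred_alignment");
  }
  return alignments[0];
}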
    /*kernels*/
    {{
      {
        /*mr*/static_cast<int>(uk.get_m_step()),
In the future, when querying for mr(s), we should ensure their weight packing function pointer is the same.
This goes to the comment about reworking the KleidiAI integration, I guess? Let me give some more thought to breaking some code out.
Adding @kimishpatel because he was curious about the PR. kernel_selector.h is the main code to pay attention to for runtime kernel selection.
torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig select_ukernel_config(torchao::ops::PackedWeightsFormat format) {
  static UKernelConfigRegistrationTable table;

  // In future, we can populate this with the current thread's uarch
Added uarch to the kernel selection cache, although it is currently just set to unknown, so the cache is effectively keyed on format.
DEFINE_WEIGHT_DATA_FNS(/*nr*/8, /*kr*/16, /*sr*/2)
@digantdesai this file is my draft reworking of the KleidiAI integration. Weight packing and activation functions are no longer in ISA/kernel-specific namespaces because many kernels share the same routines.
Kernel functions and ukernel configs are defined using macros. I would like DEFINE_KERNEL_FNS to be parametrized by things like mr, nr, and instruction (dotprod/i8mm), but I don't fully follow the KleidiAI naming convention, so for now it is indexed by first/suffix.
This looks good. An alternative would be to code-gen these wrappers at compile time, but this is clean enough.
#define DEFINE_WEIGHT_DATA_FNS(nr, kr, sr) \
  size_t weight_data_size_nr##nr##_kr##kr##_sr##sr(int n, int k, int group_size) { \
    return weight_data_size(nr, kr, sr, n, k, group_size); \
  } \
  void prepare_weight_data_nr##nr##_kr##kr##_sr##sr( \
      void* weight_data, \
      int n, \
      int k, \
      int group_size, \
      const int8_t* weight_qvals, \
      const float* weight_scales, \
      const int8_t* weight_zeros, \
      const float* bias) { \
    prepare_weight_data(nr, kr, sr, weight_data, n, k, group_size, weight_qvals, weight_scales, weight_zeros, bias); \
  }
Torture!
Suggested change: rename DEFINE_WEIGHT_DATA_FNS to DEFINE_WEIGHT_DATA_FN; the macro body is otherwise unchanged.
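For reference, an invocation like the one further down in this file stamps out concrete nr/kr/sr-suffixed symbols; the expansion below is just how the ## pasting resolves, not new API:

DEFINE_WEIGHT_DATA_FNS(/*nr*/8, /*kr*/16, /*sr*/2)
// roughly expands to:
//   size_t weight_data_size_nr8_kr16_sr2(int n, int k, int group_size);
//   void prepare_weight_data_nr8_kr16_sr2(void* weight_data, int n, int k,
//       int group_size, const int8_t* weight_qvals, const float* weight_scales,
//       const int8_t* weight_zeros, const float* bias);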
#define DEFINE_KERNEL_FNS(first, suffix) \
  namespace impl_##suffix { \
  const Ukernel get_ukernel() { \
    return Ukernel{ \
        .get_m_step = kai_get_m_step_matmul_clamp_f32_qai8dxp##first##_qsi4c32p##suffix, \
Passing in first (i.e. the lhs) as qai8dxp1x8 instead of 1x8 is better because (1) it is more meaningful, and (2) it can also cover channel-wise 4-bit quant, i.e. QC4W. Also, suffix should be three different things: rhs + output tile x kacc + isa, where rhs is not 8x8 but qsi4c32p4x8.
Suggested change:
- #define DEFINE_KERNEL_FNS(first, suffix) \
-   namespace impl_##suffix { \
-   const Ukernel get_ukernel() { \
-     return Ukernel{ \
-         .get_m_step = kai_get_m_step_matmul_clamp_f32_qai8dxp##first##_qsi4c32p##suffix, \
+ #define DEFINE_KLEIDI_KERNEL_FN(lhs, suffix) \
+   namespace impl_##suffix { \
+   const Ukernel get_ukernel() { \
+     return Ukernel{ \
+         .get_m_step = kai_get_m_step_matmul_clamp_f32_##lhs##_##suffix, \
}

// TODO: first and suffix need to be better, e.g., parametrized by mr, nr, etc
// But I don't quite follow the naming convention for KleidiAI
Naming convention: kai_matmul_<fused_ops>_<dst_info>_<lhs_info>_<rhs_info>_<mr x nr x kacc>_<technology>_<feature>_<instruction>
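As a worked example of that convention, here is one reading of a dotprod kernel name of the form this integration builds from first/suffix (the segment split is my interpretation and is illustrative only):

kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod
  fused_ops                      = clamp
  dst_info                       = f32
  lhs_info                       = qai8dxp1x8
  rhs_info                       = qsi4c32p8x8
  mr x nr x kacc                 = 1x8x32
  technology/feature/instruction = neon_dotprod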
@digantdesai I can rebase this PR on #1723 (which contains formatting changes based on the fbcode formatter). That should make it easier to review, because many of the changes here are just formatting. This PR cleans up the KleidiAI integration; the purpose of the cleanup is to make the ukernel selection logic, the main point of this PR, cleaner.
Looks good to me. Thanks Scott. Left some comments.
  sh build_and_run_tests.sh
  rm -rf /tmp/cmake-out
  popd
- name: Run torchao/experimental/ops/tests
Thank you!
size_t roundup(size_t a, size_t b) { return ((a + b - 1) / b) * b; }
namespace internal {

inline size_t roundup(size_t a, size_t b) { return ((a + b - 1) / b) * b; }
Nit: move it somewhere more shared, like a common utils header?
Created issue: #1744
roundup isn't just used here but in other places as well, and I'd rather unify them as part of one effort.
  } \
}

DEFINE_KERNEL_STRUCT(
Ideally we should wrap these in TORCHAO_ENABLE_ARM_DOTPROD
Filed issue: #1743
#endif // TORCHAO_ENABLE_ARM_I8MM

if (cpuinfo_has_arm_neon_dot()) {
  constexpr int n_step = 8;
Unfortunate that we can't do get_n_step() or get_n_step(nr) here :\
  }
}

// Not thread safe
mutex?
I think we can add thread safety when it's needed. It's currently used on the main thread only.
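If/when that changes, a minimal sketch of the guard (not part of this PR; the only new names are the mutex and lock, and it needs <mutex>):

// Sketch only: serialize cache lookup + registration so select_ukernel_config
// could be called off the main thread.
static UKernelConfigRegistrationTable table;
static std::mutex table_mutex;
std::lock_guard<std::mutex> guard(table_mutex);
// ... existing lookup / registration against `table` unchanged ...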
// Note, cpuinfo_get_current_core() is not currently implemented outside of
// Linux. XNNPACK often uses non-core-specific logic like
// cpuinfo_get_core(0)->uarch in configs.
drop or move it to commit msg?
I'd rather leave it for now, especially if we plan to add uarch differentiation soon. Otherwise, someone might try to do: cpuinfo_get_current_core()->uarch, with bad results on Apple platforms.
    (!has_weight_zeros)) { // TODO: add has_bias here
  return PackedWeightsFormat(
      torchao::ops::PackedWeightsType::kleidi_ai, weight_nbit,
      has_weight_zeros, /*has_bias*/ true, /*nr*/ 8, /*kr*/ 16, /*sr*/ 2);
check has_bias == True and use that? Re. your TODO comment, do you mean the wiring from TorchAO to the Op for the bias?
Added a comment on the bias issue to track this: #1675
has_bias is always false right now, so this would never be selected if we rely on has_bias. But if a null bias ptr is passed to KleidiAI, we construct a bias of zeros and include it in the packed weights.
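Roughly, the behavior described is the following (variable names are illustrative; the real code lives in the KleidiAI packing wrapper and needs <vector>):

// If the caller did not provide a bias, substitute zeros so the KleidiAI
// packing routine, which expects a bias pointer, still gets valid data.
std::vector<float> zero_bias;
if (bias == nullptr) {
  zero_bias.assign(n, 0.0f);  // one bias value per output channel
  bias = zero_bias.data();
}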
@@ -31,7 +31,7 @@ PackWeightDataTilingParams get_default_pack_weight_data_tiling_params(
   assert(nc >= 1);

   // Replace nc with the next number nr divides
-  nc = ((nc + ukernel_config.nr - 1) / ukernel_config.nr) * ukernel_config.nr;
+  nc = ((nc + nr - 1) / nr) * nr;
Suggested change:
- nc = ((nc + nr - 1) / nr) * nr;
+ nc = roundup(nc, nr);
I don't think roundup is defined here. Created an issue on adding shared utils where things like roundup can live.
 int weight_data_offset =
-    (n_idx / nr) * ukernel_config.weight_data_size_fn(nr, k, group_size);
+    (n_idx / nr) * ukernel_config.weight_packing_config.weight_data_size_fn(
+        nr, k, group_size);
For the future: not sure if we want to assume this, i.e. one can pack weights differently, which can break this. Ideally we should have an API for this that kernels can override.
We could make it part of the config as a follow up: e.g., ukernel_config.weight_offset_fn(n_idx, nr, k, group_size).
Created feature request: #1745
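A possible shape for that follow-up (sketch only; the weight_offset_fn typedef and field are made up for #1745, while the existing fields are from this PR):

// Hypothetical extension of weight_packing_config: let the kernel own the
// offset computation instead of assuming contiguous nr-panels.
using weight_offset_fn_type = int (*)(int n_idx, int nr, int k, int group_size);

struct weight_packing_config {
  weight_data_size_fn_type weight_data_size_fn{nullptr};
  prepare_weight_data_fn_type prepare_weight_data_fn{nullptr};
  weight_offset_fn_type weight_offset_fn{nullptr};  // new, optional
};

// The call site above would then become something like:
//   int weight_data_offset =
//       ukernel_config.weight_packing_config.weight_offset_fn(n_idx, nr, k, group_size);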
    /*bias*/ nullptr);
packed_weights_header.write(packed_weights.mutable_data_ptr<int8_t>());

// TODO: support passing in bias in future
Issue?
There's already an issue for it: #1675
This is a draft to do ukernel selection based on cpuinfo.
This relates to #1376