
Commit

add comments
peishenyan committed Jan 17, 2025
1 parent 4ace558 commit 03c0a1e
Showing 1 changed file with 65 additions and 57 deletions.
122 changes: 65 additions & 57 deletions onnxruntime/core/providers/webnn/builders/impl/attention_op_builder.cc
@@ -34,29 +34,69 @@ class AttentionOpBuilder : public BaseOpBuilder {
// Builds one (batch_index, head_index) pair per element that will be scattered into the KV cache.
std::vector<int64_t> generate_indices(int64_t batch_size, int64_t num_heads, int64_t sequence_length) {
std::vector<int64_t> indices;
for (int64_t i = 0; i < sequence_length; ++i) {
for (int64_t j = 0; j < batch_size * num_heads; ++j) {
indices.push_back(j / num_heads);
indices.push_back(j % num_heads);
}
}
return indices;
}

// Repeats each sequence position batch_size * num_heads times, matching the element order of generate_indices.
std::vector<int64_t> repeat_sequence(int64_t sequence_length, int64_t num_heads, int64_t batch_size) {
std::vector<int64_t> repeated;
for (int64_t i = 0; i < sequence_length; ++i) {
for (int64_t j = 0; j < batch_size * num_heads; ++j) {
repeated.push_back(i);
}
}
return repeated;
}
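
// Illustrative example (not part of the original source): with batch_size = 1, num_heads = 2 and
// sequence_length = 2, the helpers above produce
//   generate_indices(1, 2, 2) -> {0,0, 0,1, 0,0, 0,1}   // one (batch, head) pair per element
//   repeat_sequence(2, 2, 1)  -> {0, 0, 1, 1}           // the sequence position of each element
// Zipped together they give one (batch, head, sequence) coordinate per scattered K/V element.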

/** GroupQueryAttention SubGraph.
Abbreviations: B is batch_size, S is sequence_length, W is hidden_size,
N is number of attention heads, H is head size, and W=N*H, h=Sqrt(H).
B and S may be symbolic. '?' means the input is optional.
GQA inputs: query, key, value, past_key, past_value, seqlens_k, total_sequence_length
Note: if the data type of the inputs (QKV and past KV) is float16, we cast them to float32 to preserve precision.
query key value
| | |
q_Reshape k_Reshape v_Reshape (shape=B,S,H,N)
| | |
q_Transpose k_Transpose v_Transpose
(0,2,1,3) (0,2,3,1) (perm=0,2,1,3)
\ / | past_key
\ / | |
present_key<---\----ScatterND <------|--------+
| | | |
| opt_k_transpose? | seqlens_k
\ (0,1,3,2) | |
\ / | +----past_value
qk_MatMul | /
| [B=h] | /
| / | /
qk_Div ScatterND -----> present_value
| |
| /
Add <----------/---------------finfo_min_mask
| /
Softmax /
\ /
\ /
qkv_MatMul
|
Transpose (perm=0,2,1,3)
|
Reshape---(shape=B,S,W)
|
output
*/
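
/* Illustrative reference (not part of the original file): in scalar form, for one batch b and one
   head n, the subgraph above computes
     scores[s][t]    = (sum_h Q[b][n][s][h] * K[b][n][t][h]) / sqrt(H) + mask[b][n][s][t]
     probs[s][:]     = softmax over t of scores[s][:]
     out[b][n][s][h] = sum_t probs[s][t] * V[b][n][t][h]
   where K and V are the present_key / present_value caches produced by the ScatterND updates,
   t runs over the total (past + new) sequence length, and mask holds 0 for valid positions and
   finfo_min (the most negative representable float) for masked ones. */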

Status AttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
const Node& node,
const logging::Logger& logger) const {
const auto& op_type = node.OpType();
if (op_type != "GroupQueryAttention") {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported attention op: ", op_type);
}

@@ -85,7 +125,7 @@ Status AttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,

NodeAttrHelper helper(node);
uint32_t kv_num_heads = helper.Get("kv_num_heads", 32);
uint32_t num_heads = helper.Get("num_heads", 32);
ORT_RETURN_IF_NOT(kv_num_heads == num_heads, "Now GQA only supports kv_num_heads == num_heads");

uint32_t qkv_batch_size = SafeInt<uint32_t>(input_q_shape[0]);
@@ -128,15 +168,15 @@ Status AttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,

emscripten::val desc_left = emscripten::val::object();
ORT_RETURN_IF_NOT(SetWebnnDataType(desc_left, ONNX_NAMESPACE::TensorProto_DataType_INT64), "Unsupported data type");
emscripten::val dims_left = emscripten::val::array(std::vector<uint32_t>({qkv_batch_size * num_heads * qkv_sequence_length, 2}));
desc_left.set("dimensions", dims_left);
desc_left.set("shape", dims_left);
emscripten::val left_buffer = emscripten::val::global("BigInt64Array").new_(emscripten::val::array(left));
emscripten::val left_constant = model_builder.GetBuilder().call<emscripten::val>("constant", desc_left, left_buffer);

emscripten::val desc_right = emscripten::val::object();
ORT_RETURN_IF_NOT(SetWebnnDataType(desc_right, ONNX_NAMESPACE::TensorProto_DataType_INT64), "Unsupported data type");
emscripten::val dims_right = emscripten::val::array(std::vector<uint32_t>({qkv_batch_size * num_heads * qkv_sequence_length, 1}));
desc_right.set("dimensions", dims_right);
desc_right.set("shape", dims_right);
emscripten::val right_buffer = emscripten::val::global("BigInt64Array").new_(emscripten::val::array(right));
@@ -146,7 +186,7 @@ Status AttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
emscripten::val reshaped_query = model_builder.GetBuilder().call<emscripten::val>("reshape", query_input, emscripten::val::array(reshape_tensor_shape), common_options);

emscripten::val options = emscripten::val::object();
options.set("permutation", emscripten::val::array(std::vector<uint32_t>({0,2,1,3})));
options.set("permutation", emscripten::val::array(std::vector<uint32_t>({0, 2, 1, 3})));
options.set("label", node.Name() + "/GQA/query/transpose");
emscripten::val new_query = model_builder.GetBuilder().call<emscripten::val>("transpose", reshaped_query, options);

@@ -155,13 +195,14 @@ Status AttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
emscripten::val key_for_scatter = model_builder.GetBuilder().call<emscripten::val>("reshape", key_input, emscripten::val::array(reshape_kv_shape), common_options);

common_options.set("label", node.Name() + "/GQA/value/reshape_1");
emscripten::val value_for_scatter = model_builder.GetBuilder().call<emscripten::val>("reshape", value_input, emscripten::val::array(reshape_kv_shape), common_options);

common_options.set("label", node.Name() + "seqlens_k_casted");
emscripten::val seqlens_k_casted = model_builder.GetBuilder().call<emscripten::val>("cast", seqlens_k_input, emscripten::val("int64"), common_options);


// The prefill and decode stages require different index construction for the ScatterND operations.
// As in other EPs such as CPU and DirectML, when qkv_sequence_length > 1 the new key and value are
// scattered to the beginning of the KV cache.
std::vector<uint8_t> first_condition({(qkv_sequence_length > 1)});
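// Illustrative example (not part of the original source), assuming the index helpers above:
//   - prefill (qkv_sequence_length > 1): the S new K/V rows are scattered to cache positions 0..S-1
//     for every (batch, head) pair;
//   - decode (qkv_sequence_length == 1): the single new row is written at the offset given by
//     seqlens_k for that batch, i.e. after the tokens already present in the cache.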
emscripten::val desc = emscripten::val::object();
ORT_RETURN_IF_NOT(SetWebnnDataType(desc, ONNX_NAMESPACE::TensorProto_DataType_UINT8), "Unsupported data type");
emscripten::val dims_condition = emscripten::val::array(std::vector<uint32_t>({1}));
@@ -201,7 +242,7 @@ Status AttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
common_options.set("label", node.Name() + "/GQA/present_value/ScatterND");
emscripten::val present_value = model_builder.GetBuilder().call<emscripten::val>("scatterND", past_value_input, scatter_indices_casted, value_for_scatter, common_options);

options.set("permutation", emscripten::val::array(std::vector<uint32_t>({0,1,3,2})));
options.set("permutation", emscripten::val::array(std::vector<uint32_t>({0, 1, 3, 2})));
options.set("label", node.Name() + "/GQA/present_key/transpose");
emscripten::val true_present_key = model_builder.GetBuilder().call<emscripten::val>("transpose", present_key, options);

@@ -221,7 +262,7 @@ Status AttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
emscripten::val div_output = model_builder.GetBuilder().call<emscripten::val>("div", matmul_output, scale_constant, common_options);

// static_cast<int64_t>(qkv_batch_size), static_cast<int64_t>(num_heads), static_cast<int64_t>(qkv_sequence_length), static_cast<int64_t>(past_sequence_length)
std::vector<int64_t> mask_shape_ones_shape(qkv_batch_size * num_heads * qkv_sequence_length * past_sequence_length, 1);
emscripten::val desc_mask_shape_ones_shape = emscripten::val::object();
ORT_RETURN_IF_NOT(SetWebnnDataType(desc_mask_shape_ones_shape, ONNX_NAMESPACE::TensorProto_DataType_INT64), "Unsupported data type");
emscripten::val dims_mask_shape = emscripten::val::array(std::vector<uint32_t>({qkv_batch_size, num_heads, qkv_sequence_length, past_sequence_length}));
@@ -298,7 +339,6 @@ Status AttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
common_options.set("label", node.Name() + "/GQA/qkv/reshape");
emscripten::val output = model_builder.GetBuilder().call<emscripten::val>("reshape", transposed_attn_output, emscripten::val::array(reshape_output_shape), common_options);


if (node.OutputDefs()[0]->Type() == onnx::Utils::DataTypeUtils::ToType("float16")) {
common_options.set("label", node.Name() + "/GQA/postprocess/cast/output");
output = model_builder.GetBuilder().call<emscripten::val>("cast", output, emscripten::val("float16"), common_options);
@@ -322,9 +362,9 @@ Status AttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
// Operator support related.

bool AttentionOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
const Node& node,
const WebnnDeviceType /* device_type */,
const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& op_type = node.OpType();
NodeAttrHelper helper(node);
@@ -345,50 +385,18 @@ bool AttentionOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializ
}

bool AttentionOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& op_type = node.OpType();

if (input_defs.size() < 7) {
LOGS(logger, VERBOSE) << op_type << " requires at least seven inputs.";
return false;
} else {
LOGS(logger, VERBOSE) << op_type << " has inputs size: " << input_defs.size();
}

return true;
}

