Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify flag parsing in gather_borg_symbols.cc using fixed_option_set_flag. #22477

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions xla/tsl/BUILD
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# Tensor Standard Libraries - common utilities for implementing XLA.

load("@bazel_skylib//:bzl_library.bzl", "bzl_library")
load("@bazel_skylib//lib:selects.bzl", "selects")
load("@bazel_skylib//rules:common_settings.bzl", "bool_flag", "bool_setting")
Expand Down
23 changes: 22 additions & 1 deletion xla/tsl/util/BUILD
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Description:
# Tensor Standard Libraries.
# Tensor Standard Libraries - common utilities for implementing XLA.
#
# The libraries in this package are not allowed to have ANY dependencies
# to other TF components outside of TSL.
Expand Down Expand Up @@ -345,3 +345,24 @@ filegroup(
"//tensorflow/core/util:__pkg__",
]),
)

cc_library(
name = "fixed_option_set_flag",
srcs = ["fixed_option_set_flag.cc"],
hdrs = ["fixed_option_set_flag.h"],
deps = [
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/strings",
],
)

tsl_cc_test(
name = "fixed_option_set_flag_test",
srcs = ["fixed_option_set_flag_test.cc"],
deps = [
":fixed_option_set_flag",
"@com_google_absl//absl/strings:string_view",
"@com_google_googletest//:gtest_main",
],
)
17 changes: 17 additions & 0 deletions xla/tsl/util/fixed_option_set_flag.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
/* Copyright 2025 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Ensure that the header is self-contained.
#include "xla/tsl/util/fixed_option_set_flag.h"
160 changes: 160 additions & 0 deletions xla/tsl/util/fixed_option_set_flag.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
/* Copyright 2025 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_TSL_UTIL_FIXED_OPTION_SET_FLAG_H_
#define XLA_TSL_UTIL_FIXED_OPTION_SET_FLAG_H_

#include <string>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/log/check.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"

namespace xla {

// A parser for a flag of type T that takes a fixed set of options. This makes
// it easier and safer to define flags that take a fixed set of options.
// Requires T to support equality comparison, hashing, and conversion to
// std::string via absl::StrCat.
//
// Example usage:
//
// enum class Foo {
// kBar,
// kBaz,
// };
//
// static const FixedOptionSetFlagParser<Foo>& GetFooParser() {
// static const auto& parser = GetFixedOptionSetFlagParser<Foo>({
// {"bar", Foo::kBar, "Optional description of bar."},
// {"baz", Foo::kBaz, "Optional description of baz."},
// });
// return parser;
// }
//
// bool AbslParseFlag(absl::string_view text, Foo* foo, std::string* error) {
// return GetFooParser().Parse(text, foo, error);
// }
//
// std::string AbslUnparseFlag(Foo foo) { return GetFooParser().Unparse(foo); }
//
// Compared with implementing AbslParseFlag and AbslUnparseFlag manually, this
// class provides the following benefits:
//
// - We only need to define the mapping between options and values once, and
// the two directions are guaranteed to be consistent.
// - The parser validates the flag options, so it's impossible to have
// duplicate names or values in the mapping.
//
// This class is thread-safe.
template <typename T>
class FixedOptionSetFlagParser {
public:
// Stores the name, value, and description of one option of a flag of type T.
struct FlagOption {
std::string name;
T value;
std::string description;
};

// Creates a parser for a flag of type T that takes a fixed set of options.
// The options must be valid, i.e., there must be no duplicate names or
// values.
explicit FixedOptionSetFlagParser(const std::vector<FlagOption>& options)
: options_(ValidateFlagOptionsOrDie(options)) {}

// Parses the flag from the given text. Returns true if the text is
// valid, and sets the value to the corresponding option. Otherwise, returns
// false and sets the error message.
[[nodiscard]] bool Parse(absl::string_view text, T* value,
std::string* error) const {
for (const auto& option : options_) {
if (text == option.name) {
*value = option.value;
return true;
}
}
*error = absl::StrCat(
"Unrecognized flag option: ", text, ". Valid options are: ",
absl::StrJoin(options_, ", ",
[](std::string* out, const FlagOption& option) {
absl::StrAppend(out, option.name);
if (!option.description.empty()) {
absl::StrAppend(out, " (", option.description, ")");
}
}),
".");
return false;
}

// Unparses the flag value to the corresponding option name. If the value is
// not one of the options, returns the string representation of the value.
[[nodiscard]] std::string Unparse(const T& value) const {
for (const auto& option : options_) {
if (option.value == value) {
return std::string(option.name);
}
}
return absl::StrCat(value);
}

private:
// Validates the flag options and returns them. Dies if the options are not
// valid.
static std::vector<FlagOption> ValidateFlagOptionsOrDie(
const std::vector<FlagOption>& options) {
// Check that the same name or value is not used multiple times.
absl::flat_hash_set<std::string> names;
absl::flat_hash_set<T> values;
for (const auto& option : options) {
CHECK(!names.contains(option.name))
<< "Duplicate flag option name: " << option.name;
CHECK(!values.contains(option.value))
<< "Duplicate flag option value: " << absl::StrCat(option.value);
names.insert(option.name);
values.insert(option.value);
}
return options;
}

const std::vector<FlagOption> options_;
};

// Returns the parser for a flag of type T that takes a fixed set of options.
// The options must be valid, i.e., there must be no duplicate names or values.
// The returned parser is guaranteed to be alive for the lifetime of the
// program.
//
// For each T, the caller must call this function exactly once to get the
// parser, and then use the parser to define the AbslParseFlag and
// AbslUnparseFlag functions for T.
template <typename T>
[[nodiscard]] const FixedOptionSetFlagParser<T>& GetFixedOptionSetFlagParser(
const std::vector<typename FixedOptionSetFlagParser<T>::FlagOption>&
options) {
// Per Google C++ style guide, we use a function-local static
// variable to ensure that the parser is only created once and never
// destroyed. We cannot use absl::NoDestructor here because it is not
// available in the version of Abseil that openxla uses.
static const auto* const parser = new FixedOptionSetFlagParser<T>(options);
return *parser;
}

} // namespace xla

#endif // XLA_TSL_UTIL_FIXED_OPTION_SET_FLAG_H_
75 changes: 75 additions & 0 deletions xla/tsl/util/fixed_option_set_flag_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/* Copyright 2025 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Unit tests for FixedOptionSetFlag.

#include "xla/tsl/util/fixed_option_set_flag.h"

#include <string>

#include <gtest/gtest.h>
#include "absl/strings/string_view.h"

namespace xla {
namespace {

enum class Foo {
kBar,
kBaz,
};

static const FixedOptionSetFlagParser<Foo>& GetFooParser() {
static const auto& parser = GetFixedOptionSetFlagParser<Foo>({
{"bar", Foo::kBar, "the first option"},
{"baz", Foo::kBaz},
});
return parser;
};

bool AbslParseFlag(absl::string_view text, Foo* foo, std::string* error) {
return GetFooParser().Parse(text, foo, error);
}

std::string AbslUnparseFlag(Foo foo) { return GetFooParser().Unparse(foo); }

TEST(FixedOptionSetFlag, ParseSucceedsForValidOptions) {
Foo foo;
std::string error;
ASSERT_TRUE(AbslParseFlag("bar", &foo, &error));
EXPECT_EQ(foo, Foo::kBar);
ASSERT_TRUE(AbslParseFlag("baz", &foo, &error));
EXPECT_EQ(foo, Foo::kBaz);
}

TEST(FixedOptionSetFlag, ParseFailsForInvalidOptions) {
Foo foo;
std::string error;
ASSERT_FALSE(AbslParseFlag("foo", &foo, &error));
EXPECT_EQ(error,
"Unrecognized flag option: foo. Valid options are: bar (the first "
"option), baz.");
}

TEST(FixedOptionSetFlag, UnparseSucceedsForValidOptions) {
EXPECT_EQ(AbslUnparseFlag(Foo::kBar), "bar");
EXPECT_EQ(AbslUnparseFlag(Foo::kBaz), "baz");
}

TEST(FixedOptionSetFlag, UnparseFailsForInvalidOptions) {
EXPECT_EQ(AbslUnparseFlag(static_cast<Foo>(123)), "123");
}

} // namespace
} // namespace xla