Skip to content

Commit

Permalink
Improve readability of hlo_casting_utils.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 722514312
  • Loading branch information
toli-y authored and Google-ML-Automation committed Feb 3, 2025
1 parent 0a5c792 commit 12e7ace
Showing 1 changed file with 23 additions and 30 deletions.
53 changes: 23 additions & 30 deletions xla/hlo/ir/hlo_casting_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@ limitations under the License.
#ifndef XLA_HLO_IR_HLO_CASTING_UTILS_H_
#define XLA_HLO_IR_HLO_CASTING_UTILS_H_

#include <string>
#include <type_traits>

#include "absl/strings/str_cat.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/tsl/platform/logging.h"

Expand All @@ -29,31 +31,28 @@ template <class T>
using EnableIfDerivedFromHlo =
typename std::enable_if<std::is_base_of<HloInstruction, T>::value>::type;

template <class T>
std::string InvalidMsg(const HloInstruction* instruction) {
return absl::StrCat("Invalid HloInstruction casting. Destination type: ",
typeid(T).name(), ". Instruction: ", instruction->name());
}

// Casts an HloInstruction pointer to one of its subclasses, dies if argument is
// nullptr or runtime information does not match.
//
// Similar to LLVM's cast.
template <class T, EnableIfDerivedFromHlo<T>* = nullptr>
const T* Cast(const HloInstruction* instruction) {
CHECK(instruction != nullptr);
CHECK(T::ClassOf(instruction))
<< "Invalid HloInstruction casting. Destination type: "
<< typeid(T).name() << ". Instruction: " << instruction->name();
const T* casted = static_cast<const T*>(instruction);
#ifndef NDEBUG
const T* dynamic_casted = dynamic_cast<const T*>(instruction);
CHECK(dynamic_casted != nullptr)
<< "Invalid HloInstruction casting. Destination type: "
<< typeid(T).name() << ". Instruction: " << instruction->name();
#endif
return casted;
const T* Cast(const HloInstruction* instr) {
CHECK(instr != nullptr);
CHECK(T::ClassOf(instr)) << InvalidMsg<T>(instr);
DCHECK(dynamic_cast<const T*>(instr) != nullptr) << InvalidMsg<T>(instr);
return static_cast<const T*>(instr);
}

// Non-const overload of Cast.
template <class T, EnableIfDerivedFromHlo<T>* = nullptr>
T* Cast(HloInstruction* instruction) {
return const_cast<T*>(
Cast<T>(const_cast<const HloInstruction*>(instruction)));
T* Cast(HloInstruction* instr) {
return const_cast<T*>(Cast<T>(const_cast<const HloInstruction*>(instr)));
}

// Works just like the Cast, except that it allows for a null pointer as an
Expand All @@ -68,30 +67,25 @@ const T* CastOrNull(const HloInstruction* instruction) {
// Non-const overload of CastOrNull.
template <class T, EnableIfDerivedFromHlo<T>* = nullptr>
T* CastOrNull(HloInstruction* instruction) {
return const_cast<T*>(
CastOrNull<T>(const_cast<const HloInstruction*>(instruction)));
return instruction != nullptr ? Cast<T>(instruction) : nullptr;
}

// Casts an HloInstruction pointer to one of its subclasses, dies if argument is
// nullptr, returns nullptr if runtime information does not match.
//
// Similar to LLVM's dyn_cast.
template <class T, EnableIfDerivedFromHlo<T>* = nullptr>
const T* DynCast(const HloInstruction* instruction) {
CHECK(instruction != nullptr);
const T* casted =
T::ClassOf(instruction) ? static_cast<const T*>(instruction) : nullptr;
#ifndef NDEBUG
CHECK_EQ(casted, dynamic_cast<const T*>(instruction));
#endif
const T* DynCast(const HloInstruction* instr) {
CHECK(instr != nullptr);
const T* casted = T::ClassOf(instr) ? static_cast<const T*>(instr) : nullptr;
DCHECK_EQ(casted, dynamic_cast<const T*>(instr)) << InvalidMsg<T>(instr);
return casted;
}

// Non-const overload of DynCast.
template <class T, EnableIfDerivedFromHlo<T>* = nullptr>
T* DynCast(HloInstruction* instruction) {
return const_cast<T*>(
DynCast<T>(const_cast<const HloInstruction*>(instruction)));
T* DynCast(HloInstruction* instr) {
return const_cast<T*>(DynCast<T>(const_cast<const HloInstruction*>(instr)));
}

// Works just like the DynCast, except that it allows for a null pointer as an
Expand All @@ -106,8 +100,7 @@ const T* DynCastOrNull(const HloInstruction* instruction) {
// Non-const overload of DynCastOrNull.
template <class T, EnableIfDerivedFromHlo<T>* = nullptr>
T* DynCastOrNull(HloInstruction* instruction) {
return const_cast<T*>(
DynCastOrNull<T>(const_cast<const HloInstruction*>(instruction)));
return instruction != nullptr ? DynCast<T>(instruction) : nullptr;
}

} // namespace xla
Expand Down

0 comments on commit 12e7ace

Please sign in to comment.