Review notes (added during review; not part of any hunk below).

Purpose, as far as the visible hunks show:
- Add two `llvm::CastInfo` specializations, each inheriting from another
  `CastInfo` instantiation, in DataFlowFramework.h, LLVMDialect.h,
  OpDefinition.h and Unit.h.
- Add a second such specialization next to the existing one in
  CallInterfaces.h.
- In OpDefinition.h, replace `OpFoldResult`'s member `.dyn_cast()` calls
  with the free function `llvm::dyn_cast_if_present(...)`.
- OpDefinition.h temporarily closes `namespace mlir` so the specializations
  appear after the OpFoldResult definition and before the first cast call
  that depends on them (per the patch's own comment).

NOTE(review): template parameter/argument lists appear to be missing from
this copy. Examples: a bare `template` before each `struct CastInfo`,
`struct DenseMapInfo : public DenseMapInfo {};`, and
`llvm::dyn_cast_if_present(ofr)` with no explicit target type. The code
cannot compile as shown. Restore the `<...>` lists from the original commit
before using this patch.

NOTE(review): the patch is collapsed onto two physical lines. Its line
structure (and exact context whitespace) must be restored before it can be
applied.

NOTE(review): the OpDefinition.h hunks also delete two unrelated blank lines:
- the one between the two `operator<<` overloads;
- the one before the final `} // namespace llvm`.

This is whitespace churn unrelated to the change. Dropping it would also
require updating the affected hunk headers.

NOTE(review): the added comment "as later code in this uses it" is missing a
word. It presumably should read "as later code in this file uses it".

diff --git a/mlir/include/mlir/Analysis/DataFlowFramework.h b/mlir/include/mlir/Analysis/DataFlowFramework.h --- a/mlir/include/mlir/Analysis/DataFlowFramework.h +++ b/mlir/include/mlir/Analysis/DataFlowFramework.h @@ -470,6 +470,16 @@ template <> struct DenseMapInfo : public DenseMapInfo {}; + +// Allow llvm::cast style functions. +template +struct CastInfo + : public CastInfo {}; + +template +struct CastInfo + : public CastInfo {}; + } // end namespace llvm #endif // MLIR_ANALYSIS_DATAFLOWFRAMEWORK_H diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h @@ -236,4 +236,17 @@ } // namespace LLVM } // namespace mlir +namespace llvm { + +// Allow llvm::cast style functions. +template +struct CastInfo + : public CastInfo {}; + +template +struct CastInfo + : public CastInfo {}; + +} // namespace llvm + #endif // MLIR_DIALECT_LLVMIR_LLVMDIALECT_H_ diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -269,15 +269,35 @@ void dump() const { llvm::errs() << *this << "\n"; } }; +// Temporarily exit the MLIR namespace to add casting support as later code in +// this uses it. The CastInfo must come after the OpFoldResult definition and +// before any cast function calls depending on CastInfo. + +} // namespace mlir + +namespace llvm { + +// Allow llvm::cast style functions. +template +struct CastInfo + : public CastInfo {}; + +template +struct CastInfo + : public CastInfo {}; + +} // namespace llvm + +namespace mlir { + /// Allow printing to a stream. 
inline raw_ostream &operator<<(raw_ostream &os, OpFoldResult ofr) { - if (Value value = ofr.dyn_cast()) + if (Value value = llvm::dyn_cast_if_present(ofr)) value.print(os); else - ofr.dyn_cast().print(os); + llvm::dyn_cast_if_present(ofr).print(os); return os; } - /// Allow printing to a stream. inline raw_ostream &operator<<(raw_ostream &os, OpState op) { op.print(os, OpPrintingFlags().useLocalScope()); @@ -1554,7 +1574,7 @@ return failure(); if (OpFoldResult result = Trait::foldTrait(op, operands)) { - if (result.template dyn_cast() != op->getResult(0)) + if (llvm::dyn_cast_if_present(result) != op->getResult(0)) results.push_back(result); return success(); } @@ -1902,7 +1922,8 @@ // If the fold failed or was in-place, try to fold the traits of the // operation. - if (!result || result.template dyn_cast() == op->getResult(0)) { + if (!result || + llvm::dyn_cast_if_present(result) == op->getResult(0)) { if (succeeded(op_definition_impl::foldTraits...>( op, operands, results))) return success(); @@ -2117,7 +2138,6 @@ } static bool isEqual(T lhs, T rhs) { return lhs == rhs; } }; - } // namespace llvm #endif diff --git a/mlir/include/mlir/IR/Unit.h b/mlir/include/mlir/IR/Unit.h --- a/mlir/include/mlir/IR/Unit.h +++ b/mlir/include/mlir/IR/Unit.h @@ -39,4 +39,17 @@ } // end namespace mlir +namespace llvm { + +// Allow llvm::cast style functions. +template +struct CastInfo + : public CastInfo {}; + +template +struct CastInfo + : public CastInfo {}; + +} // namespace llvm + #endif // MLIR_IR_UNIT_H diff --git a/mlir/include/mlir/Interfaces/CallInterfaces.h b/mlir/include/mlir/Interfaces/CallInterfaces.h --- a/mlir/include/mlir/Interfaces/CallInterfaces.h +++ b/mlir/include/mlir/Interfaces/CallInterfaces.h @@ -30,10 +30,15 @@ namespace llvm { +// Allow llvm::cast style functions. template struct CastInfo : public CastInfo {}; +template +struct CastInfo + : public CastInfo {}; + } // namespace llvm #endif // MLIR_INTERFACES_CALLINTERFACES_H