Index: mlir/include/mlir/IR/OpDefinition.h =================================================================== --- mlir/include/mlir/IR/OpDefinition.h +++ mlir/include/mlir/IR/OpDefinition.h @@ -1688,17 +1688,33 @@ /// Trait to check if T provides a 'fold' method for a single result op. template using has_single_result_fold_t = - decltype(std::declval().fold(std::declval())); + decltype(std::declval().fold(std::declval>())); template constexpr static bool has_single_result_fold_v = llvm::is_detected::value; /// Trait to check if T provides a general 'fold' method. template using has_fold_t = decltype(std::declval().fold( - std::declval(), + std::declval>(), std::declval &>())); template constexpr static bool has_fold_v = llvm::is_detected::value; + /// Trait to check if T provides a 'fold' method with a FoldAdaptor for a + /// single result op. + template + using has_fold_adaptor_single_result_fold_t = + decltype(std::declval().fold(std::declval())); + template + constexpr static bool has_fold_adaptor_single_result_v = + llvm::is_detected::value; + /// Trait to check if T provides a general 'fold' method with a FoldAdaptor. + template + using has_fold_adaptor_fold_t = decltype(std::declval().fold( + std::declval(), + std::declval &>())); + template + constexpr static bool has_fold_adaptor_v = + llvm::is_detected::value; /// Trait to check if T provides a 'print' method. template @@ -1748,13 +1764,14 @@ // If the operation is single result and defines a `fold` method. if constexpr (llvm::is_one_of, Traits...>::value && - has_single_result_fold_v) + (has_single_result_fold_v || + has_fold_adaptor_single_result_v)) return [](Operation *op, ArrayRef operands, SmallVectorImpl &results) { return foldSingleResultHook(op, operands, results); }; // The operation is not single result and defines a `fold` method. - if constexpr (has_fold_v) + if constexpr (has_fold_v || has_fold_adaptor_v) return [](Operation *op, ArrayRef operands, SmallVectorImpl &results) { return foldHook(op, operands, results); @@ -1773,9 +1790,12 @@ static LogicalResult foldSingleResultHook(Operation *op, ArrayRef operands, SmallVectorImpl &results) { - OpFoldResult result = - cast(op).fold(typename ConcreteOpT::FoldAdaptor( - operands, op->getAttrDictionary(), op->getRegions())); + OpFoldResult result; + if constexpr (has_fold_adaptor_single_result_v) + result = cast(op).fold(typename ConcreteOpT::FoldAdaptor( + operands, op->getAttrDictionary(), op->getRegions())); + else + result = cast(op).fold(operands); // If the fold failed or was in-place, try to fold the traits of the // operation. @@ -1792,10 +1812,15 @@ template static LogicalResult foldHook(Operation *op, ArrayRef operands, SmallVectorImpl &results) { - LogicalResult result = cast(op).fold( - typename ConcreteOpT::FoldAdaptor(operands, op->getAttrDictionary(), - op->getRegions()), - results); + auto result = LogicalResult::failure(); + if constexpr (has_fold_adaptor_v) { + result = cast(op).fold( + typename ConcreteOpT::FoldAdaptor(operands, op->getAttrDictionary(), + op->getRegions()), + results); + } else { + result = cast(op).fold(operands, results); + } // If the fold failed or was in-place, try to fold the traits of the // operation. Index: mlir/test/IR/test-manual-cpp-fold.mlir =================================================================== --- /dev/null +++ mlir/test/IR/test-manual-cpp-fold.mlir @@ -0,0 +1,11 @@ +// RUN: mlir-opt %s -canonicalize -split-input-file | FileCheck %s + +func.func @test() -> i32 { + %c5 = "test.constant"() {value = 5 : i32} : () -> i32 + %res = "test.manual_cpp_op_with_fold"(%c5) : (i32) -> i32 + return %res : i32 +} + +// CHECK-LABEL: func.func @test +// CHECK-NEXT: %[[C:.*]] = "test.constant"() {value = 5 : i32} +// CHECK-NEXT: return %[[C]] Index: mlir/test/lib/Dialect/Test/TestDialect.h =================================================================== --- mlir/test/lib/Dialect/Test/TestDialect.h +++ mlir/test/lib/Dialect/Test/TestDialect.h @@ -58,6 +58,23 @@ #include "TestOps.h.inc" namespace test { + +// Op deliberately defined in C++ code rather than ODS to test that C++ +// Ops can still use the old `fold` method. +class ManualCppOpWithFold + : public mlir::Op { +public: + using Op::Op; + + static llvm::StringRef getOperationName() { + return "test.manual_cpp_op_with_fold"; + } + + static llvm::ArrayRef getAttributeNames() { return {}; } + + mlir::OpFoldResult fold(llvm::ArrayRef attributes); +}; + void registerTestDialect(::mlir::DialectRegistry ®istry); void populateTestReductionPatterns(::mlir::RewritePatternSet &patterns); } // namespace test Index: mlir/test/lib/Dialect/Test/TestDialect.cpp =================================================================== --- mlir/test/lib/Dialect/Test/TestDialect.cpp +++ mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -358,6 +358,7 @@ #define GET_OP_LIST #include "TestOps.cpp.inc" >(); + addOperations(); registerDynamicOp(getDynamicGenericOp(this)); registerDynamicOp(getDynamicOneOperandTwoResultsOp(this)); registerDynamicOp(getDynamicCustomParserPrinterOp(this)); @@ -1634,6 +1635,14 @@ setResultRanges(getResult(), range); } +OpFoldResult ManualCppOpWithFold::fold(ArrayRef attributes) { + // Just a simple fold for testing purposes that reads an operands constant + // value and returns it. + if (!attributes.empty()) + return attributes.front(); + return nullptr; +} + #include "TestOpEnums.cpp.inc" #include "TestOpInterfaces.cpp.inc" #include "TestTypeInterfaces.cpp.inc"