diff --git a/mlir/include/mlir/Transforms/FoldUtils.h b/mlir/include/mlir/Transforms/FoldUtils.h --- a/mlir/include/mlir/Transforms/FoldUtils.h +++ b/mlir/include/mlir/Transforms/FoldUtils.h @@ -77,14 +77,14 @@ void create(OpBuilder &builder, SmallVectorImpl &results, Location location, Args &&... args) { // The op needs to be inserted only if the fold (below) fails, or the number - // of results of the op is zero (which is treated as an in-place - // fold). Using create methods of the builder will insert the op, so not - // using it here. + // of results produced by the successful folding is zero (which is treated + // as an in-place fold). Using create methods of the builder will insert the + // op, so not using it here. OperationState state(location, OpTy::getOperationName()); OpTy::build(builder, state, std::forward(args)...); Operation *op = Operation::create(state); - if (failed(tryToFold(builder, op, results)) || op->getNumResults() == 0) { + if (failed(tryToFold(builder, op, results)) || results.empty()) { builder.insert(op); results.assign(op->result_begin(), op->result_end()); return; diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -323,6 +323,15 @@ return success(); } +OpFoldResult TestOpInPlaceFold::fold(ArrayRef operands) { + assert(operands.size() == 1); + if (operands.front()) { + setAttr("attr", operands.front()); + return getResult(); + } + return {}; +} + LogicalResult mlir::OpWithInferTypeInterfaceOp::inferReturnTypes( MLIRContext *, Optional location, ValueRange operands, ArrayRef attributes, RegionRange regions, diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -734,6 +734,17 @@ let results = (outs I32); } +def TestOpInPlaceFoldAnchor : 
TEST_Op<"op_in_place_fold_anchor"> { + let arguments = (ins I32); + let results = (outs I32); +} + +def TestOpInPlaceFold : TEST_Op<"op_in_place_fold"> { + let arguments = (ins I32:$op, I32Attr:$attr); + let results = (outs I32); + let hasFolder = 1; +} + //===----------------------------------------------------------------------===// // Test Patterns (Symbol Binding) diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -8,9 +8,11 @@ #include "TestDialect.h" #include "mlir/Conversion/StandardToStandard/StandardToStandard.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/FoldUtils.h" using namespace mlir; @@ -39,13 +41,36 @@ //===----------------------------------------------------------------------===// namespace { +struct FoldingPattern : public RewritePattern { +public: + FoldingPattern(MLIRContext *context) + : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(), + /*benefit=*/1, context) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + // Exercise OperationFolder API for a single-result operation that is folded + // upon construction. The operation being created through the folder has an + // in-place folder, and it should still be present in the output. + // Furthermore, the folder should not crash when attempting to recover the + // (unchanged) operation result. 
+ OperationFolder folder(op->getContext()); + Value result = folder.create( + rewriter, op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0), + rewriter.getI32IntegerAttr(0)); + assert(result); + rewriter.replaceOp(op, result); + return success(); + } +}; + struct TestPatternDriver : public PassWrapper { void runOnFunction() override { mlir::OwningRewritePatternList patterns; populateWithGenerated(&getContext(), &patterns); // Verify named pattern is generated with expected name. - patterns.insert(&getContext()); + patterns.insert(&getContext()); applyPatternsAndFoldGreedily(getFunction(), patterns); }