diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -131,6 +131,12 @@ return failure(); } + /// Returns true if this pattern is known to be safe to apply recursively, + /// i.e. it may generate IR that this same pattern matches again, but the + /// number of such nested applications is known to be bounded. Rewrite + /// drivers may use this to re-apply the pattern to IR it generated. + virtual bool hasBoundedRewriteRecursion() const { return false; } + /// Return a list of operations that may be generated when rewriting an /// operation instance with this pattern. ArrayRef getGeneratedOps() const { return generatedOps; } diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -789,23 +789,10 @@ Value extractedDest = extractOne(rewriter, loc, op.dest(), off); // 3. Reduce the problem to lowering a new InsertStridedSlice op with // smaller rank. - InsertStridedSliceOp insertStridedSliceOp = - rewriter.create( - loc, extractedSource, extractedDest, - getI64SubArray(op.offsets(), /* dropFront=*/1), - getI64SubArray(op.strides(), /* dropFront=*/1)); - // Call matchAndRewrite recursively from within the pattern. This - // circumvents the current limitation that a given pattern cannot - // be called multiple times by the PatternRewrite infrastructure (to - // avoid infinite recursion, but in this case, infinite recursion - // cannot happen because the rank is strictly decreasing). - // TODO(rriddle, nicolasvasilache) Implement something like a hook for - // a potential function that must decrease and allow the same pattern - // multiple times. 
- auto success = matchAndRewrite(insertStridedSliceOp, rewriter); - (void)success; - assert(succeeded(success) && "Unexpected failure"); - extractedSource = insertStridedSliceOp; + extractedSource = rewriter.create( + loc, extractedSource, extractedDest, + getI64SubArray(op.offsets(), /* dropFront=*/1), + getI64SubArray(op.strides(), /* dropFront=*/1)); } // 4. Insert the extractedSource into the res vector. res = insertOne(rewriter, loc, extractedSource, res, off); @@ -814,6 +801,9 @@ rewriter.replaceOp(op, res); return success(); } + /// Each generated InsertStridedSliceOp has strictly smaller rank than the + /// op being rewritten, so recursive application of this pattern is bounded. + bool hasBoundedRewriteRecursion() const final { return true; } }; class VectorTypeCastOpConversion : public ConvertToLLVMPattern { @@ -1068,28 +1058,19 @@ off += stride, ++idx) { Value extracted = extractOne(rewriter, loc, op.vector(), off); if (op.offsets().getValue().size() > 1) { - StridedSliceOp stridedSliceOp = rewriter.create( + extracted = rewriter.create( loc, extracted, getI64SubArray(op.offsets(), /* dropFront=*/1), getI64SubArray(op.sizes(), /* dropFront=*/1), getI64SubArray(op.strides(), /* dropFront=*/1)); - // Call matchAndRewrite recursively from within the pattern. This - // circumvents the current limitation that a given pattern cannot - // be called multiple times by the PatternRewrite infrastructure (to - // avoid infinite recursion, but in this case, infinite recursion - // cannot happen because the rank is strictly decreasing). - // TODO(rriddle, nicolasvasilache) Implement something like a hook for - // a potential function that must decrease and allow the same pattern - // multiple times. 
- auto success = matchAndRewrite(stridedSliceOp, rewriter); - (void)success; - assert(succeeded(success) && "Unexpected failure"); - extracted = stridedSliceOp; } res = insertOne(rewriter, loc, extracted, res, idx); } rewriter.replaceOp(op, {res}); return success(); } + /// Each generated StridedSliceOp has strictly smaller rank than the op + /// being rewritten, so recursive application of this pattern is bounded. + bool hasBoundedRewriteRecursion() const final { return true; } }; } // namespace diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -1256,10 +1256,9 @@ }); // Ensure that we don't cycle by not allowing the same pattern to be - // applied twice in the same recursion stack. - // TODO(riverriddle) We could eventually converge, but that requires more - // complicated analysis. - if (!appliedPatterns.insert(pattern).second) { + // applied twice in the same recursion stack unless its recursion is bounded.
+ if (!pattern->hasBoundedRewriteRecursion() && + !appliedPatterns.insert(pattern).second) { LLVM_DEBUG(logFailure(rewriterImpl.logger, "pattern was already applied")); return failure(); } diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir --- a/mlir/test/Transforms/test-legalizer.mlir +++ b/mlir/test/Transforms/test-legalizer.mlir @@ -143,6 +143,13 @@ return } +// CHECK-LABEL: @bounded_recursion +func @bounded_recursion() { + // CHECK: test.recursive_rewrite 0 + test.recursive_rewrite 3 + return +} + // ----- func @fail_to_convert_illegal_op() -> i32 { diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1061,6 +1061,12 @@ Arguments<(ins AnyType)>, Results<(outs AnyType)>; def : Pat<(TestRewriteOp $input), (replaceWithValue $input)>; +// Check that patterns can specify bounded recursion when rewriting. +def TestRecursiveRewriteOp : TEST_Op<"recursive_rewrite"> { + let arguments = (ins I64Attr:$depth); + let assemblyFormat = "$depth attr-dict"; +} + //===----------------------------------------------------------------------===// // Test Type Legalization //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -360,6 +360,33 @@ return success(); } }; + +//===----------------------------------------------------------------------===// +// Recursive Rewrite Testing +/// This pattern is applied to the same operation multiple times, but has a +/// bounded recursion.
+struct TestBoundedRecursiveRewrite + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TestRecursiveRewriteOp op, + PatternRewriter &rewriter) const final { + // Only positive depths make progress; refusing to match otherwise + // guarantees termination even if the op starts at a negative depth. + int64_t depth = op.depth().getSExtValue(); + if (depth <= 0) + return failure(); + + // Decrement the depth of the op in-place. + rewriter.updateRootInPlace(op, [&] { + op.setAttr("depth", rewriter.getI64IntegerAttr(depth - 1)); + }); + return success(); + } + + /// Each rewrite strictly decreases a positive depth, bounding the recursion. + bool hasBoundedRewriteRecursion() const final { return true; } +}; } // namespace namespace { @@ -414,7 +441,7 @@ TestCreateIllegalBlock, TestPassthroughInvalidOp, TestSplitReturnType, TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64, TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType, - TestNonRootReplacement>(&getContext()); + TestNonRootReplacement, TestBoundedRecursiveRewrite>(&getContext()); patterns.insert(&getContext(), converter); mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(), converter); @@ -449,6 +476,10 @@ op->getAttrOfType("test.recursively_legal")); }); + // Mark the bounded-recursion op as legal once its depth reaches zero. + target.addDynamicallyLegalOp( + [](TestRecursiveRewriteOp op) { return op.depth() == 0; }); + // Handle a partial conversion. if (mode == ConversionMode::Partial) { (void)applyPartialConversion(getOperation(), target, patterns,