diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -509,6 +509,8 @@ let printer = [{ return ::print(p, *this); }]; let parser = [{ return ::parse$cppClass(parser, result); }]; + + let hasCanonicalizer = 1; } def Shape_AssumingYieldOp : Shape_Op<"assuming_yield", diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -159,6 +159,44 @@ p.printOptionalAttrDict(op.getAttrs()); } +namespace { +// Removes AssumingOp with a passing witness and inlines the region. +struct AssumingWithTrue : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(AssumingOp op, + PatternRewriter &rewriter) const override { + auto witness = op.witness().getDefiningOp(); + if (!witness || !witness.passingAttr()) + return failure(); + + auto *blockBeforeAssuming = rewriter.getInsertionBlock(); + auto *assumingBlock = op.getBody(); + auto initPosition = rewriter.getInsertionPoint(); + auto *blockAfterAssuming = + rewriter.splitBlock(blockBeforeAssuming, initPosition); + + // Remove the AssumingOp and AssumingYieldOp. + auto &yieldOp = assumingBlock->back(); + rewriter.inlineRegionBefore(op.doRegion(), blockAfterAssuming); + rewriter.replaceOp(op, yieldOp.getOperands()); + rewriter.eraseOp(&yieldOp); + + // Merge blocks together as there was no branching behavior from the + // AssumingOp. 
+ rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming); + rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming); + return success(); + } +}; +}; // namespace + +void AssumingOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns, + MLIRContext *context) { + // If taking a passing witness, inline the region. + patterns.insert(context); +} + //===----------------------------------------------------------------------===// // AssumingAllOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir --- a/mlir/test/Dialect/Shape/canonicalize.mlir +++ b/mlir/test/Dialect/Shape/canonicalize.mlir @@ -325,6 +325,42 @@ } // ----- +// assuming with a known passing witness can be removed +// CHECK-LABEL: func @f +func @f() { + // CHECK-NEXT: source + // CHECK-NEXT: sink + // CHECK-NEXT: return + %0 = shape.const_witness true + %1 = shape.assuming %0 -> index { + %2 = "test.source"() : () -> (index) + shape.assuming_yield %2 : index + } + "test.sink"(%1) : (index) -> () + return +} + +// ----- +// assuming without a known passing witness cannot be removed +// CHECK-LABEL: func @f +func @f() { + // CHECK-NEXT: test.source + // CHECK-NEXT: shape.assuming + // CHECK-NEXT: test.source + // CHECK-NEXT: shape.assuming_yield + // CHECK-NEXT: } + // CHECK-NEXT: test.sink + // CHECK-NEXT: return + %0 = "test.source"() : () -> (!shape.witness) + %1 = shape.assuming %0 -> index { + %2 = "test.source"() : () -> (index) + shape.assuming_yield %2 : index + } + "test.sink"(%1) : (index) -> () + return +} + +// ----- // Broadcastable with broadcastable constant shapes can be removed. // CHECK-LABEL: func @f func @f() {