diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -441,6 +441,8 @@
 
   let printer = [{ return ::print(p, *this); }];
   let parser = [{ return ::parse$cppClass(parser, result); }];
+
+  let hasCanonicalizer = 1;
 }
 
 def Shape_AssumingYieldOp : Shape_Op<"assuming_yield",
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -164,6 +164,48 @@
   p.printOptionalAttrDict(op.getAttrs());
 }
 
+namespace {
+// Removes an AssumingOp whose witness is produced by a TrueWitnessOp and
+// inlines its single-block region into the parent block.
+struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
+  using OpRewritePattern<AssumingOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AssumingOp op,
+                                PatternRewriter &rewriter) const override {
+    // Only a statically passing witness allows dropping the assumption.
+    if (!op.witness().getDefiningOp<TrueWitnessOp>())
+      return failure();
+
+    // Split the parent block right before `op`. Derive the split point from
+    // `op` itself rather than relying on the driver's insertion point.
+    Operation *assumingOp = op.getOperation();
+    Block *firstBlock = assumingOp->getBlock();
+    Block *secondBlock = op.getBody();
+    Block *thirdBlock =
+        rewriter.splitBlock(firstBlock, Block::iterator(assumingOp));
+
+    // Move the body before the split point, forward the yielded values to the
+    // users of the AssumingOp, then remove the AssumingOp and AssumingYieldOp.
+    Operation &yieldOp = secondBlock->back();
+    rewriter.inlineRegionBefore(op.doRegion(), thirdBlock);
+    rewriter.replaceOp(op, yieldOp.getOperands());
+    rewriter.eraseOp(&yieldOp);
+
+    // Merge blocks together as there was no branching behavior from the
+    // AssumingOp.
+    rewriter.mergeBlocks(secondBlock, firstBlock);
+    rewriter.mergeBlocks(thirdBlock, firstBlock);
+    return success();
+  }
+};
+} // namespace
+
+void AssumingOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
+                                             MLIRContext *context) {
+  // If the witness is statically true, inline the region.
+  patterns.insert<AssumingWithTrue>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // AssumingAllOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -299,6 +299,42 @@
   return %1 : !shape.shape
 }
 
+// -----
+// Assuming with a known true witness can be removed.
+// CHECK-LABEL: func @f
+func @f() {
+  // CHECK-NEXT: test.source
+  // CHECK-NEXT: test.sink
+  // CHECK-NEXT: return
+  %0 = shape.true_witness
+  %1 = shape.assuming %0 -> index {
+    %2 = "test.source"() : () -> (index)
+    shape.assuming_yield %2 : index
+  }
+  "test.sink"(%1) : (index) -> ()
+  return
+}
+
+// -----
+// Assuming without a known true witness cannot be removed.
+// CHECK-LABEL: func @f
+func @f() {
+  // CHECK-NEXT: test.source
+  // CHECK-NEXT: shape.assuming
+  // CHECK-NEXT: test.source
+  // CHECK-NEXT: shape.assuming_yield
+  // CHECK-NEXT: }
+  // CHECK-NEXT: test.sink
+  // CHECK-NEXT: return
+  %0 = "test.source"() : () -> (!shape.witness)
+  %1 = shape.assuming %0 -> index {
+    %2 = "test.source"() : () -> (index)
+    shape.assuming_yield %2 : index
+  }
+  "test.sink"(%1) : (index) -> ()
+  return
+}
+
 // -----
 // Broadcastable with broadcastable constant shapes can be removed.
 // CHECK-LABEL: func @f