diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -460,6 +460,39 @@ //===----------------------------------------------------------------------===// namespace { + +// Merge multiple `shape.assuming_all` operations together. +// +// %0 = shape.assuming_all %w0, %w1 +// %1 = shape.assuming_all %w2, %0 +// +// to: +// +// %1 = shape.assuming_all %w2, %w0, %w1 +struct MergeAssumingAllOps : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(AssumingAllOp op, + PatternRewriter &rewriter) const override { + SmallVector operands; + + for (Value operand : op.getInputs()) { + if (auto assume_all = operand.getDefiningOp()) + operands.append(assume_all.operand_begin(), assume_all->operand_end()); + else + operands.push_back(operand); + } + + // We didn't find any other `assuming_all` ops to merge with. + if (operands.size() == op.getNumOperands()) + return failure(); + + // Replace with a new `assuming_all` operation with merged constraints. 
+ rewriter.replaceOpWithNewOp(op, operands); + return success(); + } +}; + struct AssumingAllToCstrEqCanonicalization : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -506,7 +539,8 @@ void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { - patterns.add>(context); } diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir --- a/mlir/test/Dialect/Shape/canonicalize.mlir +++ b/mlir/test/Dialect/Shape/canonicalize.mlir @@ -463,6 +463,26 @@ return } +// ----- + +// merge assuming_all operations +// CHECK-LABEL: func @f +func @f() { + // CHECK-NEXT: %[[W0:.*]] = "test.source" + // CHECK-NEXT: %[[W1:.*]] = "test.source" + // CHECK-NEXT: %[[W2:.*]] = "test.source" + // CHECK-NEXT: shape.assuming_all %[[W0]], %[[W1]], %[[W2]] + // CHECK-NEXT: consume.witness + // CHECK-NEXT: return + %0 = "test.source"() : () -> !shape.witness + %1 = "test.source"() : () -> !shape.witness + %2 = "test.source"() : () -> !shape.witness + %3 = shape.assuming_all %0, %1 + %4 = shape.assuming_all %3, %2 + "consume.witness"(%4) : (!shape.witness) -> () + return +} + // ----- // `assuming_all` with all `cstr_eq` and shared operands can be collapsed. // CHECK-LABEL: func @assuming_all_to_cstr_eq