diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -381,9 +381,45 @@
 // AssumingAllOp
 //===----------------------------------------------------------------------===//
 
+namespace {
+/// Collapses an `assuming_all` whose inputs are all produced by `cstr_eq` ops
+/// into a single `cstr_eq` over the concatenation of their shape operands.
+///
+/// Constraints are only fused when they share an operand. `cstr_eq(a, b)` and
+/// `cstr_eq(b, c)` together imply `cstr_eq(a, b, b, c)`, but fusing
+/// `cstr_eq(a, b)` with `cstr_eq(c, d)` would additionally assert `b == c`,
+/// which the original IR never required.
+struct AssumingAllToCstrEqCanonicalization
+    : public OpRewritePattern<AssumingAllOp> {
+  using OpRewritePattern<AssumingAllOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AssumingAllOp op,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<Value, 8> shapes;
+    for (Value v : op.inputs()) {
+      // Every witness must come from a `cstr_eq`; otherwise there is nothing
+      // to fuse it into.
+      auto cstrEqOp = v.getDefiningOp<CstrEqOp>();
+      if (!cstrEqOp)
+        return failure();
+      auto range = cstrEqOp.shapes();
+      // Bail out if this constraint is unrelated to the shapes gathered so
+      // far: merging disjoint constraints would strengthen the witness.
+      bool disjointShapes = llvm::none_of(
+          range, [&](Value s) { return llvm::is_contained(shapes, s); });
+      if (!shapes.empty() && !range.empty() && disjointShapes)
+        return failure();
+      shapes.append(range.begin(), range.end());
+    }
+    rewriter.replaceOpWithNewOp<CstrEqOp>(op, shapes);
+    return success();
+  }
+};
+} // namespace
+
 void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                 MLIRContext *context) {
-  patterns.add<AssumingAllOneOp>(context);
+  patterns.add<AssumingAllOneOp, AssumingAllToCstrEqCanonicalization>(context);
 }
 
 OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) {
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -433,6 +433,32 @@
   return
 }
 
+// -----
+// `assuming_all` with all `cstr_eq` can be collapsed.
+// CHECK-LABEL: func @assuming_all_to_cstr_eq
+// CHECK-SAME: (%[[A:.*]]: !shape.shape, %[[B:.*]]: tensor<?xindex>, %[[C:.*]]: tensor<3xindex>)
+func @assuming_all_to_cstr_eq(%a : !shape.shape, %b : tensor<?xindex>,
+    %c : tensor<3xindex>) -> !shape.witness {
+  // CHECK: %[[RESULT:.*]] = shape.cstr_eq %[[A]], %[[B]], %[[B]], %[[C]]
+  // CHECK: return %[[RESULT]]
+  %0 = shape.cstr_eq %a, %b : !shape.shape, tensor<?xindex>
+  %1 = shape.cstr_eq %b, %c : tensor<?xindex>, tensor<3xindex>
+  %2 = shape.assuming_all %0, %1
+  return %2 : !shape.witness
+}
+
+// -----
+// `assuming_all` with `cstr_eq` ops over unrelated shapes must not be merged.
+// CHECK-LABEL: func @assuming_all_disjoint_cstr_eq
+func @assuming_all_disjoint_cstr_eq(%a : !shape.shape, %b : tensor<?xindex>,
+    %c : tensor<3xindex>, %d : !shape.shape) -> !shape.witness {
+  // CHECK: shape.assuming_all
+  %0 = shape.cstr_eq %a, %b : !shape.shape, tensor<?xindex>
+  %1 = shape.cstr_eq %c, %d : tensor<3xindex>, !shape.shape
+  %2 = shape.assuming_all %0, %1
+  return %2 : !shape.witness
+}
+
 // -----
 // assuming_all with known passing witnesses can be folded
 // CHECK-LABEL: func @f