diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -21,6 +21,7 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Transforms/InliningUtils.h" +#include "llvm/ADT/SetOperations.h" #include "llvm/ADT/SmallString.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/raw_ostream.h" @@ -493,6 +494,97 @@ } }; +// Rewrite `assuming_all` operation with `cstr_broadcastable` operands, to an +// `assuming_all` operation with a smaller number of broadcastable constraints. +// +// %0 = shape.cstr_broadcastable %shape0, %shape1 +// %1 = shape.cstr_broadcastable %shape0, %shape1, %shape2 +// +// %2 = shape.cstr_broadcastable %shape3, %shape4 +// %3 = shape.cstr_broadcastable %shape3, %shape4, %shape5 +// +// %4 = shape.assuming_all %0, %1, %2, %3 +// +// to: +// +// %0 = shape.cstr_broadcastable %shape0, %shape1, %shape2 +// %1 = shape.cstr_broadcastable %shape3, %shape4, %shape5 +// %2 = shape.assuming_all %0, %1 +// +// In this example if shapes [0, 1, 2] are broadcastable, then it means that +// shapes [0, 1] are broadcastable too, and can be removed from the list of +// constraints. If shapes [0, 1, 2] are not broadcastable, then it doesn't +// matter if shapes [0, 1] are broadcastable (same for shapes [3, 4, 5]). +struct AssumingAllOfCstrBroadcastable : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(AssumingAllOp op, + PatternRewriter &rewriter) const override { + // Collect all `CstrBroadcastableOp` operands first. + SetVector operands; + for (Value operand : op.getInputs()) { + auto broadcastable = operand.getDefiningOp(); + if (!broadcastable) + return failure(); + + operands.insert(broadcastable); + } + + // Skip trivial `assuming_all` operations. + if (operands.size() <= 1) + return failure(); + + // Collect shapes checked by `cstr_broadcastable` operands. + SmallVector>> shapes; + for (auto cstr : operands) { + DenseSet shapes_set(cstr->operand_begin(), cstr->operand_end()); + shapes.emplace_back(cstr, std::move(shapes_set)); + } + + // Sort by the number of shape operands (larger to smaller). + llvm::sort(shapes, [](auto a, auto b) { + return a.first.getNumOperands() > b.first.getNumOperands(); + }); + + // We start from the `cst_broadcastable` operations with largest number of + // shape operands, and remove redundant `cst_broadcastable` operations. We + // do this until we find a set of `cst_broadcastable` operations with + // non-overlapping constraints. + SmallVector marked_for_erase; + + for (unsigned i = 0; i < shapes.size(); ++i) { + auto isSubset = [&](auto pair) { + return llvm::set_is_subset(pair.second, shapes[i].second); + }; + + // Keep redundant `cstr_broadcastable` operations to be erased. + auto *it = std::remove_if(shapes.begin() + i + 1, shapes.end(), isSubset); + for (auto *it0 = it; it0 < shapes.end(); ++it0) + marked_for_erase.push_back(it0->first); + shapes.erase(it, shapes.end()); + } + + // We didn't find any operands that could be removed. + if (marked_for_erase.empty()) + return failure(); + + // Collect non-overlapping `cst_broadcastable` constraints. + SmallVector unique_constraints; + for (auto &shape : shapes) + unique_constraints.push_back(shape.first.getResult()); + + // Replace with a new `assuming_all` operation ... + rewriter.replaceOpWithNewOp(op, unique_constraints); + + // ... and maybe erase `cstr_broadcastable` ops without uses. + for (auto &op : marked_for_erase) + if (op->use_empty()) + rewriter.eraseOp(op); + + return success(); + } +}; + struct AssumingAllToCstrEqCanonicalization : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -539,9 +631,10 @@ void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { - patterns.add>(context); + patterns + .add>(context); } OpFoldResult AssumingAllOp::fold(ArrayRef operands) { diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir --- a/mlir/test/Dialect/Shape/canonicalize.mlir +++ b/mlir/test/Dialect/Shape/canonicalize.mlir @@ -565,6 +565,46 @@ // ----- +// merge cstr_broadcastable operations +// +// CHECK-LABEL: func @f +// CHECK: %[[ARG0:[a-z0-9]*]]: !shape.shape +// CHECK-SAME: %[[ARG1:[a-z0-9]*]]: !shape.shape +// CHECK-SAME: %[[ARG2:[a-z0-9]*]]: !shape.shape +func @f(%arg0 : !shape.shape, %arg1 : !shape.shape, %arg2 : !shape.shape) { + // CHECK-NEXT: shape.cstr_broadcastable %[[ARG0]], %[[ARG1]], %[[ARG2]] + // CHECK-NEXT: consume.witness + // CHECK-NEXT: return + %0 = shape.cstr_broadcastable %arg0, %arg1 : !shape.shape, !shape.shape + %1 = shape.cstr_broadcastable %arg0, %arg1, %arg2 : !shape.shape, !shape.shape, !shape.shape + %2 = shape.assuming_all %0, %1 + "consume.witness"(%2) : (!shape.witness) -> () + return +} + +// ----- + +// do not merge cstr_broadcastable operations +// +// CHECK-LABEL: func @f +// CHECK: %[[ARG0:[a-z0-9]*]]: !shape.shape +// CHECK-SAME: %[[ARG1:[a-z0-9]*]]: !shape.shape +// CHECK-SAME: %[[ARG2:[a-z0-9]*]]: !shape.shape +func @f(%arg0 : !shape.shape, %arg1 : !shape.shape, %arg2 : !shape.shape) { + // CHECK-NEXT: shape.cstr_broadcastable %[[ARG0]], %[[ARG1]] + // CHECK-NEXT: shape.cstr_broadcastable %[[ARG1]], %[[ARG2]] + // CHECK-NEXT: shape.assuming_all + // CHECK-NEXT: consume.witness + // CHECK-NEXT: return + %0 = shape.cstr_broadcastable %arg0, %arg1 : !shape.shape, !shape.shape + %1 = shape.cstr_broadcastable %arg1, %arg2 : !shape.shape, !shape.shape + %2 = shape.assuming_all %0, %1 + "consume.witness"(%2) : (!shape.witness) -> () + return +} + +// ----- + // any can be replaced with a constant input if it has one. // CHECK-LABEL: func @f func @f(%arg : !shape.shape) -> !shape.shape {