diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -366,6 +366,7 @@
   let results = (outs Shape_ShapeType:$result);

   let assemblyFormat = "`(` $inputs `)` attr-dict";
+  let hasCanonicalizer = 1;
 }

 def Shape_AssumingAllOp : Shape_Op<"assuming_all", [NoSideEffect]> {
@@ -393,6 +394,7 @@
   let results = (outs Shape_WitnessType:$result);

   let assemblyFormat = "`(` $inputs `)` attr-dict";
+  let hasCanonicalizer = 1;
 }

 def Shape_AssumingOp : Shape_Op<"assuming",
@@ -413,6 +415,8 @@

   let printer = [{ return ::print(p, *this); }];
   let parser = [{ return ::parse$cppClass(parser, result); }];
+
+  let hasCanonicalizer = 1;
 }

 def Shape_AssumingYieldOp : Shape_Op<"assuming_yield",
@@ -453,6 +457,7 @@
   let results = (outs Shape_WitnessType:$result);

   let assemblyFormat = "`(` $lhs `,` $rhs `)` attr-dict";
+  let hasCanonicalizer = 1;
 }

 def Shape_CstrEqOp : Shape_Op<"cstr_eq", [NoSideEffect]> {
@@ -472,9 +477,28 @@
   let results = (outs Shape_WitnessType:$result);

   let assemblyFormat = "`(` $inputs `)` attr-dict";
+  let hasCanonicalizer = 1;
 }

+// TODO(tpopp): Support witness attributes and then make this ConstantLike.
+// Note: This operation might be replaced with a general op that takes a
+// True/False Attribute.
+def Shape_TrueWitnessOp : Shape_Op<"true_witness", [NoSideEffect]> {
+  let summary = "An operation that returns a successful witness.";
+  let description = [{
+    %0 = shape.const_shape [1, 2, 3]
+    %1 = shape.const_shape [1, 2, 3]
+    %w0 = shape.cstr_eq(%0, %1)        // Can be canonicalized to true_witness.
+    %w1 = shape.true_witness
+    %w2 = shape.assuming_all(%w0, %w1) // Can be canonicalized to true_witness.
+  }];
+
+  let builders = [OpBuilder<
+    "OpBuilder &b, OperationState &result",
+    "build(b, result, ::mlir::shape::WitnessType::get(b.getContext()));"
+  >];

-// Canonicalization patterns.
+  let assemblyFormat = "attr-dict";
+  let results = (outs Shape_WitnessType:$result);
+}

 #endif // SHAPE_OPS
diff --git a/mlir/lib/Dialect/Shape/CMakeLists.txt b/mlir/lib/Dialect/Shape/CMakeLists.txt
--- a/mlir/lib/Dialect/Shape/CMakeLists.txt
+++ b/mlir/lib/Dialect/Shape/CMakeLists.txt
@@ -1,3 +1,7 @@
+set(LLVM_TARGET_DEFINITIONS IR/ShapeCanonicalization.td)
+mlir_tablegen(IR/ShapeCanonicalization.inc -gen-rewriters)
+add_public_tablegen_target(MLIRShapeCanonicalizationIncGen)
+
 add_mlir_dialect_library(MLIRShape
   IR/Shape.cpp
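For context on how these pieces connect: `let hasCanonicalizer = 1` only makes ODS declare a static hook on each generated op class, which Shape.cpp below defines, while the new `-gen-rewriters` tablegen rule compiles ShapeCanonicalization.td into the ShapeCanonicalization.inc that Shape.cpp includes. A sketch of the declared hook, assuming this MLIR revision's ODS output (later revisions rename OwningRewritePatternList to RewritePatternSet):

  // Assumed form of the hook that `hasCanonicalizer = 1` declares on each
  // generated op class; the definitions live in Shape.cpp.
  static void getCanonicalizationPatterns(OwningRewritePatternList &patterns,
                                          MLIRContext *context);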
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -18,6 +18,9 @@
 using namespace mlir;
 using namespace mlir::shape;

+namespace {
+#include "IR/ShapeCanonicalization.inc"
+}
 ShapeDialect::ShapeDialect(MLIRContext *context)
     : Dialect(getDialectNamespace(), context) {
   addOperations<
@@ -98,6 +101,33 @@
 // AnyOp
 //===----------------------------------------------------------------------===//

+namespace {
+// Removes AnyOp with constant shape input.
+// TODO(tpopp): This case can be replaced with folding.
+// TODO(tpopp): Another pattern should be implemented for shapes that can be
+// determined through mixtures of the known dimensions of the inputs.
+struct AnyWithConstant : public OpRewritePattern<AnyOp> {
+  using OpRewritePattern<AnyOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AnyOp op,
+                                PatternRewriter &rewriter) const override {
+    for (auto shapeOp : op.getOperands()) {
+      if (shapeOp.getDefiningOp<ConstShapeOp>()) {
+        rewriter.replaceOp(op, shapeOp);
+        return success();
+      }
+    }
+    return failure();
+  }
+};
+} // namespace
+
+void AnyOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
+                                        MLIRContext *context) {
+  patterns.insert<AnyWithConstant>(context);
+}
+
 LogicalResult AnyOp::inferReturnTypes(MLIRContext *context,
                                       Optional<Location> location,
                                       ValueRange operands,
                                       DictionaryAttr attributes,
@@ -108,7 +138,7 @@
 }

 //===----------------------------------------------------------------------===//
-// AssumingOp
+// Assuming
 //===----------------------------------------------------------------------===//

 static ParseResult parseAssumingOp(OpAsmParser &parser,
@@ -151,6 +181,51 @@
   p.printOptionalAttrDict(op.getAttrs());
 }

+namespace {
+// Removes AssumingOp with a passing condition and inlines the region.
+struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
+  using OpRewritePattern<AssumingOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AssumingOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op.condition().getDefiningOp<TrueWitnessOp>()) {
+      auto *firstBlock = rewriter.getInsertionBlock();
+      auto *secondBlock = op.getBody();
+      auto initPosition = rewriter.getInsertionPoint();
+      auto *thirdBlock = rewriter.splitBlock(firstBlock, initPosition);
+
+      // Remove the AssumingOp and AssumingYieldOp.
+      auto &yieldOp = secondBlock->back();
+      rewriter.inlineRegionBefore(op.doRegion(), thirdBlock);
+      rewriter.replaceOp(op, yieldOp.getOperands());
+      rewriter.eraseOp(&yieldOp);
+
+      // Merge blocks together as there was no branching behavior from the
+      // AssumingOp.
+      rewriter.mergeBlocks(secondBlock, firstBlock);
+      rewriter.mergeBlocks(thirdBlock, firstBlock);
+      return success();
+    }
+    return failure();
+  }
+};
+} // namespace
+
+void AssumingOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
+                                             MLIRContext *context) {
+  patterns.insert<AssumingWithTrue>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// AssumingAllOp
+//===----------------------------------------------------------------------===//
+
+void AssumingAllOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *context) {
+  patterns.insert<ConstantAssumingAll>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // BroadcastOp
 //===----------------------------------------------------------------------===//
@@ -232,6 +307,55 @@
   return success();
 }

+//===----------------------------------------------------------------------===//
+// CstrBroadcastableOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+// TODO(tpopp): Make this and other witness-related operations commutative.
+// TODO(tpopp): Add a case for unknown shapes that are still defined by the
+// same operation.
+// TODO(tpopp): Once witnesses are attributes, replace this with folding.
+struct CstrBroadcastableTrue : public OpRewritePattern<CstrBroadcastableOp> {
+  using OpRewritePattern<CstrBroadcastableOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(CstrBroadcastableOp op,
+                                PatternRewriter &rewriter) const override {
+    // Don't try to compare equality when the shapes are not constant.
+    auto lhsShape = op.getOperand(0).getDefiningOp<ConstShapeOp>();
+    auto rhsShape = op.getOperand(1).getDefiningOp<ConstShapeOp>();
+    if (!(lhsShape && rhsShape))
+      return failure();
+
+    SmallVector<int64_t, 6> resultShape;
+    // If the shapes are not broadcast-compatible, we can't fold this away.
+    if (!OpTrait::util::getBroadcastedShape(
+            llvm::to_vector<6>(lhsShape.shape().getValues<int64_t>()),
+            llvm::to_vector<6>(rhsShape.shape().getValues<int64_t>()),
+            resultShape))
+      return failure();
+
+    rewriter.replaceOpWithNewOp<TrueWitnessOp>(op);
+    return success();
+  }
+};
+} // namespace
+
+void CstrBroadcastableOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *context) {
+  patterns.insert<CstrBroadcastableTrue>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// CstrEqOp
+//===----------------------------------------------------------------------===//
+
+void CstrEqOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
+                                           MLIRContext *context) {
+  patterns.insert<ConstCstrEq>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // ConstSizeOp
 //===----------------------------------------------------------------------===//
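The dialect never calls these hooks directly; the generic -canonicalize pass collects patterns from every registered op and applies them greedily. The rewrites also compose: CstrBroadcastableTrue and ConstCstrEq first turn satisfied constraints into shape.true_witness, which lets ConstantAssumingAll collapse shape.assuming_all, which in turn lets AssumingWithTrue inline the shape.assuming region. A rough sketch of driving just these patterns from a custom pass, assuming this era's API names (OwningRewritePatternList, applyPatternsAndFoldGreedily); the helper itself is hypothetical:

  #include "mlir/Dialect/Shape/IR/Shape.h"
  #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

  // Hypothetical helper: populate the new Shape canonicalization patterns
  // and apply them to one function until a fixed point is reached.
  static void cleanupShapeConstraints(mlir::FuncOp func) {
    mlir::MLIRContext *ctx = func.getContext();
    mlir::OwningRewritePatternList patterns;
    mlir::shape::AnyOp::getCanonicalizationPatterns(patterns, ctx);
    mlir::shape::AssumingOp::getCanonicalizationPatterns(patterns, ctx);
    mlir::shape::AssumingAllOp::getCanonicalizationPatterns(patterns, ctx);
    mlir::shape::CstrBroadcastableOp::getCanonicalizationPatterns(patterns, ctx);
    mlir::shape::CstrEqOp::getCanonicalizationPatterns(patterns, ctx);
    mlir::applyPatternsAndFoldGreedily(func, patterns);
  }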
diff --git a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
new file
--- /dev/null
+++ b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
@@ -0,0 +1,26 @@
+include "mlir/Dialect/Shape/IR/ShapeOps.td"
+
+// Constraints
+def AllInputsTrueWitnesses : Constraint<CPred<[{
+    llvm::all_of($0.getOwner()->getOperands(), [](mlir::Value op) {
+      return op.getDefiningOp<::mlir::shape::TrueWitnessOp>();
+    })
+  }]>>;
+
+def AllInputShapesEq : Constraint<CPred<[{
+    llvm::all_of($0.getOwner()->getOperands(), [&](mlir::Value val) {
+      auto shapeOp = val.getDefiningOp<::mlir::shape::ConstShapeOp>();
+      auto firstOp = $0.getOwner()->getOperand(0)
+                         .getDefiningOp<::mlir::shape::ConstShapeOp>();
+      return shapeOp && firstOp && shapeOp.shape() == firstOp.shape();
+    })
+  }]>>;
+
+// Canonicalization patterns.
+def ConstantAssumingAll : Pat<(Shape_AssumingAllOp:$op $input),
+                              (Shape_TrueWitnessOp),
+                              [(AllInputsTrueWitnesses $op)]>;
+
+def ConstCstrEq : Pat<(Shape_CstrEqOp:$op $shapes),
+                      (Shape_TrueWitnessOp),
+                      [(AllInputShapesEq $op)]>;
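For readers less familiar with DRR: each Pat<> above compiles to an OpRewritePattern subclass in the generated ShapeCanonicalization.inc, which is why Shape.cpp can hand ConstantAssumingAll and ConstCstrEq to patterns.insert<>. Roughly, ConstantAssumingAll behaves like the following hand-written pattern (a sketch of the semantics, not the literal tablegen output; it assumes the using-declarations at the top of Shape.cpp):

  // Sketch: replace an assuming_all whose operands are all produced by
  // shape.true_witness with a single shape.true_witness.
  struct ConstantAssumingAllSketch : public OpRewritePattern<AssumingAllOp> {
    using OpRewritePattern<AssumingAllOp>::OpRewritePattern;

    LogicalResult matchAndRewrite(AssumingAllOp op,
                                  PatternRewriter &rewriter) const override {
      // Mirrors the AllInputsTrueWitnesses constraint.
      if (!llvm::all_of(op.getOperands(), [](Value operand) {
            return operand.getDefiningOp<TrueWitnessOp>() != nullptr;
          }))
        return failure();
      rewriter.replaceOpWithNewOp<TrueWitnessOp>(op);
      return success();
    }
  };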
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -86,3 +86,38 @@
   %0 = "shape.to_extent_tensor"(%cs) : (!shape.shape) -> tensor<2xindex>
   return %0 : tensor<2xindex>
 }
+
+// -----
+// Confirm that cstr_eq, assuming_all, assuming, and any can all be folded away
+// when dependent only on constant shapes.
+// CHECK-LABEL: func @f
+func @f() -> !shape.shape {
+  // CHECK-NEXT: shape.const_shape
+  // CHECK-NEXT: return
+  %cs0 = shape.const_shape [0, 1]
+  %cs1 = shape.const_shape [0, 1]
+  %cs2 = shape.const_shape [0, 1]
+  %0 = shape.cstr_eq(%cs0, %cs1)
+  %1 = shape.cstr_eq(%cs0, %cs1, %cs2)
+  %2 = shape.assuming_all(%0, %1)
+  %3 = "shape.assuming"(%2) ({
+    %cs = shape.any(%cs0, %cs1, %cs2)
+    shape.assuming_yield %cs : !shape.shape
+  }) : (!shape.witness) -> !shape.shape
+  return %3 : !shape.shape
+}
+
+// -----
+// Confirm that cstr_broadcastable of const shapes is folded away.
+// CHECK-LABEL: func @f
+func @f() -> !shape.shape {
+  // CHECK: %[[CS:.*]] = shape.const_shape [1, 5]
+  // CHECK-NEXT: return %[[CS]]
+  %cs0 = shape.const_shape [1, 1]
+  %cs1 = shape.const_shape [1, 5]
+  %0 = shape.cstr_broadcastable(%cs0, %cs1)
+  %3 = "shape.assuming"(%0) ({
+    shape.assuming_yield %cs1 : !shape.shape
+  }) : (!shape.witness) -> !shape.shape
+  return %3 : !shape.shape
+}
diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -73,7 +73,8 @@
   %1 = shape.const_shape [1, 2, 3]
   %w0 = shape.cstr_broadcastable(%0, %1)
   %w1 = shape.cstr_eq(%0, %1)
-  %w3 = shape.assuming_all(%w0, %w1)
+  %w2 = shape.true_witness
+  %w3 = shape.assuming_all(%w0, %w1, %w2)
   shape.assuming %w3 -> !shape.shape {
     %2 = shape.any(%0, %1)
     shape.assuming_yield %2 : !shape.shape
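A note on the tests: the `// -----` separators mean canonicalize.mlir is processed with -split-input-file, so each case is canonicalized in isolation; both new cases check that the entire witness chain collapses to a bare shape.const_shape plus return, presumably under the file's existing mlir-opt -canonicalize RUN line (not visible in this hunk). The ops.mlir change only extends the existing round-trip parse/print test to cover the new shape.true_witness op.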