diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -463,7 +463,7 @@
   let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
 }
 
-def Shape_CstrBroadcastableOp : Shape_Op<"cstr_broadcastable", []> {
+def Shape_CstrBroadcastableOp : Shape_Op<"cstr_broadcastable", [Commutative]> {
   let summary = "Determines if 2 shapes can be successfully broadcasted.";
   let description = [{
     Given 2 input shapes, return a witness specifying if they are broadcastable.
@@ -482,6 +482,8 @@
   let results = (outs Shape_WitnessType:$result);
 
   let assemblyFormat = "$lhs `,` $rhs attr-dict";
+
+  let hasCanonicalizer = 1;
 }
 
 def Shape_CstrEqOp : Shape_Op<"cstr_eq", []> {
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -255,6 +255,46 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// CstrBroadcastableOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+// Replaces a CstrBroadcastableOp whose operands are both ConstShapeOps with
+// broadcast-compatible shapes by a TrueWitnessOp.
+// TODO: Add a case for unknown shapes that are still defined by the same
+// operation.
+// TODO: Once Witnesses are Attributes, replace this with folding.
+struct CstrBroadcastableTrue : public OpRewritePattern<CstrBroadcastableOp> {
+  using OpRewritePattern<CstrBroadcastableOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(CstrBroadcastableOp op,
+                                PatternRewriter &rewriter) const override {
+    // Only fold when both operand shapes are known constants.
+    auto lhs = op.getOperand(0).getDefiningOp<ConstShapeOp>();
+    auto rhs = op.getOperand(1).getDefiningOp<ConstShapeOp>();
+    if (!lhs || !rhs)
+      return failure();
+
+    SmallVector<int64_t, 6> resultShape;
+    // If the shapes are not compatible, we can't fold it.
+    if (!OpTrait::util::getBroadcastedShape(
+            llvm::to_vector<6>(lhs.shape().getValues<int64_t>()),
+            llvm::to_vector<6>(rhs.shape().getValues<int64_t>()), resultShape))
+      return failure();
+
+    rewriter.replaceOpWithNewOp<TrueWitnessOp>(op);
+    return success();
+  }
+};
+} // namespace
+
+void CstrBroadcastableOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *context) {
+  // Statically-known broadcastable constant shapes produce a true witness.
+  patterns.insert<CstrBroadcastableTrue>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // ConstSizeOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -159,3 +159,47 @@
   %1 = shape.any %arg0, %arg1
   return %1 : !shape.shape
 }
+
+// -----
+// Broadcastable with broadcastable constant shapes can be removed.
+// CHECK-LABEL: func @f
+func @f() {
+  // CHECK-NEXT: true_witness
+  // CHECK-NEXT: consume.witness
+  // CHECK-NEXT: return
+  %cs0 = shape.const_shape [3, 1]
+  %cs1 = shape.const_shape [1, 5]
+  %0 = shape.cstr_broadcastable %cs0, %cs1
+  "consume.witness"(%0) : (!shape.witness) -> ()
+  return
+}
+
+// -----
+// Broadcastable with non-broadcastable constant shapes cannot be removed.
+// CHECK-LABEL: func @f
+func @f() {
+  // CHECK-NEXT: shape.const_shape
+  // CHECK-NEXT: shape.const_shape
+  // CHECK-NEXT: shape.cstr_broadcastable
+  // CHECK-NEXT: consume.witness
+  // CHECK-NEXT: return
+  %cs0 = shape.const_shape [1, 3]
+  %cs1 = shape.const_shape [1, 5]
+  %0 = shape.cstr_broadcastable %cs0, %cs1
+  "consume.witness"(%0) : (!shape.witness) -> ()
+  return
+}
+
+// -----
+// Broadcastable without constant shapes cannot be removed.
+// CHECK-LABEL: func @f
+func @f(%arg0 : !shape.shape) {
+  // CHECK-NEXT: shape.const_shape
+  // CHECK-NEXT: shape.cstr_broadcastable
+  // CHECK-NEXT: consume.witness
+  // CHECK-NEXT: return
+  %cs0 = shape.const_shape [1, 3]
+  %0 = shape.cstr_broadcastable %arg0, %cs0
+  "consume.witness"(%0) : (!shape.witness) -> ()
+  return
+}