diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -277,9 +277,10 @@
     };
  }];
 
+  let hasCanonicalizer = 1;
+
   let assemblyFormat = "$shapes attr-dict `:` type($shapes)";
   let verifier = [{ return ::verify(*this); }];
-
 }
 
 def Shape_RankOp : Shape_Op<"rank", [NoSideEffect]> {
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -779,6 +779,44 @@
   return success();
 }
 
+namespace {
+struct IsBroadcastableCanonicalizationPattern
+    : public OpRewritePattern<IsBroadcastableOp> {
+  using OpRewritePattern<IsBroadcastableOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(IsBroadcastableOp op,
+                                PatternRewriter &rewriter) const override {
+    // Find unique operands.
+    SmallVector<Value, 2> unique;
+    for (Value v : op.getOperands()) {
+      if (llvm::none_of(unique, [&](Value u) { return u == v; }))
+        unique.push_back(v);
+    }
+
+    // Can always broadcast fewer than two shapes.
+    if (unique.size() < 2) {
+      rewriter.replaceOpWithNewOp<ConstantOp>(op,
+                                              rewriter.getBoolAttr(true));
+      return success();
+    }
+
+    // Reduce op to equivalent with unique operands.
+    if (unique.size() < op.getNumOperands()) {
+      rewriter.replaceOpWithNewOp<IsBroadcastableOp>(op, rewriter.getI1Type(),
+                                                     unique);
+      return success();
+    }
+
+    return failure();
+  }
+};
+} // namespace
+
+void IsBroadcastableOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *context) {
+  patterns.insert<IsBroadcastableCanonicalizationPattern>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // RankOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1069,3 +1069,28 @@
   %1 = tensor.cast %0 : tensor<1xindex> to tensor<?xindex>
   return %1 : tensor<?xindex>
 }
+
+// -----
+
+// CHECK-LABEL: @is_broadcastable_on_same_shape
+func @is_broadcastable_on_same_shape(%shape : !shape.shape) -> i1 {
+  // CHECK-NOT: is_broadcastable
+  // CHECK: %[[RES:.*]] = constant true
+  // CHECK: return %[[RES]]
+  %0 = shape.is_broadcastable %shape, %shape, %shape
+      : !shape.shape, !shape.shape, !shape.shape
+  return %0 : i1
+}
+
+// -----
+
+// CHECK-LABEL: @is_broadcastable_on_duplicate_shapes
+// CHECK-SAME: (%[[A:.*]]: !shape.shape, %[[B:.*]]: !shape.shape)
+func @is_broadcastable_on_duplicate_shapes(%a : !shape.shape, %b : !shape.shape)
+    -> i1 {
+  // CHECK: %[[RES:.*]] = shape.is_broadcastable %[[A]], %[[B]]
+  // CHECK: return %[[RES]]
+  %0 = shape.is_broadcastable %a, %b, %a, %a, %a, %b : !shape.shape,
+      !shape.shape, !shape.shape, !shape.shape, !shape.shape, !shape.shape
+  return %0 : i1
+}