diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -414,11 +414,50 @@
     return failure();
   }
 };
+
+/// Replaces a `shape.broadcast` that has exactly one operand with that operand.
+/// Broadcasting a single shape leaves its extents unchanged, but the operand's
+/// type may still differ from the op's result type (e.g. `tensor<3xindex>` vs.
+/// `tensor<?xindex>`, or an extent tensor vs. `!shape.shape`). Forwarding the
+/// operand as-is would then change the type seen by users of the result, so
+/// the appropriate cast is inserted in that case.
+struct BroadcastForwardSingleOperandPattern
+    : public OpRewritePattern<BroadcastOp> {
+  using OpRewritePattern<BroadcastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(BroadcastOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op.getNumOperands() != 1)
+      return failure();
+    Value replacement = op.shapes().front();
+
+    if (replacement.getType() != op.getType()) {
+      Location loc = op.getLoc();
+      if (op.getType().isa<ShapeType>()) {
+        // Extent tensor operand, `!shape.shape` result.
+        replacement =
+            rewriter.create<FromExtentTensorOp>(loc, op.getType(), replacement);
+      } else if (!replacement.getType().isa<ShapeType>()) {
+        // Both sides are extent tensors with different static shapes.
+        replacement =
+            rewriter.create<tensor::CastOp>(loc, op.getType(), replacement);
+      } else {
+        // A `!shape.shape` operand cannot be cast to an extent tensor result;
+        // leave the op untouched.
+        return failure();
+      }
+    }
+
+    rewriter.replaceOp(op, replacement);
+    return success();
+  }
+};
 } // namespace
 
 void BroadcastOp::getCanonicalizationPatterns(
     OwningRewritePatternList &patterns, MLIRContext *context) {
-  patterns.insert<RemoveDuplicateOperandsPattern<BroadcastOp>>(context);
+  patterns.insert<BroadcastForwardSingleOperandPattern,
+                  RemoveDuplicateOperandsPattern<BroadcastOp>>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1119,3 +1119,44 @@
       !shape.shape, !shape.shape, !shape.shape, !shape.shape -> !shape.shape
   return %0 : !shape.shape
 }
+
+// -----
+
+// Single operand whose type matches the result type: forwarded directly.
+// CHECK-LABEL: @broadcast_on_single_operand
+// CHECK-SAME: (%[[A:.*]]: tensor<?xindex>)
+func @broadcast_on_single_operand(%a : tensor<?xindex>) {
+  // CHECK-NOT: broadcast
+  // CHECK: "use"(%[[A]])
+  %0 = shape.broadcast %a : tensor<?xindex> -> tensor<?xindex>
+  "use"(%0) : (tensor<?xindex>) -> ()
+  return
+}
+
+// -----
+
+// Single operand with a more static extent tensor type: a `tensor.cast` keeps
+// the type seen by users of the result unchanged.
+// CHECK-LABEL: @broadcast_as_tensor_cast
+// CHECK-SAME: (%[[A:.*]]: tensor<3xindex>)
+func @broadcast_as_tensor_cast(%a : tensor<3xindex>) -> tensor<?xindex> {
+  // CHECK-NOT: shape.broadcast
+  // CHECK: %[[RESULT:.*]] = tensor.cast %[[A]] : tensor<3xindex> to tensor<?xindex>
+  // CHECK: return %[[RESULT]] : tensor<?xindex>
+  %0 = shape.broadcast %a : tensor<3xindex> -> tensor<?xindex>
+  return %0 : tensor<?xindex>
+}
+
+// -----
+
+// Single extent tensor operand with a `!shape.shape` result: converted with
+// `shape.from_extent_tensor`.
+// CHECK-LABEL: @broadcast_as_from_extent_tensor
+// CHECK-SAME: (%[[A:.*]]: tensor<?xindex>)
+func @broadcast_as_from_extent_tensor(%a : tensor<?xindex>) -> !shape.shape {
+  // CHECK-NOT: shape.broadcast
+  // CHECK: %[[RESULT:.*]] = shape.from_extent_tensor %[[A]] : tensor<?xindex>
+  // CHECK: return %[[RESULT]] : !shape.shape
+  %0 = shape.broadcast %a : tensor<?xindex> -> !shape.shape
+  return %0 : !shape.shape
+}