diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -470,8 +470,12 @@
 //===----------------------------------------------------------------------===//
 
 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
-  if (shapes().size() == 1)
+  if (shapes().size() == 1) {
+    // Otherwise, we need a cast which would be a canonicalization, not folding.
+    if (shapes().front().getType() != getType())
+      return nullptr;
     return shapes().front();
+  }
 
   // TODO: Support folding with more than 2 input shapes
   if (shapes().size() > 2)
@@ -556,8 +560,10 @@
                                 PatternRewriter &rewriter) const override {
     if (op.getNumOperands() == 1) {
       Value uniqueShapeOperand = op.shapes().front();
-      rewriter.replaceOp(op, uniqueShapeOperand);
-      return success();
+      if (uniqueShapeOperand.getType() == op.getType()) {
+        rewriter.replaceOp(op, uniqueShapeOperand);
+        return success();
+      }
     }
     return failure();
   }
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1217,10 +1217,21 @@
 // -----
 
 // CHECK-LABEL: @broadcast_on_single_operand
-// CHECK-SAME: (%[[A:.*]]: tensor<3xindex>)
-func @broadcast_on_single_operand(%a : tensor<3xindex>) {
+// CHECK-SAME: (%[[A:.*]]: tensor<?xindex>)
+func @broadcast_on_single_operand(%a : tensor<?xindex>) {
   // CHECK-NOT: broadcast
   // CHECK: "use"(%[[A]])
+  %0 = shape.broadcast %a : tensor<?xindex> -> tensor<?xindex>
+  "use"(%0) : (tensor<?xindex>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @broadcast_on_single_operand
+// CHECK-SAME: (%[[A:.*]]: tensor<3xindex>)
+func @broadcast_on_single_operand(%a : tensor<3xindex>) {
+  // CHECK: broadcast %[[A]]
   %0 = shape.broadcast %a : tensor<3xindex> -> tensor<?xindex>
   "use"(%0) : (tensor<?xindex>) -> ()
   return