diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -558,14 +558,26 @@
   LogicalResult matchAndRewrite(BroadcastOp op,
                                 PatternRewriter &rewriter) const override {
-    if (op.getNumOperands() == 1) {
-      Value uniqueShapeOperand = op.shapes().front();
-      if (uniqueShapeOperand.getType() == op.getType()) {
-        rewriter.replaceOp(op, uniqueShapeOperand);
-        return success();
+    if (op.getNumOperands() != 1)
+      return failure();
+    Value replacement = op.shapes().front();
+
+    // Insert cast if needed.
+    if (replacement.getType() != op.getType()) {
+      auto loc = op.getLoc();
+      if (op.getType().isa<ShapeType>()) {
+        replacement = rewriter.create<FromExtentTensorOp>(loc, replacement);
+      } else {
+        assert(!op.getType().isa<ShapeType>() &&
+               !replacement.getType().isa<ShapeType>() &&
+               "expect extent tensor cast");
+        replacement =
+            rewriter.create<tensor::CastOp>(loc, op.getType(), replacement);
       }
     }
-    return failure();
+
+    rewriter.replaceOp(op, replacement);
+    return success();
   }
 };
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1242,13 +1242,24 @@
 
 // -----
 
-// CHECK-LABEL: @broadcast_on_single_operand
+// CHECK-LABEL: @broadcast_as_tensor_cast
 // CHECK-SAME: (%[[A:.*]]: tensor<3xindex>)
-func @broadcast_on_single_operand(%a : tensor<3xindex>) {
-  // CHECK: broadcast %[[A]]
+func @broadcast_as_tensor_cast(%a : tensor<3xindex>) -> tensor<?xindex> {
+  // CHECK: %[[RESULT:.*]] = tensor.cast %[[A]] : tensor<3xindex> to tensor<?xindex>
+  // CHECK: return %[[RESULT]] : tensor<?xindex>
   %0 = shape.broadcast %a : tensor<3xindex> -> tensor<?xindex>
-  "use"(%0) : (tensor<?xindex>) -> ()
-  return
+  return %0 : tensor<?xindex>
+}
+
+// -----
+
+// CHECK-LABEL: @broadcast_as_from_extent_tensor
+// CHECK-SAME: (%[[A:.*]]: tensor<?xindex>)
+func @broadcast_as_from_extent_tensor(%a : tensor<?xindex>) -> !shape.shape {
+  // CHECK: %[[RESULT:.*]] = shape.from_extent_tensor %[[A]] : tensor<?xindex>
+  // CHECK: return %[[RESULT]] : !shape.shape
+  %0 = shape.broadcast %a : tensor<?xindex> -> !shape.shape
+  return %0 : !shape.shape
 }
 
 // -----
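
Note: a rewrite pattern like the one above only takes effect once it is registered for the op. A minimal sketch of that wiring, assuming the pattern struct is named BroadcastForwardSingleOperandPattern (the struct name is not visible in the hunk) and using the OwningRewritePatternList API of this era of MLIR (later renamed RewritePatternSet):

    // Hypothetical registration sketch -- only matchAndRewrite appears in the
    // diff above; the pattern class name here is an assumption.
    void BroadcastOp::getCanonicalizationPatterns(
        OwningRewritePatternList &patterns, MLIRContext *context) {
      // With this registered, -canonicalize rewrites a single-operand
      // shape.broadcast to its operand, inserting a tensor.cast or
      // shape.from_extent_tensor when the result type differs.
      patterns.insert<BroadcastForwardSingleOperandPattern>(context);
    }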