diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -155,17 +155,20 @@
     return lb.create<SubIOp>(indexTy, maxRank, v);
   }));
 
-  rewriter.replaceOp(
-      op, lb.create<tensor::GenerateOp>(
-              getExtentTensorType(lb.getContext()), ValueRange{maxRank},
-              [&](OpBuilder &b, Location loc, ValueRange args) {
-                Value broadcastedDim = getBroadcastedDim(
-                    ImplicitLocOpBuilder(loc, b), transformed.shapes(),
-                    rankDiffs, args[0]);
-
-                b.create<tensor::YieldOp>(loc, broadcastedDim);
-              })
-              ->getResults());
+  Value replacement = lb.create<tensor::GenerateOp>(
+      getExtentTensorType(lb.getContext()), ValueRange{maxRank},
+      [&](OpBuilder &b, Location loc, ValueRange args) {
+        Value broadcastedDim =
+            getBroadcastedDim(ImplicitLocOpBuilder(loc, b),
+                              transformed.shapes(), rankDiffs, args[0]);
+
+        b.create<tensor::YieldOp>(loc, broadcastedDim);
+      });
+  // The generated value has the generic extent tensor type, which can differ
+  // from the op's result type (e.g. tensor<3xindex>); cast when they differ.
+  if (replacement.getType() != op.getType())
+    replacement = lb.create<tensor::CastOp>(op.getType(), replacement);
+  rewriter.replaceOp(op, replacement);
   return success();
 }
 
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -593,6 +593,17 @@
   return
 }
 
+// -----
+
+// CHECK-LABEL: @broadcast_to_known_rank
+func @broadcast_to_known_rank(%a : tensor<1xindex>, %b : tensor<3xindex>)
+    -> tensor<3xindex> {
+  // CHECK: %[[RES:.*]] = tensor.cast %{{.*}} : tensor<?xindex> to tensor<3xindex>
+  // CHECK: return %[[RES]] : tensor<3xindex>
+  %0 = shape.broadcast %a, %b : tensor<1xindex>, tensor<3xindex> -> tensor<3xindex>
+  return %0 : tensor<3xindex>
+}
+
 // -----
 
 // Lower `split_at`