diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
@@ -151,8 +151,26 @@
         incIdx(srcIdx, sourceVectorType, srcRank - 1);
         incIdx(resIdx, resultVectorType, resRank - 1);
       }
-      Value e = rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
-      result = rewriter.create<vector::InsertOp>(loc, e, result, resIdx);
+
+      Value extract;
+      if (srcRank == 0) {
+        // 0-D vector special case
+        assert(srcIdx.empty() && "Unexpected indices for 0-D vector");
+        extract = rewriter.create<vector::ExtractElementOp>(
+            loc, op.getSourceVectorType().getElementType(), op.getSource());
+      } else {
+        extract =
+            rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
+      }
+
+      if (resRank == 0) {
+        // 0-D vector special case
+        assert(resIdx.empty() && "Unexpected indices for 0-D vector");
+        result = rewriter.create<vector::InsertElementOp>(loc, extract, result);
+      } else {
+        result =
+            rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
+      }
     }
     rewriter.replaceOp(op, result);
     return success();
diff --git a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
--- a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
@@ -1,4 +1,3 @@
-
 // RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | FileCheck %s
 
 // CHECK-LABEL: func @nop_shape_cast
@@ -124,9 +123,35 @@
   return %s : vector<2x1x3xf32>
 }
 
+// CHECK-LABEL: func.func @shape_cast_0d1d(
+// CHECK-SAME:    %[[VAL_0:.*]]: vector<f32>) -> vector<1xf32> {
+// CHECK:         %[[VAL_1:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
+// CHECK:         %[[VAL_2:.*]] = vector.extractelement %[[VAL_0]][] : vector<f32>
+// CHECK:         %[[VAL_3:.*]] = vector.insert %[[VAL_2]], %[[VAL_1]] [0] : f32 into vector<1xf32>
+// CHECK:         return %[[VAL_3]] : vector<1xf32>
+// CHECK:       }
+
+func.func @shape_cast_0d1d(%arg0 : vector<f32>) -> vector<1xf32> {
+  %s = vector.shape_cast %arg0 : vector<f32> to vector<1xf32>
+  return %s : vector<1xf32>
+}
+
+// CHECK-LABEL: func.func @shape_cast_1d0d(
+// CHECK-SAME:    %[[VAL_0:.*]]: vector<1xf32>) -> vector<f32> {
+// CHECK:         %[[VAL_1:.*]] = arith.constant dense<0.000000e+00> : vector<f32>
+// CHECK:         %[[VAL_2:.*]] = vector.extract %[[VAL_0]][0] : vector<1xf32>
+// CHECK:         %[[VAL_3:.*]] = vector.insertelement %[[VAL_2]], %[[VAL_1]][] : vector<f32>
+// CHECK:         return %[[VAL_3]] : vector<f32>
+// CHECK:       }
+
+func.func @shape_cast_1d0d(%arg0 : vector<1xf32>) -> vector<f32> {
+  %s = vector.shape_cast %arg0 : vector<1xf32> to vector<f32>
+  return %s : vector<f32>
+}
+
 transform.sequence failures(propagate) {
 ^bb1(%module_op: !transform.any_op):
-  %f = transform.structured.match ops{["func.func"]} in %module_op
+  %f = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op
   %f2 = transform.vector.lower_shape_cast %f