diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2099,8 +2099,8 @@
 
 def Vector_ShapeCastOp :
   Vector_Op<"shape_cast", [Pure]>,
-    Arguments<(ins AnyVector:$source)>,
-    Results<(outs AnyVector:$result)> {
+    Arguments<(ins AnyVectorOfAnyRank:$source)>,
+    Results<(outs AnyVectorOfAnyRank:$result)> {
   let summary = "shape_cast casts between vector shapes";
   let description = [{
     The shape_cast operation casts between an n-D source vector shape and
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4615,6 +4615,13 @@
   unsigned rankB = b.size();
   assert(rankA < rankB);
 
+  auto isOne = [](int64_t v) { return v == 1; };
+
+  // Special case for 0-d <-> n-D shape casts: 'b' must be all ones. The loop
+  // below never runs when 'a' is empty, so it cannot handle this case.
+  if (rankA == 0 && llvm::all_of(b, isOne))
+    return true;
+
   unsigned i = 0;
   unsigned j = 0;
   while (i < rankA && j < rankB) {
@@ -4628,7 +4635,6 @@
 
     // Handle the case when trailing dimensions are of size 1.
     // Include them into the contiguous sequence.
-    auto isOne = [](int64_t v) { return v == 1; };
     if (i < rankA && llvm::all_of(a.slice(i), isOne))
       i = rankA;
     if (j < rankB && llvm::all_of(b.slice(j), isOne))
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -458,6 +458,18 @@
   return %0, %1, %2, %3 : vector<15x2xf32>, vector<8xf32>, vector<16xf32>, vector<16x1xf32>
 }
 
+// CHECK-LABEL: @shape_cast_0d
+func.func @shape_cast_0d(%arg0 : vector<1x1x1x1xf32>) -> (vector<1x1x1x1xf32>) {
+
+  // CHECK: vector.shape_cast %{{.*}} : vector<1x1x1x1xf32> to vector<f32>
+  %0 = vector.shape_cast %arg0 : vector<1x1x1x1xf32> to vector<f32>
+
+  // CHECK: vector.shape_cast %{{.*}} : vector<f32> to vector<1x1x1x1xf32>
+  %1 = vector.shape_cast %0 : vector<f32> to vector<1x1x1x1xf32>
+
+  return %1 : vector<1x1x1x1xf32>
+}
+
 // CHECK-LABEL: @bitcast
 func.func @bitcast(%arg0 : vector<5x1x3x2xf32>,
                    %arg1 : vector<8x1xi32>,