diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -428,19 +428,21 @@ TCresVTEtIsSameAsOpBase<0, 0>>, PredOpTrait<"second operand v2 and result have same element type", TCresVTEtIsSameAsOpBase<0, 1>>]>, - Arguments<(ins AnyVector:$v1, AnyVector:$v2, I64ArrayAttr:$mask)>, + Arguments<(ins AnyVectorOfAnyRank:$v1, AnyVectorOfAnyRank:$v2, + I64ArrayAttr:$mask)>, Results<(outs AnyVector:$vector)> { let summary = "shuffle operation"; let description = [{ The shuffle operation constructs a permutation (or duplication) of elements from two input vectors, returning a vector with the same element type as the input and a length that is the same as the shuffle mask. The two input - vectors must have the same element type, rank, and trailing dimension sizes - and shuffles their values in the leading dimension (which may differ in size) - according to the given mask. The legality rules are: + vectors must have the same element type, rank (except for 0-D vectors, see + below), and trailing dimension sizes and shuffles their values in the + leading dimension (which may differ in size) according to the given mask. + The legality rules are: * the two operands must have the same element type as the result - * the two operands and the result must have the same rank and trailing - dimension sizes, viz. given two k-D operands + * the two operands must either be 0-D vectors (yielding a 1-D result) or + have the same rank and trailing dimension sizes as the result, viz.
given two k-D operands v1 : <s_1 x s_2 x .. x s_k x type> and v2 : <t_1 x t_2 x .. x t_k x type> we have s_i = t_i for all 1 < i <= k @@ -458,6 +460,8 @@ : vector<2x16xf32>, vector<1x16xf32> ; yields vector<3x16xf32> %2 = vector.shuffle %a, %b[3, 2, 1, 0] : vector<2xf32>, vector<2xf32> ; yields vector<4xf32> + %3 = vector.shuffle %a, %b[0, 1] + : vector<f32>, vector<f32> ; yields vector<2xf32> ``` }]; let builders = [ diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -464,13 +464,12 @@ // Get rank and dimension sizes. int64_t rank = vectorType.getRank(); - assert(v1Type.getRank() == rank); - assert(v2Type.getRank() == rank); - int64_t v1Dim = v1Type.getDimSize(0); + assert(v1Type.getRank() == 0 || v1Type.getRank() == rank); + assert(v2Type.getRank() == 0 || v2Type.getRank() == rank); - // For rank 1, where both operands have *exactly* the same vector type, - // there is direct shuffle support in LLVM. Use it! - if (rank == 1 && v1Type == v2Type) { + // For 1-D results (including those built from 0-D operands), where both + // operands have *exactly* the same vector type, LLVM shuffles directly. Use it! + if (rank <= 1 && v1Type == v2Type) { Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>( loc, adaptor.v1(), adaptor.v2(), maskArrayAttr); rewriter.replaceOp(shuffleOp, llvmShuffleOp); @@ -478,6 +477,7 @@ } // For all other cases, insert the individual values individually.
+ int64_t v1Dim = v1Type.getDimSize(0); Type eltType; if (auto arrayType = llvmType.dyn_cast<LLVM::LLVMArrayType>()) eltType = arrayType.getElementType(); diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -1492,7 +1492,8 @@ int64_t resRank = resultType.getRank(); int64_t v1Rank = v1Type.getRank(); int64_t v2Rank = v2Type.getRank(); - if (resRank != v1Rank || v1Rank != v2Rank) + // Two 0-D operands may only be shuffled into a 1-D result; otherwise all ranks must match. + if (v1Rank != v2Rank || (v1Rank == 0 ? resRank != 1 : resRank != v1Rank)) return op.emitOpError("rank mismatch"); // Verify all but leading dimension sizes. for (int64_t r = 1; r < v1Rank; ++r) { @@ -1508,7 +1509,8 @@ if (maskLength != resultType.getDimSize(0)) return op.emitOpError("mask length mismatch"); // Verify all indices. - int64_t indexSize = v1Type.getDimSize(0) + v2Type.getDimSize(0); + int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) + + (v2Type.getRank() == 0 ?
1 : v2Type.getDimSize(0)); for (auto en : llvm::enumerate(maskAttr)) { auto attr = en.value().dyn_cast<IntegerAttr>(); if (!attr || attr.getInt() < 0 || attr.getInt() >= indexSize) diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -393,6 +393,18 @@ // ----- +func @shuffle_0D_direct(%arg0: vector<f32>) -> vector<3xf32> { + %1 = vector.shuffle %arg0, %arg0 [0, 1, 0] : vector<f32>, vector<f32> + return %1 : vector<3xf32> +} +// CHECK-LABEL: @shuffle_0D_direct( +// CHECK-SAME: %[[A:.*]]: vector<f32> +// CHECK: %[[c:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<f32> to vector<1xf32> +// CHECK: %[[s:.*]] = llvm.shufflevector %[[c]], %[[c]] [0, 1, 0] : vector<1xf32>, vector<1xf32> +// CHECK: return %[[s]] : vector<3xf32> + +// ----- + func @shuffle_1D_direct(%arg0: vector<2xf32>, %arg1: vector<2xf32>) -> vector<2xf32> { %1 = vector.shuffle %arg0, %arg1 [0, 1] : vector<2xf32>, vector<2xf32> return %1 : vector<2xf32> diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -58,6 +58,20 @@ // ----- +func @shuffle_rank_mismatch_0d(%arg0: vector<f32>, %arg1: vector<1xf32>) { + // expected-error@+1 {{'vector.shuffle' op rank mismatch}} + %1 = vector.shuffle %arg0, %arg1 [0, 1] : vector<f32>, vector<1xf32> +} + +// ----- + +func @shuffle_rank_mismatch_0d_to_2d(%arg0: vector<f32>) { + // expected-error@+1 {{'vector.shuffle' op rank mismatch}} + %1 = "vector.shuffle"(%arg0, %arg0) {mask = [0, 1, 0]} : (vector<f32>, vector<f32>) -> vector<3x1xf32> +} + +// ----- + func @shuffle_trailing_dim_size_mismatch(%arg0: vector<2x2xf32>, %arg1: vector<2x4xf32>) { // expected-error@+1 {{'vector.shuffle' op dimension mismatch}} %1 = vector.shuffle %arg0, %arg1 [0, 1] : vector<2x2xf32>, vector<2x4xf32> diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -165,6 +165,13 @@ return %4 : vector<8x16xf32> } +// CHECK-LABEL: @shuffle0D +func
@shuffle0D(%a: vector<f32>) -> vector<3xf32> { + // CHECK: vector.shuffle %{{.*}}, %{{.*}}[0, 1, 0] : vector<f32>, vector<f32> + %1 = vector.shuffle %a, %a[0, 1, 0] : vector<f32>, vector<f32> + return %1 : vector<3xf32> +} + // CHECK-LABEL: @shuffle1D func @shuffle1D(%a: vector<2xf32>, %b: vector<4xf32>) -> vector<2xf32> { // CHECK: vector.shuffle %{{.*}}, %{{.*}}[0, 1, 2, 3] : vector<2xf32>, vector<2xf32> diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir --- a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir @@ -120,6 +120,13 @@ return } +func @shuffle_0d(%arg: vector<i32>) { + %1 = vector.shuffle %arg, %arg [0, 1, 0] : vector<i32>, vector<i32> + // CHECK: ( 42, 42, 42 ) + vector.print %1: vector<3xi32> + return +} + func @entry() { %0 = arith.constant 42.0 : f32 %1 = arith.constant dense<0.0> : vector<f32> @@ -150,6 +157,7 @@ call @fma_0d(%5) : (vector<f32>) -> () %6 = arith.constant dense<42> : vector<i32> call @transpose_0d(%6) : (vector<i32>) -> () + call @shuffle_0d(%6) : (vector<i32>) -> () return }