diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -437,22 +437,28 @@
      PredOpTrait<"second operand v2 and result have same element type",
                  TCresVTEtIsSameAsOpBase<0, 1>>,
      DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
-     Arguments<(ins AnyVector:$v1, AnyVector:$v2, I64ArrayAttr:$mask)>,
+     Arguments<(ins AnyVectorOfAnyRank:$v1, AnyVectorOfAnyRank:$v2,
+                    I64ArrayAttr:$mask)>,
      Results<(outs AnyVector:$vector)> {
   let summary = "shuffle operation";
   let description = [{
     The shuffle operation constructs a permutation (or duplication) of elements
     from two input vectors, returning a vector with the same element type as
     the input and a length that is the same as the shuffle mask. The two input
-    vectors must have the same element type, rank, and trailing dimension sizes
-    and shuffles their values in the leading dimension (which may differ in size)
-    according to the given mask. The legality rules are:
+    vectors must have the same element type, the same rank, and the same
+    trailing dimension sizes; the operation shuffles their values in the
+    leading dimension (which may differ in size) according to the given mask.
+    The legality rules are:
     * the two operands must have the same element type as the result
-    * the two operands and the result must have the same rank and trailing
-      dimension sizes, viz. given two k-D operands
-              v1 : <s_1 x s_2 x .. x s_k x type> and
-              v2 : <t_1 x t_2 x .. x t_k x type>
-      we have s_i = t_i for all 1 < i <= k
+    * the operands and the result must have compatible ranks and shapes:
+      - Either the two operands and the result must have the same
+        rank and trailing dimension sizes, viz. given two k-D operands
+                v1 : <s_1 x s_2 x .. x s_k x type> and
+                v2 : <t_1 x t_2 x .. x t_k x type>
+        we have s_i = t_i for all 1 < i <= k
+      - Or the two operands must be 0-D vectors and the result is a 1-D
+        vector; each 0-D operand then contributes a single element, so all
+        mask values must be in the range [0, 2).
     * the mask length equals the leading dimension size of the result
     * numbering the input vector indices left to right across the operands, all
       mask values must be within range, viz. given two k-D operands v1 and v2
@@ -467,12 +473,15 @@
                : vector<2x16xf32>, vector<1x16xf32> ; yields vector<3x16xf32>
     %2 = vector.shuffle %a, %b[3, 2, 1, 0]
                : vector<2xf32>, vector<2xf32>       ; yields vector<4xf32>
+    %3 = vector.shuffle %d, %e[0, 1]
+               : vector<f32>, vector<f32>           ; yields vector<2xf32>
     ```
   }];
   let builders = [
     OpBuilder<(ins "Value":$v1, "Value":$v2, "ArrayRef<int64_t>")>
   ];
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
   let extraClassDeclaration = [{
     static StringRef getMaskAttrStrName() { return "mask"; }
     VectorType getV1VectorType() {
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -595,13 +595,13 @@
 
     // Get rank and dimension sizes.
     int64_t rank = vectorType.getRank();
-    assert(v1Type.getRank() == rank);
-    assert(v2Type.getRank() == rank);
-    int64_t v1Dim = v1Type.getDimSize(0);
-
-    // For rank 1, where both operands have *exactly* the same vector type,
-    // there is direct shuffle support in LLVM. Use it!
-    if (rank == 1 && v1Type == v2Type) {
+    assert(((v1Type.getRank() == 0 && v2Type.getRank() == 0 && rank == 1) ||
+            (v1Type.getRank() == rank && v2Type.getRank() == rank)) &&
+           "op is not well-formed");
+
+    // For rank 0 and 1, where both operands have *exactly* the same vector
+    // type (always true for 0-D operands), LLVM shufflevector applies directly.
+    if (rank <= 1 && v1Type == v2Type) {
       Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
           loc, adaptor.getV1(), adaptor.getV2(),
           LLVM::convertArrayToIndices<int32_t>(maskArrayAttr));
@@ -610,6 +610,8 @@
     }
 
     // For all other cases, insert the individual values individually.
+    // 0-D shuffles always take the direct path above, so v1 is at least 1-D.
+    int64_t v1Dim = v1Type.getDimSize(0);
     Type eltType;
     if (auto arrayType = llvmType.dyn_cast<LLVM::LLVMArrayType>())
       eltType = arrayType.getElementType();
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1794,8 +1794,11 @@
   int64_t resRank = resultType.getRank();
   int64_t v1Rank = v1Type.getRank();
   int64_t v2Rank = v2Type.getRank();
-  if (resRank != v1Rank || v1Rank != v2Rank)
+  bool wellFormed0DCase = v1Rank == 0 && v2Rank == 0 && resRank == 1;
+  bool wellFormedNDCase = v1Rank == resRank && v2Rank == resRank;
+  if (!wellFormed0DCase && !wellFormedNDCase)
     return emitOpError("rank mismatch");
+
   // Verify all but leading dimension sizes.
   for (int64_t r = 1; r < v1Rank; ++r) {
     int64_t resDim = resultType.getDimSize(r);
@@ -1812,7 +1815,8 @@
   if (maskLength != resultType.getDimSize(0))
     return emitOpError("mask length mismatch");
   // Verify all indices.
-  int64_t indexSize = v1Type.getDimSize(0) + v2Type.getDimSize(0);
+  int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
+                      (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
   for (const auto &en : llvm::enumerate(maskAttr)) {
     auto attr = en.value().dyn_cast<IntegerAttr>();
     if (!attr || attr.getInt() < 0 || attr.getInt() >= indexSize)
@@ -1828,12 +1832,15 @@
                             SmallVectorImpl<Type> &inferredReturnTypes) {
   ShuffleOp::Adaptor op(operands, attributes);
   auto v1Type = op.getV1().getType().cast<VectorType>();
+  auto v1Rank = v1Type.getRank();
   // Construct resulting type: leading dimension matches mask length,
   // all trailing dimensions match the operands.
   SmallVector<int64_t, 4> shape;
-  shape.reserve(v1Type.getRank());
+  shape.reserve(v1Rank);
   shape.push_back(std::max<size_t>(1, op.getMask().size()));
-  llvm::append_range(shape, v1Type.getShape().drop_front());
+  // A 0-D operand has no trailing dimensions to append.
+  if (v1Rank > 0)
+    llvm::append_range(shape, v1Type.getShape().drop_front());
   inferredReturnTypes.push_back(
       VectorType::get(shape, v1Type.getElementType()));
   return success();
@@ -1849,9 +1856,15 @@
 }
 
 OpFoldResult vector::ShuffleOp::fold(ArrayRef<Attribute> operands) {
+  VectorType v1Type = getV1VectorType();
+  // A 0-D shuffle yields a 1-D vector, never an operand's type, so it cannot
+  // fold; single-element masks are canonicalized to vector.broadcast instead.
+  if (v1Type.getRank() == 0)
+    return {};
+
   // fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1
-  if (!getV1VectorType().isScalable() &&
-      isStepIndexArray(getMask(), 0, getV1VectorType().getDimSize(0)))
+  if (!v1Type.isScalable() &&
+      isStepIndexArray(getMask(), 0, v1Type.getDimSize(0)))
     return getV1();
   // fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2
   if (!getV1VectorType().isScalable() && !getV2VectorType().isScalable() &&
@@ -1887,6 +1900,30 @@
 
 namespace {
 
+/// Rewrites a 0-D shuffle with a single-element mask ([0] or [1]) into a
+/// vector.broadcast of the selected operand to a 1-element 1-D vector.
+struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ShuffleOp shuffleOp,
+                                PatternRewriter &rewriter) const override {
+    VectorType v1VectorType = shuffleOp.getV1VectorType();
+    ArrayAttr mask = shuffleOp.getMask();
+    if (v1VectorType.getRank() > 0)
+      return failure();
+    if (mask.size() != 1)
+      return failure();
+    Type resType = VectorType::Builder(v1VectorType).setShape({1});
+    if (mask[0].cast<IntegerAttr>().getInt() == 0)
+      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
+                                                       shuffleOp.getV1());
+    else
+      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
+                                                       shuffleOp.getV2());
+    return success();
+  }
+};
+
 /// Pattern to rewrite a ShuffleOp(SplatOp, SplatOp) to SplatOp.
 class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
 public:
@@ -1912,7 +1949,7 @@
 
 void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<ShuffleSplat>(context);
+  results.add<ShuffleSplat, Canonicalize0DShuffleOp>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -416,6 +416,18 @@
 // CHECK: %[[T19:.*]] = builtin.unrealized_conversion_cast %[[T18]] : !llvm.array<2 x vector<3xf32>> to vector<2x3xf32>
 // CHECK: return %[[T19]] : vector<2x3xf32>
 
+// -----
+
+func.func @shuffle_0D_direct(%arg0: vector<f32>) -> vector<3xf32> {
+  %1 = vector.shuffle %arg0, %arg0 [0, 1, 0] : vector<f32>, vector<f32>
+  return %1 : vector<3xf32>
+}
+// CHECK-LABEL: @shuffle_0D_direct(
+// CHECK-SAME: %[[A:.*]]: vector<f32>
+// CHECK: %[[c:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<f32> to vector<1xf32>
+// CHECK: %[[s:.*]] = llvm.shufflevector %[[c]], %[[c]] [0, 1, 0] : vector<1xf32>
+// CHECK: return %[[s]] : vector<3xf32>
+
 // -----
 
 func.func @shuffle_1D_direct(%arg0: vector<2xf32>, %arg1: vector<2xf32>) -> vector<2xf32> {
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1487,6 +1487,22 @@
   return %shuffle : vector<4xi32>
 }
 
+// CHECK-LABEL: func @shuffle_canonicalize_0d
+//  CHECK-SAME: (%[[V0:.*]]: vector<i32>, %{{.*}}: vector<i32>)
+//       CHECK: vector.broadcast %[[V0]] : vector<i32> to vector<1xi32>
+func.func @shuffle_canonicalize_0d(%v0 : vector<i32>, %v1 : vector<i32>) -> vector<1xi32> {
+  %shuffle = vector.shuffle %v0, %v1 [0] : vector<i32>, vector<i32>
+  return %shuffle : vector<1xi32>
+}
+
+// CHECK-LABEL: func @shuffle_canonicalize_second_0d
+//  CHECK-SAME: (%{{.*}}: vector<i32>, %[[V1:.*]]: vector<i32>)
+//       CHECK: vector.broadcast %[[V1]] : vector<i32> to vector<1xi32>
+func.func @shuffle_canonicalize_second_0d(%v0 : vector<i32>, %v1 : vector<i32>) -> vector<1xi32> {
+  %shuffle = vector.shuffle %v0, %v1 [1] : vector<i32>, vector<i32>
+  return %shuffle : vector<1xi32>
+}
+
 // CHECK-LABEL: func @shuffle_fold1
 // CHECK: %arg0 : vector<4xi32>
 func.func @shuffle_fold1(%v0 : vector<4xi32>, %v1 : vector<2xi32>) -> vector<4xi32> {
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -56,6 +56,13 @@
   %1 = vector.shuffle %arg0, %arg1 [0, 1] : vector<2xf32>, vector<4x2xf32>
 }
 
+// -----
+
+func.func @shuffle_rank_mismatch_0d(%arg0: vector<f32>, %arg1: vector<1xf32>) {
+  // expected-error@+1 {{'vector.shuffle' op rank mismatch}}
+  %1 = vector.shuffle %arg0, %arg1 [0, 1] : vector<f32>, vector<1xf32>
+}
+
 // -----
 
 func.func @shuffle_trailing_dim_size_mismatch(%arg0: vector<2x2xf32>, %arg1: vector<2x4xf32>) {
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -165,6 +165,13 @@
   return %4 : vector<8x16xf32>
 }
 
+// CHECK-LABEL: @shuffle0D
+func.func @shuffle0D(%a: vector<f32>) -> vector<3xf32> {
+  // CHECK: vector.shuffle %{{.*}}, %{{.*}}[0, 1, 0] : vector<f32>, vector<f32>
+  %1 = vector.shuffle %a, %a[0, 1, 0] : vector<f32>, vector<f32>
+  return %1 : vector<3xf32>
+}
+
 // CHECK-LABEL: @shuffle1D
 func.func @shuffle1D(%a: vector<2xf32>, %b: vector<4xf32>) -> vector<2xf32> {
   // CHECK: vector.shuffle %{{.*}}, %{{.*}}[0, 1, 2, 3] : vector<2xf32>, vector<2xf32>
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
@@ -127,6 +127,13 @@
   return
 }
 
+func.func @shuffle_0d(%v0: vector<i32>, %v1: vector<i32>) {
+  %1 = vector.shuffle %v0, %v1 [0, 1, 0] : vector<i32>, vector<i32>
+  // CHECK: ( 42, 43, 42 )
+  vector.print %1: vector<3xi32>
+  return
+}
+
 func.func @entry() {
   %0 = arith.constant 42.0 : f32
   %1 = arith.constant dense<0.0> : vector<f32>
@@ -159,7 +166,9 @@
   %5 = arith.constant dense<4.0> : vector<f32>
   call @fma_0d(%5) : (vector<f32>) -> ()
   %6 = arith.constant dense<42> : vector<i32>
+  %7 = arith.constant dense<43> : vector<i32>
   call @transpose_0d(%6) : (vector<i32>) -> ()
+  call @shuffle_0d(%6, %7) : (vector<i32>, vector<i32>) -> ()
   return
 }
 