diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -300,6 +300,18 @@
   }
 };
 
+/// Return the number of leading dimensions of 'transpose' that remain after
+/// dropping the trailing dimensions that map to themselves (not transposed).
+size_t getNumDimsFromFirstTransposedDim(ArrayRef transpose) {
+  size_t numTransposedDims = transpose.size();
+  for (size_t transpDim : llvm::reverse(transpose)) {
+    if (transpDim != numTransposedDims - 1)
+      break;
+    numTransposedDims--;
+  }
+  return numTransposedDims;
+}
+
 /// Progressive lowering of TransposeOp.
 /// One:
 ///   %x = vector.transpose %y, [1, 0]
@@ -351,35 +363,51 @@
       return success();
     }
 
-    // Generate fully unrolled extract/insert ops.
+    // Generate unrolled extract/insert ops. We do not unroll the rightmost
+    // (i.e., trailing) dimensions that are not transposed and leave them in
+    // vector form to improve performance.
+    size_t numLeftmostTransposedDims = getNumDimsFromFirstTransposedDim(transp);
+
+    // The type of the extract operation will be scalar if all the dimensions
+    // are unrolled. Otherwise, it will be a vector with the shape of the
+    // trailing dimensions that are not unrolled.
+    Type extractType =
+        numLeftmostTransposedDims == transp.size()
+            ? resType.getElementType()
+            : VectorType::Builder(resType).setShape(
+                  resType.getShape().drop_front(numLeftmostTransposedDims));
+
     Value result = rewriter.create(
         loc, resType, rewriter.getZeroAttr(resType));
-    SmallVector lhs(transp.size(), 0);
-    SmallVector rhs(transp.size(), 0);
-    rewriter.replaceOp(op, expandIndices(loc, resType, 0, transp, lhs, rhs,
-                                         op.vector(), result, rewriter));
+    SmallVector lhs(numLeftmostTransposedDims, 0);
+    SmallVector rhs(numLeftmostTransposedDims, 0);
+    rewriter.replaceOp(op, expandIndices(loc, resType, extractType, 0,
+                                         numLeftmostTransposedDims, transp, lhs,
+                                         rhs, op.vector(), result, rewriter));
     return success();
   }
 
 private:
   // Builds the indices arrays for the lhs and rhs. Generates the extract/insert
-  // operation when al ranks are exhausted.
-  Value expandIndices(Location loc, VectorType resType, int64_t pos,
+  // operations once 'pos' reaches 'numLeftmostTransposedDims'.
+  Value expandIndices(Location loc, VectorType resType, Type extractType,
+                      int64_t pos, int64_t numLeftmostTransposedDims,
                       SmallVector &transp,
                       SmallVector &lhs,
                       SmallVector &rhs, Value input, Value result,
                       PatternRewriter &rewriter) const {
-    if (pos >= resType.getRank()) {
+    if (pos >= numLeftmostTransposedDims) {
       auto ridx = rewriter.getI64ArrayAttr(rhs);
       auto lidx = rewriter.getI64ArrayAttr(lhs);
-      Type eltType = resType.getElementType();
-      Value e = rewriter.create(loc, eltType, input, ridx);
+      Value e =
+          rewriter.create(loc, extractType, input, ridx);
       return rewriter.create(loc, resType, e, result, lidx);
     }
     for (int64_t d = 0, e = resType.getDimSize(pos); d < e; ++d) {
       lhs[pos] = d;
       rhs[transp[pos]] = d;
-      result = expandIndices(loc, resType, pos + 1, transp, lhs, rhs, input,
+      result = expandIndices(loc, resType, extractType, pos + 1,
+                             numLeftmostTransposedDims, transp, lhs, rhs, input,
                              result, rewriter);
     }
     return result;
diff --git a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
--- a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
@@ -577,8 +577,25 @@
 
 // -----
 
+// ELTWISE-LABEL: func @transpose102_1x8x8xf32
 // AVX2-LABEL: func @transpose102_1x8x8
 func @transpose102_1x8x8xf32(%arg0: vector<1x8x8xf32>) -> vector<8x1x8xf32> {
+  // ELTWISE: vector.extract {{.*}}[0, 0] : vector<1x8x8xf32>
+  // ELTWISE-NEXT: vector.insert {{.*}} [0, 0] : vector<8xf32> into vector<8x1x8xf32>
+  // ELTWISE-NEXT: vector.extract {{.*}}[0, 1] : vector<1x8x8xf32>
+  // ELTWISE-NEXT: vector.insert {{.*}} [1, 0] : vector<8xf32> into vector<8x1x8xf32>
+  // ELTWISE-NEXT: vector.extract {{.*}}[0, 2] : vector<1x8x8xf32>
+  // ELTWISE-NEXT: vector.insert {{.*}} [2, 0] : vector<8xf32> into vector<8x1x8xf32>
+  // ELTWISE-NEXT: vector.extract {{.*}}[0, 3] : vector<1x8x8xf32>
+  // ELTWISE-NEXT: vector.insert {{.*}} [3, 0] : vector<8xf32> into vector<8x1x8xf32>
+  // ELTWISE-NEXT: vector.extract {{.*}}[0, 4] : vector<1x8x8xf32>
+  // ELTWISE-NEXT: vector.insert {{.*}} [4, 0] : vector<8xf32> into vector<8x1x8xf32>
+  // ELTWISE-NEXT: vector.extract {{.*}}[0, 5] : vector<1x8x8xf32>
+  // ELTWISE-NEXT: vector.insert {{.*}} [5, 0] : vector<8xf32> into vector<8x1x8xf32>
+  // ELTWISE-NEXT: vector.extract {{.*}}[0, 6] : vector<1x8x8xf32>
+  // ELTWISE-NEXT: vector.insert {{.*}} [6, 0] : vector<8xf32> into vector<8x1x8xf32>
+  // ELTWISE-NEXT: vector.extract {{.*}}[0, 7] : vector<1x8x8xf32>
+  // ELTWISE-NEXT: vector.insert {{.*}} [7, 0] : vector<8xf32> into vector<8x1x8xf32>
   %0 = vector.transpose %arg0, [1, 0, 2] : vector<1x8x8xf32> to vector<8x1x8xf32>
   return %0 : vector<8x1x8xf32>
 }
@@ -587,8 +604,25 @@
 
 // -----
 
+// ELTWISE-LABEL: func @transpose102_8x1x8xf32
 // AVX2-LABEL: func @transpose102_8x1x8
 func @transpose102_8x1x8xf32(%arg0: vector<8x1x8xf32>) -> vector<1x8x8xf32> {
+  // ELTWISE: vector.extract {{.*}}[0, 0] : vector<8x1x8xf32>
+  // ELTWISE-NEXT: vector.insert {{.*}} [0, 0] : vector<8xf32> into vector<1x8x8xf32>
+  // ELTWISE-NEXT: vector.extract {{.*}}[1, 0] : vector<8x1x8xf32>
+  // ELTWISE-NEXT: vector.insert {{.*}} [0, 1] : vector<8xf32> into vector<1x8x8xf32>
+  // ELTWISE-NEXT: vector.extract {{.*}}[2, 0] : vector<8x1x8xf32>
+  // ELTWISE-NEXT: vector.insert {{.*}} [0, 2] : vector<8xf32> into vector<1x8x8xf32>
+  // ELTWISE-NEXT: vector.extract {{.*}}[3, 0] : vector<8x1x8xf32>
+  // ELTWISE-NEXT: vector.insert {{.*}} [0, 3] : vector<8xf32> into vector<1x8x8xf32>
+  // ELTWISE-NEXT: vector.extract {{.*}}[4, 0] : vector<8x1x8xf32>
+  // ELTWISE-NEXT: vector.insert {{.*}} [0, 4] : vector<8xf32> into vector<1x8x8xf32>
+  // ELTWISE-NEXT: vector.extract {{.*}}[5, 0] : vector<8x1x8xf32>
+  // ELTWISE-NEXT: vector.insert {{.*}} [0, 5] : vector<8xf32> into vector<1x8x8xf32>
+  // ELTWISE-NEXT: vector.extract {{.*}}[6, 0] : vector<8x1x8xf32>
+  // ELTWISE-NEXT: vector.insert {{.*}} [0, 6] : vector<8xf32> into vector<1x8x8xf32>
+  // ELTWISE-NEXT: vector.extract {{.*}}[7, 0] : vector<8x1x8xf32>
+  // ELTWISE-NEXT: vector.insert {{.*}} [0, 7] : vector<8xf32> into vector<1x8x8xf32>
   %0 = vector.transpose %arg0, [1, 0, 2] : vector<8x1x8xf32> to vector<1x8x8xf32>
   return %0 : vector<1x8x8xf32>
 }
@@ -597,6 +631,20 @@
 
 // -----
 
+// ELTWISE-LABEL: func @transpose1023_1x1x8x8xf32(
+// AVX2-LABEL: func @transpose1023_1x1x8x8
+func @transpose1023_1x1x8x8xf32(%arg0: vector<1x1x8x8xf32>) -> vector<1x1x8x8xf32> {
+  // Note: dims 2 and 3 are not transposed, so one 2-D extract/insert pair is expected.
+  // ELTWISE: vector.extract {{.*}}[0, 0] : vector<1x1x8x8xf32>
+  // ELTWISE-NEXT: vector.insert {{.*}} [0, 0] : vector<8x8xf32> into vector<1x1x8x8xf32>
+  %0 = vector.transpose %arg0, [1, 0, 2, 3] : vector<1x1x8x8xf32> to vector<1x1x8x8xf32>
+  return %0 : vector<1x1x8x8xf32>
+}
+
+// AVX2-NOT: vector.shuffle
+
+// -----
+
 // AVX2-LABEL: func @transpose120_1x8x8
 func @transpose120_1x8x8xf32(%arg0: vector<1x8x8xf32>) -> vector<8x8x1xf32> {