diff --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h --- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h +++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h @@ -30,6 +30,12 @@ SmallVector<int64_t, 4> delinearize(ArrayRef<int64_t> strides, int64_t linearIndex); +/// Given an 'input' array and a transpose 'pattern', returns a transposed copy +/// of 'input' where `output[i] = input[pattern[i]]`. Only the first +/// `input.size()` entries of 'pattern' are used and each must index 'input'. +SmallVector<int64_t, 4> transpose(ArrayRef<int64_t> input, + ArrayRef<int64_t> pattern); + /// Helper that returns a subset of `arrayAttr` as a vector of int64_t. SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront = 0, diff --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp --- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp +++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp @@ -30,6 +30,19 @@ return vectorOffsets; } +llvm::SmallVector<int64_t, 4> mlir::transpose(ArrayRef<int64_t> input, + ArrayRef<int64_t> pattern) { + size_t rank = input.size(); + assert(rank <= pattern.size() && "Not enough patterns"); + SmallVector<int64_t, 4> output(rank); + for (size_t i = 0; i < rank; ++i) { + assert(pattern[i] >= 0 && static_cast<size_t>(pattern[i]) < rank && + "Unexpected transpose pattern"); + output[i] = input[pattern[i]]; + } + return output; +} + llvm::SmallVector<int64_t, 4> mlir::getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront, unsigned dropBack) { diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -19,6 +19,7 @@ #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/Dialect/Vector/Utils/VectorUtils.h" #include "mlir/IR/ImplicitLocOpBuilder.h" @@ -334,6
+335,8 @@ PatternRewriter &rewriter) const override { auto loc = op.getLoc(); + Value input = op.vector(); + VectorType inputType = op.getVectorType(); VectorType resType = op.getResultType(); // Set up convenience transposition table. @@ -354,7 +357,7 @@ Type flattenedType = VectorType::get(resType.getNumElements(), resType.getElementType()); auto matrix = - rewriter.create<vector::ShapeCastOp>(loc, flattenedType, op.vector()); + rewriter.create<vector::ShapeCastOp>(loc, flattenedType, input); auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]); auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]); Value trans = rewriter.create<vector::FlatTransposeOp>( @@ -366,53 +369,47 @@ // Generate unrolled extract/insert ops. We do not unroll the rightmost // (i.e., highest-order) dimensions that are not transposed and leave them // in vector form to improve performance. - size_t numLeftmostTransposedDims = getNumDimsFromFirstTransposedDim(transp); + size_t numLeftmostTranspDims = getNumDimsFromFirstTransposedDim(transp); + size_t numRightmostNonTranspDims = transp.size() - numLeftmostTranspDims; // The type of the extract operation will be scalar if all the dimensions // are unrolled. Otherwise, it will be a vector with the shape of the // dimensions that are not transposed. Type extractType = - numLeftmostTransposedDims == transp.size() + numLeftmostTranspDims == transp.size() ?
resType.getElementType() : VectorType::Builder(resType).setShape( - resType.getShape().drop_front(numLeftmostTransposedDims)); + resType.getShape().drop_front(numLeftmostTranspDims)); Value result = rewriter.create<arith::ConstantOp>( loc, resType, rewriter.getZeroAttr(resType)); - SmallVector<int64_t, 4> lhs(numLeftmostTransposedDims, 0); - SmallVector<int64_t, 4> rhs(numLeftmostTransposedDims, 0); - rewriter.replaceOp(op, expandIndices(loc, resType, extractType, 0, - numLeftmostTransposedDims, transp, lhs, - rhs, op.vector(), result, rewriter)); - return success(); - } -private: - // Builds the indices arrays for the lhs and rhs. Generates the extract/insert - // operations when all the ranks go over the last dimension being transposed. - Value expandIndices(Location loc, VectorType resType, Type extractType, - int64_t pos, int64_t numLeftmostTransposedDims, - SmallVector<int64_t, 4> &transp, - SmallVector<int64_t, 4> &lhs, - SmallVector<int64_t, 4> &rhs, Value input, Value result, - PatternRewriter &rewriter) const { - if (pos >= numLeftmostTransposedDims) { - auto ridx = rewriter.getI64ArrayAttr(rhs); - auto lidx = rewriter.getI64ArrayAttr(lhs); - Value e = - rewriter.create<vector::ExtractOp>(loc, extractType, input, ridx); - return rewriter.create<vector::InsertOp>(loc, resType, e, result, lidx); - } - for (int64_t d = 0, e = resType.getDimSize(pos); d < e; ++d) { - lhs[pos] = d; - rhs[transp[pos]] = d; - result = expandIndices(loc, resType, extractType, pos + 1, - numLeftmostTransposedDims, transp, lhs, rhs, input, - result, rewriter); + auto inputShape = inputType.getShape().drop_back(numRightmostNonTranspDims); + SmallVector<int64_t, 4> ones(numLeftmostTranspDims, 1); + SmallVector<int64_t, 4> inputStrides = computeStrides(inputShape, ones); + int64_t numTransposedElements = ShapedType::getNumElements(inputShape); + + // Generate one extract/insert pair per scalar/vector element of the + // leftmost transposed dimensions: delinearize a linear index into the
+ // input position, then permute it with `transp` to get the position + // of that element in the result (resultIdx[i] = inputIdx[transp[i]]). + for (int64_t linearIdx = 0; linearIdx < numTransposedElements; + ++linearIdx) { + SmallVector<int64_t, 4> inputIdxs = delinearize(inputStrides, linearIdx); + SmallVector<int64_t, 4> transpIdxs = transpose(inputIdxs, transp); + auto inputIdxAttrs = rewriter.getI64ArrayAttr(inputIdxs); + auto transpIdxAttrs = rewriter.getI64ArrayAttr(transpIdxs); + Value extractOp = rewriter.create<vector::ExtractOp>( + loc, extractType, input, inputIdxAttrs); + result = rewriter.create<vector::InsertOp>(loc, resType, extractOp, + result, transpIdxAttrs); } - return result; + + rewriter.replaceOp(op, result); + return success(); } +private: /// Options to control the vector patterns. vector::VectorTransformsOptions vectorTransformOptions; }; diff --git a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir --- a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir @@ -8,14 +8,14 @@ // ELTWISE: %[[Z:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32> // ELTWISE: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<2x3xf32> // ELTWISE: %[[T1:.*]] = vector.insert %[[T0]], %[[Z]] [0, 0] : f32 into vector<3x2xf32> -// ELTWISE: %[[T2:.*]] = vector.extract %[[A]][1, 0] : vector<2x3xf32> -// ELTWISE: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 1] : f32 into vector<3x2xf32> -// ELTWISE: %[[T4:.*]] = vector.extract %[[A]][0, 1] : vector<2x3xf32> -// ELTWISE: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [1, 0] : f32 into vector<3x2xf32> -// ELTWISE: %[[T6:.*]] = vector.extract %[[A]][1, 1] : vector<2x3xf32> -// ELTWISE: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 1] : f32 into vector<3x2xf32> -// ELTWISE: %[[T8:.*]] = vector.extract %[[A]][0, 2] : vector<2x3xf32> -// ELTWISE: %[[T9:.*]] =
vector.insert %[[T8]], %[[T7]] [2, 0] : f32 into vector<3x2xf32> +// ELTWISE: %[[T2:.*]] = vector.extract %[[A]][0, 1] : vector<2x3xf32> +// ELTWISE: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [1, 0] : f32 into vector<3x2xf32> +// ELTWISE: %[[T4:.*]] = vector.extract %[[A]][0, 2] : vector<2x3xf32> +// ELTWISE: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [2, 0] : f32 into vector<3x2xf32> +// ELTWISE: %[[T6:.*]] = vector.extract %[[A]][1, 0] : vector<2x3xf32> +// ELTWISE: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [0, 1] : f32 into vector<3x2xf32> +// ELTWISE: %[[T8:.*]] = vector.extract %[[A]][1, 1] : vector<2x3xf32> +// ELTWISE: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [1, 1] : f32 into vector<3x2xf32> // ELTWISE: %[[T10:.*]] = vector.extract %[[A]][1, 2] : vector<2x3xf32> // ELTWISE: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [2, 1] : f32 into vector<3x2xf32> // ELTWISE: return %[[T11]] : vector<3x2xf32>