diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -32,19 +32,6 @@ /// `[0, permutation.size())`. bool isPermutation(ArrayRef permutation); -/// Apply the permutation defined by `permutation` to `inVec`. -/// Element `i` in `inVec` is mapped to location `j = permutation[i]`. -/// E.g.: for an input vector `inVec = ['a', 'b', 'c']` and a permutation vector -/// `permutation = [2, 0, 1]`, this function leaves `inVec = ['c', 'a', 'b']`. -template -void applyPermutationToVector(SmallVector &inVec, - ArrayRef permutation) { - SmallVector auxVec(inVec.size()); - for (const auto &en : enumerate(permutation)) - auxVec[en.index()] = inVec[en.value()]; - inVec = auxVec; -} - /// Helper function that creates a memref::DimOp or tensor::DimOp depending on /// the type of `source`. Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim); diff --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h --- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h +++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h @@ -30,6 +30,19 @@ SmallVector delinearize(ArrayRef strides, int64_t linearIndex); +/// Apply the permutation defined by `permutation` to `inVec`. +/// The element at index `permutation[i]` of `inVec` is moved to index `i`. +/// E.g.: for an input vector `inVec = ['a', 'b', 'c']` and a permutation vector +/// `permutation = [2, 0, 1]`, this function leaves `inVec = ['c', 'a', 'b']`. +template +void applyPermutationToVector(SmallVector &inVec, + ArrayRef permutation) { + SmallVector auxVec(inVec.size()); + for (const auto &en : enumerate(permutation)) + auxVec[en.index()] = inVec[en.value()]; + inVec = auxVec; +} + /// Helper that returns a subset of `arrayAttr` as a vector of int64_t. 
SmallVector getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront = 0, diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp @@ -18,6 +18,7 @@ #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/Support/LLVM.h" diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/SCF/Utils/Utils.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Utils/VectorUtils.h" #include "mlir/IR/AsmState.h" diff --git a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/AffineExpr.h" diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -20,6 +20,7 @@ #include 
"mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/Transforms.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/Transforms/FoldUtils.h" diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -19,6 +19,7 @@ #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/Dialect/Vector/Utils/VectorUtils.h" #include "mlir/IR/ImplicitLocOpBuilder.h" @@ -300,16 +301,18 @@ } }; -/// Return the number of leftmost dimensions from the first rightmost transposed -/// dimension found in 'transpose'. -size_t getNumDimsFromFirstTransposedDim(ArrayRef transpose) { +/// Given a 'transpose' pattern, drop the rightmost dimensions that are not +/// transposed and append the remaining leading entries to 'result'. +void pruneNonTransposedDims(ArrayRef transpose, + SmallVectorImpl &result) { size_t numTransposedDims = transpose.size(); for (size_t transpDim : llvm::reverse(transpose)) { if (transpDim != numTransposedDims - 1) break; numTransposedDims--; } - return numTransposedDims; + + result.append(transpose.begin(), transpose.begin() + numTransposedDims); } /// Progressive lowering of TransposeOp. @@ -334,6 +337,8 @@ PatternRewriter &rewriter) const override { auto loc = op.getLoc(); + Value input = op.vector(); + VectorType inputType = op.getVectorType(); VectorType resType = op.getResultType(); // Set up convenience transposition table. 
@@ -354,7 +359,7 @@ Type flattenedType = VectorType::get(resType.getNumElements(), resType.getElementType()); auto matrix = - rewriter.create(loc, flattenedType, op.vector()); + rewriter.create(loc, flattenedType, input); auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]); auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]); Value trans = rewriter.create( @@ -365,54 +370,40 @@ // Generate unrolled extract/insert ops. We do not unroll the rightmost // (i.e., highest-order) dimensions that are not transposed and leave them - // in vector form to improve performance. - size_t numLeftmostTransposedDims = getNumDimsFromFirstTransposedDim(transp); - - // The type of the extract operation will be scalar if all the dimensions - // are unrolled. Otherwise, it will be a vector with the shape of the - // dimensions that are not transposed. - Type extractType = - numLeftmostTransposedDims == transp.size() - ? resType.getElementType() - : VectorType::Builder(resType).setShape( - resType.getShape().drop_front(numLeftmostTransposedDims)); - + // in vector form to improve performance. Therefore, we prune those + // dimensions from the shape/transpose data structures used to generate the + // extract/insert ops. + SmallVector prunedTransp; + pruneNonTransposedDims(transp, prunedTransp); + size_t numPrunedDims = transp.size() - prunedTransp.size(); + auto prunedInShape = inputType.getShape().drop_back(numPrunedDims); + SmallVector ones(prunedInShape.size(), 1); + auto prunedInStrides = computeStrides(prunedInShape, ones); + + // Generates the extract/insert operations for every scalar/vector element + // of the leftmost transposed dimensions. We traverse every transpose + // element using a linearized index that we delinearize to generate the + // appropriate indices for the extract/insert operations. 
Value result = rewriter.create( loc, resType, rewriter.getZeroAttr(resType)); - SmallVector lhs(numLeftmostTransposedDims, 0); - SmallVector rhs(numLeftmostTransposedDims, 0); - rewriter.replaceOp(op, expandIndices(loc, resType, extractType, 0, - numLeftmostTransposedDims, transp, lhs, - rhs, op.vector(), result, rewriter)); + int64_t numTransposedElements = ShapedType::getNumElements(prunedInShape); + + for (int64_t linearIdx = 0; linearIdx < numTransposedElements; + ++linearIdx) { + auto extractIdxs = delinearize(prunedInStrides, linearIdx); + SmallVector insertIdxs(extractIdxs); + applyPermutationToVector(insertIdxs, prunedTransp); + Value extractOp = + rewriter.create(loc, input, extractIdxs); + result = + rewriter.create(loc, extractOp, result, insertIdxs); + } + + rewriter.replaceOp(op, result); return success(); } private: - // Builds the indices arrays for the lhs and rhs. Generates the extract/insert - // operations when all the ranks go over the last dimension being transposed. - Value expandIndices(Location loc, VectorType resType, Type extractType, - int64_t pos, int64_t numLeftmostTransposedDims, - SmallVector &transp, - SmallVector &lhs, - SmallVector &rhs, Value input, Value result, - PatternRewriter &rewriter) const { - if (pos >= numLeftmostTransposedDims) { - auto ridx = rewriter.getI64ArrayAttr(rhs); - auto lidx = rewriter.getI64ArrayAttr(lhs); - Value e = - rewriter.create(loc, extractType, input, ridx); - return rewriter.create(loc, resType, e, result, lidx); - } - for (int64_t d = 0, e = resType.getDimSize(pos); d < e; ++d) { - lhs[pos] = d; - rhs[transp[pos]] = d; - result = expandIndices(loc, resType, extractType, pos + 1, - numLeftmostTransposedDims, transp, lhs, rhs, input, - result, rewriter); - } - return result; - } - /// Options to control the vector patterns. 
vector::VectorTransformsOptions vectorTransformOptions; }; diff --git a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir --- a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir @@ -8,14 +8,14 @@ // ELTWISE: %[[Z:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32> // ELTWISE: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<2x3xf32> // ELTWISE: %[[T1:.*]] = vector.insert %[[T0]], %[[Z]] [0, 0] : f32 into vector<3x2xf32> -// ELTWISE: %[[T2:.*]] = vector.extract %[[A]][1, 0] : vector<2x3xf32> -// ELTWISE: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 1] : f32 into vector<3x2xf32> -// ELTWISE: %[[T4:.*]] = vector.extract %[[A]][0, 1] : vector<2x3xf32> -// ELTWISE: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [1, 0] : f32 into vector<3x2xf32> -// ELTWISE: %[[T6:.*]] = vector.extract %[[A]][1, 1] : vector<2x3xf32> -// ELTWISE: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 1] : f32 into vector<3x2xf32> -// ELTWISE: %[[T8:.*]] = vector.extract %[[A]][0, 2] : vector<2x3xf32> -// ELTWISE: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [2, 0] : f32 into vector<3x2xf32> +// ELTWISE: %[[T2:.*]] = vector.extract %[[A]][0, 1] : vector<2x3xf32> +// ELTWISE: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [1, 0] : f32 into vector<3x2xf32> +// ELTWISE: %[[T4:.*]] = vector.extract %[[A]][0, 2] : vector<2x3xf32> +// ELTWISE: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [2, 0] : f32 into vector<3x2xf32> +// ELTWISE: %[[T6:.*]] = vector.extract %[[A]][1, 0] : vector<2x3xf32> +// ELTWISE: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [0, 1] : f32 into vector<3x2xf32> +// ELTWISE: %[[T8:.*]] = vector.extract %[[A]][1, 1] : vector<2x3xf32> +// ELTWISE: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [1, 1] : f32 into vector<3x2xf32> // ELTWISE: %[[T10:.*]] = vector.extract %[[A]][1, 2] : vector<2x3xf32> // ELTWISE: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] 
[2, 1] : f32 into vector<3x2xf32> // ELTWISE: return %[[T11]] : vector<3x2xf32>