diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -170,6 +170,10 @@
   /// `(d0)[s0, s1, s2] -> (d0 + s1 + s2 + 1, d0 - s0 - s2 - 1)`
   AffineMap compose(AffineMap map);
 
+  /// Applies composition by the dims of `this` to the integer `values` and
+  /// returns the resulting values. `this` must be symbol-less.
+  SmallVector<int64_t, 4> compose(ArrayRef<int64_t> values);
+
   /// Returns true if the AffineMap represents a subset (i.e. a projection) of a
   /// symbol-less permutation map.
   bool isProjectedPermutation();
@@ -180,6 +184,11 @@
   /// Returns the map consisting of the `resultPos` subset.
   AffineMap getSubMap(ArrayRef<unsigned> resultPos);
 
+  /// Returns the map consisting of the most major `numResults` results.
+  /// Returns the null AffineMap if `numResults` == 0.
+  /// Returns `*this` if `numResults` >= `this->getNumResults()`.
+  AffineMap getMajorSubMap(unsigned numResults);
+
   /// Returns the map consisting of the most minor `numResults` results.
   /// Returns the null AffineMap if `numResults` == 0.
   /// Returns `*this` if `numResults` >= `this->getNumResults()`.
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -18,6 +18,7 @@
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
@@ -602,6 +603,63 @@
   return success();
 }
 
+/// Fold the result of an ExtractOp in place when it comes from a TransposeOp.
+static LogicalResult foldExtractOpFromTranspose(ExtractOp extractOp) {
+  auto transposeOp = extractOp.vector().getDefiningOp<TransposeOp>();
+  if (!transposeOp)
+    return failure();
+
+  auto permutation = extractVector<unsigned>(transposeOp.transp());
+  auto extractedPos = extractVector<int64_t>(extractOp.position());
+
+  // If transposition permutation is larger than the ExtractOp, all minor
+  // dimensions must be an identity for folding to occur. If not, individual
+  // elements within the extracted value are transposed and this is not just a
+  // simple folding.
+  unsigned minorRank = permutation.size() - extractedPos.size();
+  MLIRContext *ctx = extractOp.getContext();
+  AffineMap permutationMap = AffineMap::getPermutationMap(permutation, ctx);
+  AffineMap minorMap = permutationMap.getMinorSubMap(minorRank);
+  if (minorMap && !AffineMap::isMinorIdentity(minorMap))
+    return failure();
+
+  //   %1 = transpose %0[x, y, z] : vector<..xf32>
+  //   %2 = extract %1[u, v] : vector<..xf32>
+  // may turn into:
+  //   %2 = extract %0[s, t] : vector<..xf32>
+  // iff z == 2 and [s, t] = [x, y]^-1 o [u, v], where o denotes composition
+  // and -1 denotes the inverse.
+  permutationMap = permutationMap.getMajorSubMap(extractedPos.size());
+  // The major submap has fewer results but the same number of dims. To compose
+  // cleanly, we need to drop dims to form a "square matrix". This is possible
+  // because:
+  //   (a) this is a permutation map and
+  //   (b) the minor map has already been checked to be identity.
+  // Therefore, the major map cannot contain dims of position greater than or
+  // equal to the number of results.
+  assert(llvm::all_of(permutationMap.getResults(),
+                      [&](AffineExpr e) {
+                        auto dim = e.dyn_cast<AffineDimExpr>();
+                        return dim && dim.getPosition() <
+                                          permutationMap.getNumResults();
+                      }) &&
+         "Unexpected map results depend on higher rank positions");
+  // Project on the first domain dimensions to allow composition.
+  permutationMap = AffineMap::get(permutationMap.getNumResults(), 0,
+                                  permutationMap.getResults(), ctx);
+
+  extractOp.setOperand(transposeOp.vector());
+  // Compose the inverse permutation map with the extractedPos.
+  auto newExtractedPos =
+      inversePermutation(permutationMap).compose(extractedPos);
+  // OpBuilder is only used as a helper to build an I64ArrayAttr.
+  OpBuilder b(extractOp.getContext());
+  extractOp.setAttr(ExtractOp::getPositionAttrName(),
+                    b.getI64ArrayAttr(newExtractedPos));
+
+  return success();
+}
+
 /// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps. The
 /// result is always the input to some InsertOp.
 static Value foldExtractOpFromInsertChainAndTranspose(ExtractOp extractOp) {
@@ -689,6 +747,8 @@
 OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
   if (succeeded(foldExtractOpFromExtractChain(*this)))
     return getResult();
+  if (succeeded(foldExtractOpFromTranspose(*this)))
+    return getResult();
   if (auto val = foldExtractOpFromInsertChainAndTranspose(*this))
     return val;
   return OpFoldResult();
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -330,6 +330,21 @@
   return AffineMap::get(numDims, numSymbols, exprs, map.getContext());
 }
 
+SmallVector<int64_t, 4> AffineMap::compose(ArrayRef<int64_t> values) {
+  assert(getNumSymbols() == 0 && "Expected symbol-less map");
+  SmallVector<AffineExpr, 4> exprs;
+  exprs.reserve(values.size());
+  MLIRContext *ctx = getContext();
+  for (auto v : values)
+    exprs.push_back(getAffineConstantExpr(v, ctx));
+  auto resMap = compose(AffineMap::get(0, 0, exprs, ctx));
+  SmallVector<int64_t, 4> res;
+  res.reserve(resMap.getNumResults());
+  for (auto e : resMap.getResults())
+    res.push_back(e.cast<AffineConstantExpr>().getValue());
+  return res;
+}
+
 bool AffineMap::isProjectedPermutation() {
   if (getNumSymbols() > 0)
     return false;
@@ -360,6 +375,14 @@
   return AffineMap::get(getNumDims(), getNumSymbols(), exprs, getContext());
 }
 
+AffineMap AffineMap::getMajorSubMap(unsigned numResults) {
+  if (numResults == 0)
+    return AffineMap();
+  if (numResults >= getNumResults())
+    return *this;
+  return getSubMap(llvm::to_vector<4>(llvm::seq<unsigned>(0, numResults)));
+}
+
 AffineMap AffineMap::getMinorSubMap(unsigned numResults) {
   if (numResults == 0)
     return AffineMap();
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -300,13 +300,48 @@
 
 // CHECK-LABEL: fold_extracts
 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<3x4x5x6xf32>
-// CHECK-NEXT: vector.extract %[[A]][0, 1, 2, 3] : vector<3x4x5x6xf32>
-// CHECK-NEXT: vector.extract %[[A]][0] : vector<3x4x5x6xf32>
-// CHECK-NEXT: return
 func @fold_extracts(%a : vector<3x4x5x6xf32>) -> (f32, vector<4x5x6xf32>) {
   %b = vector.extract %a[0] : vector<3x4x5x6xf32>
   %c = vector.extract %b[1, 2] : vector<4x5x6xf32>
+  // CHECK-NEXT: vector.extract %[[A]][0, 1, 2, 3] : vector<3x4x5x6xf32>
   %d = vector.extract %c[3] : vector<6xf32>
+
+  // CHECK-NEXT: vector.extract %[[A]][0] : vector<3x4x5x6xf32>
   %e = vector.extract %a[0] : vector<3x4x5x6xf32>
+
+  // CHECK-NEXT: return
   return %d, %e : f32, vector<4x5x6xf32>
 }
+
+// -----
+
+// CHECK-LABEL: fold_extract_transpose
+// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<3x4x5x6xf32>
+// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<3x6x5x6xf32>
+func @fold_extract_transpose(
+    %a : vector<3x4x5x6xf32>, %b : vector<3x6x5x6xf32>) -> (
+  vector<6xf32>, vector<6xf32>, vector<6xf32>) {
+  // [3] is a proper most minor identity map in transpose.
+  // Permutation is a self inverse and we have:
+  // [0, 2, 1] ^ -1 o [0, 1, 2] = [0, 2, 1] o [0, 1, 2]
+  //                            = [0, 2, 1]
+  // CHECK-NEXT: vector.extract %[[A]][0, 2, 1] : vector<3x4x5x6xf32>
+  %0 = vector.transpose %a, [0, 2, 1, 3] : vector<3x4x5x6xf32> to vector<3x5x4x6xf32>
+  %1 = vector.extract %0[0, 1, 2] : vector<3x5x4x6xf32>
+
+  // [3] is a proper most minor identity map in transpose.
+  // Permutation is not a self inverse and we have:
+  // [1, 2, 0] ^ -1 o [0, 1, 2] = [2, 0, 1] o [0, 1, 2]
+  //                            = [2, 0, 1]
+  // CHECK-NEXT: vector.extract %[[A]][2, 0, 1] : vector<3x4x5x6xf32>
+  %2 = vector.transpose %a, [1, 2, 0, 3] : vector<3x4x5x6xf32> to vector<4x5x3x6xf32>
+  %3 = vector.extract %2[0, 1, 2] : vector<4x5x3x6xf32>
+
+  // Not a minor identity map so intra-vector level has been permuted
+  // CHECK-NEXT: vector.transpose %[[B]], [0, 2, 3, 1]
+  // CHECK-NEXT: vector.extract %{{.*}}[0, 1, 2]
+  %4 = vector.transpose %b, [0, 2, 3, 1] : vector<3x6x5x6xf32> to vector<3x5x6x6xf32>
+  %5 = vector.extract %4[0, 1, 2] : vector<3x5x6x6xf32>
+
+  return %1, %3, %5 : vector<6xf32>, vector<6xf32>, vector<6xf32>
+}