diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -390,6 +390,7 @@ return vector().getType().cast(); } }]; + let hasCanonicalizer = 1; } def Vector_ExtractSlicesOp : diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h --- a/mlir/include/mlir/IR/AffineMap.h +++ b/mlir/include/mlir/IR/AffineMap.h @@ -180,6 +180,9 @@ /// Returns the map consisting of the `resultPos` subset. AffineMap getSubMap(ArrayRef resultPos); + /// Returns the map consisting of the most minor `numResults` results: a null map when `numResults` is 0, `*this` when it exceeds getNumResults(). + AffineMap getMinorSubMap(unsigned numResults); + friend ::llvm::hash_code hash_value(AffineMap arg); private: diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -469,8 +469,7 @@ return res; } -Optional> -ContractionOp::getShapeForUnroll() { +Optional> ContractionOp::getShapeForUnroll() { SmallVector shape; getIterationBounds(shape); return shape; @@ -572,6 +571,99 @@ return success(); } +static SmallVector extractUnsignedVector(ArrayAttr arrayAttr) { + return llvm::to_vector<4>(llvm::map_range( + arrayAttr.getAsRange(), + [](IntegerAttr attr) { return static_cast(attr.getInt()); })); +} + +namespace { + +/// Fold InsertOp -> ... -> ExtractOp to the same position, with potentially +/// multiple interleaved InsertOp and TransposeOp. +struct InsertTransposeExtractFolder : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ExtractOp extractOp, + PatternRewriter &rewriter) const override { + MLIRContext *context = extractOp.getContext(); + AffineMap permutationMap; + auto extractedPos = extractUnsignedVector(extractOp.position()); + // Walk back a chain of InsertOp/TransposeOp until we hit a match. 
+ // Compose TransposeOp permutations as we walk back. + auto insertOp = extractOp.vector().getDefiningOp(); + auto transposeOp = extractOp.vector().getDefiningOp(); + while (insertOp || transposeOp) { + // The candidate value that should be inspected. + Value candidate = transposeOp ? transposeOp.vector() : insertOp.dest(); + if (transposeOp) { + // If it is transposed, compose the map and iterate. + auto permutation = extractUnsignedVector(transposeOp.transp()); + AffineMap newMap = AffineMap::getPermutationMap(permutation, context); + if (!permutationMap) { + permutationMap = newMap; + } else { + if (newMap.getNumInputs() != permutationMap.getNumResults()) + return failure(); + permutationMap = newMap.compose(permutationMap); + } + } else { + // If it is inserted into, either the position matches and we have a + // successful folding; or we iterate until we run out of + // InsertOp/TransposeOp. NOTE(review): an InsertOp whose position differs is skipped even when it overlaps the extracted position (e.g. a shorter prefix); confirm that is safe. + auto insertedPos = extractUnsignedVector(insertOp.position()); + // Trivial permutations require equality checks. + if (!permutationMap || permutationMap.isIdentity()) { + if (extractedPos == insertedPos) { + rewriter.replaceOp(extractOp, insertOp.source()); + return success(); + } + } else { + // More advanced permutations require application of the permutation. + // However, it is possible `insertedPos.size()` is different from + // `permutationMap.getNumInputs()`. In this case, we need to: + // 1. apply on the `insertedPos.size()` major dimensions + // 2. check that the other dimensions of the permutation form an identity. 
+ assert(permutationMap.isPermutation() && "expected a permutation"); + if (insertedPos.size() == extractedPos.size()) { + bool fold = true; + for (unsigned idx = 0, sz = extractedPos.size(); idx < sz; ++idx) { + auto pos = permutationMap.getResult(idx) + .cast() + .getPosition(); + if (pos >= sz || insertedPos[pos] != extractedPos[idx]) { + fold = false; + break; + } + } + if (fold) { + assert(permutationMap.getNumResults() >= insertedPos.size() && + "expected map of rank larger than insert indexing"); + unsigned minorRank = + permutationMap.getNumResults() - insertedPos.size(); + AffineMap minorMap = permutationMap.getMinorSubMap(minorRank); + if (!minorMap || AffineMap::isMinorIdentity(minorMap)) { + rewriter.replaceOp(extractOp, insertOp.source()); + return success(); + } + } + } + } + } + insertOp = candidate.getDefiningOp(); + transposeOp = candidate.getDefiningOp(); + } + return failure(); + } +}; + +} // namespace + +void ExtractOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + //===----------------------------------------------------------------------===// // ExtractSlicesOp //===----------------------------------------------------------------------===// @@ -1529,8 +1621,7 @@ return OpFoldResult(); } -Optional> -TransferReadOp::getShapeForUnroll() { +Optional> TransferReadOp::getShapeForUnroll() { auto s = getVectorType().getShape(); return SmallVector{s.begin(), s.end()}; } @@ -1625,8 +1716,7 @@ return foldMemRefCast(*this); } -Optional> -TransferWriteOp::getShapeForUnroll() { +Optional> TransferWriteOp::getShapeForUnroll() { auto s = getVectorType().getShape(); return SmallVector{s.begin(), s.end()}; } diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp --- a/mlir/lib/IR/AffineMap.cpp +++ b/mlir/lib/IR/AffineMap.cpp @@ -355,12 +355,24 @@ AffineMap AffineMap::getSubMap(ArrayRef resultPos) { SmallVector exprs; exprs.reserve(resultPos.size()); - for (auto idx : 
resultPos) { + for (auto idx : resultPos) exprs.push_back(getResult(idx)); - } return AffineMap::get(getNumDims(), getNumSymbols(), exprs, getContext()); } +AffineMap AffineMap::getMinorSubMap(unsigned numResults) { + if (numResults == 0) + return AffineMap(); + if (numResults > getNumResults()) + return *this; + SmallVector resultPos; + resultPos.reserve(numResults); + for (unsigned i = getNumResults() - numResults, e = getNumResults(); i < e; + ++i) + resultPos.push_back(i); + return getSubMap(resultPos); +} + AffineMap mlir::simplifyAffineMap(AffineMap map) { SmallVector exprs; for (auto e : map.getResults()) { diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -175,3 +175,123 @@ vector.transfer_write %1, %0[%c0, %c0] : vector<4x8xf32>, memref return %1 : vector<4x8xf32> } + +// ----- + +// CHECK-LABEL: func @insert_extract_transpose_2d( +// CHECK-SAME: %[[V:[a-zA-Z0-9]*]]: vector<2x3xf32>, +// CHECK-SAME: %[[F0:[a-zA-Z0-9]*]]: f32, +// CHECK-SAME: %[[F1:[a-zA-Z0-9]*]]: f32, +// CHECK-SAME: %[[F2:[a-zA-Z0-9]*]]: f32, +// CHECK-SAME: %[[F3:[a-zA-Z0-9]*]]: f32 +func @insert_extract_transpose_2d( + %v: vector<2x3xf32>, %f0: f32, %f1: f32, %f2: f32, %f3: f32) +-> (f32, f32, f32) +{ + %0 = vector.insert %f0, %v[0, 0] : f32 into vector<2x3xf32> + %1 = vector.insert %f1, %0[0, 1] : f32 into vector<2x3xf32> + %2 = vector.insert %f2, %1[1, 0] : f32 into vector<2x3xf32> + %3 = vector.insert %f3, %2[1, 1] : f32 into vector<2x3xf32> + %4 = vector.transpose %3, [1, 0] : vector<2x3xf32> to vector<3x2xf32> + %5 = vector.insert %f3, %4[1, 0] : f32 into vector<3x2xf32> + %6 = vector.transpose %5, [1, 0] : vector<3x2xf32> to vector<2x3xf32> + + // Expected %f2 from %2 = vector.insert %f2, %1[1, 0]. 
+ %r1 = vector.extract %3[1, 0] : vector<2x3xf32> + + // Expected %f1 from %1 = vector.insert %f1, %0[0, 1] followed by + // transpose [1, 0]. + %r2 = vector.extract %4[1, 0] : vector<3x2xf32> + + // Expected %f2 from %2 = vector.insert %f2, %1[1, 0] followed by double + // transpose [1, 0]. + %r3 = vector.extract %6[1, 0] : vector<2x3xf32> + + // CHECK-NEXT: return %[[F2]], %[[F1]], %[[F2]] : f32, f32, f32 + return %r1, %r2, %r3 : f32, f32, f32 +} + +// ----- + +// CHECK-LABEL: func @insert_extract_transpose_3d( +// CHECK-SAME: %[[V:[a-zA-Z0-9]*]]: vector<2x3x4xf32>, +// CHECK-SAME: %[[F0:[a-zA-Z0-9]*]]: f32, +// CHECK-SAME: %[[F1:[a-zA-Z0-9]*]]: f32, +// CHECK-SAME: %[[F2:[a-zA-Z0-9]*]]: f32, +// CHECK-SAME: %[[F3:[a-zA-Z0-9]*]]: f32 +func @insert_extract_transpose_3d( + %v: vector<2x3x4xf32>, %f0: f32, %f1: f32, %f2: f32, %f3: f32) +-> (f32, f32, f32, f32) +{ + %0 = vector.insert %f0, %v[0, 0, 0] : f32 into vector<2x3x4xf32> + %1 = vector.insert %f1, %0[0, 1, 0] : f32 into vector<2x3x4xf32> + %2 = vector.insert %f2, %1[1, 0, 0] : f32 into vector<2x3x4xf32> + %3 = vector.insert %f3, %2[0, 0, 1] : f32 into vector<2x3x4xf32> + %4 = vector.transpose %3, [1, 2, 0] : vector<2x3x4xf32> to vector<3x4x2xf32> + %5 = vector.insert %f3, %4[1, 0, 0] : f32 into vector<3x4x2xf32> + %6 = vector.transpose %5, [1, 2, 0] : vector<3x4x2xf32> to vector<4x2x3xf32> + %7 = vector.insert %f3, %6[1, 0, 0] : f32 into vector<4x2x3xf32> + %8 = vector.transpose %7, [1, 2, 0] : vector<4x2x3xf32> to vector<2x3x4xf32> + + // Expected %f2 from %2 = vector.insert %f2, %1[1, 0, 0]. + %r1 = vector.extract %3[1, 0, 0] : vector<2x3x4xf32> + + // Expected %f1 from %1 = vector.insert %f1, %0[0, 1, 0] followed by + // transpose[1, 2, 0]. + %r2 = vector.extract %4[1, 0, 0] : vector<3x4x2xf32> + + // Expected %f3 from %3 = vector.insert %f3, %2[0, 0, 1] followed by double + // transpose[1, 2, 0]. 
+ %r3 = vector.extract %6[1, 0, 0] : vector<4x2x3xf32> + + // Expected %f2 from %2 = vector.insert %f2, %1[1, 0, 0] followed by triple + // transpose[1, 2, 0]. + %r4 = vector.extract %8[1, 0, 0] : vector<2x3x4xf32> + + // CHECK-NEXT: return %[[F2]], %[[F1]], %[[F3]], %[[F2]] : f32, f32, f32, f32 + return %r1, %r2, %r3, %r4 : f32, f32, f32, f32 +} + +// ----- + +// CHECK-LABEL: func @insert_extract_transpose_3d_2d( +// CHECK-SAME: %[[V:[a-zA-Z0-9]*]]: vector<2x3x4xf32>, +// CHECK-SAME: %[[F0:[a-zA-Z0-9]*]]: vector<4xf32>, +// CHECK-SAME: %[[F1:[a-zA-Z0-9]*]]: vector<4xf32>, +// CHECK-SAME: %[[F2:[a-zA-Z0-9]*]]: vector<4xf32>, +// CHECK-SAME: %[[F3:[a-zA-Z0-9]*]]: vector<4xf32> +func @insert_extract_transpose_3d_2d( + %v: vector<2x3x4xf32>, + %f0: vector<4xf32>, %f1: vector<4xf32>, %f2: vector<4xf32>, %f3: vector<4xf32>) +-> (vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>) +{ + %0 = vector.insert %f0, %v[0, 0] : vector<4xf32> into vector<2x3x4xf32> + %1 = vector.insert %f1, %0[0, 1] : vector<4xf32> into vector<2x3x4xf32> + %2 = vector.insert %f2, %1[1, 0] : vector<4xf32> into vector<2x3x4xf32> + %3 = vector.insert %f3, %2[1, 1] : vector<4xf32> into vector<2x3x4xf32> + %4 = vector.transpose %3, [1, 0, 2] : vector<2x3x4xf32> to vector<3x2x4xf32> + %5 = vector.transpose %4, [1, 0, 2] : vector<3x2x4xf32> to vector<2x3x4xf32> + + // Expected %f2 from %2 = vector.insert %f2, %1[1, 0]. + %r1 = vector.extract %3[1, 0] : vector<2x3x4xf32> + + // Expected %f1 from %1 = vector.insert %f1, %0[0, 1] followed by + // transpose[1, 0, 2]. + %r2 = vector.extract %4[1, 0] : vector<3x2x4xf32> + + // Expected %f2 from %2 = vector.insert %f2, %1[1, 0] followed by double + // transpose[1, 0, 2]. 
+ %r3 = vector.extract %5[1, 0] : vector<2x3x4xf32> + + %6 = vector.transpose %3, [1, 2, 0] : vector<2x3x4xf32> to vector<3x4x2xf32> + %7 = vector.transpose %6, [1, 2, 0] : vector<3x4x2xf32> to vector<4x2x3xf32> + %8 = vector.transpose %7, [1, 2, 0] : vector<4x2x3xf32> to vector<2x3x4xf32> + + // Expected %f2 from %2 = vector.insert %f2, %1[1, 0] followed by triple + // transpose[1, 2, 0]. + %r4 = vector.extract %8[1, 0] : vector<2x3x4xf32> + + // CHECK: return %[[F2]], %[[F1]], %[[F2]], %[[F2]] + // CHECK-SAME: vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32> + return %r1, %r2, %r3, %r4 : vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32> +}