diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -390,6 +390,7 @@
       return vector().getType().cast<VectorType>();
     }
   }];
+  let hasFolder = 1;
 }
 
 def Vector_ExtractSlicesOp :
diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -180,6 +180,11 @@
   /// Returns the map consisting of the `resultPos` subset.
   AffineMap getSubMap(ArrayRef<unsigned> resultPos);
 
+  /// Returns the map consisting of the most minor `numResults` results.
+  /// Returns the null AffineMap if `numResults` == 0.
+  /// Returns `*this` if `numResults` >= `this->getNumResults()`.
+  AffineMap getMinorSubMap(unsigned numResults);
+
   friend ::llvm::hash_code hash_value(AffineMap arg);
 
 private:
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -571,6 +571,118 @@
   return success();
 }
 
+static SmallVector<unsigned, 4> extractUnsignedVector(ArrayAttr arrayAttr) {
+  return llvm::to_vector<4>(llvm::map_range(
+      arrayAttr.getAsRange<IntegerAttr>(),
+      [](IntegerAttr attr) { return static_cast<unsigned>(attr.getInt()); }));
+}
+
+/// Folds `extractOp` by walking back a chain of `vector.insert` /
+/// `vector.transpose` ops until reaching an insert that writes exactly the
+/// value/slice being extracted, and returns that insert's source. Returns a
+/// null Value when no such insert is found or when the walk cannot prove
+/// anything about the extracted value.
+///
+/// Transpose permutations are composed into `permutationMap` as we walk back:
+/// dimension `idx` of the extracted-from vector corresponds to dimension
+/// `permutationMap.getResult(idx)` of the vector produced by the insert being
+/// inspected (a null map stands for the identity).
+///
+/// An insert whose written region is provably disjoint from the extracted
+/// region is skipped. An insert that overlaps the extracted region without
+/// writing exactly that region stops the walk: the extracted value then mixes
+/// newly inserted and older elements, so no single producer can be forwarded.
+static Value foldExtractOp(ExtractOp extractOp) {
+  MLIRContext *context = extractOp.getContext();
+  AffineMap permutationMap;
+  auto extractedPos = extractUnsignedVector(extractOp.position());
+  auto insertOp = extractOp.vector().getDefiningOp<InsertOp>();
+  auto transposeOp = extractOp.vector().getDefiningOp<TransposeOp>();
+  while (insertOp || transposeOp) {
+    if (transposeOp) {
+      // Dimension `i` of the transpose result is dimension `permutation[i]`
+      // of its operand. Extracted dimension `idx` was mapped to result
+      // dimension `old[idx]` by the map accumulated so far, so it now maps to
+      // operand dimension `permutation[old[idx]]`, which is exactly
+      // `permutationMap.compose(newMap)`. The reverse composition order only
+      // agrees with this when the permutations commute.
+      auto permutation = extractUnsignedVector(transposeOp.transp());
+      AffineMap newMap = AffineMap::getPermutationMap(permutation, context);
+      if (!permutationMap)
+        permutationMap = newMap;
+      else if (permutationMap.getNumDims() != newMap.getNumResults())
+        return Value();
+      else
+        permutationMap = permutationMap.compose(newMap);
+      // Compute insert/transpose for the next iteration.
+      Value transposed = transposeOp.vector();
+      insertOp = transposed.getDefiningOp<InsertOp>();
+      transposeOp = transposed.getDefiningOp<TransposeOp>();
+      continue;
+    }
+
+    assert(insertOp);
+    assert((!permutationMap || permutationMap.isPermutation()) &&
+           "expected a permutation");
+    Value insertionDest = insertOp.dest();
+    auto insertedPos = extractUnsignedVector(insertOp.position());
+    // Dimension of the insert result that extracted dimension `idx` indexes.
+    auto dimOf = [&](unsigned idx) -> unsigned {
+      if (!permutationMap)
+        return idx;
+      return permutationMap.getResult(idx).cast<AffineDimExpr>().getPosition();
+    };
+    // The extract fixes index `extractedPos[idx]` along insert-result
+    // dimension `dimOf(idx)`; the insert writes the elements whose leading
+    // `insertedPos.size()` indices equal `insertedPos`. The regions are
+    // disjoint iff they disagree on a dimension both of them fix, and they
+    // coincide iff they fix the same dimensions to the same indices.
+    bool disjoint = false;
+    bool sameRegion = insertedPos.size() == extractedPos.size();
+    for (unsigned idx = 0, e = extractedPos.size(); idx < e; ++idx) {
+      unsigned dim = dimOf(idx);
+      if (dim >= insertedPos.size()) {
+        // The insert does not fix this dimension.
+        sameRegion = false;
+        continue;
+      }
+      if (insertedPos[dim] != extractedPos[idx]) {
+        disjoint = true;
+        break;
+      }
+    }
+
+    if (!disjoint) {
+      // The regions overlap without coinciding: bail out rather than skip an
+      // insert that modified part of the extracted value/slice.
+      if (!sameRegion)
+        return Value();
+      // Same region: the inserted value/slice is forwarded as-is only if the
+      // remaining minor dimensions are not permuted; otherwise it would reach
+      // the extract transposed.
+      if (permutationMap) {
+        unsigned minorRank =
+            permutationMap.getNumResults() - insertedPos.size();
+        AffineMap minorMap = permutationMap.getMinorSubMap(minorRank);
+        if (minorMap && !AffineMap::isMinorIdentity(minorMap))
+          return Value();
+      }
+      return insertOp.source();
+    }
+
+    // The insert does not touch the extracted value/slice: keep walking back.
+    insertOp = insertionDest.getDefiningOp<InsertOp>();
+    transposeOp = insertionDest.getDefiningOp<TransposeOp>();
+  }
+  return Value();
+}
+
+OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
+  if (auto val = foldExtractOp(*this))
+    return val;
+  return OpFoldResult();
+}
+
 //===----------------------------------------------------------------------===//
 // ExtractSlicesOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -355,12 +355,20 @@
 AffineMap AffineMap::getSubMap(ArrayRef<unsigned> resultPos) {
   SmallVector<AffineExpr, 4> exprs;
   exprs.reserve(resultPos.size());
-  for (auto idx : resultPos) {
+  for (auto idx : resultPos)
     exprs.push_back(getResult(idx));
-  }
   return AffineMap::get(getNumDims(), getNumSymbols(), exprs, getContext());
 }
 
+AffineMap AffineMap::getMinorSubMap(unsigned numResults) {
+  if (numResults == 0)
+    return AffineMap();
+  if (numResults >= getNumResults())
+    return *this;
+  return getSubMap(llvm::to_vector<4>(
+      llvm::seq<unsigned>(getNumResults() - numResults, getNumResults())));
+}
+
 AffineMap mlir::simplifyAffineMap(AffineMap map) {
   SmallVector<AffineExpr, 8> exprs;
   for (auto e : map.getResults()) {
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -175,3 +175,176 @@
   vector.transfer_write %1, %0[%c0, %c0] : vector<4x8xf32>, memref<?x?xf32>
   return %1 : vector<4x8xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @insert_extract_transpose_2d(
+//  CHECK-SAME: %[[V:[a-zA-Z0-9]*]]: vector<2x3xf32>,
+//  CHECK-SAME: %[[F0:[a-zA-Z0-9]*]]: f32,
+//  CHECK-SAME: %[[F1:[a-zA-Z0-9]*]]: f32,
+//  CHECK-SAME: %[[F2:[a-zA-Z0-9]*]]: f32,
+//  CHECK-SAME: %[[F3:[a-zA-Z0-9]*]]: f32
+func @insert_extract_transpose_2d(
+    %v: vector<2x3xf32>, %f0: f32, %f1: f32, %f2: f32, %f3: f32)
+-> (f32, f32, f32)
+{
+  %0 = vector.insert %f0, %v[0, 0] : f32 into vector<2x3xf32>
+  %1 = vector.insert %f1, %0[0, 1] : f32 into vector<2x3xf32>
+  %2 = vector.insert %f2, %1[1, 0] : f32 into vector<2x3xf32>
+  %3 = vector.insert %f3, %2[1, 1] : f32 into vector<2x3xf32>
+  %4 = vector.transpose %3, [1, 0] : vector<2x3xf32> to vector<3x2xf32>
+  %5 = vector.insert %f3, %4[1, 0] : f32 into vector<3x2xf32>
+  %6 = vector.transpose %5, [1, 0] : vector<3x2xf32> to vector<2x3xf32>
+
+  // Expected %f2 from %2 = vector.insert %f2, %1[1, 0].
+  %r1 = vector.extract %3[1, 0] : vector<2x3xf32>
+
+  // Expected %f1 from %1 = vector.insert %f1, %0[0, 1] followed by
+  // transpose [1, 0].
+  %r2 = vector.extract %4[1, 0] : vector<3x2xf32>
+
+  // Expected %f2 from %2 = vector.insert %f2, %1[1, 0] followed by double
+  // transpose [1, 0].
+  %r3 = vector.extract %6[1, 0] : vector<2x3xf32>
+
+  // CHECK-NEXT: return %[[F2]], %[[F1]], %[[F2]] : f32, f32, f32
+  return %r1, %r2, %r3 : f32, f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: func @insert_extract_transpose_3d(
+//  CHECK-SAME: %[[V:[a-zA-Z0-9]*]]: vector<2x3x4xf32>,
+//  CHECK-SAME: %[[F0:[a-zA-Z0-9]*]]: f32,
+//  CHECK-SAME: %[[F1:[a-zA-Z0-9]*]]: f32,
+//  CHECK-SAME: %[[F2:[a-zA-Z0-9]*]]: f32,
+//  CHECK-SAME: %[[F3:[a-zA-Z0-9]*]]: f32
+func @insert_extract_transpose_3d(
+    %v: vector<2x3x4xf32>, %f0: f32, %f1: f32, %f2: f32, %f3: f32)
+-> (f32, f32, f32, f32)
+{
+  %0 = vector.insert %f0, %v[0, 0, 0] : f32 into vector<2x3x4xf32>
+  %1 = vector.insert %f1, %0[0, 1, 0] : f32 into vector<2x3x4xf32>
+  %2 = vector.insert %f2, %1[1, 0, 0] : f32 into vector<2x3x4xf32>
+  %3 = vector.insert %f3, %2[0, 0, 1] : f32 into vector<2x3x4xf32>
+  %4 = vector.transpose %3, [1, 2, 0] : vector<2x3x4xf32> to vector<3x4x2xf32>
+  %5 = vector.insert %f3, %4[1, 0, 0] : f32 into vector<3x4x2xf32>
+  %6 = vector.transpose %5, [1, 2, 0] : vector<3x4x2xf32> to vector<4x2x3xf32>
+  %7 = vector.insert %f3, %6[1, 0, 0] : f32 into vector<4x2x3xf32>
+  %8 = vector.transpose %7, [1, 2, 0] : vector<4x2x3xf32> to vector<2x3x4xf32>
+
+  // Expected %f2 from %2 = vector.insert %f2, %1[1, 0, 0].
+  %r1 = vector.extract %3[1, 0, 0] : vector<2x3x4xf32>
+
+  // Expected %f1 from %1 = vector.insert %f1, %0[0, 1, 0] followed by
+  // transpose[1, 2, 0].
+  %r2 = vector.extract %4[1, 0, 0] : vector<3x4x2xf32>
+
+  // Expected %f3 from %3 = vector.insert %f3, %2[0, 0, 1] followed by double
+  // transpose[1, 2, 0].
+  %r3 = vector.extract %6[1, 0, 0] : vector<4x2x3xf32>
+
+  // Expected %f2 from %2 = vector.insert %f2, %1[1, 0, 0] followed by triple
+  // transpose[1, 2, 0].
+  %r4 = vector.extract %8[1, 0, 0] : vector<2x3x4xf32>
+
+  // CHECK-NEXT: return %[[F2]], %[[F1]], %[[F3]], %[[F2]] : f32, f32, f32, f32
+  return %r1, %r2, %r3, %r4 : f32, f32, f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: func @insert_extract_transpose_3d_2d(
+//  CHECK-SAME: %[[V:[a-zA-Z0-9]*]]: vector<2x3x4xf32>,
+//  CHECK-SAME: %[[F0:[a-zA-Z0-9]*]]: vector<4xf32>,
+//  CHECK-SAME: %[[F1:[a-zA-Z0-9]*]]: vector<4xf32>,
+//  CHECK-SAME: %[[F2:[a-zA-Z0-9]*]]: vector<4xf32>,
+//  CHECK-SAME: %[[F3:[a-zA-Z0-9]*]]: vector<4xf32>
+func @insert_extract_transpose_3d_2d(
+    %v: vector<2x3x4xf32>,
+    %f0: vector<4xf32>, %f1: vector<4xf32>, %f2: vector<4xf32>, %f3: vector<4xf32>)
+-> (vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>)
+{
+  %0 = vector.insert %f0, %v[0, 0] : vector<4xf32> into vector<2x3x4xf32>
+  %1 = vector.insert %f1, %0[0, 1] : vector<4xf32> into vector<2x3x4xf32>
+  %2 = vector.insert %f2, %1[1, 0] : vector<4xf32> into vector<2x3x4xf32>
+  %3 = vector.insert %f3, %2[1, 1] : vector<4xf32> into vector<2x3x4xf32>
+  %4 = vector.transpose %3, [1, 0, 2] : vector<2x3x4xf32> to vector<3x2x4xf32>
+  %5 = vector.transpose %4, [1, 0, 2] : vector<3x2x4xf32> to vector<2x3x4xf32>
+
+  // Expected %f2 from %2 = vector.insert %f2, %1[1, 0].
+  %r1 = vector.extract %3[1, 0] : vector<2x3x4xf32>
+
+  // Expected %f1 from %1 = vector.insert %f1, %0[0, 1] followed by
+  // transpose[1, 0, 2].
+  %r2 = vector.extract %4[1, 0] : vector<3x2x4xf32>
+
+  // Expected %f2 from %2 = vector.insert %f2, %1[1, 0] followed by double
+  // transpose[1, 0, 2].
+  %r3 = vector.extract %5[1, 0] : vector<2x3x4xf32>
+
+  %6 = vector.transpose %3, [1, 2, 0] : vector<2x3x4xf32> to vector<3x4x2xf32>
+  %7 = vector.transpose %6, [1, 2, 0] : vector<3x4x2xf32> to vector<4x2x3xf32>
+  %8 = vector.transpose %7, [1, 2, 0] : vector<4x2x3xf32> to vector<2x3x4xf32>
+
+  // Expected %f2 from %2 = vector.insert %f2, %1[1, 0] followed by triple
+  // transpose[1, 2, 0].
+  %r4 = vector.extract %8[1, 0] : vector<2x3x4xf32>
+
+  // CHECK: return %[[F2]], %[[F1]], %[[F2]], %[[F2]]
+  // CHECK-SAME: vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>
+  return %r1, %r2, %r3, %r4 : vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @insert_extract_transpose_non_commuting(
+//  CHECK-SAME: %[[V:[a-zA-Z0-9]*]]: vector<2x3x4xf32>,
+//  CHECK-SAME: %[[F:[a-zA-Z0-9]*]]: f32
+func @insert_extract_transpose_non_commuting(
+    %v: vector<2x3x4xf32>, %f: f32) -> (f32, f32)
+{
+  %0 = vector.insert %f, %v[0, 1, 2] : f32 into vector<2x3x4xf32>
+  %1 = vector.transpose %0, [1, 0, 2] : vector<2x3x4xf32> to vector<3x2x4xf32>
+  %2 = vector.transpose %1, [0, 2, 1] : vector<3x2x4xf32> to vector<3x4x2xf32>
+
+  // Expected %f: the composed permutation maps element [1, 2, 0] of %2 to
+  // element [0, 1, 2] of %0.
+  %r1 = vector.extract %2[1, 2, 0] : vector<3x4x2xf32>
+
+  // Element [2, 0, 1] of %2 is element [1, 2, 0] of %v, which is not
+  // written by any insert: this must not fold.
+  %r2 = vector.extract %2[2, 0, 1] : vector<3x4x2xf32>
+
+  // CHECK: %[[R2:.*]] = vector.extract %{{.*}}[2, 0, 1] : vector<3x4x2xf32>
+  // CHECK: return %[[F]], %[[R2]] : f32, f32
+  return %r1, %r2 : f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: func @insert_extract_partial_overlap(
+//  CHECK-SAME: %[[V:[a-zA-Z0-9]*]]: vector<2x4xf32>,
+//  CHECK-SAME: %[[F:[a-zA-Z0-9]*]]: f32,
+//  CHECK-SAME: %[[S:[a-zA-Z0-9]*]]: vector<4xf32>
+func @insert_extract_partial_overlap(
+    %v: vector<2x4xf32>, %f: f32, %s: vector<4xf32>)
+-> (f32, vector<4xf32>)
+{
+  %0 = vector.insert %f, %v[0, 1] : f32 into vector<2x4xf32>
+  %1 = vector.insert %s, %0[0] : vector<4xf32> into vector<2x4xf32>
+  %2 = vector.insert %f, %1[0, 2] : f32 into vector<2x4xf32>
+
+  // Element [0, 1] of %1 was overwritten by the insertion of %s at [0]: this
+  // must not fold to %f.
+  %r1 = vector.extract %1[0, 1] : vector<2x4xf32>
+
+  // Row [0] of %2 was partially overwritten by the insertion of %f at
+  // [0, 2]: this must not fold to %s.
+  %r2 = vector.extract %2[0] : vector<2x4xf32>
+
+  // CHECK: %[[R1:.*]] = vector.extract %{{.*}}[0, 1] : vector<2x4xf32>
+  // CHECK: %[[R2:.*]] = vector.extract %{{.*}}[0] : vector<2x4xf32>
+  // CHECK: return %[[R1]], %[[R2]] : f32, vector<4xf32>
+  return %r1, %r2 : f32, vector<4xf32>
+}