diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -946,7 +946,7 @@ } template -static SmallVector extractVector(ArrayAttr arrayAttr) { +static SmallVector extractVector(ArrayAttr arrayAttr) { return llvm::to_vector<4>(llvm::map_range( arrayAttr.getAsRange(), [](IntegerAttr attr) { return static_cast(attr.getInt()); })); @@ -960,12 +960,12 @@ SmallVector globalPosition; ExtractOp currentOp = extractOp; - auto extractedPos = extractVector(currentOp.position()); - globalPosition.append(extractedPos.rbegin(), extractedPos.rend()); + auto extrPos = extractVector(currentOp.position()); + globalPosition.append(extrPos.rbegin(), extrPos.rend()); while (ExtractOp nextOp = currentOp.vector().getDefiningOp()) { currentOp = nextOp; - auto extractedPos = extractVector(currentOp.position()); - globalPosition.append(extractedPos.rbegin(), extractedPos.rend()); + auto extrPos = extractVector(currentOp.position()); + globalPosition.append(extrPos.rbegin(), extrPos.rend()); } extractOp.setOperand(currentOp.vector()); // OpBuilder is only used as a helper to build an I64ArrayAttr. @@ -976,144 +976,219 @@ return success(); } -/// Fold the result of an ExtractOp in place when it comes from a TransposeOp. -static LogicalResult foldExtractOpFromTranspose(ExtractOp extractOp) { - auto transposeOp = extractOp.vector().getDefiningOp(); - if (!transposeOp) - return failure(); +namespace { +/// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps. +/// Walk back a chain of InsertOp/TransposeOp until we hit a match. +/// Compose TransposeOp permutations as we walk back. +/// This helper class keeps an updated extraction position `extractPosition` +/// with extra trailing sentinels. +/// The sentinels encode the internal transposition status of the result vector. +/// As we iterate, extractPosition is permuted and updated. 
+class ExtractFromInsertTransposeChainState { +public: + ExtractFromInsertTransposeChainState(ExtractOp e); + + /// Iterate over producing insert and transpose ops until we find a fold. + Value fold(); + +private: + /// Return true if the vector at position `a` is contained within the vector + /// at position `b`. Under insert/extract semantics, this is the same as `a` + /// is a prefix of `b`. + template <typename ContainerA, typename ContainerB> + bool isContainedWithin(const ContainerA &a, const ContainerB &b) { + return a.size() <= b.size() && + std::equal(a.begin(), a.begin() + a.size(), b.begin()); + } - auto permutation = extractVector(transposeOp.transp()); - auto extractedPos = extractVector(extractOp.position()); + /// Return true if the vector at position `a` intersects the vector at + /// position `b`. Under insert/extract semantics, this is the same as equality + /// of all corresponding entries of `a` and `b` that are both >= 0; a pair in + /// which either entry is negative is skipped. + /// Comparison is on the common prefix (i.e. zip). + template <typename ContainerA, typename ContainerB> + bool intersectsWhereNonNegative(const ContainerA &a, const ContainerB &b) { + for (auto it : llvm::zip(a, b)) { + // Skip the pair when the entry of `a` or the entry of `b` is negative. + if (std::get<0>(it) < 0 || std::get<1>(it) < 0) + continue; + if (std::get<0>(it) != std::get<1>(it)) + return false; + } + return true; + } + + /// Folding is only possible in the absence of an internal permutation in the + /// result vector. + bool canFold() { + return (sentinels == + makeArrayRef(extractPosition).drop_front(extractedRank)); + } - // If transposition permutation is larger than the ExtractOp, all minor - // dimensions must be an identity for folding to occur. If not, individual - // elements within the extracted value are transposed and this is not just a - // simple folding. 
- unsigned minorRank = permutation.size() - extractedPos.size(); - MLIRContext *ctx = extractOp.getContext(); - AffineMap permutationMap = AffineMap::getPermutationMap(permutation, ctx); - AffineMap minorMap = permutationMap.getMinorSubMap(minorRank); - if (minorMap && !minorMap.isMinorIdentity()) + // Helper to get the next defining op of interest: at most one of + // `nextInsertOp` / `nextTransposeOp` is non-null afterwards. + void updateStateForNextIteration(Value v) { + nextInsertOp = v.getDefiningOp<InsertOp>(); + nextTransposeOp = v.getDefiningOp<TransposeOp>(); + } + + // Case 1. If we hit a transpose, just compose the map and iterate. + // Invariant: insert + transpose do not change rank, we can always compose. + LogicalResult handleTransposeOp(); + + // Case 2: the insert position matches extractPosition exactly, early return. + LogicalResult handleInsertOpWithMatchingPos(Value &res); + + /// Case 3: if the insert position is a prefix of extractPosition, extract a + /// portion of the source of the insert. + /// Example: + /// ``` + /// %ins = vector.insert %source, %dest[1]: vector<3x4x5> into vector<2x3x4x5> + /// // extractPosition == [1, 2, 3] + /// %ext = vector.extract %ins[1, 2, 3]: vector<2x3x4x5> + /// // can fold to + /// %ext = vector.extract %source[2, 3]: vector<3x4x5> + /// ``` + /// To traverse through %source, we need to set the leading dims to 0 and + /// drop the extra leading dims. + /// This method updates the internal state. + LogicalResult handleInsertOpWithPrefixPos(Value &res); + + /// Try to fold in place to extract(source, extractPosition) and return the + /// folded result. Return null if folding is not possible (e.g. due to an + /// internal transposition in the result). + Value tryToFoldExtractOpInPlace(Value source); + + ExtractOp extractOp; + int64_t vectorRank; + int64_t extractedRank; + + InsertOp nextInsertOp; + TransposeOp nextTransposeOp; + + /// Sentinel values that encode the internal permutation status of the result. + /// They are set to (-1, ... 
, -k) at the beginning and appended to + /// `extractPosition`. + /// In the end, the tail of `extractPosition` must be exactly `sentinels` to + /// ensure that there is no internal transposition. + /// Internal transposition cannot be accounted for with a folding pattern. + // TODO: We could relax the internal transposition with an extra transposition + // operation in a future canonicalizer. + SmallVector sentinels; + SmallVector extractPosition; +}; +} // namespace + +ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState( + ExtractOp e) + : extractOp(e), vectorRank(extractOp.getVectorType().getRank()), + extractedRank(extractOp.position().size()) { + assert(vectorRank >= extractedRank && "extracted pos overflow"); + sentinels.reserve(vectorRank - extractedRank); + for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i) + sentinels.push_back(-(i + 1)); + extractPosition = extractVector(extractOp.position()); + llvm::append_range(extractPosition, sentinels); +} + +// Case 1. If we hit a transpose, just compose the map and iterate. +// Invariant: insert + transpose do not change rank, we can always compose. +LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() { + if (!nextTransposeOp) return failure(); + auto permutation = extractVector(nextTransposeOp.transp()); + AffineMap m = inversePermutation( + AffineMap::getPermutationMap(permutation, extractOp.getContext())); + extractPosition = applyPermutationMap(m, makeArrayRef(extractPosition)); + return success(); +} - // %1 = transpose %0[x, y, z] : vector - // %2 = extract %1[u, v] : vector<..xf32> - // may turn into: - // %2 = extract %0[w, x] : vector<..xf32> - // iff z == 2 and [w, x] = [x, y]^-1 o [u, v] here o denotes composition and - // -1 denotes the inverse. - permutationMap = permutationMap.getMajorSubMap(extractedPos.size()); - // The major submap has fewer results but the same number of dims. 
To compose - // cleanly, we need to drop dims to form a "square matrix". This is possible - // because: - // (a) this is a permutation map and - // (b) the minor map has already been checked to be identity. - // Therefore, the major map cannot contain dims of position greater or equal - // than the number of results. - assert(llvm::all_of(permutationMap.getResults(), - [&](AffineExpr e) { - auto dim = e.dyn_cast(); - return dim && dim.getPosition() < - permutationMap.getNumResults(); - }) && - "Unexpected map results depend on higher rank positions"); - // Project on the first domain dimensions to allow composition. - permutationMap = AffineMap::get(permutationMap.getNumResults(), 0, - permutationMap.getResults(), ctx); - - extractOp.setOperand(transposeOp.vector()); - // Compose the inverse permutation map with the extractedPos. - auto newExtractedPos = - inversePermutation(permutationMap).compose(extractedPos); - // OpBuilder is only used as a helper to build an I64ArrayAttr. - OpBuilder b(extractOp.getContext()); - extractOp->setAttr(ExtractOp::getPositionAttrName(), - b.getI64ArrayAttr(newExtractedPos)); +// Case 2: the insert position matches extractPosition exactly, early return. +// On success, `res` holds the source of the matching insert. Fails when the +// positions differ or when an internal transposition is present. +LogicalResult +ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos( + Value &res) { + auto insertedPos = extractVector<int64_t>(nextInsertOp.position()); + if (makeArrayRef(insertedPos) != + llvm::makeArrayRef(extractPosition).take_front(extractedRank)) + return failure(); + // Case 2.a. early-exit fold. + res = nextInsertOp.source(); + // Case 2.b. if internal transposition is present, canFold is false and the + // caller must not return `res` as the folded value, so report failure. + return success(canFold()); +} +/// Case 3: if inserted position is a prefix of extractPosition, +/// extract a portion of the source of the insertion. +/// This method updates the internal state. 
+LogicalResult +ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) { + auto insertedPos = extractVector(nextInsertOp.position()); + if (!isContainedWithin(insertedPos, extractPosition)) + return failure(); + // Set leading dims to zero. + std::fill_n(extractPosition.begin(), insertedPos.size(), 0); + // Drop extra leading dims. + extractPosition.erase(extractPosition.begin(), + extractPosition.begin() + insertedPos.size()); + extractedRank = extractPosition.size() - sentinels.size(); + // Case 3.a. early-exit fold (break and delegate to post-while path). + res = nextInsertOp.source(); + // Case 3.b. if internal transposition is present, canFold will be false. return success(); } -/// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps. The -/// result is always the input to some InsertOp. -static Value foldExtractOpFromInsertChainAndTranspose(ExtractOp extractOp) { - MLIRContext *context = extractOp.getContext(); - AffineMap permutationMap; - auto extractedPos = extractVector(extractOp.position()); - // Walk back a chain of InsertOp/TransposeOp until we hit a match. - // Compose TransposeOp permutations as we walk back. - auto insertOp = extractOp.vector().getDefiningOp(); - auto transposeOp = extractOp.vector().getDefiningOp(); - while (insertOp || transposeOp) { - if (transposeOp) { - // If it is transposed, compose the map and iterate. - auto permutation = extractVector(transposeOp.transp()); - AffineMap newMap = AffineMap::getPermutationMap(permutation, context); - if (!permutationMap) - permutationMap = newMap; - else if (newMap.getNumInputs() != permutationMap.getNumResults()) - return Value(); - else - permutationMap = newMap.compose(permutationMap); - // Compute insert/transpose for the next iteration. 
- Value transposed = transposeOp.vector(); - insertOp = transposed.getDefiningOp(); - transposeOp = transposed.getDefiningOp(); - continue; - } +/// Try to fold in place to extract(source, extractPosition) and return the +/// folded result. Return null if folding is not possible (e.g. due to an +/// internal transposition in the result). +Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace( + Value source) { + // If we can't fold (either internal transposition, or nothing to fold), bail. + bool nothingToFold = (source == extractOp.vector()); + if (nothingToFold || !canFold()) + return Value(); + // Otherwise, fold by updating the op in place and return its result. + // Only the first `extractedRank` entries of extractPosition are written to + // the position attribute; the trailing sentinels are dropped. + OpBuilder b(extractOp.getContext()); + extractOp->setAttr( + extractOp.positionAttrName(), + b.getI64ArrayAttr( + makeArrayRef(extractPosition).take_front(extractedRank))); + extractOp.vectorMutable().assign(source); + return extractOp.getResult(); +} - assert(insertOp); - Value insertionDest = insertOp.dest(); - // If it is inserted into, either the position matches and we have a - // successful folding; or we iterate until we run out of - // InsertOp/TransposeOp. This is because `vector.insert %scalar, %vector` - // produces a new vector with 1 modified value/slice in exactly the static - // position we need to match. - auto insertedPos = extractVector(insertOp.position()); - // Trivial permutations are solved with position equality checks. - if (!permutationMap || permutationMap.isIdentity()) { - if (extractedPos == insertedPos) - return insertOp.source(); - // Fallthrough: if the position does not match, just skip to the next - // producing `vector.insert` / `vector.transpose`. - // Compute insert/transpose for the next iteration. - insertOp = insertionDest.getDefiningOp(); - transposeOp = insertionDest.getDefiningOp(); +/// Iterate over producing insert and transpose ops until we find a fold. 
+Value ExtractFromInsertTransposeChainState::fold() { + Value valueToExtractFrom = extractOp.vector(); + updateStateForNextIteration(valueToExtractFrom); + while (nextInsertOp || nextTransposeOp) { + // Case 1. If we hit a transpose, just compose the map and iterate. + // Invariant: insert + transpose do not change rank, we can always compose. + if (succeeded(handleTransposeOp())) { + valueToExtractFrom = nextTransposeOp.vector(); + updateStateForNextIteration(valueToExtractFrom); continue; } - // More advanced permutations require application of the permutation. - // However, the rank of `insertedPos` may be different from that of the - // `permutationMap`. To support such case, we need to: - // 1. apply on the `insertedPos.size()` major dimensions - // 2. check the other dimensions of the permutation form a minor identity. - assert(permutationMap.isPermutation() && "expected a permutation"); - if (insertedPos.size() == extractedPos.size()) { - bool fold = true; - for (unsigned idx = 0, sz = extractedPos.size(); idx < sz; ++idx) { - auto pos = permutationMap.getDimPosition(idx); - if (pos >= sz || insertedPos[pos] != extractedPos[idx]) { - fold = false; - break; - } - } - if (fold) { - assert(permutationMap.getNumResults() >= insertedPos.size() && - "expected map of rank larger than insert indexing"); - unsigned minorRank = - permutationMap.getNumResults() - insertedPos.size(); - AffineMap minorMap = permutationMap.getMinorSubMap(minorRank); - if (!minorMap || minorMap.isMinorIdentity()) - return insertOp.source(); - } - } + Value result; + // Case 2: the position match exactly. + if (succeeded(handleInsertOpWithMatchingPos(result))) + return result; - // If we haven't found a match, just continue to the next producing - // `vector.insert` / `vector.transpose`. - // Compute insert/transpose for the next iteration. 
- insertOp = insertionDest.getDefiningOp(); - transposeOp = insertionDest.getDefiningOp(); + // Case 3: if the inserted position is a prefix of extractPosition, we can + // just extract a portion of the source of the insert. + if (succeeded(handleInsertOpWithPrefixPos(result))) + return tryToFoldExtractOpInPlace(result); + + // Case 4: extractPositionRef intersects insertedPosRef on non-sentinel + // values. This is a more difficult case and we bail. + auto insertedPos = extractVector(nextInsertOp.position()); + if (isContainedWithin(extractPosition, insertedPos) || + intersectsWhereNonNegative(extractPosition, insertedPos)) + return Value(); + + // Case 5: No intersection, we forward the extract to insertOp.dest(). + valueToExtractFrom = nextInsertOp.dest(); + updateStateForNextIteration(valueToExtractFrom); } - return Value(); + // If after all this we can fold, go for it. + return tryToFoldExtractOpInPlace(valueToExtractFrom); } /// Fold extractOp with scalar result coming from BroadcastOp or SplatOp. 
@@ -1209,14 +1284,12 @@ return vector(); if (succeeded(foldExtractOpFromExtractChain(*this))) return getResult(); - if (succeeded(foldExtractOpFromTranspose(*this))) - return getResult(); - if (auto val = foldExtractOpFromInsertChainAndTranspose(*this)) - return val; - if (auto val = foldExtractFromBroadcast(*this)) - return val; - if (auto val = foldExtractFromShapeCast(*this)) - return val; + if (auto res = ExtractFromInsertTransposeChainState(*this).fold()) + return res; + if (auto res = foldExtractFromBroadcast(*this)) + return res; + if (auto res = foldExtractFromShapeCast(*this)) + return res; return OpFoldResult(); } diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -316,87 +316,103 @@ // ----- -// CHECK-LABEL: func @insert_extract_transpose_3d( -// CHECK-SAME: %[[V:[a-zA-Z0-9]*]]: vector<2x3x4xf32>, -// CHECK-SAME: %[[F0:[a-zA-Z0-9]*]]: f32, -// CHECK-SAME: %[[F1:[a-zA-Z0-9]*]]: f32, -// CHECK-SAME: %[[F2:[a-zA-Z0-9]*]]: f32, -// CHECK-SAME: %[[F3:[a-zA-Z0-9]*]]: f32 -func @insert_extract_transpose_3d( - %v: vector<2x3x4xf32>, %f0: f32, %f1: f32, %f2: f32, %f3: f32) --> (f32, f32, f32, f32) -{ - %0 = vector.insert %f0, %v[0, 0, 0] : f32 into vector<2x3x4xf32> - %1 = vector.insert %f1, %0[0, 1, 0] : f32 into vector<2x3x4xf32> - %2 = vector.insert %f2, %1[1, 0, 0] : f32 into vector<2x3x4xf32> - %3 = vector.insert %f3, %2[0, 0, 1] : f32 into vector<2x3x4xf32> - %4 = vector.transpose %3, [1, 2, 0] : vector<2x3x4xf32> to vector<3x4x2xf32> - %5 = vector.insert %f3, %4[1, 0, 0] : f32 into vector<3x4x2xf32> - %6 = vector.transpose %5, [1, 2, 0] : vector<3x4x2xf32> to vector<4x2x3xf32> - %7 = vector.insert %f3, %6[1, 0, 0] : f32 into vector<4x2x3xf32> - %8 = vector.transpose %7, [1, 2, 0] : vector<4x2x3xf32> to vector<2x3x4xf32> - - // Expected %f2 from %2 = vector.insert %f2, %1[1, 0, 0]. 
- %r1 = vector.extract %3[1, 0, 0] : vector<2x3x4xf32> - - // Expected %f1 from %1 = vector.insert %f1, %0[0, 1, 0] followed by - // transpose[1, 2, 0]. - %r2 = vector.extract %4[1, 0, 0] : vector<3x4x2xf32> - - // Expected %f3 from %3 = vector.insert %f3, %0[0, 0, 1] followed by double - // transpose[1, 2, 0]. - %r3 = vector.extract %6[1, 0, 0] : vector<4x2x3xf32> - - // Expected %f2 from %2 = vector.insert %f2, %1[1, 0, 0] followed by triple - // transpose[1, 2, 0]. - %r4 = vector.extract %8[1, 0, 0] : vector<2x3x4xf32> - - // CHECK-NEXT: return %[[F2]], %[[F1]], %[[F3]], %[[F2]] : f32, f32, f32 - return %r1, %r2, %r3, %r4 : f32, f32, f32, f32 -} - -// ----- - -// CHECK-LABEL: func @insert_extract_transpose_3d_2d( -// CHECK-SAME: %[[V:[a-zA-Z0-9]*]]: vector<2x3x4xf32>, -// CHECK-SAME: %[[F0:[a-zA-Z0-9]*]]: vector<4xf32>, -// CHECK-SAME: %[[F1:[a-zA-Z0-9]*]]: vector<4xf32>, -// CHECK-SAME: %[[F2:[a-zA-Z0-9]*]]: vector<4xf32>, -// CHECK-SAME: %[[F3:[a-zA-Z0-9]*]]: vector<4xf32> -func @insert_extract_transpose_3d_2d( - %v: vector<2x3x4xf32>, - %f0: vector<4xf32>, %f1: vector<4xf32>, %f2: vector<4xf32>, %f3: vector<4xf32>) --> (vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>) -{ - %0 = vector.insert %f0, %v[0, 0] : vector<4xf32> into vector<2x3x4xf32> - %1 = vector.insert %f1, %0[0, 1] : vector<4xf32> into vector<2x3x4xf32> - %2 = vector.insert %f2, %1[1, 0] : vector<4xf32> into vector<2x3x4xf32> - %3 = vector.insert %f3, %2[1, 1] : vector<4xf32> into vector<2x3x4xf32> - %4 = vector.transpose %3, [1, 0, 2] : vector<2x3x4xf32> to vector<3x2x4xf32> - %5 = vector.transpose %4, [1, 0, 2] : vector<3x2x4xf32> to vector<2x3x4xf32> - - // Expected %f2 from %2 = vector.insert %f2, %1[1, 0]. - %r1 = vector.extract %3[1, 0] : vector<2x3x4xf32> - - // Expected %f1 from %1 = vector.insert %f1, %0[0, 1] followed by - // transpose[1, 0, 2]. 
- %r2 = vector.extract %4[1, 0] : vector<3x2x4xf32> - - // Expected %f2 from %2 = vector.insert %f2, %1[1, 0, 0] followed by double - // transpose[1, 0, 2]. - %r3 = vector.extract %5[1, 0] : vector<2x3x4xf32> - - %6 = vector.transpose %3, [1, 2, 0] : vector<2x3x4xf32> to vector<3x4x2xf32> - %7 = vector.transpose %6, [1, 2, 0] : vector<3x4x2xf32> to vector<4x2x3xf32> - %8 = vector.transpose %7, [1, 2, 0] : vector<4x2x3xf32> to vector<2x3x4xf32> +// CHECK-LABEL: insert_extract_chain +// CHECK-SAME: %[[V234:[a-zA-Z0-9]*]]: vector<2x3x4xf32> +// CHECK-SAME: %[[V34:[a-zA-Z0-9]*]]: vector<3x4xf32> +// CHECK-SAME: %[[V4:[a-zA-Z0-9]*]]: vector<4xf32> +func @insert_extract_chain(%v234: vector<2x3x4xf32>, %v34: vector<3x4xf32>, %v4: vector<4xf32>) + -> (vector<4xf32>, vector<4xf32>, vector<3x4xf32>, vector<3x4xf32>) { + // CHECK-NEXT: %[[A34:.*]] = vector.insert + %A34 = vector.insert %v34, %v234[0]: vector<3x4xf32> into vector<2x3x4xf32> + // CHECK-NEXT: %[[B34:.*]] = vector.insert + %B34 = vector.insert %v34, %A34[1]: vector<3x4xf32> into vector<2x3x4xf32> + // CHECK-NEXT: %[[A4:.*]] = vector.insert + %A4 = vector.insert %v4, %B34[1, 0]: vector<4xf32> into vector<2x3x4xf32> + // CHECK-NEXT: %[[B4:.*]] = vector.insert + %B4 = vector.insert %v4, %A4[1, 1]: vector<4xf32> into vector<2x3x4xf32> + + // Case 2.a. [1, 1] == insertpos ([1, 1]) + // Match %A4 insertionpos and fold to its source(i.e. %V4). + %r0 = vector.extract %B4[1, 1]: vector<2x3x4xf32> + + // Case 3.a. insertpos ([1]) is a prefix of [1, 0]. + // Traverse %B34 to its source(i.e. %V34@[*0*]). + // CHECK-NEXT: %[[R1:.*]] = vector.extract %[[V34]][0] + %r1 = vector.extract %B34[1, 0]: vector<2x3x4xf32> + + // Case 4. [1] is a prefix of insertpos ([1, 1]). + // Cannot traverse %B4. + // CHECK-NEXT: %[[R2:.*]] = vector.extract %[[B4]][1] + %r2 = vector.extract %B4[1]: vector<2x3x4xf32> + + // Case 5. [0] is disjoint from insertpos ([1, 1]). + // Traverse %B4 to its dest(i.e. %A4@[0]). 
+ // Traverse %A4 to its dest(i.e. %B34@[0]). + // Traverse %B34 to its dest(i.e. %A34@[0]). + // Match %A34 insertionpos and fold to its source(i.e. %V34). + %r3 = vector.extract %B4[0]: vector<2x3x4xf32> + + // CHECK: return %[[V4]], %[[R1]], %[[R2]], %[[V34]] + return %r0, %r1, %r2, %r3: + vector<4xf32>, vector<4xf32>, vector<3x4xf32>, vector<3x4xf32> +} - // Expected %f2 from %2 = vector.insert %f2, %1[1, 0, 0] followed by triple - // transpose[1, 2, 0]. - %r4 = vector.extract %8[1, 0] : vector<2x3x4xf32> +// ----- - // CHECK: return %[[F2]], %[[F1]], %[[F2]], %[[F2]] - // CHECK-SAME: vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32> - return %r1, %r2, %r3, %r4 : vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32> +// CHECK-LABEL: func @insert_extract_transpose_3d( +// CHECK-SAME: %[[V234:[a-zA-Z0-9]*]]: vector<2x3x4xf32> +func @insert_extract_transpose_3d( + %v234: vector<2x3x4xf32>, %v43: vector<4x3xf32>, %f0: f32) + -> (vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<3x4xf32>) { + + %a432 = vector.transpose %v234, [2, 1, 0] : vector<2x3x4xf32> to vector<4x3x2xf32> + %b432 = vector.insert %f0, %a432[0, 0, 1] : f32 into vector<4x3x2xf32> + %c234 = vector.transpose %b432, [2, 1, 0] : vector<4x3x2xf32> to vector<2x3x4xf32> + // Case 1. %c234 = transpose [2,1,0] posWithSentinels [1,2,-1] -> [-1,2,1] + // Case 5. %b432 = insert [0,0,1] (inter([.,2,1], [.,0,1]) == 0) prop to %a432 + // Case 1. %a432 = transpose [2,1,0] posWithSentinels [-1,2,1] -> [1,2,-1] + // can extract directly from %v234, the rest folds. 
+ // CHECK: %[[R0:.*]] = vector.extract %[[V234]][1, 2] + %r0 = vector.extract %c234[1, 2] : vector<2x3x4xf32> + + // CHECK-NEXT: vector.transpose + // CHECK-NEXT: vector.insert + // CHECK-NEXT: %[[F234:.*]] = vector.transpose + %d432 = vector.transpose %v234, [2, 1, 0] : vector<2x3x4xf32> to vector<4x3x2xf32> + %e432 = vector.insert %f0, %d432[0, 2, 1] : f32 into vector<4x3x2xf32> + %f234 = vector.transpose %e432, [2, 1, 0] : vector<4x3x2xf32> to vector<2x3x4xf32> + // Case 1. %f234 = transpose [2,1,0] posWithSentinels [1,2,-1] -> [-1,2,1] + // Case 4. %e432 = insert [0,2,1] (inter([.,2,1], [.,2,1]) != 0) + // Bail, cannot do better than the current. + // CHECK: %[[R1:.*]] = vector.extract %[[F234]] + %r1 = vector.extract %f234[1, 2] : vector<2x3x4xf32> + + // CHECK-NEXT: vector.transpose + // CHECK-NEXT: vector.insert + // CHECK-NEXT: %[[H234:.*]] = vector.transpose + %g243 = vector.transpose %v234, [0, 2, 1] : vector<2x3x4xf32> to vector<2x4x3xf32> + %h243 = vector.insert %v43, %g243[0] : vector<4x3xf32> into vector<2x4x3xf32> + %i234 = vector.transpose %h243, [0, 2, 1] : vector<2x4x3xf32> to vector<2x3x4xf32> + // Case 1. %i234 = transpose [0,2,1] posWithSentinels [0,1,-1] -> [0,-1,1] + // Case 3.b. %h243 = insert [0] is prefix of [0,.,.] but internal transpose. + // Bail, cannot do better than the current. + // CHECK: %[[R2:.*]] = vector.extract %[[H234]][0, 1] + %r2 = vector.extract %i234[0, 1] : vector<2x3x4xf32> + + // CHECK-NEXT: vector.transpose + // CHECK-NEXT: vector.insert + // CHECK-NEXT: %[[K234:.*]] = vector.transpose + %j243 = vector.transpose %v234, [0, 2, 1] : vector<2x3x4xf32> to vector<2x4x3xf32> + %k243 = vector.insert %v43, %j243[0] : vector<4x3xf32> into vector<2x4x3xf32> + %l234 = vector.transpose %k243, [0, 2, 1] : vector<2x4x3xf32> to vector<2x3x4xf32> + // Case 1. %l234 = transpose [0,2,1] posWithSentinels [0,-1,-2] -> [0,-2,-1] + // Case 2.b. %k243 = insert [0] == [0,.,.] but internal transpose. 
+ // Bail, cannot do better than the current. + // CHECK: %[[R3:.*]] = vector.extract %[[K234]][0] + %r3 = vector.extract %l234[0] : vector<2x3x4xf32> + + // CHECK-NEXT: return %[[R0]], %[[R1]], %[[R2]], %[[R3]] + return %r0, %r1, %r2, %r3: vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<3x4xf32> } // -----