diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -865,6 +865,11 @@
     VectorType getDestVectorType() {
       return dest().getType().cast<VectorType>();
     }
+    bool hasNonUnitStrides() {
+      return llvm::any_of(strides(), [](Attribute attr) {
+        return attr.cast<IntegerAttr>().getInt() != 1;
+      });
+    }
   }];
 
   let hasFolder = 1;
@@ -1120,6 +1125,11 @@
     static StringRef getStridesAttrName() { return "strides"; }
     VectorType getVectorType(){ return vector().getType().cast<VectorType>(); }
     void getOffsets(SmallVectorImpl<int64_t> &results);
+    bool hasNonUnitStrides() {
+      return llvm::any_of(strides(), [](Attribute attr) {
+        return attr.cast<IntegerAttr>().getInt() != 1;
+      });
+    }
   }];
   let hasCanonicalizer = 1;
   let hasFolder = 1;
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1204,6 +1204,109 @@
   return extractOp.getResult();
 }
 
+/// Fold an ExtractOp from ExtractStridedSliceOp.
+static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
+  auto extractStridedSliceOp =
+      extractOp.vector().getDefiningOp<vector::ExtractStridedSliceOp>();
+  if (!extractStridedSliceOp)
+    return Value();
+  // Return if 'extractStridedSliceOp' has non-unit strides.
+  if (extractStridedSliceOp.hasNonUnitStrides())
+    return Value();
+
+  // Trim offsets for dimensions fully extracted.
+  auto sliceOffsets = extractVector<int64_t>(extractStridedSliceOp.offsets());
+  while (!sliceOffsets.empty()) {
+    size_t lastOffset = sliceOffsets.size() - 1;
+    if (sliceOffsets.back() != 0 ||
+        extractStridedSliceOp.getType().getDimSize(lastOffset) !=
+            extractStridedSliceOp.getVectorType().getDimSize(lastOffset))
+      break;
+    sliceOffsets.pop_back();
+  }
+  unsigned destinationRank = 0;
+  if (auto vecType = extractOp.getType().dyn_cast<VectorType>())
+    destinationRank = vecType.getRank();
+  // The dimensions of the result need to be untouched by the
+  // extractStridedSlice op.
+  if (destinationRank >
+      extractStridedSliceOp.getVectorType().getRank() - sliceOffsets.size())
+    return Value();
+  auto extractedPos = extractVector<int64_t>(extractOp.position());
+  assert(extractedPos.size() >= sliceOffsets.size());
+  for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
+    extractedPos[i] = extractedPos[i] + sliceOffsets[i];
+  extractOp.vectorMutable().assign(extractStridedSliceOp.vector());
+  // OpBuilder is only used as a helper to build an I64ArrayAttr.
+  OpBuilder b(extractOp.getContext());
+  extractOp->setAttr(ExtractOp::getPositionAttrName(),
+                     b.getI64ArrayAttr(extractedPos));
+  return extractOp.getResult();
+}
+
+/// Fold extract_op fed from a chain of insertStridedSlice ops.
+static Value foldExtractStridedOpFromInsertChain(ExtractOp op) {
+  int64_t destinationRank = op.getType().isa<VectorType>()
+                                ? op.getType().cast<VectorType>().getRank()
+                                : 0;
+  auto insertOp = op.vector().getDefiningOp<vector::InsertStridedSliceOp>();
+  while (insertOp) {
+    int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
+                             insertOp.getSourceVectorType().getRank();
+    if (destinationRank > insertOp.getSourceVectorType().getRank())
+      return Value();
+    auto insertOffsets = extractVector<int64_t>(insertOp.offsets());
+    auto extractOffsets = extractVector<int64_t>(op.position());
+
+    if (llvm::any_of(insertOp.strides(), [](Attribute attr) {
+          return attr.cast<IntegerAttr>().getInt() != 1;
+        }))
+      return Value();
+    bool disjoint = false;
+    SmallVector<int64_t, 4> offsetDiffs;
+    for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
+      int64_t start = insertOffsets[dim];
+      int64_t size =
+          (dim < insertRankDiff)
+              ? 1
+              : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
+      int64_t end = start + size;
+      int64_t offset = extractOffsets[dim];
+      // Check if the start of the extract offset is in the interval inserted.
+      if (start <= offset && offset < end) {
+        if (dim >= insertRankDiff)
+          offsetDiffs.push_back(offset - start);
+        continue;
+      }
+      disjoint = true;
+      break;
+    }
+    // The extract element chunk overlap with the vector inserted.
+    if (!disjoint) {
+      // If any of the inner dimensions are only partially inserted we have a
+      // partial overlap.
+      int64_t srcRankDiff =
+          insertOp.getSourceVectorType().getRank() - destinationRank;
+      for (int64_t i = 0; i < destinationRank; i++) {
+        if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
+            insertOp.getDestVectorType().getDimSize(i + srcRankDiff +
+                                                    insertRankDiff))
+          return Value();
+      }
+      op.vectorMutable().assign(insertOp.source());
+      // OpBuilder is only used as a helper to build an I64ArrayAttr.
+      OpBuilder b(op.getContext());
+      op->setAttr(ExtractOp::getPositionAttrName(),
+                  b.getI64ArrayAttr(offsetDiffs));
+      return op.getResult();
+    }
+    // If the chunk extracted is disjoint from the chunk inserted, keep
+    // looking in the insert chain.
+    insertOp = insertOp.dest().getDefiningOp<vector::InsertStridedSliceOp>();
+  }
+  return Value();
+}
+
 OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
   if (position().empty())
     return vector();
@@ -1217,6 +1320,10 @@
     return val;
   if (auto val = foldExtractFromShapeCast(*this))
     return val;
+  if (auto val = foldExtractFromExtractStrided(*this))
+    return val;
+  if (auto val = foldExtractStridedOpFromInsertChain(*this))
+    return val;
   return OpFoldResult();
 }
 
@@ -2183,9 +2290,7 @@
     if (!constantMaskOp)
       return failure();
     // Return if 'extractStridedSliceOp' has non-unit strides.
-    if (llvm::any_of(extractStridedSliceOp.strides(), [](Attribute attr) {
-          return attr.cast<IntegerAttr>().getInt() != 1;
-        }))
+    if (extractStridedSliceOp.hasNonUnitStrides())
       return failure();
     // Gather constant mask dimension sizes.
     SmallVector<int64_t, 4> maskDimSizes;
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1109,3 +1109,87 @@
   %1 = vector.extract %cst_1[1, 4, 5] : vector<4x37x9xi32>
   return %0, %1 : vector<7xf32>, i32
 }
+
+// -----
+
+// CHECK-LABEL: extract_extract_strided
+//  CHECK-SAME: %[[A:.*]]: vector<32x16x4xf16>
+//       CHECK: %[[V:.*]] = vector.extract %[[A]][9, 7] : vector<32x16x4xf16>
+//       CHECK: return %[[V]] : vector<4xf16>
+func @extract_extract_strided(%arg0: vector<32x16x4xf16>) -> vector<4xf16> {
+  %1 = vector.extract_strided_slice %arg0
+    {offsets = [7, 3], sizes = [10, 8], strides = [1, 1]} :
+      vector<32x16x4xf16> to vector<10x8x4xf16>
+  %2 = vector.extract %1[2, 4] : vector<10x8x4xf16>
+  return %2 : vector<4xf16>
+}
+
+// -----
+
+// CHECK-LABEL: extract_insert_strided
+//  CHECK-SAME: %[[A:.*]]: vector<6x4xf32>
+//       CHECK: %[[V:.*]] = vector.extract %[[A]][0, 2] : vector<6x4xf32>
+//       CHECK: return %[[V]] : f32
+func @extract_insert_strided(%a: vector<6x4xf32>, %b: vector<8x16xf32>)
+  -> f32 {
+  %0 = vector.insert_strided_slice %a, %b {offsets = [2, 2], strides = [1, 1]}
+    : vector<6x4xf32> into vector<8x16xf32>
+  %2 = vector.extract %0[2, 4] : vector<8x16xf32>
+  return %2 : f32
+}
+
+// -----
+
+// CHECK-LABEL: extract_insert_rank_reduce
+//  CHECK-SAME: %[[A:.*]]: vector<4xf32>
+//       CHECK: %[[V:.*]] = vector.extract %[[A]][2] : vector<4xf32>
+//       CHECK: return %[[V]] : f32
+func @extract_insert_rank_reduce(%a: vector<4xf32>, %b: vector<8x16xf32>)
+  -> f32 {
+  %0 = vector.insert_strided_slice %a, %b {offsets = [2, 2], strides = [1]}
+    : vector<4xf32> into vector<8x16xf32>
+  %2 = vector.extract %0[2, 4] : vector<8x16xf32>
+  return %2 : f32
+}
+
+// -----
+
+// CHECK-LABEL: extract_insert_negative
+//       CHECK: vector.insert_strided_slice
+//       CHECK: vector.extract
+func @extract_insert_negative(%a: vector<2x15xf32>, %b: vector<12x8x16xf32>)
+  -> vector<16xf32> {
+  %0 = vector.insert_strided_slice %a, %b {offsets = [4, 2, 0], strides = [1, 1]}
+    : vector<2x15xf32> into vector<12x8x16xf32>
+  %2 = vector.extract %0[4, 2] : vector<12x8x16xf32>
+  return %2 : vector<16xf32>
+}
+
+// -----
+
+// CHECK-LABEL: extract_insert_chain
+//  CHECK-SAME: (%[[A:.*]]: vector<2x16xf32>, %[[B:.*]]: vector<12x8x16xf32>, %[[C:.*]]: vector<2x16xf32>)
+//       CHECK: %[[V:.*]] = vector.extract %[[C]][0] : vector<2x16xf32>
+//       CHECK: return %[[V]] : vector<16xf32>
+func @extract_insert_chain(%a: vector<2x16xf32>, %b: vector<12x8x16xf32>, %c: vector<2x16xf32>)
+  -> vector<16xf32> {
+  %0 = vector.insert_strided_slice %c, %b {offsets = [4, 2, 0], strides = [1, 1]}
+    : vector<2x16xf32> into vector<12x8x16xf32>
+  %1 = vector.insert_strided_slice %a, %0 {offsets = [0, 2, 0], strides = [1, 1]}
+    : vector<2x16xf32> into vector<12x8x16xf32>
+  %2 = vector.extract %1[4, 2] : vector<12x8x16xf32>
+  return %2 : vector<16xf32>
+}
+
+// -----
+
+// CHECK-LABEL: extract_extract_strided2
+//  CHECK-SAME: %[[A:.*]]: vector<2x4xf32>
+//       CHECK: %[[V:.*]] = vector.extract %[[A]][1] : vector<2x4xf32>
+//       CHECK: return %[[V]] : vector<4xf32>
+func @extract_extract_strided2(%A: vector<2x4xf32>)
+  -> (vector<4xf32>) {
+ %0 = vector.extract_strided_slice %A {offsets = [1, 0], sizes = [1, 4], strides = [1, 1]} : vector<2x4xf32> to vector<1x4xf32>
+ %1 = vector.extract %0[0] : vector<1x4xf32>
+ return %1 : vector<4xf32>
+}