diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1016,6 +1016,7 @@
     void getOffsets(SmallVectorImpl<int64_t> &results);
   }];
   let hasCanonicalizer = 1;
+  let hasFolder = 1;
   let assemblyFormat = "$vector attr-dict `:` type($vector) `to` type(results)";
 }
 
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1629,6 +1629,105 @@
   return success();
 }
 
+// When the source of an ExtractStridedSliceOp comes from a chain of
+// InsertStridedSliceOps, try to extract directly from the source vector of one
+// of those InsertStridedSliceOps, provided the extracted chunk is fully
+// contained in the chunk it inserted. Inserts whose chunk is disjoint from the
+// extracted chunk are skipped; a partial overlap stops the search since the
+// extracted elements then come from more than one vector.
+static LogicalResult
+foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
+  // Helper to extract integer out of ArrayAttr.
+  auto getElement = [](ArrayAttr array, int64_t idx) {
+    return array[idx].cast<IntegerAttr>().getInt();
+  };
+  ArrayAttr extractOffsets = op.offsets();
+  ArrayAttr extractStrides = op.strides();
+  ArrayAttr extractSizes = op.sizes();
+  VectorType extractSourceType = op.getVectorType();
+  int64_t rank = extractSourceType.getRank();
+  // The extract attributes may list fewer dimensions than the vector rank; the
+  // remaining trailing dimensions are extracted whole.
+  int64_t numExtractDims = extractOffsets.size();
+  if (numExtractDims > rank)
+    return failure();
+  auto insertOp = op.vector().getDefiningOp<InsertStridedSliceOp>();
+  while (insertOp) {
+    VectorType insertSourceType = insertOp.getSourceVectorType();
+    // Only handle inserts of a vector with the same rank as the destination.
+    if (insertSourceType.getRank() != rank)
+      return failure();
+    ArrayAttr insertOffsets = insertOp.offsets();
+    ArrayAttr insertStrides = insertOp.strides();
+    // Every dimension must have an insert offset and stride; bail out
+    // conservatively otherwise.
+    if (static_cast<int64_t>(insertOffsets.size()) < rank ||
+        static_cast<int64_t>(insertStrides.size()) < rank)
+      return failure();
+    bool partialOverlap = false;
+    bool disjoint = false;
+    SmallVector<int64_t, 4> offsetDiffs;
+    for (int64_t dim = 0; dim < rank; ++dim) {
+      bool isExplicitDim = dim < numExtractDims;
+      // Dimensions omitted from the extract attributes are taken whole with a
+      // unit stride.
+      int64_t stride = isExplicitDim ? getElement(extractStrides, dim) : 1;
+      if (stride != getElement(insertStrides, dim))
+        return failure();
+      int64_t start = getElement(insertOffsets, dim);
+      int64_t end = start + insertSourceType.getDimSize(dim);
+      int64_t offset = isExplicitDim ? getElement(extractOffsets, dim) : 0;
+      int64_t size = isExplicitDim ? getElement(extractSizes, dim)
+                                   : extractSourceType.getDimSize(dim);
+      // Two chunks are disjoint as soon as they are disjoint along a single
+      // dimension.
+      if (offset + size <= start || offset >= end) {
+        disjoint = true;
+        break;
+      }
+      // The intervals intersect. If the extracted interval starts before or
+      // ends after the inserted one, only part of it comes from this insert.
+      // Keep scanning: a later dimension may still prove the chunks disjoint.
+      if (offset < start || offset + size > end) {
+        partialOverlap = true;
+        continue;
+      }
+      if (isExplicitDim)
+        offsetDiffs.push_back(offset - start);
+    }
+    // If the chunk extracted is disjoint from the chunk inserted, keep looking
+    // in the insert chain.
+    if (disjoint) {
+      insertOp = insertOp.dest().getDefiningOp<InsertStridedSliceOp>();
+      continue;
+    }
+    // The extracted vector partially overlaps the inserted vector, we cannot
+    // fold.
+    if (partialOverlap)
+      return failure();
+    // The extracted chunk is a subset of the inserted vector: extract from it
+    // directly, with the offsets rebased onto the inserted vector.
+    op.setOperand(insertOp.source());
+    // OpBuilder is only used as a helper to build an I64ArrayAttr.
+    OpBuilder b(op.getContext());
+    op.setAttr(ExtractStridedSliceOp::getOffsetsAttrName(),
+               b.getI64ArrayAttr(offsetDiffs));
+    return success();
+  }
+  return failure();
+}
+
+OpFoldResult ExtractStridedSliceOp::fold(ArrayRef<Attribute> operands) {
+  // Extracting a vector of the same type as the source is the identity.
+  if (getVectorType() == getResult().getType())
+    return vector();
+  // Otherwise try to look through producing insert_strided_slice ops; on
+  // success the op has been updated in place.
+  if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
+    return getResult();
+  return {};
+}
+
 void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
   populateFromInt64AttrArray(offsets(), results);
 }
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -90,6 +90,141 @@
 
 // -----
 
+// CHECK-LABEL: extract_strided_fold
+// CHECK-SAME: (%[[ARG:.*]]: vector<4x3xi1>)
+// CHECK-NEXT: return %[[ARG]] : vector<4x3xi1>
+func @extract_strided_fold(%arg : vector<4x3xi1>) -> (vector<4x3xi1>) {
+  %0 = vector.extract_strided_slice %arg
+    {offsets = [0, 0], sizes = [4, 3], strides = [1, 1]}
+      : vector<4x3xi1> to vector<4x3xi1>
+  return %0 : vector<4x3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: extract_strided_fold_insert
+// CHECK-SAME: (%[[ARG:.*]]: vector<4x4xf32>
+// CHECK-NEXT: return %[[ARG]] : vector<4x4xf32>
+func @extract_strided_fold_insert(%a: vector<4x4xf32>, %b: vector<8x16xf32>)
+  -> (vector<4x4xf32>) {
+  %0 = vector.insert_strided_slice %a, %b {offsets = [2, 2], strides = [1, 1]}
+    : vector<4x4xf32> into vector<8x16xf32>
+  %1 = vector.extract_strided_slice %0
+    {offsets = [2, 2], sizes = [4, 4], strides = [1, 1]}
+      : vector<8x16xf32> to vector<4x4xf32>
+  return %1 : vector<4x4xf32>
+}
+
+// -----
+
+// Case where the vector extracted is a subset of the vector inserted.
+// CHECK-LABEL: extract_strided_fold_insert
+// CHECK-SAME: (%[[ARG0:.*]]: vector<6x4xf32>
+// CHECK-NEXT: %[[EXT:.*]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [0, 0], sizes = [4, 4], strides = [1, 1]}
+// CHECK-SAME: : vector<6x4xf32> to vector<4x4xf32>
+// CHECK-NEXT: return %[[EXT]] : vector<4x4xf32>
+func @extract_strided_fold_insert(%a: vector<6x4xf32>, %b: vector<8x16xf32>)
+  -> (vector<4x4xf32>) {
+  %0 = vector.insert_strided_slice %a, %b {offsets = [2, 2], strides = [1, 1]}
+    : vector<6x4xf32> into vector<8x16xf32>
+  %1 = vector.extract_strided_slice %0
+    {offsets = [2, 2], sizes = [4, 4], strides = [1, 1]}
+      : vector<8x16xf32> to vector<4x4xf32>
+  return %1 : vector<4x4xf32>
+}
+
+// -----
+
+// Negative test where the extract is not a subset of the vector inserted.
+// CHECK-LABEL: extract_strided_fold_negative
+// CHECK-SAME: (%[[ARG0:.*]]: vector<4x4xf32>, %[[ARG1:.*]]: vector<8x16xf32>
+// CHECK: %[[INS:.*]] = vector.insert_strided_slice %[[ARG0]], %[[ARG1]]
+// CHECK-SAME: {offsets = [2, 2], strides = [1, 1]}
+// CHECK-SAME: : vector<4x4xf32> into vector<8x16xf32>
+// CHECK: %[[EXT:.*]] = vector.extract_strided_slice %[[INS]]
+// CHECK-SAME: {offsets = [2, 2], sizes = [6, 4], strides = [1, 1]}
+// CHECK-SAME: : vector<8x16xf32> to vector<6x4xf32>
+// CHECK-NEXT: return %[[EXT]] : vector<6x4xf32>
+func @extract_strided_fold_negative(%a: vector<4x4xf32>, %b: vector<8x16xf32>)
+  -> (vector<6x4xf32>) {
+  %0 = vector.insert_strided_slice %a, %b {offsets = [2, 2], strides = [1, 1]}
+    : vector<4x4xf32> into vector<8x16xf32>
+  %1 = vector.extract_strided_slice %0
+    {offsets = [2, 2], sizes = [6, 4], strides = [1, 1]}
+      : vector<8x16xf32> to vector<6x4xf32>
+  return %1 : vector<6x4xf32>
+}
+
+// -----
+
+// Case where we need to go through 2 levels of insert_strided_slice ops.
+// CHECK-LABEL: extract_strided_fold_insert
+// CHECK-SAME: (%[[ARG0:.*]]: vector<2x4xf32>, %[[ARG1:.*]]: vector<1x4xf32>,
+// CHECK-NEXT: %[[EXT:.*]] = vector.extract_strided_slice %[[ARG1]]
+// CHECK-SAME: {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]}
+// CHECK-SAME: : vector<1x4xf32> to vector<1x1xf32>
+// CHECK-NEXT: return %[[EXT]] : vector<1x1xf32>
+func @extract_strided_fold_insert(%a: vector<2x4xf32>, %b: vector<1x4xf32>,
+  %c : vector<1x4xf32>) -> (vector<1x1xf32>) {
+  %0 = vector.insert_strided_slice %b, %a {offsets = [0, 0], strides = [1, 1]}
+    : vector<1x4xf32> into vector<2x4xf32>
+  %1 = vector.insert_strided_slice %c, %0 {offsets = [1, 0], strides = [1, 1]}
+    : vector<1x4xf32> into vector<2x4xf32>
+  %2 = vector.extract_strided_slice %1
+    {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]}
+      : vector<2x4xf32> to vector<1x1xf32>
+  return %2 : vector<1x1xf32>
+}
+
+// -----
+
+// Negative test where the extract starts before the last insert and partially
+// overlaps it: the fold must not look through that insert to the first one.
+// CHECK-LABEL: extract_strided_fold_insert_partial_overlap_before
+// CHECK-SAME: (%[[ARG0:.*]]: vector<8x4xf32>, %[[ARG1:.*]]: vector<4x4xf32>,
+// CHECK-SAME: %[[ARG2:.*]]: vector<2x4xf32>)
+// CHECK: %[[INS0:.*]] = vector.insert_strided_slice %[[ARG1]], %[[ARG0]]
+// CHECK: %[[INS1:.*]] = vector.insert_strided_slice %[[ARG2]], %[[INS0]]
+// CHECK: %[[EXT:.*]] = vector.extract_strided_slice %[[INS1]]
+// CHECK-SAME: {offsets = [1, 0], sizes = [2, 4], strides = [1, 1]}
+// CHECK-SAME: : vector<8x4xf32> to vector<2x4xf32>
+// CHECK-NEXT: return %[[EXT]] : vector<2x4xf32>
+func @extract_strided_fold_insert_partial_overlap_before(%a: vector<8x4xf32>,
+  %b: vector<4x4xf32>, %c: vector<2x4xf32>) -> (vector<2x4xf32>) {
+  %0 = vector.insert_strided_slice %b, %a {offsets = [0, 0], strides = [1, 1]}
+    : vector<4x4xf32> into vector<8x4xf32>
+  %1 = vector.insert_strided_slice %c, %0 {offsets = [2, 0], strides = [1, 1]}
+    : vector<2x4xf32> into vector<8x4xf32>
+  %2 = vector.extract_strided_slice %1
+    {offsets = [1, 0], sizes = [2, 4], strides = [1, 1]}
+      : vector<8x4xf32> to vector<2x4xf32>
+  return %2 : vector<2x4xf32>
+}
+
+// -----
+
+// Negative test where the extract omits the trailing dimension, so it reads
+// whole rows, while the insert only covers part of each row.
+// CHECK-LABEL: extract_strided_fold_insert_implicit_dim
+// CHECK-SAME: (%[[ARG0:.*]]: vector<2x2xf32>, %[[ARG1:.*]]: vector<4x4xf32>
+// CHECK: %[[INS:.*]] = vector.insert_strided_slice %[[ARG0]], %[[ARG1]]
+// CHECK: %[[EXT:.*]] = vector.extract_strided_slice %[[INS]]
+// CHECK-SAME: {offsets = [0], sizes = [2], strides = [1]}
+// CHECK-SAME: : vector<4x4xf32> to vector<2x4xf32>
+// CHECK-NEXT: return %[[EXT]] : vector<2x4xf32>
+func @extract_strided_fold_insert_implicit_dim(%a: vector<2x2xf32>,
+  %b: vector<4x4xf32>) -> (vector<2x4xf32>) {
+  %0 = vector.insert_strided_slice %a, %b {offsets = [0, 0], strides = [1, 1]}
+    : vector<2x2xf32> into vector<4x4xf32>
+  %1 = vector.extract_strided_slice %0
+    {offsets = [0], sizes = [2], strides = [1]}
+      : vector<4x4xf32> to vector<2x4xf32>
+  return %1 : vector<2x4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: transpose_1D_identity
 // CHECK-SAME: ([[ARG:%.*]]: vector<4xf32>)
 func @transpose_1D_identity(%arg : vector<4xf32>) -> vector<4xf32> {