diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -18,11 +18,15 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Utils/ReshapeOpsUtils.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/Dialect/Vector/Utils/VectorUtils.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectImplementation.h" @@ -30,14 +34,22 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Support/LLVM.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/Sequence.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringSet.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/ADT/bit.h" +#include +#include +#include #include #include "mlir/Dialect/Vector/IR/VectorOpsDialect.cpp.inc" // Pull in all enum type and utility function definitions. #include "mlir/Dialect/Vector/IR/VectorOpsEnums.cpp.inc" +#include "mlir/Support/LogicalResult.h" using namespace mlir; using namespace mlir::vector; @@ -2690,30 +2702,131 @@ }; // Pattern to rewrite a ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp. 
-class StridedSliceConstantFolder final
+class StridedSliceSplatConstantFolder final
     : public OpRewritePattern<ExtractStridedSliceOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
                                 PatternRewriter &rewriter) const override {
-    // Return if 'extractStridedSliceOp' operand is not defined by a
+    // Return if 'ExtractStridedSliceOp' operand is not defined by a splat
     // ConstantOp.
-    auto constantOp =
-        extractStridedSliceOp.getVector().getDefiningOp<arith::ConstantOp>();
-    if (!constantOp)
+    Value sourceVector = extractStridedSliceOp.getVector();
+    Attribute vectorCst;
+    if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
       return failure();
-    auto dense = constantOp.getValue().dyn_cast<SplatElementsAttr>();
-    if (!dense)
+
+    auto splat = vectorCst.dyn_cast<SplatElementsAttr>();
+    if (!splat)
       return failure();
-    auto newAttr = DenseElementsAttr::get(extractStridedSliceOp.getType(),
-                                          dense.getSplatValue<Attribute>());
+
+    auto newAttr = SplatElementsAttr::get(extractStridedSliceOp.getType(),
+                                          splat.getSplatValue<Attribute>());
     rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
                                                    newAttr);
     return success();
   }
 };
 
+// Pattern to rewrite an ExtractStridedSliceOp(non-splat ConstantOp) ->
+// ConstantOp. Only unit strides are supported; the splat case is handled by
+// `StridedSliceSplatConstantFolder`.
+class StridedSliceNonSplatConstantFolder final
+    : public OpRewritePattern<ExtractStridedSliceOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
+                                PatternRewriter &rewriter) const override {
+    // Return if 'ExtractStridedSliceOp' operand is not defined by a non-splat
+    // ConstantOp.
+    Value sourceVector = extractStridedSliceOp.getVector();
+    Attribute vectorCst;
+    if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
+      return failure();
+
+    // The splat case is handled by `StridedSliceSplatConstantFolder`.
+    auto dense = vectorCst.dyn_cast<DenseElementsAttr>();
+    if (!dense || dense.isSplat())
+      return failure();
+
+    // TODO: Handle non-unit strides when they become available.
+    if (extractStridedSliceOp.hasNonUnitStrides())
+      return failure();
+
+    auto sourceVecTy = sourceVector.getType().cast<VectorType>();
+    ArrayRef<int64_t> sourceShape = sourceVecTy.getShape();
+    SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape);
+
+    // The result shape is the full slice shape: the explicit `sizes` followed
+    // by the trailing source dimensions that are not sliced.
+    VectorType sliceVecTy = extractStridedSliceOp.getType();
+    ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
+    int64_t sliceRank = sliceVecTy.getRank();
+
+    // Expand offsets to match the vector rank; dimensions without an explicit
+    // offset start at 0.
+    SmallVector<int64_t, 4> offsets(sliceRank, 0);
+    llvm::copy(getI64SubArray(extractStridedSliceOp.getOffsets()),
+               offsets.begin());
+
+    // Enumerate the slice positions in lexicographic (row-major) order and
+    // fetch each source element by its linearized position. This visits only
+    // the extracted elements instead of every element in the linearized range
+    // spanned by the slice, so narrow slices of large constants stay cheap
+    // and no per-element delinearization is needed.
+    auto denseValuesBegin = dense.value_begin<Attribute>();
+    SmallVector<Attribute> sliceValues;
+    sliceValues.reserve(sliceVecTy.getNumElements());
+    SmallVector<int64_t, 4> currSlicePosition(offsets.begin(), offsets.end());
+    do {
+      int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides);
+      assert(linearizedPosition < sourceVecTy.getNumElements() &&
+             "Invalid index");
+      sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
+    } while (
+        succeeded(incSlicePosition(currSlicePosition, sliceShape, offsets)));
+
+    // Every slice position must have been visited exactly once.
+    assert(static_cast<int64_t>(sliceValues.size()) ==
+               sliceVecTy.getNumElements() &&
+           "Invalid number of slice elements");
+    auto newAttr = DenseElementsAttr::get(sliceVecTy, sliceValues);
+    rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
+                                                   newAttr);
+    return success();
+  }
+
+private:
+  // Advances `position` to the next position of the slice that starts at
+  // `offsets` and has shape `sliceShape`, in lexicographic order: the
+  // innermost dimension is incremented first and, once it runs past the end
+  // of the slice, it is reset to its offset and the increment carries into
+  // the next-outer dimension. Returns failure after the last position of the
+  // slice, i.e. when the carry propagates out of the outermost dimension (at
+  // that point `position` has wrapped back to `offsets`).
+  static LogicalResult incSlicePosition(MutableArrayRef<int64_t> position,
+                                        ArrayRef<int64_t> sliceShape,
+                                        ArrayRef<int64_t> offsets) {
+    assert(position.size() == sliceShape.size() &&
+           position.size() == offsets.size() && "Rank mismatch");
+    for (int64_t dim = static_cast<int64_t>(position.size()) - 1; dim >= 0;
+         --dim) {
+      ++position[dim];
+      if (position[dim] < offsets[dim] + sliceShape[dim])
+        return success();
+
+      // This dimension is exhausted: wrap it around and carry the increment
+      // into the next-outer dimension.
+      position[dim] = offsets[dim];
+    }
+
+    // The carry propagated out of the outermost dimension: every position of
+    // the slice has been visited.
+    return failure();
+  }
+};
+
 // Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
 // BroadcastOp(ExtractStrideSliceOp).
 class StridedSliceBroadcast final
@@ -2780,8 +2893,9 @@
     RewritePatternSet &results, MLIRContext *context) {
   // Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) ->
   // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
-  results.add<StridedSliceConstantMaskFolder, StridedSliceConstantFolder,
-              StridedSliceBroadcast, StridedSliceSplat>(context);
+  results.add<StridedSliceConstantMaskFolder, StridedSliceSplatConstantFolder,
+              StridedSliceNonSplatConstantFolder, StridedSliceBroadcast,
+              StridedSliceSplat>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1488,6 +1488,63 @@
 
 // -----
 
+// CHECK-LABEL: func.func @extract_strided_slice_1d_constant
+// CHECK-DAG: %[[ACST:.*]] = arith.constant dense<[0, 1, 2]> : vector<3xi32>
+// CHECK-DAG: %[[BCST:.*]] = arith.constant dense<[1, 2]> : vector<2xi32>
+// CHECK-DAG: %[[CCST:.*]] = arith.constant dense<2> : vector<1xi32>
+// CHECK-NEXT: return %[[ACST]], %[[BCST]], %[[CCST]] : vector<3xi32>, vector<2xi32>, vector<1xi32>
+func.func @extract_strided_slice_1d_constant() -> (vector<3xi32>, vector<2xi32>, vector<1xi32>) {
+  %cst = arith.constant dense<[0, 1, 2]> : vector<3xi32>
+  %a = vector.extract_strided_slice %cst
+    {offsets = [0], sizes = [3], strides = [1]} : vector<3xi32> to vector<3xi32>
+  %b = vector.extract_strided_slice %cst
+    {offsets = [1], sizes = [2], strides = [1]} : vector<3xi32> to vector<2xi32>
+  %c = vector.extract_strided_slice %cst
+    {offsets = [2], sizes = [1], strides = [1]} : vector<3xi32> to vector<1xi32>
+  return %a, %b, %c : vector<3xi32>, vector<2xi32>, vector<1xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @extract_strided_slice_2d_constant
+// CHECK-DAG: %[[ACST:.*]] = arith.constant dense<0> : vector<1x1xi32>
+// CHECK-DAG: %[[BCST:.*]] = arith.constant dense<{{\[\[4, 5\]\]}}> : vector<1x2xi32>
+// CHECK-DAG: %[[CCST:.*]] = arith.constant dense<{{\[\[1, 2\], \[4, 5\]\]}}> : vector<2x2xi32>
+// CHECK-NEXT: return %[[ACST]], %[[BCST]], %[[CCST]] : vector<1x1xi32>, vector<1x2xi32>, vector<2x2xi32>
+func.func @extract_strided_slice_2d_constant() -> (vector<1x1xi32>, vector<1x2xi32>, vector<2x2xi32>) {
+  %cst = arith.constant dense<[[0, 1, 2], [3, 4, 5]]> : vector<2x3xi32>
+  %a = vector.extract_strided_slice %cst
+    {offsets = [0, 0], sizes = [1, 1], strides = [1, 1]} : vector<2x3xi32> to vector<1x1xi32>
+  %b = vector.extract_strided_slice %cst
+    {offsets = [1, 1], sizes = [1, 2], strides = [1, 1]} : vector<2x3xi32> to vector<1x2xi32>
+  %c = vector.extract_strided_slice %cst
+    {offsets = [0, 1], sizes = [2, 2], strides = [1, 1]} : vector<2x3xi32> to vector<2x2xi32>
+  return %a, %b, %c : vector<1x1xi32>, vector<1x2xi32>, vector<2x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @extract_strided_slice_3d_constant
+// CHECK-DAG: %[[ACST:.*]] = arith.constant dense<{{\[\[\[8, 9\], \[10, 11\]\]\]}}> : vector<1x2x2xi32>
+// CHECK-DAG: %[[BCST:.*]] = arith.constant dense<{{\[\[\[2, 3\]\]\]}}> : vector<1x1x2xi32>
+// CHECK-DAG: %[[CCST:.*]] = arith.constant dense<{{\[\[\[6, 7\]\], \[\[10, 11\]\]\]}}> : vector<2x1x2xi32>
+// CHECK-DAG: %[[DCST:.*]] = arith.constant dense<11> : vector<1x1x1xi32>
+// CHECK-NEXT: return %[[ACST]], %[[BCST]], %[[CCST]], %[[DCST]]
+func.func @extract_strided_slice_3d_constant() -> (vector<1x2x2xi32>, vector<1x1x2xi32>, vector<2x1x2xi32>, vector<1x1x1xi32>) {
+  %cst = arith.constant dense<[[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[8, 9], [10, 11]]]> : vector<3x2x2xi32>
+  %a = vector.extract_strided_slice %cst
+    {offsets = [2], sizes = [1], strides = [1]} : vector<3x2x2xi32> to vector<1x2x2xi32>
+  %b = vector.extract_strided_slice %cst
+    {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]} : vector<3x2x2xi32> to vector<1x1x2xi32>
+  %c = vector.extract_strided_slice %cst
+    {offsets = [1, 1, 0], sizes = [2, 1, 2], strides = [1, 1, 1]} : vector<3x2x2xi32> to vector<2x1x2xi32>
+  %d = vector.extract_strided_slice %cst
+    {offsets = [2, 1, 1], sizes = [1, 1, 1], strides = [1, 1, 1]} : vector<3x2x2xi32> to vector<1x1x1xi32>
+  return %a, %b, %c, %d : vector<1x2x2xi32>, vector<1x1x2xi32>, vector<2x1x2xi32>, vector<1x1x1xi32>
+}
+
+// -----
+
 // CHECK-LABEL: extract_extract_strided
 // CHECK-SAME: %[[A:.*]]: vector<32x16x4xf16>
 // CHECK: %[[V:.*]] = vector.extract %[[A]][9, 7] : vector<32x16x4xf16>