diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -30,9 +30,13 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Support/LLVM.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringSet.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/ADT/bit.h" + +#include #include #include "mlir/Dialect/Vector/IR/VectorOpsDialect.cpp.inc" @@ -2670,28 +2674,117 @@ }; // Pattern to rewrite a ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp. -class StridedSliceConstantFolder final +class StridedSliceSplatConstantFolder final : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp, PatternRewriter &rewriter) const override { - // Return if 'extractStridedSliceOp' operand is not defined by a + // Return if 'ExtractStridedSliceOp' operand is not defined by a splat // ConstantOp. - auto constantOp = - extractStridedSliceOp.getVector().getDefiningOp(); - if (!constantOp) + Value sourceVector = extractStridedSliceOp.getVector(); + Attribute vectorCst; + if (!matchPattern(sourceVector, m_Constant(&vectorCst))) return failure(); - auto dense = constantOp.getValue().dyn_cast(); - if (!dense) + + auto splat = vectorCst.dyn_cast(); + if (!splat) + return failure(); + + auto newAttr = SplatElementsAttr::get(extractStridedSliceOp.getType(), + splat.getSplatValue()); + rewriter.replaceOpWithNewOp(extractStridedSliceOp, + newAttr); + return success(); + } +}; + +// Pattern to rewrite a ExtractStridedSliceOp(non-splat ConstantOp) -> +// ConstantOp. 
+class StridedSliceNonSplatConstantFolder final + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp, + PatternRewriter &rewriter) const override { + // Return if 'ExtractStridedSliceOp' operand is not defined by a non-splat + // ConstantOp. + Value sourceVector = extractStridedSliceOp.getVector(); + Attribute vectorCst; + if (!matchPattern(sourceVector, m_Constant(&vectorCst))) + return failure(); + + // The splat case is handled by `StridedSliceSplatConstantFolder`. + auto dense = vectorCst.dyn_cast(); + if (!dense || dense.isSplat()) + return failure(); + + // TODO: Handle non-unit strides when they become available. + if (extractStridedSliceOp.hasNonUnitStrides()) return failure(); - auto newAttr = DenseElementsAttr::get(extractStridedSliceOp.getType(), - dense.getSplatValue()); + + auto sourceVecTy = sourceVector.getType().cast(); + ArrayRef sourceShape = sourceVecTy.getShape(); + SmallVector sourceStrides = computeStrides(sourceShape); + + VectorType sliceVecTy = extractStridedSliceOp.getType(); + ArrayRef sliceShape = sliceVecTy.getShape(); + int64_t sliceRank = sliceVecTy.getRank(); + + // Expand offsets and sizes to match the vector rank. + SmallVector offsets(sliceRank, 0); + llvm::copy(getI64SubArray(extractStridedSliceOp.getOffsets()), + offsets.begin()); + + SmallVector sizes(sourceShape.begin(), sourceShape.end()); + llvm::copy(getI64SubArray(extractStridedSliceOp.getSizes()), sizes.begin()); + + // Calculate the slice elements by enumerating all slice positions and + // linearizing them. The enumeration order is lexicographic which yields a + // sequence of monotonically increasing linearized position indices. 
+ auto denseValuesBegin = dense.value_begin(); + SmallVector sliceValues; + sliceValues.reserve(sliceVecTy.getNumElements()); + SmallVector currSlicePosition(offsets.begin(), offsets.end()); + do { + int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides); + assert(linearizedPosition < sourceVecTy.getNumElements() && + "Invalid index"); + sliceValues.push_back(*(denseValuesBegin + linearizedPosition)); + } while (succeeded(incPosition(currSlicePosition, sliceShape, offsets))); + + assert(static_cast(sliceValues.size()) == + sliceVecTy.getNumElements() && + "Invalid number of slice elements"); + auto newAttr = DenseElementsAttr::get(sliceVecTy, sliceValues); rewriter.replaceOpWithNewOp(extractStridedSliceOp, newAttr); return success(); } + +private: + // Calculate the next `position` in the n-D vector of size `shape`, + // applying an offset `offsets`. Modifies the `position` in place. + // Returns a failure when `position` becomes the end position. + static LogicalResult incPosition(MutableArrayRef position, + ArrayRef shape, + ArrayRef offsets) { + assert(position.size() == shape.size()); + assert(position.size() == offsets.size()); + for (auto [posInDim, dimSize, offsetInDim] : + llvm::reverse(llvm::zip(position, shape, offsets))) { + ++posInDim; + if (posInDim < dimSize + offsetInDim) + return success(); + + // Carry the overflow to the next loop iteration. + posInDim = offsetInDim; + } + + return failure(); + } }; // Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to @@ -2760,8 +2853,9 @@ RewritePatternSet &results, MLIRContext *context) { // Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) -> // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp. 
- results.add(context); + results.add(context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -1488,6 +1488,63 @@ // ----- +// CHECK-LABEL: func.func @extract_strided_slice_1d_constant +// CHECK-DAG: %[[ACST:.*]] = arith.constant dense<[0, 1, 2]> : vector<3xi32> +// CHECK-DAG: %[[BCST:.*]] = arith.constant dense<[1, 2]> : vector<2xi32> +// CHECK-DAG: %[[CCST:.*]] = arith.constant dense<2> : vector<1xi32> +// CHECK-NEXT: return %[[ACST]], %[[BCST]], %[[CCST]] : vector<3xi32>, vector<2xi32>, vector<1xi32> +func.func @extract_strided_slice_1d_constant() -> (vector<3xi32>, vector<2xi32>, vector<1xi32>) { + %cst = arith.constant dense<[0, 1, 2]> : vector<3xi32> + %a = vector.extract_strided_slice %cst + {offsets = [0], sizes = [3], strides = [1]} : vector<3xi32> to vector<3xi32> + %b = vector.extract_strided_slice %cst + {offsets = [1], sizes = [2], strides = [1]} : vector<3xi32> to vector<2xi32> + %c = vector.extract_strided_slice %cst + {offsets = [2], sizes = [1], strides = [1]} : vector<3xi32> to vector<1xi32> + return %a, %b, %c : vector<3xi32>, vector<2xi32>, vector<1xi32> +} + +// ----- + +// CHECK-LABEL: func.func @extract_strided_slice_2d_constant +// CHECK-DAG: %[[ACST:.*]] = arith.constant dense<0> : vector<1x1xi32> +// CHECK-DAG: %[[BCST:.*]] = arith.constant dense<{{\[\[4, 5\]\]}}> : vector<1x2xi32> +// CHECK-DAG: %[[CCST:.*]] = arith.constant dense<{{\[\[1, 2\], \[4, 5\]\]}}> : vector<2x2xi32> +// CHECK-NEXT: return %[[ACST]], %[[BCST]], %[[CCST]] : vector<1x1xi32>, vector<1x2xi32>, vector<2x2xi32> +func.func @extract_strided_slice_2d_constant() -> (vector<1x1xi32>, vector<1x2xi32>, vector<2x2xi32>) { + %cst = arith.constant dense<[[0, 1, 2], [3, 4, 5]]> : vector<2x3xi32> + %a = vector.extract_strided_slice %cst + 
{offsets = [0, 0], sizes = [1, 1], strides = [1, 1]} : vector<2x3xi32> to vector<1x1xi32> + %b = vector.extract_strided_slice %cst + {offsets = [1, 1], sizes = [1, 2], strides = [1, 1]} : vector<2x3xi32> to vector<1x2xi32> + %c = vector.extract_strided_slice %cst + {offsets = [0, 1], sizes = [2, 2], strides = [1, 1]} : vector<2x3xi32> to vector<2x2xi32> + return %a, %b, %c : vector<1x1xi32>, vector<1x2xi32>, vector<2x2xi32> +} + +// ----- + +// CHECK-LABEL: func.func @extract_strided_slice_3d_constant +// CHECK-DAG: %[[ACST:.*]] = arith.constant dense<{{\[\[\[8, 9\], \[10, 11\]\]\]}}> : vector<1x2x2xi32> +// CHECK-DAG: %[[BCST:.*]] = arith.constant dense<{{\[\[\[2, 3\]\]\]}}> : vector<1x1x2xi32> +// CHECK-DAG: %[[CCST:.*]] = arith.constant dense<{{\[\[\[6, 7\]\], \[\[10, 11\]\]\]}}> : vector<2x1x2xi32> +// CHECK-DAG: %[[DCST:.*]] = arith.constant dense<11> : vector<1x1x1xi32> +// CHECK-NEXT: return %[[ACST]], %[[BCST]], %[[CCST]], %[[DCST]] +func.func @extract_strided_slice_3d_constant() -> (vector<1x2x2xi32>, vector<1x1x2xi32>, vector<2x1x2xi32>, vector<1x1x1xi32>) { + %cst = arith.constant dense<[[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[8, 9], [10, 11]]]> : vector<3x2x2xi32> + %a = vector.extract_strided_slice %cst + {offsets = [2], sizes = [1], strides = [1]} : vector<3x2x2xi32> to vector<1x2x2xi32> + %b = vector.extract_strided_slice %cst + {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]} : vector<3x2x2xi32> to vector<1x1x2xi32> + %c = vector.extract_strided_slice %cst + {offsets = [1, 1, 0], sizes = [2, 1, 2], strides = [1, 1, 1]} : vector<3x2x2xi32> to vector<2x1x2xi32> + %d = vector.extract_strided_slice %cst + {offsets = [2, 1, 1], sizes = [1, 1, 1], strides = [1, 1, 1]} : vector<3x2x2xi32> to vector<1x1x1xi32> + return %a, %b, %c, %d : vector<1x2x2xi32>, vector<1x1x2xi32>, vector<2x1x2xi32>, vector<1x1x1xi32> +} + +// ----- + // CHECK-LABEL: extract_extract_strided // CHECK-SAME: %[[A:.*]]: vector<32x16x4xf16> // CHECK: %[[V:.*]] = vector.extract 
%[[A]][9, 7] : vector<32x16x4xf16>