diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -31,6 +31,7 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Support/LLVM.h" +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringSet.h" @@ -38,6 +39,7 @@ #include "llvm/ADT/bit.h" #include +#include #include #include "mlir/Dialect/Vector/IR/VectorOpsDialect.cpp.inc" @@ -212,6 +214,28 @@ return isDisjointTransferIndices(transferA, transferB); } +// Helper to iterate over n-D vector slice elements. Calculate the next +// `position` in the n-D vector of size `shape`, applying an offset `offsets`. +// Modifies the `position` in place. Returns a failure when `position` becomes +// the end position. +static LogicalResult incSlicePosition(MutableArrayRef position, + ArrayRef shape, + ArrayRef offsets) { + assert(position.size() == shape.size()); + assert(position.size() == offsets.size()); + for (auto [posInDim, dimSize, offsetInDim] : + llvm::reverse(llvm::zip(position, shape, offsets))) { + ++posInDim; + if (posInDim < dimSize + offsetInDim) + return success(); + + // Carry the overflow to the next loop iteration. + posInDim = offsetInDim; + } + + return failure(); +} + //===----------------------------------------------------------------------===// // CombiningKindAttr //===----------------------------------------------------------------------===// @@ -2354,12 +2378,88 @@ } }; +// Pattern to rewrite an InsertStridedSliceOp(ConstantOp into ConstantOp) -> +// ConstantOp. +class InsertStridedSliceConstantFolder final + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + // Do not create constants with more than `vectorSizeFoldThreshold` elements, + // unless the destination vector constant has a single use. 
+ static constexpr int64_t vectorSizeFoldThreshold = 256; + + LogicalResult matchAndRewrite(InsertStridedSliceOp op, + PatternRewriter &rewriter) const override { + // Return if the 'InsertStridedSliceOp' dest operand is not defined by a + // compatible vector ConstantOp. + TypedValue destVector = op.getDest(); + Attribute vectorDestCst; + if (!matchPattern(destVector, m_Constant(&vectorDestCst))) + return failure(); + + VectorType destTy = destVector.getType(); + if (destTy.isScalable()) + return failure(); + + // Make sure we do not create too many large constants. + if (destTy.getNumElements() > vectorSizeFoldThreshold && + !destVector.hasOneUse()) + return failure(); + + auto denseDest = vectorDestCst.cast(); + + TypedValue sourceValue = op.getSource(); + Attribute sourceCst; + if (!matchPattern(sourceValue, m_Constant(&sourceCst))) + return failure(); + + // TODO: Handle non-unit strides when they become available. + if (op.hasNonUnitStrides()) + return failure(); + + VectorType sliceVecTy = sourceValue.getType(); + ArrayRef sliceShape = sliceVecTy.getShape(); + int64_t rankDifference = destTy.getRank() - sliceVecTy.getRank(); + SmallVector offsets = getI64SubArray(op.getOffsets()); + SmallVector destStrides = computeStrides(destTy.getShape()); + + // Calculate the destination element indices by enumerating all slice + // positions within the destination and linearizing them. The enumeration + // order is lexicographic which yields a sequence of monotonically + // increasing linearized position indices. + // Because the destination may have higher dimensionality than the slice, + // we keep track of two overlapping sets of positions and offsets. 
+ auto denseSlice = sourceCst.cast(); + auto sliceValuesIt = denseSlice.value_begin(); + auto newValues = llvm::to_vector(denseDest.getValues()); + SmallVector currDestPosition(offsets.begin(), offsets.end()); + MutableArrayRef currSlicePosition( + currDestPosition.begin() + rankDifference, currDestPosition.end()); + ArrayRef sliceOffsets(offsets.begin() + rankDifference, + offsets.end()); + do { + int64_t linearizedPosition = linearize(currDestPosition, destStrides); + assert(linearizedPosition < destTy.getNumElements() && "Invalid index"); + assert(sliceValuesIt != denseSlice.value_end() && + "Invalid slice element"); + newValues[linearizedPosition] = *sliceValuesIt; + ++sliceValuesIt; + } while (succeeded( + incSlicePosition(currSlicePosition, sliceShape, sliceOffsets))); + + auto newAttr = DenseElementsAttr::get(destTy, newValues); + rewriter.replaceOpWithNewOp(op, newAttr); + return success(); + } +}; + } // namespace void vector::InsertStridedSliceOp::getCanonicalizationPatterns( RewritePatternSet &results, MLIRContext *context) { - results.add( - context); + results.add(context); } OpFoldResult InsertStridedSliceOp::fold(ArrayRef operands) { @@ -2817,7 +2917,8 @@ assert(linearizedPosition < sourceVecTy.getNumElements() && "Invalid index"); sliceValues.push_back(*(denseValuesBegin + linearizedPosition)); - } while (succeeded(incPosition(currSlicePosition, sliceShape, offsets))); + } while ( + succeeded(incSlicePosition(currSlicePosition, sliceShape, offsets))); assert(static_cast(sliceValues.size()) == sliceVecTy.getNumElements() && @@ -2827,28 +2928,6 @@ newAttr); return success(); } - -private: - // Calculate the next `position` in the n-D vector of size `shape`, - // applying an offset `offsets`. Modifies the `position` in place. - // Returns a failure when `position` becomes the end position. 
- static LogicalResult incPosition(MutableArrayRef position, - ArrayRef shape, - ArrayRef offsets) { - assert(position.size() == shape.size()); - assert(position.size() == offsets.size()); - for (auto [posInDim, dimSize, offsetInDim] : - llvm::reverse(llvm::zip(position, shape, offsets))) { - ++posInDim; - if (posInDim < dimSize + offsetInDim) - return success(); - - // Carry the overflow to the next loop iteration. - posInDim = offsetInDim; - } - - return failure(); - } }; // Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -1983,6 +1983,55 @@ // ----- +// CHECK-LABEL: func.func @insert_strided_1d_constant +// CHECK-DAG: %[[ACST:.*]] = arith.constant dense<[4, 1, 2]> : vector<3xi32> +// CHECK-DAG: %[[BCST:.*]] = arith.constant dense<[0, 1, 4]> : vector<3xi32> +// CHECK-DAG: %[[CCST:.*]] = arith.constant dense<[5, 6, 2]> : vector<3xi32> +// CHECK-DAG: %[[DCST:.*]] = arith.constant dense<[0, 5, 6]> : vector<3xi32> +// CHECK-DAG: %[[ECST:.*]] = arith.constant dense<[7, 8, 9]> : vector<3xi32> +// CHECK-NEXT: return %[[ACST]], %[[BCST]], %[[CCST]], %[[DCST]], %[[ECST]] +func.func @insert_strided_1d_constant() -> + (vector<3xi32>, vector<3xi32>, vector<3xi32>, vector<3xi32>, vector<3xi32>) { + %vcst = arith.constant dense<[0, 1, 2]> : vector<3xi32> + %cst_1 = arith.constant dense<4> : vector<1xi32> + %cst_2 = arith.constant dense<[5, 6]> : vector<2xi32> + %cst_3 = arith.constant dense<[7, 8, 9]> : vector<3xi32> + %a = vector.insert_strided_slice %cst_1, %vcst {offsets = [0], strides = [1]} : vector<1xi32> into vector<3xi32> + %b = vector.insert_strided_slice %cst_1, %vcst {offsets = [2], strides = [1]} : vector<1xi32> into vector<3xi32> + %c = vector.insert_strided_slice %cst_2, %vcst {offsets = [0], strides = [1]} : vector<2xi32> into vector<3xi32> + %d = 
vector.insert_strided_slice %cst_2, %vcst {offsets = [1], strides = [1]} : vector<2xi32> into vector<3xi32> + %e = vector.insert_strided_slice %cst_3, %vcst {offsets = [0], strides = [1]} : vector<3xi32> into vector<3xi32> + return %a, %b, %c, %d, %e : vector<3xi32>, vector<3xi32>, vector<3xi32>, vector<3xi32>, vector<3xi32> +} + +// CHECK-LABEL: func.func @insert_strided_2d_constant +// CHECK-DAG: %[[ACST:.*]] = arith.constant dense<{{\[\[0, 1\], \[9, 3\], \[4, 5\]\]}}> : vector<3x2xi32> +// CHECK-DAG: %[[BCST:.*]] = arith.constant dense<{{\[\[0, 1\], \[2, 3\], \[4, 9\]\]}}> : vector<3x2xi32> +// CHECK-DAG: %[[CCST:.*]] = arith.constant dense<{{\[\[18, 19\], \[2, 3\], \[4, 5\]\]}}> : vector<3x2xi32> +// CHECK-DAG: %[[DCST:.*]] = arith.constant dense<{{\[\[0, 1\], \[18, 19\], \[4, 5\]\]}}> : vector<3x2xi32> +// CHECK-DAG: %[[ECST:.*]] = arith.constant dense<{{\[\[0, 1\], \[2, 3\], \[18, 19\]\]}}> : vector<3x2xi32> +// CHECK-DAG: %[[FCST:.*]] = arith.constant dense<{{\[\[28, 29\], \[38, 39\], \[4, 5\]\]}}> : vector<3x2xi32> +// CHECK-DAG: %[[GCST:.*]] = arith.constant dense<{{\[\[0, 1\], \[28, 29\], \[38, 39\]\]}}> : vector<3x2xi32> +// CHECK-NEXT: return %[[ACST]], %[[BCST]], %[[CCST]], %[[DCST]], %[[ECST]], %[[FCST]], %[[GCST]] +func.func @insert_strided_2d_constant() -> + (vector<3x2xi32>, vector<3x2xi32>, vector<3x2xi32>, vector<3x2xi32>, vector<3x2xi32>, vector<3x2xi32>, vector<3x2xi32>) { + %vcst = arith.constant dense<[[0, 1], [2, 3], [4, 5]]> : vector<3x2xi32> + %cst_1 = arith.constant dense<9> : vector<1xi32> + %cst_2 = arith.constant dense<[18, 19]> : vector<2xi32> + %cst_3 = arith.constant dense<[[28, 29], [38, 39]]> : vector<2x2xi32> + %a = vector.insert_strided_slice %cst_1, %vcst {offsets = [1, 0], strides = [1]} : vector<1xi32> into vector<3x2xi32> + %b = vector.insert_strided_slice %cst_1, %vcst {offsets = [2, 1], strides = [1]} : vector<1xi32> into vector<3x2xi32> + %c = vector.insert_strided_slice %cst_2, %vcst {offsets = [0, 0], strides = [1]} : 
vector<2xi32> into vector<3x2xi32> + %d = vector.insert_strided_slice %cst_2, %vcst {offsets = [1, 0], strides = [1]} : vector<2xi32> into vector<3x2xi32> + %e = vector.insert_strided_slice %cst_2, %vcst {offsets = [2, 0], strides = [1]} : vector<2xi32> into vector<3x2xi32> + %f = vector.insert_strided_slice %cst_3, %vcst {offsets = [0, 0], strides = [1, 1]} : vector<2x2xi32> into vector<3x2xi32> + %g = vector.insert_strided_slice %cst_3, %vcst {offsets = [1, 0], strides = [1, 1]} : vector<2x2xi32> into vector<3x2xi32> + return %a, %b, %c, %d, %e, %f, %g : + vector<3x2xi32>, vector<3x2xi32>, vector<3x2xi32>, vector<3x2xi32>, vector<3x2xi32>, vector<3x2xi32>, vector<3x2xi32> +} + +// ----- + // CHECK-LABEL: func @shuffle_splat // CHECK-SAME: (%[[ARG:.*]]: i32) // CHECK-NEXT: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<4xi32>