diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -1635,11 +1635,10 @@ if (!dense || dense.isSplat()) return failure(); - // Calculate the linearized position of the continous chunk of elements to + // Calculate the linearized position of the continuous chunk of elements to // extract. llvm::SmallVector completePositions(vecTy.getRank(), 0); - llvm::copy(getI64SubArray(extractOp.getPosition()), - completePositions.begin()); + copy(getI64SubArray(extractOp.getPosition()), completePositions.begin()); int64_t elemBeginPosition = linearize(completePositions, computeStrides(vecTy.getShape())); auto denseValuesBegin = dense.value_begin() + elemBeginPosition; @@ -2084,11 +2083,68 @@ } }; +// Pattern to rewrite an InsertOp(ConstantOp into ConstantOp) -> ConstantOp. +class InsertOpConstantFolder final : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + // Do not create constants with more than `vectorSizeFoldThreshold` elements, + // unless the destination vector constant has a single use. + static constexpr int64_t vectorSizeFoldThreshold = 256; + + LogicalResult matchAndRewrite(InsertOp op, + PatternRewriter &rewriter) const override { + // Return if 'InsertOp' operand is not defined by a compatible vector + // ConstantOp. + TypedValue destVector = op.getDest(); + Attribute vectorDestCst; + if (!matchPattern(destVector, m_Constant(&vectorDestCst))) + return failure(); + + VectorType destTy = destVector.getType(); + if (destTy.isScalable()) + return failure(); + + // Make sure we do not create too many large constants. 
+ if (destTy.getNumElements() > vectorSizeFoldThreshold && + !destVector.hasOneUse()) + return failure(); + + auto denseDest = vectorDestCst.cast(); + + Value sourceValue = op.getSource(); + Attribute sourceCst; + if (!matchPattern(sourceValue, m_Constant(&sourceCst))) + return failure(); + + // Calculate the linearized position of the continuous chunk of elements to + // insert. + llvm::SmallVector completePositions(destTy.getRank(), 0); + copy(getI64SubArray(op.getPosition()), completePositions.begin()); + int64_t insertBeginPosition = + linearize(completePositions, computeStrides(destTy.getShape())); + + SmallVector insertedValues; + if (auto denseSource = sourceCst.dyn_cast()) + llvm::append_range(insertedValues, denseSource.getValues()); + else + insertedValues.push_back(sourceCst); + + auto allValues = llvm::to_vector(denseDest.getValues()); + copy(insertedValues, allValues.begin() + insertBeginPosition); + auto newAttr = DenseElementsAttr::get(destTy, allValues); + + rewriter.replaceOpWithNewOp(op, newAttr); + return success(); + } +}; + } // namespace void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add(context); + results.add(context); } // Eliminates insert operations that produce values identical to their source @@ -2744,13 +2800,12 @@ // Expand offsets and sizes to match the vector rank. SmallVector offsets(sliceRank, 0); - llvm::copy(getI64SubArray(extractStridedSliceOp.getOffsets()), - offsets.begin()); + copy(getI64SubArray(extractStridedSliceOp.getOffsets()), offsets.begin()); SmallVector sizes(sourceShape.begin(), sourceShape.end()); - llvm::copy(getI64SubArray(extractStridedSliceOp.getSizes()), sizes.begin()); + copy(getI64SubArray(extractStridedSliceOp.getSizes()), sizes.begin()); - // Calcualte the slice elements by enumerating all slice positions and + // Calculate the slice elements by enumerating all slice positions and // linearizing them. 
The enumeration order is lexicographic which yields a // sequence of monotonically increasing linearized position indices. auto denseValuesBegin = dense.value_begin(); diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -1795,6 +1795,64 @@ // ----- +// CHECK-LABEL: func.func @insert_1d_constant +// CHECK-DAG: %[[ACST:.*]] = arith.constant dense<[9, 1, 2]> : vector<3xi32> +// CHECK-DAG: %[[BCST:.*]] = arith.constant dense<[0, 9, 2]> : vector<3xi32> +// CHECK-DAG: %[[CCST:.*]] = arith.constant dense<[0, 1, 9]> : vector<3xi32> +// CHECK-NEXT: return %[[ACST]], %[[BCST]], %[[CCST]] : vector<3xi32>, vector<3xi32>, vector<3xi32> +func.func @insert_1d_constant() -> (vector<3xi32>, vector<3xi32>, vector<3xi32>) { + %vcst = arith.constant dense<[0, 1, 2]> : vector<3xi32> + %icst = arith.constant 9 : i32 + %a = vector.insert %icst, %vcst[0] : i32 into vector<3xi32> + %b = vector.insert %icst, %vcst[1] : i32 into vector<3xi32> + %c = vector.insert %icst, %vcst[2] : i32 into vector<3xi32> + return %a, %b, %c : vector<3xi32>, vector<3xi32>, vector<3xi32> +} + +// ----- + +// CHECK-LABEL: func.func @insert_2d_constant +// CHECK-DAG: %[[ACST:.*]] = arith.constant dense<{{\[\[99, 1, 2\], \[3, 4, 5\]\]}}> : vector<2x3xi32> +// CHECK-DAG: %[[BCST:.*]] = arith.constant dense<{{\[\[0, 1, 2\], \[3, 4, 99\]\]}}> : vector<2x3xi32> +// CHECK-DAG: %[[CCST:.*]] = arith.constant dense<{{\[\[90, 91, 92\], \[3, 4, 5\]\]}}> : vector<2x3xi32> +// CHECK-DAG: %[[DCST:.*]] = arith.constant dense<{{\[\[0, 1, 2\], \[90, 91, 92\]\]}}> : vector<2x3xi32> +// CHECK-NEXT: return %[[ACST]], %[[BCST]], %[[CCST]], %[[DCST]] +func.func @insert_2d_constant() -> (vector<2x3xi32>, vector<2x3xi32>, vector<2x3xi32>, vector<2x3xi32>) { + %vcst = arith.constant dense<[[0, 1, 2], [3, 4, 5]]> : vector<2x3xi32> + %cst_scalar = arith.constant 99 : i32 + %cst_1d = 
arith.constant dense<[90, 91, 92]> : vector<3xi32> + %a = vector.insert %cst_scalar, %vcst[0, 0] : i32 into vector<2x3xi32> + %b = vector.insert %cst_scalar, %vcst[1, 2] : i32 into vector<2x3xi32> + %c = vector.insert %cst_1d, %vcst[0] : vector<3xi32> into vector<2x3xi32> + %d = vector.insert %cst_1d, %vcst[1] : vector<3xi32> into vector<2x3xi32> + return %a, %b, %c, %d : vector<2x3xi32>, vector<2x3xi32>, vector<2x3xi32>, vector<2x3xi32> +} + +// ----- + +// CHECK-LABEL: func.func @insert_2d_splat_constant +// CHECK-DAG: %[[ACST:.*]] = arith.constant dense<0> : vector<2x3xi32> +// CHECK-DAG: %[[BCST:.*]] = arith.constant dense<{{\[\[99, 0, 0\], \[0, 0, 0\]\]}}> : vector<2x3xi32> +// CHECK-DAG: %[[CCST:.*]] = arith.constant dense<{{\[\[0, 0, 0\], \[0, 99, 0\]\]}}> : vector<2x3xi32> +// CHECK-DAG: %[[DCST:.*]] = arith.constant dense<{{\[\[33, 33, 33\], \[0, 0, 0\]\]}}> : vector<2x3xi32> +// CHECK-DAG: %[[ECST:.*]] = arith.constant dense<{{\[\[0, 0, 0\], \[33, 33, 33\]\]}}> : vector<2x3xi32> +// CHECK-NEXT: return %[[ACST]], %[[BCST]], %[[CCST]], %[[DCST]], %[[ECST]] +func.func @insert_2d_splat_constant() + -> (vector<2x3xi32>, vector<2x3xi32>, vector<2x3xi32>, vector<2x3xi32>, vector<2x3xi32>) { + %vcst = arith.constant dense<0> : vector<2x3xi32> + %cst_zero = arith.constant 0 : i32 + %cst_scalar = arith.constant 99 : i32 + %cst_1d = arith.constant dense<33> : vector<3xi32> + %a = vector.insert %cst_zero, %vcst[0, 0] : i32 into vector<2x3xi32> + %b = vector.insert %cst_scalar, %vcst[0, 0] : i32 into vector<2x3xi32> + %c = vector.insert %cst_scalar, %vcst[1, 1] : i32 into vector<2x3xi32> + %d = vector.insert %cst_1d, %vcst[0] : vector<3xi32> into vector<2x3xi32> + %e = vector.insert %cst_1d, %vcst[1] : vector<3xi32> into vector<2x3xi32> + return %a, %b, %c, %d, %e : vector<2x3xi32>, vector<2x3xi32>, vector<2x3xi32>, vector<2x3xi32>, vector<2x3xi32> +} + +// ----- + // CHECK-LABEL: func @insert_element_fold // CHECK: %[[V:.+]] = arith.constant dense<[0, 1, 7, 3]> : 
vector<4xi32> // CHECK: return %[[V]]