Index: mlir/include/mlir/Dialect/Vector/IR/VectorOps.td =================================================================== --- mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -314,7 +314,9 @@ Vector_Op<"multi_reduction", [NoSideEffect, PredOpTrait<"source operand and result have same element type", TCresVTEtIsSameAsOpBase<0, 0>>, - DeclareOpInterfaceMethods]>, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]>, Arguments<(ins Vector_CombiningKindAttr:$kind, AnyVector:$source, I64ArrayAttr:$reduction_dims)>, Index: mlir/lib/Dialect/Vector/IR/VectorOps.cpp =================================================================== --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -371,6 +371,10 @@ return {}; } +Optional> MultiDimReductionOp::getShapeForUnroll() { + return llvm::to_vector<4>(getSourceVectorType().getShape()); +} + //===----------------------------------------------------------------------===// // ReductionOp //===----------------------------------------------------------------------===// Index: mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp =================================================================== --- mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp +++ mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp @@ -207,23 +207,23 @@ vector::UnrollVectorOptions options; }; -struct UnrollContractionPattern - : public OpRewritePattern { - struct OffsetMapInfo { - static SmallVector getEmptyKey() { return {int64_t(-1)}; } +struct OffsetMapInfo { + static SmallVector getEmptyKey() { return {int64_t(-1)}; } - static SmallVector getTombstoneKey() { return {int64_t(-2)}; } + static SmallVector getTombstoneKey() { return {int64_t(-2)}; } - static unsigned getHashValue(const SmallVector &v) { - return static_cast( - llvm::hash_combine_range(v.begin(), v.end())); - } + static unsigned getHashValue(const SmallVector &v) { + 
return static_cast(llvm::hash_combine_range(v.begin(), v.end())); + } - static bool isEqual(const SmallVector &lhs, - const SmallVector &rhs) { - return lhs == rhs; - } - }; + static bool isEqual(const SmallVector &lhs, + const SmallVector &rhs) { + return lhs == rhs; + } +}; + +struct UnrollContractionPattern + : public OpRewritePattern { UnrollContractionPattern(MLIRContext *context, const vector::UnrollVectorOptions &options) : OpRewritePattern(context, /*benefit=*/1), @@ -320,6 +320,74 @@ vector::UnrollVectorOptions options; }; +struct UnrollMultiReductionPattern + : public OpRewritePattern { + UnrollMultiReductionPattern(MLIRContext *context, + const vector::UnrollVectorOptions &options) + : OpRewritePattern(context, /*benefit=*/1), + options(options) {} + + LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp, + PatternRewriter &rewriter) const override { + Optional> targetShape = + getTargetShape(options, reductionOp); + if (!targetShape) + return failure(); + SmallVector originalSize = *reductionOp.getShapeForUnroll(); + SmallVector ratio = *shapeRatio(originalSize, *targetShape); + llvm::MapVector< + SmallVector, Value, + llvm::DenseMap, unsigned, OffsetMapInfo>> + accCache; + // Total number of target-shape slices that tile the source vector. 
+ int64_t sliceCount = computeMaxLinearIndex(ratio); + Location loc = reductionOp.getLoc(); + for (int64_t i = 0; i < sliceCount; i++) { + SmallVector offsets = + getVectorOffset(originalSize, *targetShape, i); + + SmallVector operandStrides(offsets.size(), 1); + Value slicedOperand = rewriter.create( + loc, reductionOp.getOperand(), offsets, *targetShape, operandStrides); + + SmallVector dstShape; + SmallVector destOffset; + for (size_t dim : llvm::seq(size_t(0), targetShape->size())) { + if (!reductionOp.isReducedDim(dim)) { + destOffset.push_back(offsets[dim]); + dstShape.push_back((*targetShape)[dim]); + } + } + auto targetType = VectorType::get( + dstShape, reductionOp.getSourceVectorType().getElementType()); + Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, reductionOp, + slicedOperand, targetType); + Value result = newOp->getResult(0); + // Save the accumulated value until all the loops are unrolled since + // reduction loop keeps updating the accumulator. + auto accIt = accCache.find(destOffset); + if (accIt != accCache.end()) + result = makeArithReduction(rewriter, loc, reductionOp.kind(), result, + accIt->second); + accCache[destOffset] = result; + } + // Assemble back the accumulator into a single vector. 
+ Value result = rewriter.create( + loc, reductionOp.getDestType(), + rewriter.getZeroAttr(reductionOp.getDestType())); + for (const auto &it : accCache) { + SmallVector dstStrides(it.first.size(), 1); + result = rewriter.create( + loc, it.second, result, it.first, dstStrides); + } + rewriter.replaceOp(reductionOp, result); + return success(); + } + +private: + vector::UnrollVectorOptions options; +}; + struct UnrollElementwisePattern : public RewritePattern { UnrollElementwisePattern(MLIRContext *context, const vector::UnrollVectorOptions &options) @@ -568,8 +636,8 @@ void mlir::vector::populateVectorUnrollPatterns( RewritePatternSet &patterns, const UnrollVectorOptions &options) { patterns.add( - patterns.getContext(), options); + UnrollContractionPattern, UnrollElementwisePattern, + UnrollMultiReductionPattern>(patterns.getContext(), options); } void mlir::vector::populatePropagateVectorDistributionPatterns( Index: mlir/test/Dialect/Vector/vector-unroll-options.mlir =================================================================== --- mlir/test/Dialect/Vector/vector-unroll-options.mlir +++ mlir/test/Dialect/Vector/vector-unroll-options.mlir @@ -80,3 +80,29 @@ } // CHECK-LABEL: func @vector_fma // CHECK-COUNT-4: vector.fma %{{.+}}, %{{.+}}, %{{.+}} : vector<2x2xf32> + +func @vector_multi_reduction(%v : vector<4x6xf32>) -> vector<4xf32> { + %0 = vector.multi_reduction #vector.kind, %v [1] : vector<4x6xf32> to vector<4xf32> + return %0 : vector<4xf32> +} +// CHECK-LABEL: func @vector_multi_reduction +// CHECK: %[[V0:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32> +// CHECK: %[[E0:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32> +// CHECK: %[[R0:.*]] = vector.multi_reduction , %[[E0]] [1] : vector<2x2xf32> to vector<2xf32> +// CHECK: %[[E1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
+// CHECK: %[[R1:.*]] = vector.multi_reduction , %[[E1]] [1] : vector<2x2xf32> to vector<2xf32> +// CHECK: %[[A0:.*]] = arith.addf %[[R1]], %[[R0]] : vector<2xf32> +// CHECK: %[[E2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32> +// CHECK: %[[R2:.*]] = vector.multi_reduction , %[[E2]] [1] : vector<2x2xf32> to vector<2xf32> +// CHECK: %[[A1:.*]] = arith.addf %[[R2]], %[[A0]] : vector<2xf32> +// CHECK: %[[E3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32> +// CHECK: %[[R3:.*]] = vector.multi_reduction , %[[E3]] [1] : vector<2x2xf32> to vector<2xf32> +// CHECK: %[[E4:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32> +// CHECK: %[[R4:.*]] = vector.multi_reduction , %[[E4]] [1] : vector<2x2xf32> to vector<2xf32> +// CHECK: %[[A2:.*]] = arith.addf %[[R4]], %[[R3]] : vector<2xf32> +// CHECK: %[[E5:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32> +// CHECK: %[[R5:.*]] = vector.multi_reduction , %[[E5]] [1] : vector<2x2xf32> to vector<2xf32> +// CHECK: %[[A3:.*]] = arith.addf %[[R5]], %[[A2]] : vector<2xf32> +// CHECK: %[[V1:.*]] = vector.insert_strided_slice %[[A1]], %[[V0]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32> +// CHECK: %[[V2:.*]] = vector.insert_strided_slice %[[A3]], %[[V1]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32> +// CHECK: return %[[V2]] : vector<4xf32> Index: mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp =================================================================== --- mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -265,7 +265,8 @@ patterns, UnrollVectorOptions() .setNativeShape(ArrayRef{2, 2})
.setFilterConstraint([](Operation *op) { - return success(isa(op)); + return success(isa(op)); })); if (unrollBasedOnType) {