diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -266,7 +266,9 @@
 def Vector_ReductionOp :
   Vector_Op<"reduction", [NoSideEffect,
      PredOpTrait<"source operand and result have same element type",
-                 TCresVTEtIsSameAsOpBase<0, 0>>]>,
+                 TCresVTEtIsSameAsOpBase<0, 0>>,
+     DeclareOpInterfaceMethods<VectorUnrollOpInterface,
+                               ["getShapeForUnroll"]>]>,
     Arguments<(ins Vector_CombiningKindAttr:$kind, AnyVector:$vector,
                Optional<AnyType>:$acc)>,
     Results<(outs AnyType:$dest)> {
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -484,6 +484,10 @@
   return nullptr;
 }
 
+Optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
+  return llvm::to_vector<4>(getVectorType().getShape());
+}
+
 //===----------------------------------------------------------------------===//
 // ContractionOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
@@ -631,13 +631,69 @@
   }
 };
 
+/// Unrolls a 1-D `vector.reduction` into reductions over slices of the
+/// native shape and combines the partial results with the op's kind.
+struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
+  UnrollReductionPattern(MLIRContext *context,
+                         const vector::UnrollVectorOptions &options)
+      : OpRewritePattern<vector::ReductionOp>(context, /*benefit=*/1),
+        options(options) {}
+
+  LogicalResult matchAndRewrite(vector::ReductionOp reductionOp,
+                                PatternRewriter &rewriter) const override {
+    Optional<SmallVector<int64_t, 4>> targetShape =
+        getTargetShape(options, reductionOp);
+    if (!targetShape)
+      return failure();
+    SmallVector<int64_t, 4> originalSize = *reductionOp.getShapeForUnroll();
+    int64_t ratio = (*shapeRatio(originalSize, *targetShape))[0];
+
+    // Create unrolled vector reduction.
+    Location loc = reductionOp.getLoc();
+    Value accumulator = nullptr;
+    for (int64_t i = 0; i < ratio; ++i) {
+      SmallVector<int64_t, 4> offsets =
+          getVectorOffset(originalSize, *targetShape, i);
+      SmallVector<int64_t, 4> strides(offsets.size(), 1);
+      Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
+          loc, reductionOp.vector(), offsets, *targetShape, strides);
+      Operation *newOp = cloneOpWithOperandsAndTypes(
+          rewriter, loc, reductionOp, slicedOperand, reductionOp.getType());
+      Value result = newOp->getResult(0);
+
+      if (!accumulator) {
+        // This is the first reduction.
+        accumulator = result;
+      } else {
+        // On subsequent reduction, combine with the accumulator.
+        accumulator = makeArithReduction(rewriter, loc, reductionOp.kind(),
+                                         accumulator, result);
+      }
+    }
+
+    // The per-slice reductions are cloned with only the sliced vector as
+    // operand, so the optional `acc` operand of the original op would be
+    // dropped. Fold it into the combined result once all slices are reduced.
+    if (Value acc = reductionOp.acc())
+      accumulator = makeArithReduction(rewriter, loc, reductionOp.kind(),
+                                       accumulator, acc);
+
+    rewriter.replaceOp(reductionOp, accumulator);
+    return success();
+  }
+
+private:
+  const vector::UnrollVectorOptions options;
+};
+
 } // namespace
 
 void mlir::vector::populateVectorUnrollPatterns(
     RewritePatternSet &patterns, const UnrollVectorOptions &options) {
   patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
                UnrollContractionPattern, UnrollElementwisePattern,
-               UnrollMultiReductionPattern>(patterns.getContext(), options);
+               UnrollReductionPattern, UnrollMultiReductionPattern>(
+      patterns.getContext(), options);
 }
 
 void mlir::vector::populatePropagateVectorDistributionPatterns(
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -106,3 +106,34 @@
 //       CHECK:   %[[V1:.*]] = vector.insert_strided_slice %[[A1]], %[[V0]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
 //       CHECK:   %[[V2:.*]] = vector.insert_strided_slice %[[A3]], %[[V1]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
 //       CHECK:   return %[[V2]] : vector<4xf32>
+
+// CHECK-LABEL: func @vector_reduction(
+//  CHECK-SAME:   %[[v:.*]]: vector<8xf32>
+//       CHECK:   %[[s0:.*]] = vector.extract_strided_slice %[[v]] {offsets = [0], sizes = [2]
+//       CHECK:   %[[r0:.*]] = vector.reduction <add>, %[[s0]]
+//       CHECK:   %[[s1:.*]] = vector.extract_strided_slice %[[v]] {offsets = [2], sizes = [2]
+//       CHECK:   %[[r1:.*]] = vector.reduction <add>, %[[s1]]
+//       CHECK:   %[[add1:.*]] = arith.addf %[[r0]], %[[r1]]
+//       CHECK:   %[[s2:.*]] = vector.extract_strided_slice %[[v]] {offsets = [4], sizes = [2]
+//       CHECK:   %[[r2:.*]] = vector.reduction <add>, %[[s2]]
+//       CHECK:   %[[add2:.*]] = arith.addf %[[add1]], %[[r2]]
+//       CHECK:   %[[s3:.*]] = vector.extract_strided_slice %[[v]] {offsets = [6], sizes = [2]
+//       CHECK:   %[[r3:.*]] = vector.reduction <add>, %[[s3]]
+//       CHECK:   %[[add3:.*]] = arith.addf %[[add2]], %[[r3]]
+//       CHECK:   return %[[add3]]
+func @vector_reduction(%v : vector<8xf32>) -> f32 {
+  %0 = vector.reduction <add>, %v : vector<8xf32> into f32
+  return %0 : f32
+}
+
+// CHECK-LABEL: func @vector_reduction_acc(
+//  CHECK-SAME:   %[[v:.*]]: vector<8xf32>, %[[acc:.*]]: f32
+// CHECK-COUNT-4:   vector.reduction <add>, %{{[0-9]+}} : vector<2xf32> into f32
+//       CHECK:   %[[add3:.*]] = arith.addf
+//       CHECK:   %[[res:.*]] = arith.addf %[[add3]], %[[acc]]
+//       CHECK:   return %[[res]]
+func @vector_reduction_acc(%v : vector<8xf32>, %acc : f32) -> f32 {
+  %0 = vector.reduction <add>, %v, %acc : vector<8xf32> into f32
+  return %0 : f32
+}
+
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -268,6 +268,12 @@
                         return success(isa<arith::AddFOp, vector::FMAOp,
                                            vector::MultiDimReductionOp>(op));
                       }));
+    populateVectorUnrollPatterns(
+        patterns, UnrollVectorOptions()
+                      .setNativeShape(ArrayRef<int64_t>{2})
+                      .setFilterConstraint([](Operation *op) {
+                        return success(isa<vector::ReductionOp>(op));
+                      }));
 
     if (unrollBasedOnType) {
       UnrollVectorOptions::NativeShapeFnType nativeShapeFn =