diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -309,6 +309,8 @@
     ```mlir
     %1 = vector.multi_reduction "add", %0 [1, 3] :
       vector<4x8x16x32xf32> into vector<4x16xf32>
+    %2 = vector.multi_reduction "add", %1 [0, 1] :
+      vector<4x16xf32> into f32
     ```
   }];
   let builders = [
@@ -322,8 +324,14 @@
     VectorType getSourceVectorType() {
      return source().getType().cast<VectorType>();
    }
-    VectorType getDestVectorType() {
-      return dest().getType().cast<VectorType>();
+    Type getDestType() {
+      return dest().getType();
+    }
+
+    bool isReducedDim(int64_t d) {
+      assert(d >= 0 && d < static_cast<int64_t>(getReductionMask().size()) &&
+             "d overflows the number of dims");
+      return getReductionMask()[d];
     }
 
     SmallVector<bool> getReductionMask() {
@@ -341,18 +349,28 @@
     }
 
     static SmallVector<int64_t> inferDestShape(
-        ArrayRef<int64_t> shape, ArrayRef<bool> reducedDimsMask) {
-      assert(shape.size() == reducedDimsMask.size() &&
-             "shape and maks of different sizes");
+        ArrayRef<int64_t> sourceShape, ArrayRef<bool> reducedDimsMask) {
+      assert(sourceShape.size() == reducedDimsMask.size() &&
+             "sourceShape and mask of different sizes");
       SmallVector<int64_t> res;
-      for (auto it : llvm::zip(reducedDimsMask, shape))
+      for (auto it : llvm::zip(reducedDimsMask, sourceShape))
         if (!std::get<0>(it))
           res.push_back(std::get<1>(it));
       return res;
     }
+
+    static Type inferDestType(
+        ArrayRef<int64_t> sourceShape, ArrayRef<bool> reducedDimsMask,
+        Type elementType) {
+      auto targetShape = inferDestShape(sourceShape, reducedDimsMask);
+      // TODO: update to also allow 0-d vectors when available.
+      if (targetShape.empty())
+        return elementType;
+      return VectorType::get(targetShape, elementType);
+    }
   }];
   let assemblyFormat = "$kind `,` $source attr-dict $reduction_dims `:` type($source) `to` type($dest)";
+  let hasFolder = 1;
 }
 
 def Vector_BroadcastOp :
diff --git a/mlir/lib/Dialect/Vector/CMakeLists.txt b/mlir/lib/Dialect/Vector/CMakeLists.txt
--- a/mlir/lib/Dialect/Vector/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRVector
   VectorOps.cpp
+  VectorMultiDimReductionTransforms.cpp
   VectorTransferOpTransforms.cpp
   VectorTransforms.cpp
   VectorUtils.cpp
diff --git a/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp b/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp
@@ -0,0 +1,430 @@
+//===- VectorMultiDimReductionTransforms.cpp - Multi-Reduction Transforms ===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements target-independent rewrites as 1->N patterns.
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Vector/VectorOps.h" +#include "mlir/Dialect/Vector/VectorTransforms.h" +#include "mlir/Dialect/Vector/VectorUtils.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/TypeUtilities.h" + +#define DEBUG_TYPE "vector-multi-reduction-transforms" + +using namespace mlir; + +/// This file implements the following transformations as composable atomic +/// patterns: +/// - InnerOuterDimReductionConversion: rewrites vector.multi_reduction such +/// that all reduction dimensions are either innermost or outermost, by adding +/// the proper vector.transpose operations. +/// - ReduceMultiDimReductionRank: once in innermost or outermost reduction form +/// rewrites nd vector.multi_reduction into 2d vector.multi_reduction, by +/// introducing vector.shape_cast ops to collapse + multi-reduce + expand back. +/// - TwoDimMultiReductionToElementWise: once in 2d vector.multi_reduction form, +/// with an +/// **outermost** reduction dimension, unroll the outer dimension to obtain a +/// sequence of 1-d vector ops. This also has an opportunity for tree-reduction +/// in the future. +/// - TwoDimMultiReductionToReduction: once in 2d vector.multi_reduction form, +/// with an **innermost** reduction dimension, unroll the outer dimension to +/// obtain a sequence of extract + vector.reduction + insert. This can further +/// lower to horizontal reduction ops. +/// - OneDimMultiReductionToTwoDim: for cases that reduce to 1-d vector +/// reduction (and are thus missing either a parallel or a reduction), we lift +/// them back up to 2-d with a simple vector.shape_cast to vector<1xk> so that +/// the other patterns can kick in, thus fully exiting out of the +/// vector.multi_reduction abstraction. +/// +/// Patterns are exposed by `populateVectorMultiReductionLoweringPatterns`. 
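+///
+/// For intuition, a sketch (illustrative only, not lifted from this patch's
+/// tests) of the stages for a reduction over dims [0, 2] when
+/// useInnerDimsForReduction is set:
+///
+///   %r = vector.multi_reduction #vector.kind<add>, %v [0, 2]
+///          : vector<2x3x4xf32> to vector<3xf32>
+///   // 1. Transpose so the reduced dims become innermost.
+///   %t = vector.transpose %v, [1, 0, 2] : vector<2x3x4xf32> to vector<3x2x4xf32>
+///   // 2. Collapse to a 2-D [parallel, reduction] shape.
+///   %c = vector.shape_cast %t : vector<3x2x4xf32> to vector<3x8xf32>
+///   // 3. Emit one vector.reduction per outer (parallel) index.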
+
+// Converts vector.multi_reduction into inner-most/outer-most reduction form
+// by using vector.transpose.
+class InnerOuterDimReductionConversion
+    : public OpRewritePattern<vector::MultiDimReductionOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  explicit InnerOuterDimReductionConversion(MLIRContext *context,
+                                            bool useInnerDimsForReduction)
+      : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context),
+        useInnerDimsForReduction(useInnerDimsForReduction) {}
+
+  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
+                                PatternRewriter &rewriter) const override {
+    auto src = multiReductionOp.source();
+    auto loc = multiReductionOp.getLoc();
+    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
+
+    // Separate reduction and parallel dims
+    auto reductionDimsRange =
+        multiReductionOp.reduction_dims().getAsValueRange<IntegerAttr, APInt>();
+    auto reductionDims = llvm::to_vector<4>(llvm::map_range(
+        reductionDimsRange, [](APInt a) { return a.getZExtValue(); }));
+    llvm::SmallDenseSet<int64_t> reductionDimsSet(reductionDims.begin(),
+                                                  reductionDims.end());
+    int64_t reductionSize = reductionDims.size();
+    SmallVector<int64_t, 4> parallelDims;
+    for (int64_t i = 0; i < srcRank; i++) {
+      if (!reductionDimsSet.contains(i))
+        parallelDims.push_back(i);
+    }
+
+    // Add transpose only if inner-most/outer-most dimensions are not parallel
+    if (useInnerDimsForReduction &&
+        (parallelDims ==
+         llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
+      return failure();
+
+    if (!useInnerDimsForReduction &&
+        (parallelDims !=
+         llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
+      return failure();
+
+    SmallVector<int64_t, 4> indices;
+    if (useInnerDimsForReduction) {
+      indices.append(parallelDims.begin(), parallelDims.end());
+      indices.append(reductionDims.begin(), reductionDims.end());
+    } else {
+      indices.append(reductionDims.begin(), reductionDims.end());
+      indices.append(parallelDims.begin(), parallelDims.end());
+    }
+    auto transposeOp = rewriter.create<vector::TransposeOp>(loc, src, indices);
+    SmallVector<bool> reductionMask(srcRank, false);
+    for (int i = 0; i < reductionSize; ++i) {
+      if (useInnerDimsForReduction)
+        reductionMask[srcRank - i - 1] = true;
+      else
+        reductionMask[i] = true;
+    }
+    rewriter.replaceOpWithNewOp<vector::MultiDimReductionOp>(
+        multiReductionOp, transposeOp.result(), reductionMask,
+        multiReductionOp.kind());
+    return success();
+  }
+
+private:
+  const bool useInnerDimsForReduction;
+};
+
+// Reduces the rank of vector.multi_reduction nd -> 2d given all reduction
+// dimensions are either inner most or outer most.
+class ReduceMultiDimReductionRank
+    : public OpRewritePattern<vector::MultiDimReductionOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  explicit ReduceMultiDimReductionRank(MLIRContext *context,
+                                       bool useInnerDimsForReduction)
+      : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context),
+        useInnerDimsForReduction(useInnerDimsForReduction) {}
+
+  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
+                                PatternRewriter &rewriter) const override {
+    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
+    auto srcShape = multiReductionOp.getSourceVectorType().getShape();
+    auto loc = multiReductionOp.getLoc();
+
+    // If rank less than 2, nothing to do.
+    if (srcRank < 2)
+      return failure();
+
+    // If already rank-2 ["parallel", "reduce"] or ["reduce", "parallel"] bail.
+    auto reductionMask = multiReductionOp.getReductionMask();
+    if (srcRank == 2 && reductionMask.front() != reductionMask.back())
+      return failure();
+
+    // 1. Separate reduction and parallel dims.
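+    // For example (illustrative values): for a vector<2x3x4x5xf32> reduced
+    // over [1, 3], parallelDims/parallelShapes are {0, 2}/{2, 4} and
+    // reductionDims/reductionShapes are {1, 3}/{3, 5}.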
+    SmallVector<int64_t, 4> parallelDims, parallelShapes;
+    SmallVector<int64_t, 4> reductionDims, reductionShapes;
+    for (auto it : llvm::enumerate(reductionMask)) {
+      auto i = it.index();
+      bool isReduction = it.value();
+      if (isReduction) {
+        reductionDims.push_back(i);
+        reductionShapes.push_back(srcShape[i]);
+      } else {
+        parallelDims.push_back(i);
+        parallelShapes.push_back(srcShape[i]);
+      }
+    }
+
+    // 2. Compute flattened parallel and reduction sizes.
+    int flattenedParallelDim = 0;
+    int flattenedReductionDim = 0;
+    if (parallelShapes.size() > 0) {
+      flattenedParallelDim = 1;
+      for (auto d : parallelShapes)
+        flattenedParallelDim *= d;
+    }
+    if (reductionShapes.size() > 0) {
+      flattenedReductionDim = 1;
+      for (auto d : reductionShapes)
+        flattenedReductionDim *= d;
+    }
+    // We must at least have some parallel or some reduction.
+    assert((flattenedParallelDim || flattenedReductionDim) &&
+           "expected at least one parallel or reduction dim");
+
+    // 3. Fail if reduction/parallel dims are not contiguous.
+    if (useInnerDimsForReduction &&
+        (parallelDims !=
+         llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
+      return failure();
+    if (!useInnerDimsForReduction &&
+        (parallelDims ==
+         llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
+      return failure();
+
+    // 4. Shape cast to collapse consecutive parallel (resp. reduction) dims
+    // into a single parallel (resp. reduction) dim.
+    SmallVector<bool, 2> mask;
+    SmallVector<int64_t, 2> vectorShape;
+    if (flattenedParallelDim) {
+      mask.push_back(false);
+      vectorShape.push_back(flattenedParallelDim);
+    }
+    if (flattenedReductionDim) {
+      mask.push_back(true);
+      vectorShape.push_back(flattenedReductionDim);
+    }
+    if (!useInnerDimsForReduction && vectorShape.size() == 2) {
+      std::swap(mask.front(), mask.back());
+      std::swap(vectorShape.front(), vectorShape.back());
+    }
+    auto castedType = VectorType::get(
+        vectorShape, multiReductionOp.getSourceVectorType().getElementType());
+    Value cast = rewriter.create<vector::ShapeCastOp>(
+        loc, castedType, multiReductionOp.source());
+
+    // 5. Create the flattened form of vector.multi_reduction with inner/outer
+    // most dim as reduction.
+    auto newOp = rewriter.create<vector::MultiDimReductionOp>(
+        loc, cast, mask, multiReductionOp.kind());
+
+    // 6. If there are no parallel shapes, the result is a scalar.
+    // TODO: support 0-d vectors when available.
+    if (parallelShapes.empty()) {
+      rewriter.replaceOp(multiReductionOp, newOp.dest());
+      return success();
+    }
+
+    // 7. Create a shape cast for the output 2d -> nd.
+    VectorType outputCastedType = VectorType::get(
+        parallelShapes,
+        multiReductionOp.getSourceVectorType().getElementType());
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
+        multiReductionOp, outputCastedType, newOp.dest());
+    return success();
+  }
+
+private:
+  const bool useInnerDimsForReduction;
+};
+
+// Unrolls vector.multi_reduction with outermost reductions
+// and combines results
+struct TwoDimMultiReductionToElementWise
+    : public OpRewritePattern<vector::MultiDimReductionOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
+                                PatternRewriter &rewriter) const override {
+    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
+    // Rank-2 ["reduce", "parallel"] or bail.
+    if (srcRank != 2)
+      return failure();
+
+    if (multiReductionOp.isReducedDim(1) || !multiReductionOp.isReducedDim(0))
+      return failure();
+
+    auto loc = multiReductionOp.getLoc();
+    ArrayRef<int64_t> srcShape =
+        multiReductionOp.getSourceVectorType().getShape();
+
+    Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
+    if (!elementType.isIntOrIndexOrFloat())
+      return failure();
+
+    Value condition;
+    Value result =
+        rewriter.create<vector::ExtractOp>(loc, multiReductionOp.source(), 0)
+            .getResult();
+    for (int64_t i = 1; i < srcShape[0]; i++) {
+      auto operand = rewriter.create<vector::ExtractOp>(
+          loc, multiReductionOp.source(), i);
+      switch (multiReductionOp.kind()) {
+      case vector::CombiningKind::ADD:
+        if (elementType.isIntOrIndex())
+          result = rewriter.create<AddIOp>(loc, operand, result);
+        else
+          result = rewriter.create<AddFOp>(loc, operand, result);
+        break;
+      case vector::CombiningKind::MUL:
+        if (elementType.isIntOrIndex())
+          result = rewriter.create<MulIOp>(loc, operand, result);
+        else
+          result = rewriter.create<MulFOp>(loc, operand, result);
+        break;
+      case vector::CombiningKind::MINUI:
+        result = rewriter.create<MinUIOp>(loc, operand, result);
+        break;
+      case vector::CombiningKind::MINSI:
+        result = rewriter.create<MinSIOp>(loc, operand, result);
+        break;
+      case vector::CombiningKind::MINF:
+        result = rewriter.create<MinFOp>(loc, operand, result);
+        break;
+      case vector::CombiningKind::MAXUI:
+        result = rewriter.create<MaxUIOp>(loc, operand, result);
+        break;
+      case vector::CombiningKind::MAXSI:
+        result = rewriter.create<MaxSIOp>(loc, operand, result);
+        break;
+      case vector::CombiningKind::MAXF:
+        result = rewriter.create<MaxFOp>(loc, operand, result);
+        break;
+      case vector::CombiningKind::AND:
+        result = rewriter.create<AndOp>(loc, operand, result);
+        break;
+      case vector::CombiningKind::OR:
+        result = rewriter.create<OrOp>(loc, operand, result);
+        break;
+      case vector::CombiningKind::XOR:
+        result = rewriter.create<XOrOp>(loc, operand, result);
+        break;
+      }
+    }
+
+    rewriter.replaceOp(multiReductionOp, result);
+    return success();
+  }
+};
+
+// Converts 2d vector.multi_reduction with inner most reduction dimension into
+// a sequence of vector.reduction ops.
+struct TwoDimMultiReductionToReduction
+    : public OpRewritePattern<vector::MultiDimReductionOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
+                                PatternRewriter &rewriter) const override {
+    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
+    if (srcRank != 2)
+      return failure();
+
+    if (multiReductionOp.isReducedDim(0) || !multiReductionOp.isReducedDim(1))
+      return failure();
+
+    auto loc = multiReductionOp.getLoc();
+    Value result = rewriter.create<ConstantOp>(
+        loc, multiReductionOp.getDestType(),
+        rewriter.getZeroAttr(multiReductionOp.getDestType()));
+    int outerDim = multiReductionOp.getSourceVectorType().getShape()[0];
+
+    // TODO: Add vector::CombiningKind attribute instead of string to
+    // vector.reduction.
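+    // For intuition, with an "add" kind on vector<2x4xf32> (a sketch with
+    // made-up value names), row 0 becomes:
+    //   %v0   = vector.extract %src[0] : vector<2x4xf32>
+    //   %r0   = vector.reduction "add", %v0 : vector<4xf32> into f32
+    //   %acc0 = vector.insertelement %r0, %zero[%c0 : i32] : vector<2xf32>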
+    auto getKindStr = [](vector::CombiningKind kind) {
+      switch (kind) {
+      case vector::CombiningKind::ADD:
+        return "add";
+      case vector::CombiningKind::MUL:
+        return "mul";
+      case vector::CombiningKind::MINUI:
+        return "minui";
+      case vector::CombiningKind::MINSI:
+        return "minsi";
+      case vector::CombiningKind::MINF:
+        return "minf";
+      case vector::CombiningKind::MAXUI:
+        return "maxui";
+      case vector::CombiningKind::MAXSI:
+        return "maxsi";
+      case vector::CombiningKind::MAXF:
+        return "maxf";
+      case vector::CombiningKind::AND:
+        return "and";
+      case vector::CombiningKind::OR:
+        return "or";
+      case vector::CombiningKind::XOR:
+        return "xor";
+      }
+      llvm_unreachable("unknown combining kind");
+    };
+
+    for (int i = 0; i < outerDim; ++i) {
+      auto v = rewriter.create<vector::ExtractOp>(
+          loc, multiReductionOp.source(), ArrayRef<int64_t>{i});
+      auto reducedValue = rewriter.create<vector::ReductionOp>(
+          loc, getElementTypeOrSelf(multiReductionOp.getDestType()),
+          rewriter.getStringAttr(getKindStr(multiReductionOp.kind())), v,
+          ValueRange{});
+      result = rewriter.create<vector::InsertElementOp>(loc, reducedValue,
+                                                        result, i);
+    }
+    rewriter.replaceOp(multiReductionOp, result);
+    return success();
+  }
+};
+
+// Converts 1d vector.multi_reduction with a single reduction dimension to a 2d
+// form with both a single parallel and reduction dimension.
+// This is achieved with a simple vector.shape_cast that inserts a leading 1.
+// The case with a single parallel dimension is a noop and folds away
+// separately.
+struct OneDimMultiReductionToTwoDim
+    : public OpRewritePattern<vector::MultiDimReductionOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
+                                PatternRewriter &rewriter) const override {
+    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
+    // Rank-1 or bail.
+    if (srcRank != 1)
+      return failure();
+
+    auto loc = multiReductionOp.getLoc();
+    auto srcVectorType = multiReductionOp.getSourceVectorType();
+    auto srcShape = srcVectorType.getShape();
+    auto castedType = VectorType::get(ArrayRef<int64_t>{1, srcShape.back()},
+                                      srcVectorType.getElementType());
+    assert(!multiReductionOp.getDestType().isa<VectorType>() &&
+           "multi_reduction with a single dimension expects a scalar result");
+
+    // If the unique dim is reduced and we insert a parallel in front, we need
+    // a {false, true} mask.
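+    // For example (sketch):
+    //   %0 = vector.multi_reduction #vector.kind<add>, %v [0] : vector<8xf32> to f32
+    // becomes a multi_reduction of vector<1x8xf32> over its inner dim,
+    // followed by a vector.extract of element [0] of the resulting vector<1xf32>.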
+    SmallVector<bool, 2> mask{false, true};
+
+    /// vector.extract(vector.multi_reduce(vector.shape_cast(v, 1xk)), 0)
+    Value cast = rewriter.create<vector::ShapeCastOp>(
+        loc, castedType, multiReductionOp.source());
+    Value reduced = rewriter.create<vector::MultiDimReductionOp>(
+        loc, cast, mask, multiReductionOp.kind());
+    rewriter.replaceOpWithNewOp<vector::ExtractOp>(multiReductionOp, reduced,
+                                                   ArrayRef<int64_t>{0});
+    return success();
+  }
+};
+
+void mlir::vector::populateVectorMultiReductionLoweringPatterns(
+    RewritePatternSet &patterns, bool useInnerDimsForReduction) {
+  patterns.add<InnerOuterDimReductionConversion, ReduceMultiDimReductionRank>(
+      patterns.getContext(), useInnerDimsForReduction);
+  if (useInnerDimsForReduction)
+    patterns.add<TwoDimMultiReductionToReduction, OneDimMultiReductionToTwoDim>(
+        patterns.getContext());
+  else
+    patterns.add<TwoDimMultiReductionToElementWise,
+                 OneDimMultiReductionToTwoDim>(patterns.getContext());
+}
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -260,11 +260,10 @@
                                     CombiningKind kind) {
   result.addOperands(source);
   auto sourceVectorType = source.getType().cast<VectorType>();
-  auto targetShape = MultiDimReductionOp::inferDestShape(
-      sourceVectorType.getShape(), reductionMask);
-  auto targetVectorType =
-      VectorType::get(targetShape, sourceVectorType.getElementType());
-  result.addTypes(targetVectorType);
+  auto targetType = MultiDimReductionOp::inferDestType(
+      sourceVectorType.getShape(), reductionMask,
+      sourceVectorType.getElementType());
+  result.addTypes(targetType);
 
   SmallVector<int64_t> reductionDims;
   for (auto en : llvm::enumerate(reductionMask))
@@ -278,17 +277,23 @@
 
 static LogicalResult verify(MultiDimReductionOp op) {
   auto reductionMask = op.getReductionMask();
-  auto targetShape = MultiDimReductionOp::inferDestShape(
-      op.getSourceVectorType().getShape(), reductionMask);
-  auto targetVectorType =
-      VectorType::get(targetShape, op.getSourceVectorType().getElementType());
-  if (targetVectorType != op.getDestVectorType())
+  auto targetType = MultiDimReductionOp::inferDestType(
+      op.getSourceVectorType().getShape(), reductionMask,
+      op.getSourceVectorType().getElementType());
+  // TODO: update to support 0-d vectors when available.
+  if (targetType != op.getDestType())
     return op.emitError("invalid output vector type: ")
-           << op.getDestVectorType() << " (expected: " << targetVectorType
-           << ")";
+           << op.getDestType() << " (expected: " << targetType << ")";
   return success();
 }
 
+OpFoldResult MultiDimReductionOp::fold(ArrayRef<Attribute> operands) {
+  // Single parallel dim, this is a noop.
+  if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
+    return source();
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // ReductionOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -875,14 +875,14 @@
   case CombiningKind::MAXF:
     combinedResult = rewriter.create<MaxFOp>(loc, mul, acc);
     break;
-  case CombiningKind::ADD: // Already handled this special case above.
-  case CombiningKind::AND: // Only valid for integer types.
+  case CombiningKind::ADD:   // Already handled this special case above.
+  case CombiningKind::AND:   // Only valid for integer types.
   case CombiningKind::MINUI: // Only valid for integer types.
   case CombiningKind::MINSI: // Only valid for integer types.
   case CombiningKind::MAXUI: // Only valid for integer types.
   case CombiningKind::MAXSI: // Only valid for integer types.
- case CombiningKind::OR: // Only valid for integer types. - case CombiningKind::XOR: // Only valid for integer types. + case CombiningKind::OR: // Only valid for integer types. + case CombiningKind::XOR: // Only valid for integer types. return Optional(); } return Optional(combinedResult); @@ -3504,315 +3504,6 @@ const bool enableIndexOptimizations; }; -// Converts vector.multi_reduction into inner-most/outer-most reduction form -// by using vector.tranpose -class InnerOuterDimReductionConversion - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - explicit InnerOuterDimReductionConversion(MLIRContext *context, - bool useInnerDimsForReduction) - : mlir::OpRewritePattern(context), - useInnerDimsForReduction(useInnerDimsForReduction) {} - - LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, - PatternRewriter &rewriter) const override { - auto src = multiReductionOp.source(); - auto loc = multiReductionOp.getLoc(); - auto srcRank = multiReductionOp.getSourceVectorType().getRank(); - - // Separate reduction and parallel dims - auto reductionDimsRange = - multiReductionOp.reduction_dims().getAsValueRange(); - auto reductionDims = llvm::to_vector<4>(llvm::map_range( - reductionDimsRange, [](APInt a) { return a.getZExtValue(); })); - llvm::SmallDenseSet reductionDimsSet(reductionDims.begin(), - reductionDims.end()); - int64_t reductionSize = reductionDims.size(); - SmallVector parallelDims; - for (int64_t i = 0; i < srcRank; i++) { - if (!reductionDimsSet.contains(i)) - parallelDims.push_back(i); - } - - // Add transpose only if inner-most/outer-most dimensions are not parallel - if (useInnerDimsForReduction && - (parallelDims == - llvm::to_vector<4>(llvm::seq(0, parallelDims.size())))) - return failure(); - - if (!useInnerDimsForReduction && - (parallelDims != - llvm::to_vector<4>(llvm::seq(0, parallelDims.size())))) - return failure(); - - SmallVector indices; - if (useInnerDimsForReduction) { - indices.append(parallelDims.begin(), parallelDims.end()); - indices.append(reductionDims.begin(), reductionDims.end()); - } else { - indices.append(reductionDims.begin(), reductionDims.end()); - indices.append(parallelDims.begin(), parallelDims.end()); - } - auto transposeOp = rewriter.create(loc, src, indices); - SmallVector reductionMask(srcRank, false); - for (int i = 0; i < reductionSize; ++i) { - if (useInnerDimsForReduction) - reductionMask[srcRank - i - 1] = true; - else - reductionMask[i] = true; - } - rewriter.replaceOpWithNewOp( - multiReductionOp, transposeOp.result(), reductionMask, - multiReductionOp.kind()); - return success(); - } - -private: - const bool useInnerDimsForReduction; -}; - -// Reduces the rank of vector.mult_reduction nd -> 2d given all reduction -// dimensions are either inner most or outer most. 
-class ReduceMultiDimReductionRank - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - explicit ReduceMultiDimReductionRank(MLIRContext *context, - bool useInnerDimsForReduction) - : mlir::OpRewritePattern(context), - useInnerDimsForReduction(useInnerDimsForReduction) {} - - LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, - PatternRewriter &rewriter) const override { - auto srcRank = multiReductionOp.getSourceVectorType().getRank(); - auto srcShape = multiReductionOp.getSourceVectorType().getShape(); - auto loc = multiReductionOp.getLoc(); - if (srcRank == 2) - return failure(); - - // Separate reduction and parallel dims - auto reductionDimsRange = - multiReductionOp.reduction_dims().getAsValueRange(); - auto reductionDims = llvm::to_vector<4>(llvm::map_range( - reductionDimsRange, [](APInt a) { return a.getZExtValue(); })); - llvm::SmallDenseSet reductionDimsSet(reductionDims.begin(), - reductionDims.end()); - SmallVector parallelDims, parallelShapes; - int canonicalReductionDim = 1; - int canonicalParallelDim = 1; - for (int64_t i = 0; i < srcRank; i++) { - if (!reductionDimsSet.contains(i)) { - parallelDims.push_back(i); - parallelShapes.push_back(srcShape[i]); - canonicalParallelDim *= srcShape[i]; - } else { - canonicalReductionDim *= srcShape[i]; - } - } - - // Fail if reduction dims are not either inner-most or outer-most - if (useInnerDimsForReduction && - (parallelDims != - llvm::to_vector<4>(llvm::seq(0, parallelDims.size())))) - return failure(); - - if (!useInnerDimsForReduction && - (parallelDims == - llvm::to_vector<4>(llvm::seq(0, parallelDims.size())))) - return failure(); - - // Creates shape cast for the inputs n_d -> 2d - int64_t outerDim = - useInnerDimsForReduction ? canonicalParallelDim : canonicalReductionDim; - int64_t innerDim = - useInnerDimsForReduction ? canonicalReductionDim : canonicalParallelDim; - - auto castedType = VectorType::get( - ArrayRef{outerDim, innerDim}, - multiReductionOp.getSourceVectorType().getElementType()); - auto castedOp = rewriter.create( - loc, castedType, multiReductionOp.source()); - - // Creates the canonical form of 2d vector.multi_reduction with inner/outer - // most dim as reduction. 
- SmallVector mask{!useInnerDimsForReduction, - useInnerDimsForReduction}; - auto newOp = rewriter.create( - loc, castedOp.result(), mask, multiReductionOp.kind()); - - // Creates shape cast for the output 2d -> nd - VectorType outputCastedType = VectorType::get( - parallelShapes, - multiReductionOp.getSourceVectorType().getElementType()); - Value castedOutputOp = rewriter.create( - loc, outputCastedType, newOp.dest()); - - rewriter.replaceOp(multiReductionOp, castedOutputOp); - return success(); - } - -private: - const bool useInnerDimsForReduction; -}; - -// Unrolls vector.multi_reduction with outermost reductions -// and combines results -struct UnrollOuterMultiReduction - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, - PatternRewriter &rewriter) const override { - auto srcRank = multiReductionOp.getSourceVectorType().getRank(); - if (srcRank != 2) - return failure(); - - if (multiReductionOp.getReductionMask()[1] || - !multiReductionOp.getReductionMask()[0]) - return failure(); - - auto loc = multiReductionOp.getLoc(); - ArrayRef srcShape = - multiReductionOp.getSourceVectorType().getShape(); - - Type elementType = multiReductionOp.getDestVectorType().getElementType(); - if (!elementType.isIntOrIndexOrFloat()) - return failure(); - - Value condition; - Value result = - rewriter.create(loc, multiReductionOp.source(), 0) - .getResult(); - for (int64_t i = 1; i < srcShape[0]; i++) { - auto operand = - rewriter.create(loc, multiReductionOp.source(), i); - switch (multiReductionOp.kind()) { - case vector::CombiningKind::ADD: - if (elementType.isIntOrIndex()) - result = rewriter.create(loc, operand, result); - else - result = rewriter.create(loc, operand, result); - break; - case vector::CombiningKind::MUL: - if (elementType.isIntOrIndex()) - result = rewriter.create(loc, operand, result); - else - result = rewriter.create(loc, operand, result); - break; - case vector::CombiningKind::MINUI: - result = rewriter.create(loc, operand, result); - break; - case vector::CombiningKind::MINSI: - result = rewriter.create(loc, operand, result); - break; - case vector::CombiningKind::MINF: - result = rewriter.create(loc, operand, result); - break; - case vector::CombiningKind::MAXUI: - result = rewriter.create(loc, operand, result); - break; - case vector::CombiningKind::MAXSI: - result = rewriter.create(loc, operand, result); - break; - case vector::CombiningKind::MAXF: - result = rewriter.create(loc, operand, result); - break; - case vector::CombiningKind::AND: - result = rewriter.create(loc, operand, result); - break; - case vector::CombiningKind::OR: - result = rewriter.create(loc, operand, result); - break; - case vector::CombiningKind::XOR: - result = rewriter.create(loc, operand, result); - break; - } - } - - rewriter.replaceOp(multiReductionOp, result); - return success(); - } -}; - -// Converts 2d vector.multi_reduction with inner most reduction dimension into a -// sequence of vector.reduction ops. 
-struct TwoDimMultiReductionToReduction - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, - PatternRewriter &rewriter) const override { - auto srcRank = multiReductionOp.getSourceVectorType().getRank(); - if (srcRank != 2) - return failure(); - - if (multiReductionOp.getReductionMask()[0] || - !multiReductionOp.getReductionMask()[1]) - return failure(); - - auto loc = multiReductionOp.getLoc(); - - Value result = - multiReductionOp.getDestVectorType().getElementType().isIntOrIndex() - ? rewriter.create( - loc, multiReductionOp.getDestVectorType(), - DenseElementsAttr::get(multiReductionOp.getDestVectorType(), - 0)) - : rewriter.create( - loc, multiReductionOp.getDestVectorType(), - DenseElementsAttr::get(multiReductionOp.getDestVectorType(), - 0.0f)); - - int outerDim = multiReductionOp.getSourceVectorType().getShape()[0]; - - // TODO: Add vector::CombiningKind attribute instead of string to - // vector.reduction. - auto getKindStr = [](vector::CombiningKind kind) { - switch (kind) { - case vector::CombiningKind::ADD: - return "add"; - case vector::CombiningKind::MUL: - return "mul"; - case vector::CombiningKind::MINUI: - return "minui"; - case vector::CombiningKind::MINSI: - return "minsi"; - case vector::CombiningKind::MINF: - return "minf"; - case vector::CombiningKind::MAXUI: - return "maxui"; - case vector::CombiningKind::MAXSI: - return "maxsi"; - case vector::CombiningKind::MAXF: - return "maxf"; - case vector::CombiningKind::AND: - return "and"; - case vector::CombiningKind::OR: - return "or"; - case vector::CombiningKind::XOR: - return "xor"; - } - llvm_unreachable("unknown combining kind"); - }; - - for (int i = 0; i < outerDim; ++i) { - auto v = rewriter.create( - loc, multiReductionOp.source(), ArrayRef{i}); - auto reducedValue = rewriter.create( - loc, multiReductionOp.getDestVectorType().getElementType(), - rewriter.getStringAttr(getKindStr(multiReductionOp.kind())), v, - ValueRange{}); - result = rewriter.create(loc, reducedValue, - result, i); - } - rewriter.replaceOp(multiReductionOp, result); - return success(); - } -}; - void mlir::vector::populateVectorMaskMaterializationPatterns( RewritePatternSet &patterns, bool enableIndexOptimizations) { patterns.add(patterns.getContext()); } -void mlir::vector::populateVectorMultiReductionLoweringPatterns( - RewritePatternSet &patterns, bool useInnerDimsForReduction) { - patterns.add( - patterns.getContext(), useInnerDimsForReduction); - if (useInnerDimsForReduction) - patterns.add(patterns.getContext()); - else - patterns.add(patterns.getContext()); -} - void mlir::vector::populateVectorUnrollPatterns( RewritePatternSet &patterns, const UnrollVectorOptions &options) { patterns.add into tensor return %1 : tensor } + +// ----- + +// CHECK-LABEL: func @vector_multi_reduction_single_parallel( +// CHECK-SAME: %[[v:.*]]: vector<2xf32> +func @vector_multi_reduction_single_parallel(%arg0: vector<2xf32>) -> vector<2xf32> { + %0 = vector.multi_reduction #vector.kind, %arg0 [] : vector<2xf32> to vector<2xf32> + +// CHECK: return %[[v]] : vector<2xf32> + return %0 : vector<2xf32> +} diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -621,3 +621,11 @@ return %r, %r2 : vector<32xf32>, vector<16x32xf32> } +// CHECK-LABEL: @multi_reduction +func @multi_reduction(%0: vector<4x8x16x32xf32>) -> f32 { + %1 = vector.multi_reduction 
#vector.kind<add>, %0 [1, 3] :
+    vector<4x8x16x32xf32> to vector<4x16xf32>
+  %2 = vector.multi_reduction #vector.kind<add>, %1 [0, 1] :
+    vector<4x16xf32> to f32
+  return %2 : f32
+}
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
@@ -17,6 +17,18 @@
 // CHECK: %[[RESULT_VEC:.+]] = vector.insertelement %[[RV1:.+]], %[[RESULT_VEC_1]][%[[C1]] : i32] : vector<2xf32>
 // CHECK: return %[[RESULT_VEC]]
+
+func @vector_multi_reduction_to_scalar(%arg0: vector<2x4xf32>) -> f32 {
+    %0 = vector.multi_reduction #vector.kind<mul>, %arg0 [0, 1] : vector<2x4xf32> to f32
+    return %0 : f32
+}
+// CHECK-LABEL: func @vector_multi_reduction_to_scalar
+// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>
+// CHECK: %[[CASTED:.*]] = vector.shape_cast %[[INPUT]] : vector<2x4xf32> to vector<8xf32>
+// CHECK: %[[REDUCED:.*]] = vector.reduction "mul", %[[CASTED]] : vector<8xf32> into f32
+// CHECK: %[[INSERTED:.*]] = vector.insertelement %[[REDUCED]], {{.*}} : vector<1xf32>
+// CHECK: %[[RES:.*]] = vector.extract %[[INSERTED]][0] : vector<1xf32>
+// CHECK: return %[[RES]]
+
 func @vector_reduction_inner(%arg0: vector<2x3x4x5xi32>) -> vector<2x3xi32> {
     %0 = vector.multi_reduction #vector.kind<add>, %arg0 [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32>
     return %0 : vector<2x3xi32>
 }
@@ -50,7 +62,7 @@
 // CHECK: %[[V5R:.+]] = vector.reduction "add", %[[V5]] : vector<20xi32> into i32
 // CHECK: %[[FLAT_RESULT_VEC:.+]] = vector.insertelement %[[V5R]], %[[FLAT_RESULT_VEC_5]][%[[C5]] : i32] : vector<6xi32>
 // CHECK: %[[RESULT:.+]] = vector.shape_cast %[[FLAT_RESULT_VEC]] : vector<6xi32> to vector<2x3xi32>
-// CHECK: return %[[RESULT]]
+// CHECK: return %[[RESULT]]
 
 func @vector_multi_reduction_transposed(%arg0: vector<2x3x4x5xf32>) -> vector<2x5xf32> {
@@ -63,7 +75,7 @@
 // CHECK: %[[TRANSPOSED_INPUT:.+]] = vector.transpose %[[INPUT]], [0, 3, 1, 2] : vector<2x3x4x5xf32> to vector<2x5x3x4xf32>
 // CHECK: vector.shape_cast %[[TRANSPOSED_INPUT]] : vector<2x5x3x4xf32> to vector<10x12xf32>
 // CHECK: %[[RESULT:.+]] = vector.shape_cast %{{.*}} : vector<10xf32> to vector<2x5xf32>
-// CHECK: return %[[RESULT]]
+// CHECK: return %[[RESULT]]
 
 func @vector_multi_reduction_ordering(%arg0: vector<3x2x4xf32>) -> vector<2x4xf32> {
     %0 = vector.multi_reduction #vector.kind, %arg0 [0] : vector<3x2x4xf32> to vector<2x4xf32>