Index: mlir/include/mlir/Analysis/Utils.h =================================================================== --- mlir/include/mlir/Analysis/Utils.h +++ mlir/include/mlir/Analysis/Utils.h @@ -354,8 +354,11 @@ Optional<int64_t> getMemoryFootprintBytes(AffineForOp forOp, int memorySpace = -1); -/// Returns true if `forOp' is a parallel loop. -bool isLoopParallel(AffineForOp forOp); +/// Returns true if `forOp` is a parallel loop. By default loops with +/// loop-carried variables (iter_args) are considered non-parallel, unless +/// `ignoreIterArgs = true` is provided, in which case only memory operations +/// are checked. See also `mlir::isParallelReductionLoop`. +bool isLoopParallel(AffineForOp forOp, bool ignoreIterArgs = false); /// Simplify the integer set by simplifying the underlying affine expressions by /// flattening and some simple inference. Also, drop any duplicate constraints. Index: mlir/include/mlir/Dialect/Affine/Passes.td =================================================================== --- mlir/include/mlir/Dialect/Affine/Passes.td +++ mlir/include/mlir/Dialect/Affine/Passes.td @@ -112,7 +112,11 @@ "Specify a 1-D, 2-D or 3-D pattern of fastest varying memory " "dimensions to match. See defaultPatterns in Vectorize.cpp for " "a description and examples. This is used for testing purposes", - "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated"> + "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">, + Option<"vectorizeReductions", "vectorize-reductions", "bool", + /*default=*/"false", + "Vectorize known reductions expressed via iter_args. 
" + "Switched off by default."> ]; } Index: mlir/include/mlir/Dialect/Affine/Utils.h =================================================================== --- mlir/include/mlir/Dialect/Affine/Utils.h +++ mlir/include/mlir/Dialect/Affine/Utils.h @@ -25,6 +25,7 @@ class AffineParallelOp; struct LogicalResult; class Operation; +class ReductionRecognizer; /// Replaces parallel affine.for op with 1-d affine.parallel op. /// mlir::isLoopParallel detect the parallel affine.for ops. @@ -76,16 +77,23 @@ // The candidate will be vectorized using the vectorization factor in // 'vectorSizes' for that dimension. DenseMap<Operation *, unsigned> loopToVectorDim; + // An optional reduction recognizer that will be used to recognize reduction + // loops vectorizable along the reduction dimension. + const ReductionRecognizer *reductionRecognizer = nullptr; }; /// Vectorizes affine loops in 'loops' using the n-D vectorization factors in /// 'vectorSizes'. By default, each vectorization factor is applied /// inner-to-outer to the loops of each loop nest. 'fastestVaryingPattern' can /// be optionally used to provide a different loop vectorization order. +/// If `reductionRecognizer` is not null, recognized reduction loops may be +/// vectorized along the reduction dimension. +/// TODO: Vectorizing reductions is supported only for 1-D vectorization.
void vectorizeAffineLoops( Operation *parentOp, llvm::DenseSet<Operation *, DenseMapInfo<Operation *>> &loops, - ArrayRef<int64_t> vectorSizes, ArrayRef<int64_t> fastestVaryingPattern); + ArrayRef<int64_t> vectorSizes, ArrayRef<int64_t> fastestVaryingPattern, + const ReductionRecognizer *reductionRecognizer = nullptr); /// External utility to vectorize affine loops from a single loop nest using an /// n-D vectorization strategy (see doc in VectorizationStrategy definition). Index: mlir/include/mlir/Transforms/ReductionUtils.h =================================================================== --- /dev/null +++ mlir/include/mlir/Transforms/ReductionUtils.h @@ -0,0 +1,121 @@ +//===- ReductionUtils.h - Reduction-related utilities -----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header file declares utilities for recognizing and manipulating +// reductions. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TRANSFORMS_REDUCTIONUTILS_H +#define MLIR_TRANSFORMS_REDUCTIONUTILS_H + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LLVM.h" + +namespace mlir { + +class OpBuilder; + +/// This abstract class is a collection of utilities needed for transforming +/// (vectorizing, etc.) a specific kind of reduction (like sum, product, etc.). +/// TODO: Support more utilities like generating `affine.parallel` and +/// `atomic_rmw`.
+class ReductionInfo { +public: + /// Creates an operation that takes the vector value `vector` and reduces it + /// into a scalar, for example: + /// %res = vector.reduction "add", %vector : vector<128xf32> into f32 + virtual Value createVectorReduction(Value vector, + OpBuilder &builder) const = 0; + + /// Creates an operation that combines scalar values `lhs` and `rhs`, e.g.: + /// %res = addf %lhs, %rhs : f32 + virtual Value combine(Value lhs, Value rhs, OpBuilder &builder) const = 0; + + /// Creates an attribute of type `elemType` representing the neutral element + /// of this reduction. + virtual Attribute getNeutralElementAttr(Type elemType, + OpBuilder &builder) const = 0; + + virtual ~ReductionInfo() = default; +}; + +/// The base class for reduction recognizers that check if the given +/// loop-carried variable represents a known reduction kind. +class ReductionRecognizer { +public: + /// Checks if the loop-carried variable represented by the argument passed to + /// the current iteration `arg` and the value passed to the next iteration + /// `yielded` computes a known reduction. Returns an instance of + /// `ReductionInfo` for this reduction kind if the reduction is recognized and + /// `nullptr` if it's not. + virtual const ReductionInfo *recognize(BlockArgument arg, + Value yielded) const = 0; + virtual ~ReductionRecognizer() = default; +}; + +/// A recognizer that rejects everything. +class NullReductionRecognizer : public ReductionRecognizer { +public: + const ReductionInfo *recognize(BlockArgument arg, + Value yielded) const override { + return nullptr; + } +}; + +/// An implementation of `ReductionInfo` for standard reductions implementable +/// with a single operation `Op`.
+template <typename Op> +class StandardReductionInfo : public ReductionInfo { +public: + Value createVectorReduction(Value vector, + OpBuilder &builder) const override; + Value combine(Value lhs, Value rhs, + OpBuilder &builder) const override; + Attribute getNeutralElementAttr(Type elemType, + OpBuilder &builder) const override; + + /// Returns an instance of this class. + static ReductionInfo *get(); + + /// The string used as the kind for `vector.reduction`. + static const char *const kindString; +}; + +/// A reduction recognizer that recognizes standard parallelizable reductions. +/// Currently supports addf, mulf, addi, muli. +/// TODO: Support max and min. +class StandardReductionRecognizer : public ReductionRecognizer { +public: + const ReductionInfo *recognize(BlockArgument arg, + Value yielded) const override; + +protected: + Operation *getSingleOpCombiner(BlockArgument arg, Value yielded) const; + virtual const ReductionInfo *recognizeSingleOpCombiner(BlockArgument arg, + Value yielded) const; +}; + +/// Returns true if `forOp` is a parallel loop possibly implementing known +/// reductions via loop-carried variables (iter_args). Reductions are considered +/// known (and parallel) if they are recognized by `reductionRecognizer`. +bool isParallelReductionLoop(AffineForOp forOp, + const ReductionRecognizer &reductionRecognizer); + +/// Populates `reductions` with the information about known reductions +/// implemented by `forOp`. Reductions are considered known if they are +/// recognized by `reductionRecognizer`. Returns `true` if all iteration +/// variables implement recognizable reductions and `false` otherwise.
+bool getKnownReductions(AffineForOp forOp, + const ReductionRecognizer &reductionRecognizer, + SmallVectorImpl<const ReductionInfo *> &reductions); + +} // end namespace mlir + +#endif // MLIR_TRANSFORMS_REDUCTIONUTILS_H Index: mlir/lib/Analysis/Utils.cpp =================================================================== --- mlir/lib/Analysis/Utils.cpp +++ mlir/lib/Analysis/Utils.cpp @@ -1268,12 +1268,14 @@ }); } -/// Returns true if 'forOp' is parallel. -bool mlir::isLoopParallel(AffineForOp forOp) { - // Loop is not parallel if it has SSA loop-carried dependences. - // TODO: Conditionally support reductions and other loop-carried dependences - // that could be handled in the context of a parallel loop. - if (forOp.getNumIterOperands() > 0) +/// Returns true if `forOp` is a parallel loop. By default loops with +/// loop-carried variables (iter_args) are considered non-parallel, unless +/// `ignoreIterArgs = true` is provided, in which case only memory operations +/// are checked. See also `mlir::isParallelReductionLoop`. +bool mlir::isLoopParallel(AffineForOp forOp, bool ignoreIterArgs) { + // Loop is not parallel if it has SSA loop-carried dependences (unless it is + // explicitly requested to ignore them). + if (!ignoreIterArgs && forOp.getNumIterOperands() > 0) return false; // Collect all load and store ops in loop nest rooted at 'forOp'.
Index: mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp =================================================================== --- mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp +++ mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp @@ -20,8 +20,10 @@ #include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/Dialect/Vector/VectorUtils.h" #include "mlir/IR/BlockAndValueMapping.h" -#include "llvm/Support/Debug.h" #include "mlir/Support/LLVM.h" +#include "mlir/Transforms/ReductionUtils.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Debug.h" using namespace mlir; using namespace vector; @@ -324,7 +326,7 @@ /// Unsupported cases, extensions, and work in progress (help welcome :-) ): /// ======================================================================== /// 1. lowering to concrete vector types for various HW; -/// 2. reduction support; +/// 2. reduction support for n-D vectorization and non-unit steps; /// 3. non-effecting padding during vector.transfer_read and filter during /// vector.transfer_write; /// 4. misalignment support vector.transfer_read / vector.transfer_write @@ -487,6 +489,73 @@ /// /// Of course, much more intricate n-D imperfectly-nested patterns can be /// vectorized too and specified in a fully declarative fashion. +/// +/// Reduction: +/// ========== +/// Vectorizing reduction loops along the reduction dimension is supported if: +/// - the reduction is recognizable (see `ReductionRecognizer`), +/// - the vectorization is 1-D, and +/// - the step size of the loop is one. +/// +/// Compared to the non-vector-dimension case, two additional things are done +/// during vectorization of such loops: +/// - The resulting vector returned from the loop is reduced to a scalar using +/// `vector.reduction`. +/// - In some cases a mask is applied to the vector yielded at the end of the +/// loop to prevent garbage values from being written to the accumulator.
+/// +/// Reduction vectorization is switched off by default; it can be enabled by +/// passing a reduction recognizer to utility functions, or by passing +/// `vectorize-reductions=true` to the vectorization pass. +/// +/// Consider the following example: +/// ```mlir +/// func @vecred(%in: memref<512xf32>) -> f32 { +/// %cst = constant 0.000000e+00 : f32 +/// %sum = affine.for %i = 0 to 500 iter_args(%part_sum = %cst) -> (f32) { +/// %ld = affine.load %in[%i] : memref<512xf32> +/// %cos = math.cos %ld : f32 +/// %add = addf %part_sum, %cos : f32 +/// affine.yield %add : f32 +/// } +/// return %sum : f32 +/// } +/// ``` +/// +/// The -affine-vectorize pass with the following arguments: +/// ``` +/// -affine-vectorize="virtual-vector-size=128 test-fastest-varying=0 \ +/// vectorize-reductions=true" +/// ``` +/// produces the following output: +/// ```mlir +/// #map = affine_map<(d0) -> (-d0 + 500)> +/// func @vecred(%arg0: memref<512xf32>) -> f32 { +/// %cst = constant 0.000000e+00 : f32 +/// %cst_0 = constant dense<0.000000e+00> : vector<128xf32> +/// %0 = affine.for %arg1 = 0 to 500 step 128 iter_args(%arg2 = %cst_0) +/// -> (vector<128xf32>) { +/// // %2 is the number of iterations left in the original loop. +/// %2 = affine.apply #map(%arg1) +/// %3 = vector.create_mask %2 : vector<128xi1> +/// %cst_1 = constant 0.000000e+00 : f32 +/// %4 = vector.transfer_read %arg0[%arg1], %cst_1 : +/// memref<512xf32>, vector<128xf32> +/// %5 = math.cos %4 : vector<128xf32> +/// %6 = addf %arg2, %5 : vector<128xf32> +/// // We filter out the effect of the last 12 elements using the mask. +/// %7 = select %3, %6, %arg2 : vector<128xi1>, vector<128xf32> +/// affine.yield %7 : vector<128xf32> +/// } +/// %1 = vector.reduction "add", %0 : vector<128xf32> into f32 +/// return %1 : f32 +/// } +/// ``` +/// +/// Note that because of loop misalignment we needed to apply a mask to prevent +/// the last 12 elements from affecting the final result.
The mask is full of ones +/// in every iteration except for the last one, in which it has the form +/// `11...100...0` with 116 ones and 12 zeros. #define DEBUG_TYPE "early-vect" @@ -644,6 +713,17 @@ void registerValueScalarReplacement(BlockArgument replaced, BlockArgument replacement); + /// Registers the scalar replacement of a scalar result returned from a + /// reduction loop. 'replacement' must be scalar. + /// + /// This utility is used to register the replacement for scalar results of + /// vectorized reduction loops with iter_args. + /// + /// Example 2: + /// * 'replaced': %0 = affine.for %i = 0 to 512 iter_args(%x = ...) -> (f32) + /// * 'replacement': %1 = vector.reduction "add" %0 : vector<4xf32> into f32 + void registerLoopResultScalarReplacement(Value replaced, Value replacement); + /// Returns in 'replacedVals' the scalar replacement for values in /// 'inputVals'. void getScalarValueReplacementsFor(ValueRange inputVals, @@ -663,10 +743,16 @@ // Maps input scalar values to their new scalar counterparts in the vector // loop nest. BlockAndValueMapping valueScalarReplacement; + // Maps results of reduction loops to their new scalar counterparts. + DenseMap<Value, Value> loopResultScalarReplacement; // Maps the newly created vector loops to their vector dimension. DenseMap<Operation *, unsigned> vecLoopToVecDim; + // Maps the new vectorized loops to the corresponding vector masks if one is + // required. + DenseMap<Operation *, Value> vecLoopToMask; + // The strategy drives which loop to vectorize by which amount. const VectorizationStrategy *strategy; @@ -761,6 +847,26 @@ registerValueScalarReplacementImpl(replaced, replacement); } +/// Registers the scalar replacement of a scalar result returned from a +/// reduction loop. 'replacement' must be scalar. +/// +/// This utility is used to register the replacement for scalar results of +/// vectorized reduction loops with iter_args. +/// +/// Example 2: +/// * 'replaced': %0 = affine.for %i = 0 to 512 iter_args(%x = ...)
-> (f32) +/// * 'replacement': %1 = vector.reduction "add" %0 : vector<4xf32> into f32 +void VectorizationState::registerLoopResultScalarReplacement( + Value replaced, Value replacement) { + assert(isa<AffineForOp>(replaced.getDefiningOp())); + assert(loopResultScalarReplacement.count(replaced) == 0 && + "already registered"); + LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ will replace a result of the loop " + "with scalar: " + << replacement); + loopResultScalarReplacement[replaced] = replacement; +} + void VectorizationState::registerValueScalarReplacementImpl(Value replaced, Value replacement) { assert(!valueScalarReplacement.contains(replaced) && @@ -850,6 +956,101 @@ return newConstOp; } +/// Creates a constant vector filled with the neutral elements of the given +/// reduction. The scalar type of vector elements will be taken from +/// `oldOperand`. +static ConstantOp createInitialVector(const ReductionInfo *reduction, + Value oldOperand, + VectorizationState &state) { + Type scalarTy = oldOperand.getType(); + if (!VectorType::isValidElementType(scalarTy)) + return nullptr; + + Attribute valueAttr = + reduction->getNeutralElementAttr(scalarTy, state.builder); + auto vecTy = getVectorType(scalarTy, state.strategy); + auto vecAttr = DenseElementsAttr::get(vecTy, valueAttr); + auto newConstOp = + state.builder.create<ConstantOp>(oldOperand.getLoc(), vecAttr); + + return newConstOp; +} + +/// Creates a mask used to filter out garbage elements in the last iteration +/// of unaligned loops. If a mask is not required then `nullptr` is returned. +/// The mask will be a vector of booleans representing meaningful vector +/// elements in the current iteration. It is filled with ones for each iteration +/// except for the last one, where it has the form `11...100...0` with the +/// number of ones equal to the number of meaningful elements (i.e. the number +/// of iterations that would be left in the original loop).
+static Value createMask(AffineForOp vecForOp, VectorizationState &state) { + assert(state.strategy->vectorSizes.size() == 1 && + "Creating a mask for non-1-D vectors is not supported."); + assert(vecForOp.getStep() == state.strategy->vectorSizes[0] && + "Creating a mask for loops with non-unit original step size is not " + "supported."); + + // Check if we have already created the mask. + if (Value mask = state.vecLoopToMask.lookup(vecForOp)) + return mask; + + // If the loop has constant bounds and the original number of iterations is + // divisible by the vector size then we don't need a mask. + if (vecForOp.hasConstantBounds()) { + int64_t originalTripCount = + vecForOp.getConstantUpperBound() - vecForOp.getConstantLowerBound(); + if (originalTripCount % vecForOp.getStep() == 0) + return nullptr; + } + + OpBuilder::InsertionGuard guard(state.builder); + state.builder.setInsertionPointToStart(vecForOp.getBody()); + + // We generate the mask using the `vector.create_mask` operation which accepts + // the number of meaningful elements (i.e. the length of the prefix of 1s). + // To compute the number of meaningful elements we subtract the current value + // of the iteration variable from the upper bound of the loop. Example: + // + // // 500 is the upper bound of the loop + // #map = affine_map<(d0) -> (500 - d0)> + // %elems_left = affine.apply #map(%iv) + // %mask = vector.create_mask %elems_left : vector<128xi1> + + Location loc = vecForOp.getLoc(); + + // First we get the upper bound of the loop using `affine.apply` or + // `affine.min`. + AffineMap ubMap = vecForOp.getUpperBoundMap(); + Value ub; + if (ubMap.getNumResults() == 1) + ub = state.builder.create<AffineApplyOp>(loc, ubMap, + vecForOp.getUpperBoundOperands()); + else + ub = state.builder.create<AffineMinOp>(loc, ubMap, + vecForOp.getUpperBoundOperands()); + // Then we compute the number of (original) iterations left in the loop. 
+ AffineExpr subExpr = + state.builder.getAffineDimExpr(0) - state.builder.getAffineDimExpr(1); + Value itersLeft = + makeComposedAffineApply(state.builder, loc, AffineMap::get(2, 0, subExpr), + {ub, vecForOp.getInductionVar()}); + // If the affine maps were successfully composed then `ub` is unneeded. + if (ub.use_empty()) + ub.getDefiningOp()->erase(); + // Finally we create the mask. + Type maskTy = VectorType::get(state.strategy->vectorSizes, + state.builder.getIntegerType(1)); + Value mask = + state.builder.create<vector::CreateMaskOp>(loc, maskTy, itersLeft); + + LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ creating a mask:\n" + << itersLeft << "\n" + << mask << "\n"); + + state.vecLoopToMask[vecForOp] = mask; + return mask; +} + /// Returns true if the provided value is vector uniform given the vectorization /// strategy. // TODO: For now, only values that are invariants to all the loops in the @@ -1023,26 +1224,40 @@ return transfer; } +/// Returns true if `value` is a constant equal to the neutral element of the +/// given vectorizable reduction. +static bool isNeutralElementConst(const ReductionInfo *reduction, Value value, + VectorizationState &state) { + Type scalarTy = value.getType(); + if (!VectorType::isValidElementType(scalarTy)) + return false; + Attribute valueAttr = + reduction->getNeutralElementAttr(scalarTy, state.builder); + if (auto constOp = dyn_cast_or_null<ConstantOp>(value.getDefiningOp())) + return constOp.value() == valueAttr; + return false; +} + /// Vectorizes a loop with the vectorization strategy in 'state'. A new loop is /// created and registered as replacement for the scalar loop. The builder's /// insertion point is set to the new loop's body so that subsequent vectorized /// operations are inserted into the new loop. If the loop is a vector /// dimension, the step of the newly created loop will reflect the vectorization /// factor used to vectorized that dimension. -// TODO: Add support for 'iter_args'.
Related operands and results will be -// vectorized at this point. static Operation *vectorizeAffineForOp(AffineForOp forOp, VectorizationState &state) { const VectorizationStrategy &strategy = *state.strategy; auto loopToVecDimIt = strategy.loopToVectorDim.find(forOp); bool isLoopVecDim = loopToVecDimIt != strategy.loopToVectorDim.end(); - // We only support 'iter_args' when the loop is not one of the vector - // dimensions. - // TODO: Support vector dimension loops. They require special handling: - // generate horizontal reduction, last-value extraction, etc. - if (forOp.getNumIterOperands() > 0 && isLoopVecDim) + // TODO: Vectorization of reduction loops is not supported for non-unit steps. + if (isLoopVecDim && forOp.getNumIterOperands() > 0 && forOp.getStep() != 1) { + LLVM_DEBUG( + dbgs() + << "\n[early-vect]+++++ unsupported step size for reduction loop: " + << forOp.getStep() << "\n"); return nullptr; + } // If we are vectorizing a vector dimension, compute a new step for the new // vectorized loop using the vectorization factor for the vector dimension. @@ -1057,10 +1272,40 @@ newStep = forOp.getStep(); } + // Get information about recognized reduction kinds. + SmallVector<const ReductionInfo *, 8> reductions; + if (isLoopVecDim && forOp.getNumIterOperands() > 0) { + if (!strategy.reductionRecognizer) { + LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ cannot vectorize loop with " + "iter_args: no recognizer provided\n"); + return nullptr; + } + bool allKnown = + getKnownReductions(forOp, *strategy.reductionRecognizer, reductions); + if (!allKnown) { + LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ cannot vectorize loop with " + "iter_args: some reductions are not recognized\n"); + return nullptr; + } + } + // Vectorize 'iter_args'. 
+SmallVector<Value, 8> vecIterOperands; - for (auto operand : forOp.getIterOperands()) - vecIterOperands.push_back(vectorizeOperand(operand, state)); + if (!isLoopVecDim) { + for (auto operand : forOp.getIterOperands()) + vecIterOperands.push_back(vectorizeOperand(operand, state)); + } else { + // For reduction loops we need to pass a vector of neutral elements as an + // initial value of the accumulator. We will add the original initial value + // later. A null result means the element type cannot be vectorized. + for (auto redAndOperand : llvm::zip(reductions, forOp.getIterOperands())) { + ConstantOp initVec = createInitialVector( + std::get<0>(redAndOperand), std::get<1>(redAndOperand), state); + if (!initVec) + return nullptr; + vecIterOperands.push_back(initVec); + } + } auto vecForOp = state.builder.create<AffineForOp>( forOp.getLoc(), forOp.getLowerBoundOperands(), forOp.getLowerBoundMap(), @@ -1075,13 +1320,16 @@ // Register loop-related replacements: // 1) The new vectorized loop is registered as vector replacement of the // scalar loop. - // TODO: Support reductions along the vector dimension. // 2) The new iv of the vectorized loop is registered as scalar replacement // since a scalar copy of the iv will prevail in the vectorized loop. // TODO: A vector replacement will also be added in the future when // vectorization of linear ops is supported. // 3) The new 'iter_args' region arguments are registered as vector // replacements since they have been vectorized. + // 4) If the loop performs a reduction along the vector dimension, a + // `vector.reduction` or similar op is inserted for each resulting value + // of the loop and its scalar value replaces the corresponding scalar + // result of the loop. 
+state.registerOpVectorReplacement(forOp, vecForOp); state.registerValueScalarReplacement(forOp.getInductionVar(), vecForOp.getInductionVar()); @@ -1090,12 +1338,35 @@ state.registerBlockArgVectorReplacement(std::get<0>(iterTuple), std::get<1>(iterTuple)); + if (isLoopVecDim) { + for (unsigned i = 0; i < vecForOp.getNumIterOperands(); ++i) { + // First, we reduce the vector returned from the loop into a scalar. + Value reducedRes = reductions[i]->createVectorReduction( + vecForOp.getResult(i), state.builder); + LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ creating a vector reduction: " + << reducedRes); + // Then we combine it with the original (scalar) initial value unless it + // is equal to the neutral element of the reduction. + Value origInit = forOp.getOperand(forOp.getNumControlOperands() + i); + Value finalRes = reducedRes; + if (!isNeutralElementConst(reductions[i], origInit, state)) + finalRes = reductions[i]->combine(reducedRes, origInit, state.builder); + state.registerLoopResultScalarReplacement(forOp.getResult(i), finalRes); + } + } + if (isLoopVecDim) state.vecLoopToVecDim[vecForOp] = loopToVecDimIt->second; // Change insertion point so that upcoming vectorized instructions are // inserted into the vectorized loop's body. state.builder.setInsertionPointToStart(vecForOp.getBody()); + + // If this is a reduction loop then we may need to create a mask to filter out + // garbage in the last iteration. + if (isLoopVecDim && forOp.getNumIterOperands() > 0) + createMask(vecForOp, state); + return vecForOp; } @@ -1133,11 +1404,35 @@ /// Vectorizes a yield operation by widening its types. The builder's insertion /// point is set after the vectorized parent op to continue vectorizing the -/// operations after the parent op. +/// operations after the parent op. When vectorizing a reduction loop a mask may +/// be used to prevent adding garbage values to the accumulator.
static Operation *vectorizeAffineYieldOp(AffineYieldOp yieldOp, VectorizationState &state) { Operation *newYieldOp = widenOp(yieldOp, state); Operation *newParentOp = state.builder.getInsertionBlock()->getParentOp(); + + // If there is a mask for this loop then we must prevent garbage values from + // being added to the accumulator by inserting `select` operations, for + // example: + // + // %res = addf %acc, %val : vector<128xf32> + // %res_masked = select %mask, %res, %acc : vector<128xi1>, vector<128xf32> + // affine.yield %res_masked : vector<128xf32> + // + if (Value mask = state.vecLoopToMask.lookup(newParentOp)) { + state.builder.setInsertionPoint(newYieldOp); + for (unsigned i = 0; i < newYieldOp->getNumOperands(); ++i) { + Value result = newYieldOp->getOperand(i); + Value iterArg = cast<AffineForOp>(newParentOp).getRegionIterArgs()[i]; + Value maskedResult = state.builder.create<SelectOp>(result.getLoc(), mask, + result, iterArg); + LLVM_DEBUG( + dbgs() << "\n[early-vect]+++++ masking a yielded vector value: " + << maskedResult); + newYieldOp->setOperand(i, maskedResult); + } + } + state.builder.setInsertionPointAfter(newParentOp); return newYieldOp; } @@ -1241,8 +1533,12 @@ auto opVecResult = rootLoop.walk([&](Operation *op) { LLVM_DEBUG(dbgs() << "[early-vect]+++++ Vectorizing: " << *op); Operation *vectorOp = vectorizeOneOperation(op, state); - if (!vectorOp) + if (!vectorOp) { + LLVM_DEBUG( + dbgs() << "[early-vect]+++++ failed vectorizing the operation: " + << *op << "\n"); return WalkResult::interrupt(); + } return WalkResult::advance(); }); @@ -1258,6 +1554,11 @@ return failure(); } + // Replace results of reduction loops with the scalar values computed using + // `vector.reduction` or similar ops. 
++ for (auto resPair : state.loopResultScalarReplacement) + resPair.first.replaceAllUsesWith(resPair.second); + assert(state.opVectorReplacement.count(rootLoop) == 1 && "Expected vector replacement for loop nest"); LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ success vectorizing pattern"); @@ -1328,7 +1629,11 @@ /// vectorization order. static void vectorizeLoops(Operation *parentOp, DenseSet<Operation *> &loops, ArrayRef<int64_t> vectorSizes, - ArrayRef<int64_t> fastestVaryingPattern) { + ArrayRef<int64_t> fastestVaryingPattern, + const ReductionRecognizer *reductionRecognizer) { + assert((!reductionRecognizer || vectorSizes.size() == 1) && + "Vectorizing reductions is supported only for 1-D vectors"); + // Compute 1-D, 2-D or 3-D loop pattern to be matched on the target loops. Optional<NestedPattern> pattern = makePattern(loops, vectorSizes.size(), fastestVaryingPattern); @@ -1359,6 +1664,7 @@ VectorizationStrategy strategy; // TODO: depending on profitability, elect to reduce the vector size. strategy.vectorSizes.assign(vectorSizes.begin(), vectorSizes.end()); + strategy.reductionRecognizer = reductionRecognizer; if (failed(analyzeProfitability(match.getMatchedChildren(), 1, patternDepth, &strategy))) { continue; @@ -1396,15 +1702,29 @@ return signalPassFailure(); } + if (vectorizeReductions && vectorSizes.size() != 1) { + f.emitError("Vectorizing reductions is supported only for 1-D vectors."); + return signalPassFailure(); + } + + // If 'vectorize-reductions=true' is provided, use the standard reduction + // recognizer, otherwise use the null reduction recognizer which rejects all + // reductions. + const ReductionRecognizer &stdRedRecognizer = StandardReductionRecognizer(); + const ReductionRecognizer &nullRedRecognizer = NullReductionRecognizer(); + const ReductionRecognizer &reductionRecognizer = + vectorizeReductions ? 
+stdRedRecognizer : nullRedRecognizer; + DenseSet<Operation *> parallelLoops; - f.walk([&parallelLoops](AffineForOp loop) { - if (isLoopParallel(loop)) + f.walk([&parallelLoops, &reductionRecognizer](AffineForOp loop) { + if (isParallelReductionLoop(loop, reductionRecognizer)) parallelLoops.insert(loop); }); // Thread-safe RAII local context, BumpPtrAllocator freed on exit. NestedPatternContext mlContext; - vectorizeLoops(f, parallelLoops, vectorSizes, fastestVaryingPattern); + vectorizeLoops(f, parallelLoops, vectorSizes, fastestVaryingPattern, + vectorizeReductions ? &reductionRecognizer : nullptr); } /// Verify that affine loops in 'loops' meet the nesting criteria expected by @@ -1452,12 +1772,17 @@ /// factor is applied inner-to-outer to the loops of each loop nest. /// 'fastestVaryingPattern' can be optionally used to provide a different loop /// vectorization order. +/// If `reductionRecognizer` is not null, recognized reduction loops may be +/// vectorized along the reduction dimension. +/// TODO: Vectorizing reductions is supported only for 1-D vectorization. void vectorizeAffineLoops(Operation *parentOp, DenseSet<Operation *> &loops, ArrayRef<int64_t> vectorSizes, - ArrayRef<int64_t> fastestVaryingPattern) { + ArrayRef<int64_t> fastestVaryingPattern, + const ReductionRecognizer *reductionRecognizer) { // Thread-safe RAII local context, BumpPtrAllocator freed on exit. 
+NestedPatternContext mlContext; - vectorizeLoops(parentOp, loops, vectorSizes, fastestVaryingPattern); + vectorizeLoops(parentOp, loops, vectorSizes, fastestVaryingPattern, + reductionRecognizer); } /// External utility to vectorize affine loops from a single loop nest using an Index: mlir/lib/Transforms/Utils/CMakeLists.txt =================================================================== --- mlir/lib/Transforms/Utils/CMakeLists.txt +++ mlir/lib/Transforms/Utils/CMakeLists.txt @@ -5,6 +5,7 @@ InliningUtils.cpp LoopFusionUtils.cpp LoopUtils.cpp + ReductionUtils.cpp RegionUtils.cpp Utils.cpp @@ -23,4 +24,5 @@ MLIRPass MLIRRewrite MLIRStandard + MLIRVector ) Index: mlir/lib/Transforms/Utils/ReductionUtils.cpp =================================================================== --- /dev/null +++ mlir/lib/Transforms/Utils/ReductionUtils.cpp @@ -0,0 +1,195 @@ +//===- ReductionUtils.cpp - Reduction-related utilities ---------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements utilities for recognizing and manipulating reductions. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Transforms/ReductionUtils.h" +#include "mlir/Analysis/Utils.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/Vector/VectorOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "reduction-utils" + +using namespace mlir; +using llvm::dbgs; + +/// Builds a reduction of the vector value into a scalar.
+template +Value StandardReductionInfo::createVectorReduction( + Value vector, OpBuilder &builder) const { + Type scalarType = vector.getType().cast().getElementType(); + return builder.create(vector.getLoc(), scalarType, + builder.getStringAttr(kindString), + vector, ValueRange{}); +} + +/// Combines two scalar values. +template +Value StandardReductionInfo::combine(Value lhs, Value rhs, + OpBuilder &builder) const { + return builder.create(lhs.getLoc(), lhs, rhs); +} + +/// Returns an instance of this class. +template +ReductionInfo *StandardReductionInfo::get() { + static StandardReductionInfo reduction; + return &reduction; +} + +/// The string used as the kind for `vector.reduction`. +template <> +const char *const StandardReductionInfo::kindString = "add"; +template <> +const char *const StandardReductionInfo::kindString = "add"; +template <> +const char *const StandardReductionInfo::kindString = "mul"; +template <> +const char *const StandardReductionInfo::kindString = "mul"; + +template <> +Attribute +StandardReductionInfo::getNeutralElementAttr(Type elemType, + OpBuilder &builder) const { + return builder.getZeroAttr(elemType); +} + +template <> +Attribute +StandardReductionInfo::getNeutralElementAttr(Type elemType, + OpBuilder &builder) const { + return builder.getZeroAttr(elemType); +} + +template <> +Attribute +StandardReductionInfo::getNeutralElementAttr(Type elemType, + OpBuilder &builder) const { + return builder.getIntegerAttr(elemType, 1); +} + +template <> +Attribute +StandardReductionInfo::getNeutralElementAttr(Type elemType, + OpBuilder &builder) const { + return builder.getFloatAttr(elemType, 1); +} + +/// Recognize a standard reduction given the scalar value from the previous +/// iteration `arg` and the value passed to the next iteration `yielded`. 
+/// Example: +/// +/// %sum = affine.for %j = 0 to 512 iter_args(%arg = %cst0) -> (f32) { +/// %ld = affine.load %in[%j] : memref<512xf32> +/// %yielded = addf %arg, %ld : f32 +/// affine.yield %yielded : f32 +/// } +/// +/// Returns the matching ReductionInfo, or nullptr if no known reduction is +/// recognized. +const ReductionInfo * +StandardReductionRecognizer::recognize(BlockArgument arg, Value yielded) const { + // Check that there is a single combining operation that doesn't leak + // information. This should be suitable for most kinds of reductions, except + // for min and max that need two operations (compare and select). + Operation *combiner = getSingleOpCombiner(arg, yielded); + if (!combiner) { + LLVM_DEBUG(dbgs() << "Reduction not recognized, the combiner doesn't use " + "`arg` or information is leaked:" + << "\narg: " << arg << "\nyielded value: " << yielded); + return nullptr; + } + + const ReductionInfo *reductionInfo = recognizeSingleOpCombiner(arg, yielded); + if (!reductionInfo) + LLVM_DEBUG(dbgs() << "Reduction not recognized, unknown op: " << *combiner); + + return reductionInfo; +} + +/// Checks if `arg` is used exclusively by the operation generating the yielded +/// value. On success the combining operation is returned; otherwise nullptr +/// is returned (e.g. if `yielded` has no defining op, or if `arg` or +/// `yielded` does not have exactly one use). +Operation * +StandardReductionRecognizer::getSingleOpCombiner(BlockArgument arg, + Value yielded) const { + // This is the combining operation like addf or muli. + Operation *combiner = yielded.getDefiningOp(); + if (!combiner) + return nullptr; + // If `arg` or `yielded` are used elsewhere then we may leak intermediate + // values (like partial sums) making the reduction loop unvectorizable and + // unparallelizable (with simple methods at least). + if (!llvm::hasSingleElement(arg.getUses()) || + !llvm::hasSingleElement(yielded.getUses())) + return nullptr; + // The only user of `arg` must be the combining operation. + if (*std::begin(arg.getUsers()) != combiner) + return nullptr; + return combiner; +} + +/// Recognizes the simplest case when the reduction uses one operation.
+const ReductionInfo * +StandardReductionRecognizer::recognizeSingleOpCombiner(BlockArgument arg, + Value yielded) const { + Operation *op = yielded.getDefiningOp(); + if (isa(op)) + return StandardReductionInfo::get(); + if (isa(op)) + return StandardReductionInfo::get(); + if (isa(op)) + return StandardReductionInfo::get(); + if (isa(op)) + return StandardReductionInfo::get(); + return nullptr; +} + +/// Returns true if `forOp` is a parallel loop possibly implementing known +/// reductions via loop-carried variables (iter_args). Reductions are considered +/// known (and parallel) if they are recognized by `reductionRecognizer`. +bool mlir::isParallelReductionLoop( + AffineForOp forOp, const ReductionRecognizer &reductionRecognizer) { + // Check that there are no loop-carried memory dependences. + if (!isLoopParallel(forOp, /*ignoreIterArgs=*/true)) + return false; + + // Check that all iteration arguments implement known reductions. + auto iterArgs = forOp.getRegionIterArgs(); + auto yieldedValues = forOp.getBody()->getTerminator()->getOperands(); + for (auto argAndYield : llvm::zip(iterArgs, yieldedValues)) { + const ReductionInfo *reduction = reductionRecognizer.recognize( + std::get<0>(argAndYield), std::get<1>(argAndYield)); + if (!reduction) + return false; + } + + return true; +} + +/// Populates `reductions` with the information about known reductions +/// implemented by `forOp`. Reductions are considered known if they are +/// recognized by `reductionRecognizer`. Returns `true` if all iteration +/// variables implement recognizable reductions and `false` otherwise. 
+bool mlir::getKnownReductions( + AffineForOp forOp, const ReductionRecognizer &reductionRecognizer, + SmallVectorImpl &reductions) { + bool allKnown = true; + auto iterArgs = forOp.getRegionIterArgs(); + auto yieldedValues = forOp.getBody()->getTerminator()->getOperands(); + for (auto argAndYield : llvm::zip(iterArgs, yieldedValues)) { + const ReductionInfo *reduction = reductionRecognizer.recognize( + std::get<0>(argAndYield), std::get<1>(argAndYield)); + if (!reduction) + allKnown = false; + reductions.push_back(reduction); + } + + return allKnown; +} Index: mlir/test/Dialect/Affine/SuperVectorize/vectorize_1d.mlir =================================================================== --- mlir/test/Dialect/Affine/SuperVectorize/vectorize_1d.mlir +++ mlir/test/Dialect/Affine/SuperVectorize/vectorize_1d.mlir @@ -590,3 +590,24 @@ // CHECK: } // CHECK: vector.transfer_write %[[last_val]], %{{.*}} : vector<128xf32>, memref<256xf32> // CHECK: } + +// ----- + +// The inner reduction loop '%j' is not vectorized if we do not request +// reduction vectorization. 
+ +func @vec_vecdim_reduction_rejected(%in: memref<256x512xf32>, %out: memref<256xf32>) { + %cst = constant 0.000000e+00 : f32 + affine.for %i = 0 to 256 { + %final_red = affine.for %j = 0 to 512 iter_args(%red_iter = %cst) -> (f32) { + %ld = affine.load %in[%i, %j] : memref<256x512xf32> + %add = addf %red_iter, %ld : f32 + affine.yield %add : f32 + } + affine.store %final_red, %out[%i] : memref<256xf32> + } + return +} + +// CHECK-LABEL: @vec_vecdim_reduction_rejected +// CHECK-NOT: vector Index: mlir/test/Dialect/Affine/SuperVectorize/vectorize_reduction.mlir =================================================================== --- /dev/null +++ mlir/test/Dialect/Affine/SuperVectorize/vectorize_reduction.mlir @@ -0,0 +1,468 @@ +// RUN: mlir-opt %s -affine-super-vectorize="virtual-vector-size=128 test-fastest-varying=0 vectorize-reductions=true" -split-input-file | FileCheck %s + +// The inner reduction loop '%j' is vectorized. + +func @vecdim_reduction(%in: memref<256x512xf32>, %out: memref<256xf32>) { + %cst = constant 0.000000e+00 : f32 + affine.for %i = 0 to 256 { + %final_red = affine.for %j = 0 to 512 iter_args(%red_iter = %cst) -> (f32) { + %ld = affine.load %in[%i, %j] : memref<256x512xf32> + %add = addf %red_iter, %ld : f32 + affine.yield %add : f32 + } + affine.store %final_red, %out[%i] : memref<256xf32> + } + return +} + +// CHECK-LABEL: @vecdim_reduction +// CHECK: affine.for %{{.*}} = 0 to 256 { +// CHECK: %[[vzero:.*]] = constant dense<0.000000e+00> : vector<128xf32> +// CHECK: %[[vred:.*]] = affine.for %{{.*}} = 0 to 512 step 128 iter_args(%[[red_iter:.*]] = %[[vzero]]) -> (vector<128xf32>) { +// CHECK: %[[ld:.*]] = vector.transfer_read %{{.*}} : memref<256x512xf32>, vector<128xf32> +// CHECK: %[[add:.*]] = addf %[[red_iter]], %[[ld]] : vector<128xf32> +// CHECK: affine.yield %[[add]] : vector<128xf32> +// CHECK: } +// CHECK: %[[final_sum:.*]] = vector.reduction "add", %[[vred]] : vector<128xf32> into f32 +// CHECK: affine.store %[[final_sum]],
%{{.*}} : memref<256xf32> +// CHECK: } + +// ----- + +// The inner reduction loop '%j' is vectorized. (The order of addf's operands is +// different than in the previous test case). + +func @vecdim_reduction_comm(%in: memref<256x512xf32>, %out: memref<256xf32>) { + %cst = constant 0.000000e+00 : f32 + affine.for %i = 0 to 256 { + %final_red = affine.for %j = 0 to 512 iter_args(%red_iter = %cst) -> (f32) { + %ld = affine.load %in[%i, %j] : memref<256x512xf32> + %add = addf %ld, %red_iter : f32 + affine.yield %add : f32 + } + affine.store %final_red, %out[%i] : memref<256xf32> + } + return +} + +// CHECK-LABEL: @vecdim_reduction_comm +// CHECK: affine.for %{{.*}} = 0 to 256 { +// CHECK: %[[vzero:.*]] = constant dense<0.000000e+00> : vector<128xf32> +// CHECK: %[[vred:.*]] = affine.for %{{.*}} = 0 to 512 step 128 iter_args(%[[red_iter:.*]] = %[[vzero]]) -> (vector<128xf32>) { +// CHECK: %[[ld:.*]] = vector.transfer_read %{{.*}} : memref<256x512xf32>, vector<128xf32> +// CHECK: %[[add:.*]] = addf %[[ld]], %[[red_iter]] : vector<128xf32> +// CHECK: affine.yield %[[add]] : vector<128xf32> +// CHECK: } +// CHECK: %[[final_sum:.*]] = vector.reduction "add", %[[vred]] : vector<128xf32> into f32 +// CHECK: affine.store %[[final_sum]], %{{.*}} : memref<256xf32> +// CHECK: } + +// ----- + +// The inner reduction loop '%j' is vectorized. Transforming the input before +// performing the accumulation doesn't cause any problem.
+ +func @vecdim_reduction_expsin(%in: memref<256x512xf32>, %out: memref<256xf32>) { + %cst = constant 0.000000e+00 : f32 + affine.for %i = 0 to 256 { + %final_red = affine.for %j = 0 to 512 iter_args(%red_iter = %cst) -> (f32) { + %ld = affine.load %in[%i, %j] : memref<256x512xf32> + %sin = math.sin %ld : f32 + %exp = math.exp %sin : f32 + %add = addf %red_iter, %exp : f32 + affine.yield %add : f32 + } + affine.store %final_red, %out[%i] : memref<256xf32> + } + return +} + +// CHECK-LABEL: @vecdim_reduction_expsin +// CHECK: affine.for %{{.*}} = 0 to 256 { +// CHECK: %[[vzero:.*]] = constant dense<0.000000e+00> : vector<128xf32> +// CHECK: %[[vred:.*]] = affine.for %{{.*}} = 0 to 512 step 128 iter_args(%[[red_iter:.*]] = %[[vzero]]) -> (vector<128xf32>) { +// CHECK: %[[ld:.*]] = vector.transfer_read %{{.*}} : memref<256x512xf32>, vector<128xf32> +// CHECK: %[[sin:.*]] = math.sin %[[ld]] +// CHECK: %[[exp:.*]] = math.exp %[[sin]] +// CHECK: %[[add:.*]] = addf %[[red_iter]], %[[exp]] : vector<128xf32> +// CHECK: affine.yield %[[add]] : vector<128xf32> +// CHECK: } +// CHECK: %[[final_sum:.*]] = vector.reduction "add", %[[vred]] : vector<128xf32> into f32 +// CHECK: affine.store %[[final_sum]], %{{.*}} : memref<256xf32> +// CHECK: } + +// ----- + +// Two reductions at the same time. The inner reduction loop '%j' is vectorized. + +func @two_vecdim_reductions(%in: memref<256x512xf32>, %out_sum: memref<256xf32>, %out_prod: memref<256xf32>) { + %cst = constant 1.000000e+00 : f32 + affine.for %i = 0 to 256 { + // Note that we pass the same constant '1.0' as initial values for both + // reductions.
+ %sum, %prod = affine.for %j = 0 to 512 iter_args(%part_sum = %cst, %part_prod = %cst) -> (f32, f32) { + %ld = affine.load %in[%i, %j] : memref<256x512xf32> + %add = addf %part_sum, %ld : f32 + %mul = mulf %part_prod, %ld : f32 + affine.yield %add, %mul : f32, f32 + } + affine.store %sum, %out_sum[%i] : memref<256xf32> + affine.store %prod, %out_prod[%i] : memref<256xf32> + } + return +} + +// CHECK-LABEL: @two_vecdim_reductions +// CHECK: %[[cst:.*]] = constant 1.000000e+00 : f32 +// CHECK: affine.for %{{.*}} = 0 to 256 { +// CHECK: %[[vzero:.*]] = constant dense<0.000000e+00> : vector<128xf32> +// CHECK: %[[vone:.*]] = constant dense<1.000000e+00> : vector<128xf32> +// CHECK: %[[vred:.*]]:2 = affine.for %{{.*}} = 0 to 512 step 128 iter_args(%[[part_sum:.*]] = %[[vzero]], %[[part_prod:.*]] = %[[vone]]) -> (vector<128xf32>, vector<128xf32>) { +// CHECK: %[[ld:.*]] = vector.transfer_read %{{.*}} : memref<256x512xf32>, vector<128xf32> +// CHECK: %[[add:.*]] = addf %[[part_sum]], %[[ld]] : vector<128xf32> +// CHECK: %[[mul:.*]] = mulf %[[part_prod]], %[[ld]] : vector<128xf32> +// CHECK: affine.yield %[[add]], %[[mul]] : vector<128xf32>, vector<128xf32> +// CHECK: } +// CHECK: %[[nonfinal_sum:.*]] = vector.reduction "add", %[[vred]]#0 : vector<128xf32> into f32 +// Note that to compute the final sum we need to add the original initial value +// (%cst) since it is not zero. +// CHECK: %[[final_sum:.*]] = addf %[[nonfinal_sum]], %[[cst]] : f32 +// For the final product we don't need to do this additional step because the +// initial value equals 1 (the neutral element for multiplication). +// CHECK: %[[final_prod:.*]] = vector.reduction "mul", %[[vred]]#1 : vector<128xf32> into f32 +// CHECK: affine.store %[[final_sum]], %{{.*}} : memref<256xf32> +// CHECK: affine.store %[[final_prod]], %{{.*}} : memref<256xf32> +// CHECK: } + +// ----- + +// The integer case.
+ +func @two_vecdim_reductions_int(%in: memref<256x512xi64>, %out_sum: memref<256xi64>, %out_prod: memref<256xi64>) { + %cst0 = constant 0 : i64 + %cst1 = constant 1 : i64 + affine.for %i = 0 to 256 { + %sum, %prod = affine.for %j = 0 to 512 iter_args(%part_sum = %cst0, %part_prod = %cst1) -> (i64, i64) { + %ld = affine.load %in[%i, %j] : memref<256x512xi64> + %add = addi %part_sum, %ld : i64 + %mul = muli %part_prod, %ld : i64 + affine.yield %add, %mul : i64, i64 + } + affine.store %sum, %out_sum[%i] : memref<256xi64> + affine.store %prod, %out_prod[%i] : memref<256xi64> + } + return +} + +// CHECK-LABEL: @two_vecdim_reductions_int +// CHECK: affine.for %{{.*}} = 0 to 256 { +// CHECK: %[[vzero:.*]] = constant dense<0> : vector<128xi64> +// CHECK: %[[vone:.*]] = constant dense<1> : vector<128xi64> +// CHECK: %[[vred:.*]]:2 = affine.for %{{.*}} = 0 to 512 step 128 iter_args(%[[part_sum:.*]] = %[[vzero]], %[[part_prod:.*]] = %[[vone]]) -> (vector<128xi64>, vector<128xi64>) { +// CHECK: %[[ld:.*]] = vector.transfer_read %{{.*}} : memref<256x512xi64>, vector<128xi64> +// CHECK: %[[add:.*]] = addi %[[part_sum]], %[[ld]] : vector<128xi64> +// CHECK: %[[mul:.*]] = muli %[[part_prod]], %[[ld]] : vector<128xi64> +// CHECK: affine.yield %[[add]], %[[mul]] : vector<128xi64>, vector<128xi64> +// CHECK: } +// CHECK: %[[final_sum:.*]] = vector.reduction "add", %[[vred]]#0 : vector<128xi64> into i64 +// CHECK: %[[final_prod:.*]] = vector.reduction "mul", %[[vred]]#1 : vector<128xi64> into i64 +// CHECK: affine.store %[[final_sum]], %{{.*}} : memref<256xi64> +// CHECK: affine.store %[[final_prod]], %{{.*}} : memref<256xi64> +// CHECK: } + +// ----- + +// The outer reduction loop '%j' is vectorized.
+ +func @vecdim_reduction_nested(%in: memref<256x512xf32>, %out: memref<1xf32>) { + %cst = constant 0.000000e+00 : f32 + %outer_red = affine.for %j = 0 to 512 iter_args(%outer_iter = %cst) -> (f32) { + %inner_red = affine.for %i = 0 to 256 iter_args(%inner_iter = %cst) -> (f32) { + %ld = affine.load %in[%i, %j] : memref<256x512xf32> + %add = addf %inner_iter, %ld : f32 + affine.yield %add : f32 + } + %outer_add = addf %outer_iter, %inner_red : f32 + affine.yield %outer_add : f32 + } + affine.store %outer_red, %out[0] : memref<1xf32> + return +} + +// CHECK-LABEL: @vecdim_reduction_nested +// CHECK: %[[vzero:.*]] = constant dense<0.000000e+00> : vector<128xf32> +// CHECK: %[[outer_red:.*]] = affine.for %{{.*}} = 0 to 512 step 128 iter_args(%[[outer_iter:.*]] = %[[vzero]]) -> (vector<128xf32>) { +// CHECK: %[[vzero:.*]] = constant dense<0.000000e+00> : vector<128xf32> +// CHECK: %[[inner_red:.*]] = affine.for %{{.*}} = 0 to 256 iter_args(%[[inner_iter:.*]] = %[[vzero]]) -> (vector<128xf32>) { +// CHECK: %[[ld:.*]] = vector.transfer_read %{{.*}} : memref<256x512xf32>, vector<128xf32> +// CHECK: %[[add:.*]] = addf %[[inner_iter]], %[[ld]] : vector<128xf32> +// CHECK: affine.yield %[[add]] : vector<128xf32> +// CHECK: } +// CHECK: %[[outer_add:.*]] = addf %[[outer_iter]], %[[inner_red]] : vector<128xf32> +// CHECK: affine.yield %[[outer_add]] : vector<128xf32> +// CHECK: } +// CHECK: %[[final_sum:.*]] = vector.reduction "add", %[[outer_red]] : vector<128xf32> into f32 +// CHECK: affine.store %[[final_sum]], %{{.*}} : memref<1xf32> + +// ----- + +// The inner reduction loop '%j' computes partial sums as a side effect and +// is not vectorized.
+ +func @vecdim_partial_sums_1_rejected(%in: memref<256x512xf32>, %out_sum: memref<256xf32>, %out_prod: memref<256xf32>, %out_partsum: memref<256x512xf32>) { + %cst = constant 1.000000e+00 : f32 + affine.for %i = 0 to 256 { + %sum, %prod = affine.for %j = 0 to 512 iter_args(%part_sum = %cst, %part_prod = %cst) -> (f32, f32) { + %ld = affine.load %in[%i, %j] : memref<256x512xf32> + %add = addf %part_sum, %ld : f32 + %mul = mulf %part_prod, %ld : f32 + affine.store %add, %out_partsum[%i, %j] : memref<256x512xf32> + affine.yield %add, %mul : f32, f32 + } + affine.store %sum, %out_sum[%i] : memref<256xf32> + affine.store %prod, %out_prod[%i] : memref<256xf32> + } + return +} + +// CHECK-LABEL: @vecdim_partial_sums_1_rejected +// CHECK-NOT: vector + +// ----- + +// The inner reduction loop '%j' computes partial sums as a side effect (the +// accumulator is stored before it is updated) and is not vectorized. + +func @vecdim_partial_sums_2_rejected(%in: memref<256x512xf32>, %out_sum: memref<256xf32>, %out_prod: memref<256xf32>, %out_partsum: memref<256x512xf32>) { + %cst = constant 1.000000e+00 : f32 + affine.for %i = 0 to 256 { + %sum, %prod = affine.for %j = 0 to 512 iter_args(%part_sum = %cst, %part_prod = %cst) -> (f32, f32) { + affine.store %part_sum, %out_partsum[%i, %j] : memref<256x512xf32> + %ld = affine.load %in[%i, %j] : memref<256x512xf32> + %add = addf %part_sum, %ld : f32 + %mul = mulf %part_prod, %ld : f32 + affine.yield %add, %mul : f32, f32 + } + affine.store %sum, %out_sum[%i] : memref<256xf32> + affine.store %prod, %out_prod[%i] : memref<256xf32> + } + return +} + +// CHECK-LABEL: @vecdim_partial_sums_2_rejected +// CHECK-NOT: vector + +// ----- + +// The reduction loop '%j' performs an unknown reduction operation and is not +// vectorized.
+ +func @vecdim_unknown_reduction_rejected(%in: memref<256x512xf32>, %out: memref<256xf32>) { + %cst = constant 1.000000e+00 : f32 + %final_red = affine.for %j = 0 to 512 iter_args(%red_iter = %cst) -> (f32) { + %add = addf %red_iter, %red_iter : f32 + affine.yield %add : f32 + } + affine.store %final_red, %out[0] : memref<256xf32> + return +} + +// CHECK-LABEL: @vecdim_unknown_reduction_rejected +// CHECK-NOT: vector + +// ----- + +// The reduction loop '%j' doesn't perform any operation (it yields its +// iteration argument unchanged), so it is not recognized as a standard +// reduction and is not vectorized. + +func @vecdim_none_reduction_rejected(%in: memref<256x512xf32>, %out: memref<256xf32>) { + %cst = constant 1.000000e+00 : f32 + %final_red = affine.for %j = 0 to 512 iter_args(%red_iter = %cst) -> (f32) { + affine.yield %red_iter : f32 + } + affine.store %final_red, %out[0] : memref<256xf32> + return +} + +// CHECK-LABEL: @vecdim_none_reduction_rejected +// CHECK-NOT: vector + +// ----- + +// The number of iterations is not divisible by the vector size, so a mask has +// to be applied to the last update of the accumulator.
+ +func @vecdim_reduction_masked(%in: memref<256x512xf32>, %out: memref<256xf32>) { + %cst = constant 0.000000e+00 : f32 + affine.for %i = 0 to 256 { + %final_red = affine.for %j = 0 to 500 iter_args(%red_iter = %cst) -> (f32) { + %ld = affine.load %in[%i, %j] : memref<256x512xf32> + %add = addf %red_iter, %ld : f32 + affine.yield %add : f32 + } + affine.store %final_red, %out[%i] : memref<256xf32> + } + return +} + +// CHECK: #[[$map0:.*]] = affine_map<([[d0:.*]]) -> (-[[d0]] + 500)> +// CHECK-LABEL: @vecdim_reduction_masked +// CHECK: affine.for %{{.*}} = 0 to 256 { +// CHECK: %[[vzero:.*]] = constant dense<0.000000e+00> : vector<128xf32> +// CHECK: %[[vred:.*]] = affine.for %[[iv:.*]] = 0 to 500 step 128 iter_args(%[[red_iter:.*]] = %[[vzero]]) -> (vector<128xf32>) { +// CHECK: %[[elems_left:.*]] = affine.apply #[[$map0]](%[[iv]]) +// CHECK: %[[mask:.*]] = vector.create_mask %[[elems_left]] : vector<128xi1> +// CHECK: %[[ld:.*]] = vector.transfer_read %{{.*}} : memref<256x512xf32>, vector<128xf32> +// CHECK: %[[add:.*]] = addf %[[red_iter]], %[[ld]] : vector<128xf32> +// CHECK: %[[new_acc:.*]] = select %[[mask]], %[[add]], %[[red_iter]] : vector<128xi1>, vector<128xf32> +// CHECK: affine.yield %[[new_acc]] : vector<128xf32> +// CHECK: } +// CHECK: %[[final_sum:.*]] = vector.reduction "add", %[[vred]] : vector<128xf32> into f32 +// CHECK: affine.store %[[final_sum]], %{{.*}} : memref<256xf32> +// CHECK: } + +// ----- + +// The number of iterations is not known, so a mask has to be applied.
+ +func @vecdim_reduction_masked_unknown_ub(%in: memref<256x512xf32>, %out: memref<256xf32>, %bnd: index) { + %cst = constant 0.000000e+00 : f32 + affine.for %i = 0 to 256 { + %final_red = affine.for %j = 0 to %bnd iter_args(%red_iter = %cst) -> (f32) { + %ld = affine.load %in[%i, %j] : memref<256x512xf32> + %add = addf %red_iter, %ld : f32 + affine.yield %add : f32 + } + affine.store %final_red, %out[%i] : memref<256xf32> + } + return +} + +// CHECK: #[[$map1:.*]] = affine_map<([[d0:.*]]){{\[}}[[s0:.*]]{{\]}} -> (-[[d0]] + [[s0]])> +// CHECK-LABEL: @vecdim_reduction_masked_unknown_ub +// CHECK: affine.for %{{.*}} = 0 to 256 { +// CHECK: %[[vzero:.*]] = constant dense<0.000000e+00> : vector<128xf32> +// CHECK: %[[vred:.*]] = affine.for %[[iv:.*]] = 0 to %[[bnd:.*]] step 128 iter_args(%[[red_iter:.*]] = %[[vzero]]) -> (vector<128xf32>) { +// CHECK: %[[elems_left:.*]] = affine.apply #[[$map1]](%[[iv]])[%[[bnd]]] +// CHECK: %[[mask:.*]] = vector.create_mask %[[elems_left]] : vector<128xi1> +// CHECK: %[[ld:.*]] = vector.transfer_read %{{.*}} : memref<256x512xf32>, vector<128xf32> +// CHECK: %[[add:.*]] = addf %[[red_iter]], %[[ld]] : vector<128xf32> +// CHECK: %[[new_acc:.*]] = select %[[mask]], %[[add]], %[[red_iter]] : vector<128xi1>, vector<128xf32> +// CHECK: affine.yield %[[new_acc]] : vector<128xf32> +// CHECK: } +// CHECK: %[[final_sum:.*]] = vector.reduction "add", %[[vred]] : vector<128xf32> into f32 +// CHECK: affine.store %[[final_sum]], %{{.*}} : memref<256xf32> +// CHECK: } + +// ----- + +// The lower bound is nonzero, but the number of iterations is divisible by the +// vector size, so masking is not needed.
+ +func @vecdim_reduction_nonzero_lb(%in: memref<256x512xf32>, %out: memref<256xf32>) { + %cst = constant 0.000000e+00 : f32 + affine.for %i = 0 to 256 { + %final_red = affine.for %j = 127 to 511 iter_args(%red_iter = %cst) -> (f32) { + %ld = affine.load %in[%i, %j] : memref<256x512xf32> + %add = addf %red_iter, %ld : f32 + affine.yield %add : f32 + } + affine.store %final_red, %out[%i] : memref<256xf32> + } + return +} + +// CHECK-LABEL: @vecdim_reduction_nonzero_lb +// CHECK: %{{.*}} = affine.for %{{.*}} = 127 to 511 step 128 iter_args({{.*}}) -> (vector<128xf32>) { +// CHECK-NOT: vector.create_mask + +// ----- + +// The lower bound is unknown, so we need to create a mask. + +func @vecdim_reduction_masked_unknown_lb(%in: memref<256x512xf32>, %out: memref<256xf32>, %lb: index) { + %cst = constant 0.000000e+00 : f32 + affine.for %i = 0 to 256 { + %final_red = affine.for %j = %lb to 512 iter_args(%red_iter = %cst) -> (f32) { + %ld = affine.load %in[%i, %j] : memref<256x512xf32> + %add = addf %red_iter, %ld : f32 + affine.yield %add : f32 + } + affine.store %final_red, %out[%i] : memref<256xf32> + } + return +} + +// CHECK: #[[$map2:.*]] = affine_map<([[d0:.*]]) -> (-[[d0]] + 512)> +// CHECK-LABEL: @vecdim_reduction_masked_unknown_lb +// CHECK: %{{.*}} = affine.for %[[iv:.*]] = %[[lb:.*]] to 512 step 128 iter_args(%[[red_iter:.*]] = {{.*}}) -> (vector<128xf32>) { +// CHECK: %[[elems_left:.*]] = affine.apply #[[$map2]](%[[iv]]) +// CHECK: %[[mask:.*]] = vector.create_mask %[[elems_left]] : vector<128xi1> +// CHECK: %[[ld:.*]] = vector.transfer_read %{{.*}} : memref<256x512xf32>, vector<128xf32> +// CHECK: %[[add:.*]] = addf %[[red_iter]], %[[ld]] : vector<128xf32> +// CHECK: %[[new_acc:.*]] = select %[[mask]], %[[add]], %[[red_iter]] : vector<128xi1>, vector<128xf32> +// CHECK: affine.yield %[[new_acc]] : vector<128xf32> + +// ----- + +// The upper bound is a minimum expression. 
+ +func @vecdim_reduction_complex_ub(%in: memref<256x512xf32>, %out: memref<256xf32>, %M: index, %N: index) { + %cst = constant 0.000000e+00 : f32 + affine.for %i = 0 to 256 { + %final_red = affine.for %j = 0 to min affine_map<(d0, d1) -> (d0, d1*2)>(%M, %N) iter_args(%red_iter = %cst) -> (f32) { + %ld = affine.load %in[%i, %j] : memref<256x512xf32> + %add = addf %red_iter, %ld : f32 + affine.yield %add : f32 + } + affine.store %final_red, %out[%i] : memref<256xf32> + } + return +} + +// CHECK: #[[$map3:.*]] = affine_map<([[d0:.*]], [[d1:.*]]) -> ([[d0]], [[d1]] * 2)> +// CHECK: #[[$map3_sub:.*]] = affine_map<([[d0:.*]], [[d1:.*]]) -> ([[d0]] - [[d1]])> +// CHECK-LABEL: @vecdim_reduction_complex_ub +// CHECK: %{{.*}} = affine.for %[[iv:.*]] = 0 to min #[[$map3]](%[[M:.*]], %[[N:.*]]) step 128 iter_args(%[[red_iter:.*]] = {{.*}}) -> (vector<128xf32>) { +// CHECK: %[[ub:.*]] = affine.min #[[$map3]](%[[M]], %[[N]]) +// CHECK: %[[elems_left:.*]] = affine.apply #[[$map3_sub]](%[[ub]], %[[iv]]) +// CHECK: %[[mask:.*]] = vector.create_mask %[[elems_left]] : vector<128xi1> +// CHECK: %[[ld:.*]] = vector.transfer_read %{{.*}} : memref<256x512xf32>, vector<128xf32> +// CHECK: %[[add:.*]] = addf %[[red_iter]], %[[ld]] : vector<128xf32> +// CHECK: %[[new_acc:.*]] = select %[[mask]], %[[add]], %[[red_iter]] : vector<128xi1>, vector<128xf32> +// CHECK: affine.yield %[[new_acc]] : vector<128xf32> + +// ----- + +// The same mask is applied to both reductions. 
+ +func @vecdim_two_reductions_masked(%in: memref<256x512xf32>, %out: memref<512xf32>) { + %cst = constant 0.000000e+00 : f32 + affine.for %i = 0 to 256 { + %final_sum, %final_expsum = affine.for %j = 0 to 500 iter_args(%sum_iter = %cst, %expsum_iter = %cst) -> (f32, f32) { + %ld = affine.load %in[%i, %j] : memref<256x512xf32> + %exp = math.exp %ld : f32 + %add = addf %sum_iter, %ld : f32 + %eadd = addf %expsum_iter, %exp : f32 + affine.yield %add, %eadd : f32, f32 + } + affine.store %final_sum, %out[2*%i] : memref<512xf32> + affine.store %final_expsum, %out[2*%i + 1] : memref<512xf32> + } + return +} + +// CHECK: #[[$map4:.*]] = affine_map<([[d0:.*]]) -> (-[[d0]] + 500)> +// CHECK-LABEL: @vecdim_two_reductions_masked +// CHECK: affine.for %{{.*}} = 0 to 256 { +// CHECK: %{{.*}} = affine.for %[[iv:.*]] = 0 to 500 step 128 iter_args(%[[sum_iter:.*]] = {{.*}}, %[[esum_iter:.*]] = {{.*}}) -> (vector<128xf32>, vector<128xf32>) { +// CHECK: %[[elems_left:.*]] = affine.apply #[[$map4]](%[[iv]]) +// CHECK: %[[mask:.*]] = vector.create_mask %[[elems_left]] : vector<128xi1> +// CHECK: %[[ld:.*]] = vector.transfer_read %{{.*}} : memref<256x512xf32>, vector<128xf32> +// CHECK: %[[exp:.*]] = math.exp %[[ld]] : vector<128xf32> +// CHECK: %[[add:.*]] = addf %[[sum_iter]], %[[ld]] : vector<128xf32> +// CHECK: %[[eadd:.*]] = addf %[[esum_iter]], %[[exp]] : vector<128xf32> +// CHECK: %[[new_acc:.*]] = select %[[mask]], %[[add]], %[[sum_iter]] : vector<128xi1>, vector<128xf32> +// CHECK: %[[new_eacc:.*]] = select %[[mask]], %[[eadd]], %[[esum_iter]] : vector<128xi1>, vector<128xf32> +// CHECK: affine.yield %[[new_acc]], %[[new_eacc]] : vector<128xf32>, vector<128xf32> +// CHECK: } Index: mlir/test/Dialect/Affine/SuperVectorize/vectorize_reduction_2d.mlir =================================================================== --- /dev/null +++ mlir/test/Dialect/Affine/SuperVectorize/vectorize_reduction_2d.mlir @@ -0,0 +1,24 @@ +// RUN: mlir-opt %s -affine-super-vectorize="virtual-vector-size=32,256
test-fastest-varying=1,0 vectorize-reductions=true" -verify-diagnostics + +// TODO: Vectorization of reduction loops along the reduction dimension is not +// supported for higher-rank vectors yet, so we are just checking that an +// error message is produced. + +// expected-error@+1 {{Vectorizing reductions is supported only for 1-D vectors}} +func @vecdim_reduction_2d(%in: memref<256x512x1024xf32>, %out: memref<256xf32>) { + %cst = constant 0.000000e+00 : f32 + affine.for %i = 0 to 256 { + %sum_j = affine.for %j = 0 to 512 iter_args(%red_iter_j = %cst) -> (f32) { + %sum_k = affine.for %k = 0 to 1024 iter_args(%red_iter_k = %cst) -> (f32) { + %ld = affine.load %in[%i, %j, %k] : memref<256x512x1024xf32> + %add = addf %red_iter_k, %ld : f32 + affine.yield %add : f32 + } + %add = addf %red_iter_j, %sum_k : f32 + affine.yield %add : f32 + } + affine.store %sum_j, %out[%i] : memref<256xf32> + } + return +} +