diff --git a/mlir/include/mlir/Dialect/Vector/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/VectorUtils.h
--- a/mlir/include/mlir/Dialect/Vector/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorUtils.h
@@ -19,6 +19,7 @@
 class AffineApplyOp;
 class AffineForOp;
 class AffineMap;
+class Block;
 class Location;
 class OpBuilder;
 class Operation;
@@ -98,8 +99,10 @@
 /// Note that loopToVectorDim is a whole function map from which only enclosing
 /// loop information is extracted.
 ///
-/// Prerequisites: `opInst` is a vectorizable load or store operation (i.e. at
-/// most one invariant index along each AffineForOp of `loopToVectorDim`).
+/// Prerequisites: `indices` belong to a vectorizable load or store operation
+/// (i.e. at most one invariant index along each AffineForOp of
+/// `loopToVectorDim`). `insertPoint` is the insertion point for the vectorized
+/// load or store operation.
 ///
 /// Example 1:
 /// The following MLIR snippet:
@@ -151,7 +154,10 @@
 /// `%arg0[%c0, %c0]` into vector<128xf32> which needs a 1-D vector broadcast.
 ///
 AffineMap
-makePermutationMap(Operation *op, ArrayRef<Value> indices,
+makePermutationMap(Block *insertPoint, ArrayRef<Value> indices,
+                   const DenseMap<Operation *, unsigned> &loopToVectorDim);
+AffineMap
+makePermutationMap(Operation *insertPoint, ArrayRef<Value> indices,
                    const DenseMap<Operation *, unsigned> &loopToVectorDim);
 
 /// Build the default minor identity map suitable for a vector transfer. This
diff --git a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
--- a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
@@ -14,28 +14,13 @@
 #include "PassDetail.h"
 #include "mlir/Analysis/LoopAnalysis.h"
 #include "mlir/Analysis/NestedMatcher.h"
-#include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Analysis/Utils.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "mlir/Dialect/Affine/Passes.h"
 #include "mlir/Dialect/Affine/Utils.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/Dialect/Vector/VectorUtils.h"
-#include "mlir/IR/AffineExpr.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/Location.h"
-#include "mlir/IR/Types.h"
+#include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/Support/LLVM.h"
-#include "mlir/Transforms/FoldUtils.h"
-
-#include "llvm/ADT/DenseMap.h"
-#include "llvm/ADT/DenseSet.h"
-#include "llvm/ADT/SetVector.h"
-#include "llvm/ADT/SmallString.h"
-#include "llvm/ADT/SmallVector.h"
-#include "llvm/Support/CommandLine.h"
-#include "llvm/Support/Debug.h"
 
 using namespace mlir;
 using namespace vector;
@@ -252,61 +237,38 @@
 ///      fastest varying) ;
 ///   2. analyzing those patterns for profitability (TODO: and
 ///      interference);
-///   3. Then, for each pattern in order:
-///     a. applying iterative rewriting of the loop and the load operations in
-///       inner-to-outer order. Rewriting is implemented by coarsening the loops
-///       and turning load operations into opaque vector.transfer_read ops;
-///     b. keeping track of the load operations encountered as "roots" and the
-///       store operations as "terminals";
-///     c. traversing the use-def chains starting from the roots and iteratively
-///       propagating vectorized values. Scalar values that are encountered
-///       during this process must come from outside the scope of the current
-///       pattern (TODO: enforce this and generalize). Such a scalar value
-///       is vectorized only if it is a constant (into a vector splat). The
-///       non-constant case is not supported for now and results in the pattern
-///       failing to vectorize;
-///     d. performing a second traversal on the terminals (store ops) to
-///       rewriting the scalar value they write to memory into vector form.
-///       If the scalar value has been vectorized previously, we simply replace
-///       it by its vector form. Otherwise, if the scalar value is a constant,
-///       it is vectorized into a splat. In all other cases, vectorization for
-///       the pattern currently fails.
-///     e. if everything under the root AffineForOp in the current pattern
-///       vectorizes properly, we commit that loop to the IR. Otherwise we
-///       discard it and restore a previously cloned version of the loop. Thanks
-///       to the recursive scoping nature of matchers and captured patterns,
-///       this is transparently achieved by a simple RAII implementation.
-///     f. vectorization is applied on the next pattern in the list. Because
+///   3. then, for each pattern in order:
+///     a. applying iterative rewriting of the loops and all their nested
+///       operations in topological order. Rewriting is implemented by
+///       coarsening the loops and converting operations and operands to their
+///       vector forms. Processing operations in topological order is relatively
+///       simple due to the structured nature of the control-flow
+///       representation. This order ensures that all the operands of a given
+///       operation have been vectorized before the operation itself in a single
+///       traversal, except for operands defined outside of the loop nest. The
+///       algorithm can convert the following operations to their vector form:
+///         * Affine load and store operations are converted to opaque vector
+///           transfer read and write operations.
+///         * Scalar constant operations/operands are converted to vector
+///           constant operations (splat).
+///         * Uniform operands (only operands defined outside of the loop nest,
+///           for now) are broadcasted to a vector.
+///           TODO: Support more uniform cases.
+///         * The remaining operations in the loop nest are vectorized by
+///           widening their scalar types to vector types.
+///         * TODO: Add vectorization support for loops with 'iter_args' and
+///           more complex loops with divergent lbs and/or ubs.
+///     b. if everything under the root AffineForOp in the current pattern
+///       is vectorized properly, we commit that loop to the IR and remove the
+///       scalar loop. Otherwise, we discard the vectorized loop and keep the
+///       original scalar loop.
+///     c. vectorization is applied on the next pattern in the list. Because
 ///      pattern interference avoidance is not yet implemented and that we do
 ///      not support further vectorizing an already vector load we need to
 ///      re-verify that the pattern is still vectorizable. This is expected to
 ///      make cost models more difficult to write and is subject to improvement
 ///      in the future.
 ///
-/// Points c. and d. above are worth additional comment. In most passes that
-/// do not change the type of operands, it is usually preferred to eagerly
-/// `replaceAllUsesWith`. Unfortunately this does not work for vectorization
-/// because during the use-def chain traversal, all the operands of an operation
-/// must be available in vector form. Trying to propagate eagerly makes the IR
-/// temporarily invalid and results in errors such as:
-///   `vectorize.mlir:308:13: error: 'addf' op requires the same type for all
-///   operands and results
-///   %s5 = addf %a5, %b5 : f32`
-///
-/// Lastly, we show a minimal example for which use-def chains rooted in load /
-/// vector.transfer_read are not enough. This is what motivated splitting
-/// terminal processing out of the use-def chains starting from loads. In the
-/// following snippet, there is simply no load::
-/// ```mlir
-/// func @fill(%A : memref<128xf32>) -> () {
-///   %f1 = constant 1.0 : f32
-///   affine.for %i0 = 0 to 32 {
-///     affine.store %f1, %A[%i0] : memref<128xf32, 0>
-///   }
-///   return
-/// }
-/// ```
-///
 /// Choice of loop transformation to support the algorithm:
 /// =======================================================
 /// The choice of loop transformation to apply for coarsening vectorized loops
@@ -527,7 +489,6 @@
 #define DEBUG_TYPE "early-vect"
 
 using llvm::dbgs;
-using llvm::SetVector;
 
 /// Forward declaration.
 static FilterFunctionType
@@ -632,199 +593,196 @@
 namespace {
 
 struct VectorizationState {
-  /// Adds an entry of pre/post vectorization operations in the state.
-  void registerReplacement(Operation *key, Operation *value);
-  /// When the current vectorization pattern is successful, this erases the
-  /// operations that were marked for erasure in the proper order and resets
-  /// the internal state for the next pattern.
-  void finishVectorizationPattern();
-
-  // In-order tracking of original Operation that have been vectorized.
-  // Erase in reverse order.
-  SmallVector<Operation *, 16> toErase;
-  // Set of Operation that have been vectorized (the values in the
-  // vectorizationMap for hashed access). The vectorizedSet is used in
-  // particular to filter the operations that have already been vectorized by
-  // this pattern, when iterating over nested loops in this pattern.
-  DenseSet<Operation *> vectorizedSet;
-  // Map of old scalar Operation to new vectorized Operation.
-  DenseMap<Operation *, Operation *> vectorizationMap;
-  // Map of old scalar Value to new vectorized Value.
-  DenseMap<Value, Value> replacementMap;
+
+  VectorizationState(MLIRContext *context) : builder(context) {}
+
+  /// Registers the vector replacement of a scalar operation and its result
+  /// values. Both operations must have the same number of results.
+  ///
+  /// This utility is used to register the replacement for the vast majority of
+  /// the vectorized operations.
+  ///
+  /// Example:
+  ///   * 'replaced': %0 = addf %1, %2 : f32
+  ///   * 'replacement': %0 = addf %1, %2 : vector<128xf32>
+  void registerOpVectorReplacement(Operation *replaced, Operation *replacement);
+
+  /// Registers the vector replacement of a scalar value. The replacement
+  /// operation should have a single result, which replaces the scalar value.
+  ///
+  /// This utility is used to register the vector replacement of block arguments
+  /// and operation results which are not directly vectorized (i.e., their
+  /// scalar version still exists after vectorization), like uniforms.
+  ///
+  /// Example:
+  ///   * 'replaced': block argument or operation outside of the vectorized
+  ///     loop.
+  ///   * 'replacement': %0 = vector.broadcast %1 : f32 to vector<128xf32>
+  void registerValueVectorReplacement(Value replaced, Operation *replacement);
+
+  /// Registers the scalar replacement of a scalar value. 'replacement' must be
+  /// scalar. Both values must be block arguments. Operation results should be
+  /// replaced using the 'registerOp*' utilities.
+  ///
+  /// This utility is used to register the replacement of block arguments
+  /// that are within the loop to be vectorized and will continue being scalar
+  /// within the vector loop.
+  ///
+  /// Example:
+  ///   * 'replaced': induction variable of a loop to be vectorized.
+  ///   * 'replacement': new induction variable in the new vector loop.
+  void registerValueScalarReplacement(BlockArgument replaced,
+                                      BlockArgument replacement);
+
+  /// Returns in 'replacedVals' the scalar replacement for values in
+  /// 'inputVals'.
+  void getScalarValueReplacementsFor(ValueRange inputVals,
+                                     SmallVectorImpl<Value> &replacedVals);
+
+  /// Erases the scalar loop nest after its successful vectorization.
+  void finishVectorizationPattern(AffineForOp rootLoop);
+
+  // Used to build and insert all the new operations created. The insertion
+  // point is preserved and updated along the vectorization process.
+  OpBuilder builder;
+
+  // Maps input scalar operations to their vector counterparts.
+  DenseMap<Operation *, Operation *> opVectorReplacement;
+  // Maps input scalar values to their vector counterparts.
+  BlockAndValueMapping valueVectorReplacement;
+  // Maps input scalar values to their new scalar counterparts in the vector
+  // loop nest.
+  BlockAndValueMapping valueScalarReplacement;
+
+  // Maps the newly created vector loops to their vector dimension.
+  DenseMap<Operation *, unsigned> vecLoopToVecDim;
+
   // The strategy drives which loop to vectorize by which amount.
   const VectorizationStrategy *strategy;
 
-  // Use-def roots. These represent the starting points for the worklist in the
-  // vectorizeNonTerminals function. They consist of the subset of load
-  // operations that have been vectorized. They can be retrieved from
-  // `vectorizationMap` but it is convenient to keep track of them in a separate
-  // data structure.
-  DenseSet<Operation *> roots;
-  // Terminal operations for the worklist in the vectorizeNonTerminals
-  // function. They consist of the subset of store operations that have been
-  // vectorized. They can be retrieved from `vectorizationMap` but it is
-  // convenient to keep track of them in a separate data structure. Since they
-  // do not necessarily belong to use-def chains starting from loads (e.g
-  // storing a constant), we need to handle them in a post-pass.
-  DenseSet<Operation *> terminals;
-  // Checks that the type of `op` is AffineStoreOp and adds it to the terminals
-  // set.
-  void registerTerminal(Operation *op);
-  // Folder used to factor out constant creation.
-  OperationFolder *folder;
 
 private:
-  void registerReplacement(Value key, Value value);
+  /// Internal implementation to map input scalar values to new vector or scalar
+  /// values.
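+  /// Each value may be mapped at most once; the implementations assert that
+  /// vector replacements have vector types and scalar replacements have
+  /// scalar types.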
+  void registerValueVectorReplacementImpl(Value replaced, Value replacement);
+  void registerValueScalarReplacementImpl(Value replaced, Value replacement);
 };
 
 } // end namespace
 
-void VectorizationState::registerReplacement(Operation *key, Operation *value) {
-  LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ commit vectorized op: ");
-  LLVM_DEBUG(key->print(dbgs()));
-  LLVM_DEBUG(dbgs() << " into ");
-  LLVM_DEBUG(value->print(dbgs()));
-  assert(key->getNumResults() == 1 && "already registered");
-  assert(value->getNumResults() == 1 && "already registered");
-  assert(vectorizedSet.count(value) == 0 && "already registered");
-  assert(vectorizationMap.count(key) == 0 && "already registered");
-  toErase.push_back(key);
-  vectorizedSet.insert(value);
-  vectorizationMap.insert(std::make_pair(key, value));
-  registerReplacement(key->getResult(0), value->getResult(0));
-  if (isa<AffineLoadOp>(key)) {
-    assert(roots.count(key) == 0 && "root was already inserted previously");
-    roots.insert(key);
-  }
+/// Registers the vector replacement of a scalar operation and its result
+/// values. Both operations must have the same number of results.
+///
+/// This utility is used to register the replacement for the vast majority of
+/// the vectorized operations.
+///
+/// Example:
+///   * 'replaced': %0 = addf %1, %2 : f32
+///   * 'replacement': %0 = addf %1, %2 : vector<128xf32>
+void VectorizationState::registerOpVectorReplacement(Operation *replaced,
+                                                     Operation *replacement) {
+  LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ commit vectorized op:\n");
+  LLVM_DEBUG(dbgs() << *replaced << "\n");
+  LLVM_DEBUG(dbgs() << "into\n");
+  LLVM_DEBUG(dbgs() << *replacement << "\n");
+
+  assert(replaced->getNumResults() <= 1 && "Unsupported multi-result op");
+  assert(replaced->getNumResults() == replacement->getNumResults() &&
+         "Unexpected replaced and replacement results");
+  assert(opVectorReplacement.count(replaced) == 0 && "already registered");
+  opVectorReplacement[replaced] = replacement;
+
+  if (replaced->getNumResults() > 0)
+    registerValueVectorReplacementImpl(replaced->getResult(0),
+                                       replacement->getResult(0));
 }
 
-void VectorizationState::registerTerminal(Operation *op) {
-  assert(isa<AffineStoreOp>(op) && "terminal must be a AffineStoreOp");
-  assert(terminals.count(op) == 0 &&
-         "terminal was already inserted previously");
-  terminals.insert(op);
+/// Registers the vector replacement of a scalar value. The replacement
+/// operation should have a single result, which replaces the scalar value.
+///
+/// This utility is used to register the vector replacement of block arguments
+/// and operation results which are not directly vectorized (i.e., their
+/// scalar version still exists after vectorization), like uniforms.
+///
+/// Example:
+///   * 'replaced': block argument or operation outside of the vectorized loop.
+///   * 'replacement': %0 = vector.broadcast %1 : f32 to vector<128xf32>
+void VectorizationState::registerValueVectorReplacement(
+    Value replaced, Operation *replacement) {
+  assert(replacement->getNumResults() == 1 &&
+         "Expected single-result replacement");
+  if (Operation *defOp = replaced.getDefiningOp())
+    registerOpVectorReplacement(defOp, replacement);
+  else
+    registerValueVectorReplacementImpl(replaced, replacement->getResult(0));
 }
 
-void VectorizationState::finishVectorizationPattern() {
-  while (!toErase.empty()) {
-    auto *op = toErase.pop_back_val();
-    LLVM_DEBUG(dbgs() << "\n[early-vect] finishVectorizationPattern erase: ");
-    LLVM_DEBUG(op->print(dbgs()));
-    op->erase();
-  }
+void VectorizationState::registerValueVectorReplacementImpl(Value replaced,
+                                                            Value replacement) {
+  assert(!valueVectorReplacement.contains(replaced) &&
+         "Vector replacement already registered");
+  assert(replacement.getType().isa<VectorType>() &&
+         "Expected vector type in vector replacement");
+  valueVectorReplacement.map(replaced, replacement);
 }
 
-void VectorizationState::registerReplacement(Value key, Value value) {
-  assert(replacementMap.count(key) == 0 && "replacement already registered");
-  replacementMap.insert(std::make_pair(key, value));
+/// Registers the scalar replacement of a scalar value. 'replacement' must be
+/// scalar. Both values must be block arguments. Operation results should be
+/// replaced using the 'registerOp*' utilities.
+///
+/// This utility is used to register the replacement of block arguments
+/// that are within the loop to be vectorized and will continue being scalar
+/// within the vector loop.
+///
+/// Example:
+///   * 'replaced': induction variable of a loop to be vectorized.
+///   * 'replacement': new induction variable in the new vector loop.
+void VectorizationState::registerValueScalarReplacement(
+    BlockArgument replaced, BlockArgument replacement) {
+  registerValueScalarReplacementImpl(replaced, replacement);
+}
+
+void VectorizationState::registerValueScalarReplacementImpl(Value replaced,
+                                                            Value replacement) {
+  assert(!valueScalarReplacement.contains(replaced) &&
+         "Scalar value replacement already registered");
+  assert(!replacement.getType().isa<VectorType>() &&
+         "Expected scalar type in scalar replacement");
+  valueScalarReplacement.map(replaced, replacement);
+}
+
+/// Returns in 'replacedVals' the scalar replacement for values in 'inputVals'.
+void VectorizationState::getScalarValueReplacementsFor(
+    ValueRange inputVals, SmallVectorImpl<Value> &replacedVals) {
+  for (Value inputVal : inputVals)
+    replacedVals.push_back(valueScalarReplacement.lookupOrDefault(inputVal));
+}
+
+/// Erases a loop nest, including all its nested operations.
+static void eraseLoopNest(AffineForOp forOp) {
+  LLVM_DEBUG(dbgs() << "[early-vect]+++++ erasing:\n" << forOp << "\n");
+  forOp.erase();
+}
+
+/// Erases the scalar loop nest after its successful vectorization.
+void VectorizationState::finishVectorizationPattern(AffineForOp rootLoop) {
+  LLVM_DEBUG(dbgs() << "\n[early-vect] Finalizing vectorization\n");
+  eraseLoopNest(rootLoop);
 }
 
 // Apply 'map' with 'mapOperands' returning resulting values in 'results'.
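+// For example, a map such as '(d0, d1) -> (d0 + 1, d1)' yields one
++// affine.apply per result expression: one computing 'd0 + 1' and one
+// computing 'd1', both evaluated on 'mapOperands'.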
 static void computeMemoryOpIndices(Operation *op, AffineMap map,
                                    ValueRange mapOperands,
+                                   VectorizationState &state,
                                    SmallVectorImpl<Value> &results) {
-  OpBuilder builder(op);
   for (auto resultExpr : map.getResults()) {
     auto singleResMap =
         AffineMap::get(map.getNumDims(), map.getNumSymbols(), resultExpr);
-    auto afOp =
-        builder.create<AffineApplyOp>(op->getLoc(), singleResMap, mapOperands);
+    auto afOp = state.builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
+                                                    mapOperands);
     results.push_back(afOp);
   }
 }
 
-////// TODO: Hoist to a VectorizationMaterialize.cpp when appropriate. ////
-
-/// Handles the vectorization of load and store MLIR operations.
-///
-/// AffineLoadOp operations are the roots of the vectorizeNonTerminals call.
-/// They are vectorized immediately. The resulting vector.transfer_read is
-/// immediately registered to replace all uses of the AffineLoadOp in this
-/// pattern's scope.
-///
-/// AffineStoreOp are the terminals of the vectorizeNonTerminals call. They
-/// need to be vectorized late once all the use-def chains have been traversed.
-/// Additionally, they may have ssa-values operands which come from outside the
-/// scope of the current pattern.
-/// Such special cases force us to delay the vectorization of the stores until
-/// the last step. Here we merely register the store operation.
-template <typename LoadOrStoreOpPointer>
-static LogicalResult vectorizeRootOrTerminal(Value iv,
-                                             LoadOrStoreOpPointer memoryOp,
-                                             VectorizationState *state) {
-  auto memRefType = memoryOp.getMemRef().getType().template cast<MemRefType>();
-
-  auto elementType = memRefType.getElementType();
-  // TODO: ponder whether we want to further vectorize a vector value.
-  assert(VectorType::isValidElementType(elementType) &&
-         "Not a valid vector element type");
-  auto vectorType = VectorType::get(state->strategy->vectorSizes, elementType);
-
-  // Materialize a MemRef with 1 vector.
-  auto *opInst = memoryOp.getOperation();
-  // For now, vector.transfers must be aligned, operate only on indices with an
-  // identity subset of AffineMap and do not change layout.
-  // TODO: increase the expressiveness power of vector.transfer operations
-  // as needed by various targets.
-  if (auto load = dyn_cast<AffineLoadOp>(opInst)) {
-    OpBuilder b(opInst);
-    ValueRange mapOperands = load.getMapOperands();
-    SmallVector<Value, 8> indices;
-    indices.reserve(load.getMemRefType().getRank());
-    if (load.getAffineMap() !=
-        b.getMultiDimIdentityMap(load.getMemRefType().getRank())) {
-      computeMemoryOpIndices(opInst, load.getAffineMap(), mapOperands, indices);
-    } else {
-      indices.append(mapOperands.begin(), mapOperands.end());
-    }
-    auto permutationMap =
-        makePermutationMap(opInst, indices, state->strategy->loopToVectorDim);
-    if (!permutationMap)
-      return failure();
-    LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: ");
-    LLVM_DEBUG(permutationMap.print(dbgs()));
-    auto transfer = b.create<vector::TransferReadOp>(
-        opInst->getLoc(), vectorType, memoryOp.getMemRef(), indices,
-        permutationMap);
-    state->registerReplacement(opInst, transfer.getOperation());
-  } else {
-    state->registerTerminal(opInst);
-  }
-  return success();
-}
-/// end TODO: Hoist to a VectorizationMaterialize.cpp when appropriate. ///
-
-/// Coarsens the loops bounds and transforms all remaining load and store
-/// operations into the appropriate vector.transfer.
-static LogicalResult vectorizeAffineForOp(AffineForOp loop, int64_t step,
-                                          VectorizationState *state) {
-  loop.setStep(step);
-
-  FilterFunctionType notVectorizedThisPattern = [state](Operation &op) {
-    if (!matcher::isLoadOrStore(op)) {
-      return false;
-    }
-    return state->vectorizationMap.count(&op) == 0 &&
-           state->vectorizedSet.count(&op) == 0 &&
-           state->roots.count(&op) == 0 && state->terminals.count(&op) == 0;
-  };
-  auto loadAndStores = matcher::Op(notVectorizedThisPattern);
-  SmallVector<NestedMatch, 8> loadAndStoresMatches;
-  loadAndStores.match(loop.getOperation(), &loadAndStoresMatches);
-  for (auto ls : loadAndStoresMatches) {
-    auto *opInst = ls.getMatchedOperation();
-    auto load = dyn_cast<AffineLoadOp>(opInst);
-    auto store = dyn_cast<AffineStoreOp>(opInst);
-    LLVM_DEBUG(opInst->print(dbgs()));
-    LogicalResult result =
-        load ? vectorizeRootOrTerminal(loop.getInductionVar(), load, state)
-             : vectorizeRootOrTerminal(loop.getInductionVar(), store, state);
-    if (failed(result)) {
-      return failure();
-    }
-  }
-  return success();
-}
-
 /// Returns a FilterFunctionType that can be used in NestedPattern to match a
 /// loop whose underlying load/store accesses are either invariant or all
 // varying along the `fastestVaryingMemRefDimension`.
@@ -846,68 +804,6 @@
   };
 }
 
-/// Apply vectorization of `loop` according to `state`. `loops` are processed in
-/// inner-to-outer order to ensure that all the children loops have already been
-/// vectorized before vectorizing the parent loop.
-static LogicalResult
-vectorizeLoopsAndLoads(std::vector<SmallVector<AffineForOp, 2>> &loops,
-                       VectorizationState *state) {
-  // Vectorize loops in inner-to-outer order. If any children fails, the parent
-  // will fail too.
-  for (auto &loopsInLevel : llvm::reverse(loops)) {
-    for (AffineForOp loop : loopsInLevel) {
-      // 1. This loop may have been omitted from vectorization for various
-      // reasons (e.g. due to the performance model or pattern depth > vector
-      // size).
-      auto it = state->strategy->loopToVectorDim.find(loop.getOperation());
-      if (it == state->strategy->loopToVectorDim.end())
-        continue;
-
-      // 2. Actual inner-to-outer transformation.
-      auto vectorDim = it->second;
-      assert(vectorDim < state->strategy->vectorSizes.size() &&
-             "vector dim overflow");
-      //   a. get actual vector size
-      auto vectorSize = state->strategy->vectorSizes[vectorDim];
-      //   b. loop transformation for early vectorization is still subject to
-      //      exploratory tradeoffs (see top of the file). Apply coarsening,
-      //      i.e.:
-      //         | ub -> ub
-      //         | step -> step * vectorSize
-      LLVM_DEBUG(dbgs() << "\n[early-vect] vectorizeForOp by " << vectorSize
                        << " : \n"
-                        << loop);
-      if (failed(
-              vectorizeAffineForOp(loop, loop.getStep() * vectorSize, state)))
-        return failure();
-    } // end for.
-  }
-
-  return success();
-}
-
-/// Tries to transform a scalar constant into a vector splat of that constant.
-/// Returns the vectorized splat operation if the constant is a valid vector
-/// element type.
-/// If `type` is not a valid vector type or if the scalar constant is not a
-/// valid vector element type, returns nullptr.
-static Value vectorizeConstant(Operation *op, ConstantOp constant, Type type) {
-  if (!type || !type.isa<VectorType>() ||
-      !VectorType::isValidElementType(constant.getType())) {
-    return nullptr;
-  }
-  OpBuilder b(op);
-  Location loc = op->getLoc();
-  auto vectorType = type.cast<VectorType>();
-  auto attr = DenseElementsAttr::get(vectorType, constant.getValue());
-  auto *constantOpInst = constant.getOperation();
-
-  OperationState state(loc, constantOpInst->getName().getStringRef(), {},
-                       {vectorType}, {b.getNamedAttr("value", attr)});
-
-  return b.createOperation(state)->getResult(0);
-}
-
 /// Returns the vector type resulting from applying the provided vectorization
 /// strategy on the scalar type.
 static VectorType getVectorType(Type scalarTy,
@@ -916,6 +812,24 @@
   return VectorType::get(strategy->vectorSizes, scalarTy);
 }
 
+/// Tries to transform a scalar constant into a vector constant. Returns the
+/// vector constant if the scalar type is a valid vector element type. Returns
+/// nullptr otherwise.
+static ConstantOp vectorizeConstant(ConstantOp constOp,
+                                    VectorizationState &state) {
+  Type scalarTy = constOp.getType();
+  if (!VectorType::isValidElementType(scalarTy))
+    return nullptr;
+
+  auto vecTy = getVectorType(scalarTy, state.strategy);
+  auto vecAttr = DenseElementsAttr::get(vecTy, constOp.getValue());
+  auto newConstOp = state.builder.create<ConstantOp>(constOp.getLoc(), vecAttr);
+
+  // Register vector replacement for future uses in the scope.
+  state.registerOpVectorReplacement(constOp, newConstOp);
+  return newConstOp;
+}
+
 /// Returns true if the provided value is vector uniform given the vectorization
 /// strategy.
 // TODO: For now, only values that are invariants to all the loops in the
@@ -932,32 +846,27 @@
 
 /// Generates a broadcast op for the provided uniform value using the
 /// vectorization strategy in 'state'.
-static Value vectorizeUniform(Value value, VectorizationState *state) {
-  OpBuilder builder(value.getContext());
-  builder.setInsertionPointAfterValue(value);
-
-  auto vectorTy = getVectorType(value.getType(), state->strategy);
-  auto bcast = builder.create<vector::BroadcastOp>(value.getLoc(), vectorTy, value);
-
-  // Add broadcast to the replacement map to reuse it for other uses.
-  state->replacementMap[value] = bcast;
-  return bcast;
+static Operation *vectorizeUniform(Value uniformVal,
+                                   VectorizationState &state) {
+  OpBuilder::InsertionGuard guard(state.builder);
+  state.builder.setInsertionPointAfterValue(uniformVal);
+
+  auto vectorTy = getVectorType(uniformVal.getType(), state.strategy);
+  auto bcastOp = state.builder.create<vector::BroadcastOp>(uniformVal.getLoc(),
+                                                           vectorTy, uniformVal);
+  state.registerValueVectorReplacement(uniformVal, bcastOp);
+  return bcastOp;
 }
 
-/// Tries to vectorize a given operand `op` of Operation `op` during
-/// def-chain propagation or during terminal vectorization, by applying the
-/// following logic:
-/// 1. if the defining operation is part of the vectorizedSet (i.e. vectorized
-///    useby -def propagation), `op` is already in the proper vector form;
-/// 2. otherwise, the `op` may be in some other vector form that fails to
-///    vectorize atm (i.e. broadcasting required), returns nullptr to indicate
-///    failure;
-/// 3. if the `op` is a constant, returns the vectorized form of the constant;
-/// 4. if the `op` is uniform, returns a vector broadcast of the `op`;
-/// 5. non-constant scalars are currently non-vectorizable, in particular to
-///    guard against vectorizing an index which may be loop-variant and needs
-///    special handling.
-///
+/// Tries to vectorize a given `operand` by applying the following logic:
+/// 1. if the defining operation has already been vectorized, `operand` is
+///    already in the proper vector form;
+/// 2. if the `operand` is a constant, returns the vectorized form of the
+///    constant;
+/// 3. if the `operand` is uniform, returns a vector broadcast of the
+///    `operand`;
+/// 4. otherwise, the vectorization of `operand` is not supported.
+/// Newly created vector operations are registered in `state` as replacement
+/// for their scalar counterparts.
 /// In particular this logic captures some of the use cases where definitions
 /// that are not scoped under the current pattern are needed to vectorize.
 /// One such example is top level function constants that need to be splatted.
@@ -966,112 +875,213 @@
 /// vectorization is possible with the above logic. Returns nullptr otherwise.
 ///
 /// TODO: handle more complex cases.
-static Value vectorizeOperand(Value operand, Operation *op,
-                              VectorizationState *state) {
-  LLVM_DEBUG(dbgs() << "\n[early-vect]vectorize operand: " << operand);
-  // 1. If this value has already been vectorized this round, we are done.
-  if (state->vectorizedSet.count(operand.getDefiningOp()) > 0) {
-    LLVM_DEBUG(dbgs() << " -> already vector operand");
-    return operand;
+static Value vectorizeOperand(Value operand, VectorizationState &state) {
+  LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorize operand: " << operand);
+  // If this value is already vectorized, we are done.
+  if (Value vecRepl = state.valueVectorReplacement.lookupOrNull(operand)) {
+    LLVM_DEBUG(dbgs() << " -> already vectorized: " << vecRepl);
+    return vecRepl;
   }
-  // 1.b. Delayed on-demand replacement of a use.
-  //    Note that we cannot just call replaceAllUsesWith because it may result
-  //    in ops with mixed types, for ops whose operands have not all yet
-  //    been vectorized. This would be invalid IR.
-  auto it = state->replacementMap.find(operand);
-  if (it != state->replacementMap.end()) {
-    auto res = it->second;
-    LLVM_DEBUG(dbgs() << "-> delayed replacement by: " << res);
-    return res;
+
+  // A vector operand that is not in the replacement map should never reach
+  // this point. Reaching this point could mean that the code was already
+  // vectorized and we shouldn't try to vectorize already vectorized code.
+  assert(!operand.getType().isa<VectorType>() &&
+         "Vector op not found in replacement map");
+
+  // Vectorize constant.
+  if (auto constOp = operand.getDefiningOp<ConstantOp>()) {
+    ConstantOp vecConstant = vectorizeConstant(constOp, state);
+    LLVM_DEBUG(dbgs() << "-> constant: " << vecConstant);
+    return vecConstant.getResult();
  }
-  // 2. TODO: broadcast needed.
-  if (operand.getType().isa<VectorType>()) {
-    LLVM_DEBUG(dbgs() << "-> non-vectorizable");
-    return nullptr;
+
+  // Vectorize uniform values.
+  if (isUniformDefinition(operand, state.strategy)) {
+    Operation *vecUniform = vectorizeUniform(operand, state);
+    LLVM_DEBUG(dbgs() << "-> uniform: " << *vecUniform);
+    return vecUniform->getResult(0);
  }
-  // 3. vectorize constant.
-  if (auto constant = operand.getDefiningOp<ConstantOp>())
-    return vectorizeConstant(op, constant,
-                             getVectorType(operand.getType(), state->strategy));
-  // 4. Uniform values.
-  if (isUniformDefinition(operand, state->strategy))
-    return vectorizeUniform(operand, state);
 
+  // Check for unsupported block argument scenarios. A supported block argument
+  // should have been vectorized already.
+  if (!operand.getDefiningOp())
+    LLVM_DEBUG(dbgs() << "-> unsupported block argument\n");
+  else
+    // Generic unsupported case.
+    LLVM_DEBUG(dbgs() << "-> non-vectorizable\n");
+
-  // 5. currently non-vectorizable.
-  LLVM_DEBUG(dbgs() << "-> non-vectorizable: " << operand);
   return nullptr;
 }
 
-/// Encodes Operation-specific behavior for vectorization. In general we assume
-/// that all operands of an op must be vectorized but this is not always true.
-/// In the future, it would be nice to have a trait that describes how a
-/// particular operation vectorizes. For now we implement the case distinction
-/// here.
-/// Returns a vectorized form of an operation or nullptr if vectorization fails.
-// TODO: consider adding a trait to Op to describe how it gets vectorized.
-// Maybe some Ops are not vectorizable or require some tricky logic, we cannot
-// do one-off logic here; ideally it would be TableGen'd.
-static Operation *vectorizeOneOperation(Operation *opInst,
-                                        VectorizationState *state) {
-  // Sanity checks.
-  assert(!isa<AffineLoadOp>(opInst) &&
-         "all loads must have already been fully vectorized independently");
-  assert(!isa<vector::TransferReadOp>(opInst) &&
-         "vector.transfer_read cannot be further vectorized");
-  assert(!isa<vector::TransferWriteOp>(opInst) &&
-         "vector.transfer_write cannot be further vectorized");
+/// Vectorizes an affine load with the vectorization strategy in 'state' by
+/// generating a 'vector.transfer_read' op with the proper permutation map
+/// inferred from the indices of the load. The new 'vector.transfer_read' is
+/// registered as replacement of the scalar load. Returns the newly created
+/// 'vector.transfer_read' if vectorization was successful. Returns nullptr
+/// otherwise.
+static Operation *vectorizeAffineLoad(AffineLoadOp loadOp,
+                                      VectorizationState &state) {
+  MemRefType memRefType = loadOp.getMemRefType();
+  Type elementType = memRefType.getElementType();
+  auto vectorType = VectorType::get(state.strategy->vectorSizes, elementType);
+
+  // Replace map operands with operands from the vector loop nest.
+  SmallVector<Value, 8> mapOperands;
+  state.getScalarValueReplacementsFor(loadOp.getMapOperands(), mapOperands);
+
+  // Compute indices for the transfer op. AffineApplyOp's may be generated.
+  SmallVector<Value, 8> indices;
+  indices.reserve(memRefType.getRank());
+  if (loadOp.getAffineMap() !=
+      state.builder.getMultiDimIdentityMap(memRefType.getRank()))
+    computeMemoryOpIndices(loadOp, loadOp.getAffineMap(), mapOperands, state,
+                           indices);
+  else
+    indices.append(mapOperands.begin(), mapOperands.end());
+
+  // Compute permutation map using the information of new vector loops.
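+  // Note that the permutation map is computed at the insertion point of the
+  // new transfer op, i.e., within the new vector loop nest, since the vector
+  // dimensions are attached to the new loops in 'vecLoopToVecDim'.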
+  auto permutationMap = makePermutationMap(state.builder.getInsertionBlock(),
+                                           indices, state.vecLoopToVecDim);
+  if (!permutationMap) {
+    LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ can't compute permutationMap\n");
+    return nullptr;
+  }
+  LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: ");
+  LLVM_DEBUG(permutationMap.print(dbgs()));
 
-  if (auto store = dyn_cast<AffineStoreOp>(opInst)) {
-    OpBuilder b(opInst);
-    auto memRef = store.getMemRef();
-    auto value = store.getValueToStore();
-    auto vectorValue = vectorizeOperand(value, opInst, state);
-    if (!vectorValue)
-      return nullptr;
+  auto transfer = state.builder.create<vector::TransferReadOp>(
+      loadOp.getLoc(), vectorType, loadOp.getMemRef(), indices, permutationMap);
 
-    ValueRange mapOperands = store.getMapOperands();
-    SmallVector<Value, 8> indices;
-    indices.reserve(store.getMemRefType().getRank());
-    if (store.getAffineMap() !=
-        b.getMultiDimIdentityMap(store.getMemRefType().getRank())) {
-      computeMemoryOpIndices(opInst, store.getAffineMap(), mapOperands,
-                             indices);
-    } else {
-      indices.append(mapOperands.begin(), mapOperands.end());
-    }
+  // Register replacement for future uses in the scope.
+  state.registerOpVectorReplacement(loadOp, transfer);
+  return transfer;
+}
 
-    auto permutationMap =
-        makePermutationMap(opInst, indices, state->strategy->loopToVectorDim);
-    if (!permutationMap)
-      return nullptr;
-    LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: ");
-    LLVM_DEBUG(permutationMap.print(dbgs()));
-    auto transfer = b.create<vector::TransferWriteOp>(
-        opInst->getLoc(), vectorValue, memRef, indices, permutationMap);
-    auto *res = transfer.getOperation();
-    LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorized store: " << *res);
-    // "Terminals" (i.e. AffineStoreOps) are erased on the spot.
-    opInst->erase();
-    return res;
-  }
-  if (opInst->getNumRegions() != 0)
+/// Vectorizes an affine store with the vectorization strategy in 'state' by
+/// generating a 'vector.transfer_write' op with the proper permutation map
+/// inferred from the indices of the store. The new 'vector.transfer_write' is
+/// registered as replacement of the scalar store. Returns the newly created
+/// 'vector.transfer_write' if vectorization was successful. Returns nullptr
+/// otherwise.
+static Operation *vectorizeAffineStore(AffineStoreOp storeOp,
+                                       VectorizationState &state) {
+  MemRefType memRefType = storeOp.getMemRefType();
+  Value vectorValue = vectorizeOperand(storeOp.getValueToStore(), state);
+  if (!vectorValue)
+    return nullptr;
+
+  // Replace map operands with operands from the vector loop nest.
+  SmallVector<Value, 8> mapOperands;
+  state.getScalarValueReplacementsFor(storeOp.getMapOperands(), mapOperands);
+
+  // Compute indices for the transfer op. AffineApplyOp's may be generated.
+  SmallVector<Value, 8> indices;
+  indices.reserve(memRefType.getRank());
+  if (storeOp.getAffineMap() !=
+      state.builder.getMultiDimIdentityMap(memRefType.getRank()))
+    computeMemoryOpIndices(storeOp, storeOp.getAffineMap(), mapOperands, state,
+                           indices);
+  else
+    indices.append(mapOperands.begin(), mapOperands.end());
+
+  // Compute permutation map using the information of new vector loops.
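+  // As in the load case, the map is derived from the vector loops enclosing
+  // the current insertion point rather than from the scalar store.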
+  auto permutationMap = makePermutationMap(state.builder.getInsertionBlock(),
+                                           indices, state.vecLoopToVecDim);
+  if (!permutationMap)
+    return nullptr;
+  LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: ");
+  LLVM_DEBUG(permutationMap.print(dbgs()));
+
+  auto transfer = state.builder.create<vector::TransferWriteOp>(
+      storeOp.getLoc(), vectorValue, storeOp.getMemRef(), indices,
+      permutationMap);
+  LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorized store: " << transfer);
+
+  // Register replacement for future uses in the scope.
+  state.registerOpVectorReplacement(storeOp, transfer);
+  return transfer;
+}
+
+/// Vectorizes a loop with the vectorization strategy in 'state'. A new loop is
+/// created and registered as replacement for the scalar loop. The builder's
+/// insertion point is set to the new loop's body so that subsequent vectorized
+/// operations are inserted into the new loop. If the loop is a vector
+/// dimension, the step of the newly created loop will reflect the vectorization
+/// factor used to vectorize that dimension.
+// TODO: Add support for 'iter_args'. Related operands and results will be
+// vectorized at this point.
+static Operation *vectorizeAffineForOp(AffineForOp forOp,
+                                       VectorizationState &state) {
+  // 'iter_args' not supported yet.
+  if (forOp.getNumIterOperands() > 0)
     return nullptr;
 
+  // If we are vectorizing a vector dimension, compute a new step for the new
+  // vectorized loop using the vectorization factor for the vector dimension.
+  // Otherwise, propagate the step of the scalar loop.
+  const VectorizationStrategy &strategy = *state.strategy;
+  auto loopToVecDimIt = strategy.loopToVectorDim.find(forOp);
+  bool isLoopVecDim = loopToVecDimIt != strategy.loopToVectorDim.end();
+  unsigned newStep;
+  if (isLoopVecDim) {
+    unsigned vectorDim = loopToVecDimIt->second;
+    assert(vectorDim < strategy.vectorSizes.size() && "vector dim overflow");
+    int64_t forOpVecFactor = strategy.vectorSizes[vectorDim];
+    newStep = forOp.getStep() * forOpVecFactor;
+  } else {
+    newStep = forOp.getStep();
+  }
+
+  auto vecForOp = state.builder.create<AffineForOp>(
+      forOp.getLoc(), forOp.getLowerBoundOperands(), forOp.getLowerBoundMap(),
+      forOp.getUpperBoundOperands(), forOp.getUpperBoundMap(), newStep,
+      forOp.getIterOperands(),
+      /*bodyBuilder=*/[](OpBuilder &, Location, Value, ValueRange) {
+        // Make sure we don't create a default terminator in the loop body as
+        // the proper terminator will be added during vectorization.
+        return;
+      });
+
+  // Register loop-related replacements:
+  //   1) The new vectorized loop is registered as vector replacement of the
+  //      scalar loop.
+  //      TODO: Support reductions along the vector dimension.
+  //   2) The new iv of the vectorized loop is registered as scalar replacement
+  //      since a scalar copy of the iv will prevail in the vectorized loop.
+  //      TODO: A vector replacement will also be added in the future when
+  //      vectorization of linear ops is supported.
+  //   3) TODO: Support 'iter_args' along non-vector dimensions.
+  state.registerOpVectorReplacement(forOp, vecForOp);
+  state.registerValueScalarReplacement(forOp.getInductionVar(),
+                                       vecForOp.getInductionVar());
+  // Map the new vectorized loop to its vector dimension.
+  if (isLoopVecDim)
+    state.vecLoopToVecDim[vecForOp] = loopToVecDimIt->second;
+
+  // Change insertion point so that upcoming vectorized instructions are
+  // inserted into the vectorized loop's body.
+  state.builder.setInsertionPointToStart(vecForOp.getBody());
+  return vecForOp;
+}
+
+/// Vectorizes an arbitrary operation by plain widening. We apply generic type
+/// widening of all its results and retrieve the vector counterparts for all
+/// its operands.
+static Operation *widenOp(Operation *op, VectorizationState &state) {
   SmallVector<Type, 8> vectorTypes;
-  for (auto v : opInst->getResults()) {
+  for (Value result : op->getResults())
     vectorTypes.push_back(
-        VectorType::get(state->strategy->vectorSizes, v.getType()));
-  }
+        VectorType::get(state.strategy->vectorSizes, result.getType()));
+
   SmallVector<Value, 8> vectorOperands;
-  for (auto v : opInst->getOperands()) {
-    vectorOperands.push_back(vectorizeOperand(v, opInst, state));
-  }
-  // Check whether a single operand is null. If so, vectorization failed.
-  bool success = llvm::all_of(vectorOperands, [](Value op) { return op; });
-  if (!success) {
-    LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ an operand failed vectorize");
-    return nullptr;
+  for (Value operand : op->getOperands()) {
+    Value vecOperand = vectorizeOperand(operand, state);
+    if (!vecOperand) {
+      LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ an operand failed vectorize\n");
+      return nullptr;
+    }
+    vectorOperands.push_back(vecOperand);
   }
 
   // Create a clone of the op with the proper operands and return types.
@@ -1079,59 +1089,64 @@
   // name that works both in scalar mode and vector mode.
   // TODO: Is it worth considering an Operation.clone operation which
   // changes the type so we can promote an Operation with less boilerplate?
-  OpBuilder b(opInst);
-  OperationState newOp(opInst->getLoc(), opInst->getName().getStringRef(),
-                       vectorOperands, vectorTypes, opInst->getAttrs(),
-                       /*successors=*/{}, /*regions=*/{});
-  return b.createOperation(newOp);
+  OperationState vecOpState(op->getLoc(), op->getName().getStringRef(),
+                            vectorOperands, vectorTypes, op->getAttrs(),
+                            /*successors=*/{}, /*regions=*/{});
+  Operation *vecOp = state.builder.createOperation(vecOpState);
+  state.registerOpVectorReplacement(op, vecOp);
+  return vecOp;
 }
 
-/// Iterates over the forward slice from the loads in the vectorization pattern
-/// and rewrites them using their vectorized counterpart by:
-///   1. Create the forward slice starting from the loads in the vectorization
-///   pattern.
-///   2. Topologically sorts the forward slice.
-///   3. For each operation in the slice, create the vector form of this
-///   operation, replacing each operand by a replacement operands retrieved from
-///   replacementMap. If any such replacement is missing, vectorization fails.
-static LogicalResult vectorizeNonTerminals(VectorizationState *state) {
-  // 1. create initial worklist with the uses of the roots.
-  SetVector<Operation *> worklist;
-  // Note: state->roots have already been vectorized and must not be vectorized
-  // again. This fits `getForwardSlice` which does not insert `op` in the
-  // result.
-  // Note: we have to exclude terminals because some of their defs may not be
-  // nested under the vectorization pattern (e.g. constants defined in an
-  // encompassing scope).
-  // TODO: Use a backward slice for terminals, avoid special casing and
-  // merge implementations.
-  for (auto *op : state->roots) {
-    getForwardSlice(op, &worklist, [state](Operation *op) {
-      return state->terminals.count(op) == 0; // propagate if not terminal
-    });
-  }
-  // We merged multiple slices, topological order may not hold anymore.
-  worklist = topologicalSort(worklist);
-
-  for (unsigned i = 0; i < worklist.size(); ++i) {
-    auto *op = worklist[i];
-    LLVM_DEBUG(dbgs() << "\n[early-vect] vectorize use: ");
-    LLVM_DEBUG(op->print(dbgs()));
-
-    // Create vector form of the operation.
-    // Insert it just before op, on success register op as replaced.
-    auto *vectorizedInst = vectorizeOneOperation(op, state);
-    if (!vectorizedInst) {
-      return failure();
-    }
-
-    // 3. Register replacement for future uses in the scope.
-    //    Note that we cannot just call replaceAllUsesWith because it may
-    //    result in ops with mixed types, for ops whose operands have not all
-    //    yet been vectorized. This would be invalid IR.
-    state->registerReplacement(op, vectorizedInst);
-  }
-  return success();
+/// Vectorizes a yield operation by widening its types. The builder's insertion
+/// point is set after the vectorized parent op to continue vectorizing the
+/// operations after the parent op.
+static Operation *vectorizeAffineYieldOp(AffineYieldOp yieldOp,
+                                         VectorizationState &state) {
+  // 'iter_args' not supported yet.
+  if (yieldOp.getNumOperands() > 0)
+    return nullptr;
+
+  // Vectorize the yield op and change the insertion point right after the new
+  // parent op.
+  Operation *newYieldOp = widenOp(yieldOp, state);
+  Operation *newParentOp = state.builder.getInsertionBlock()->getParentOp();
+  state.builder.setInsertionPointAfter(newParentOp);
+  return newYieldOp;
+}
+
+/// Encodes Operation-specific behavior for vectorization. In general we
+/// assume that all operands of an op must be vectorized but this is not
+/// always true. In the future, it would be nice to have a trait that
+/// describes how a particular operation vectorizes. For now we implement the
+/// case distinction here. Returns a vectorized form of an operation or
+/// nullptr if vectorization fails.
+// TODO: consider adding a trait to Op to describe how it gets vectorized.
+// Maybe some Ops are not vectorizable or require some tricky logic, we cannot
+// do one-off logic here; ideally it would be TableGen'd.
+static Operation *vectorizeOneOperation(Operation *op,
+                                        VectorizationState &state) {
+  // Sanity checks.
+  assert(!isa<vector::TransferReadOp>(op) &&
+         "vector.transfer_read cannot be further vectorized");
+  assert(!isa<vector::TransferWriteOp>(op) &&
+         "vector.transfer_write cannot be further vectorized");
+
+  if (auto loadOp = dyn_cast<AffineLoadOp>(op))
+    return vectorizeAffineLoad(loadOp, state);
+  if (auto storeOp = dyn_cast<AffineStoreOp>(op))
+    return vectorizeAffineStore(storeOp, state);
+  if (auto forOp = dyn_cast<AffineForOp>(op))
+    return vectorizeAffineForOp(forOp, state);
+  if (auto yieldOp = dyn_cast<AffineYieldOp>(op))
+    return vectorizeAffineYieldOp(yieldOp, state);
+  if (auto constant = dyn_cast<ConstantOp>(op))
+    return vectorizeConstant(constant, state);
+
+  // Other ops with regions are not supported.
+  if (op->getNumRegions() != 0)
+    return nullptr;
+
+  return widenOp(op, state);
 }
 
 /// Recursive implementation to convert all the nested loops in 'match' to a 2D
@@ -1171,10 +1186,9 @@
                          const VectorizationStrategy &strategy) {
   assert(loops[0].size() == 1 && "Expected single root loop");
   AffineForOp rootLoop = loops[0][0];
-  OperationFolder folder(rootLoop.getContext());
-  VectorizationState state;
+  VectorizationState state(rootLoop.getContext());
+  state.builder.setInsertionPointAfter(rootLoop);
   state.strategy = &strategy;
-  state.folder = &folder;
 
   // Since patterns are recursive, they can very well intersect.
  // Since we do not want a fully greedy strategy in general, we decouple
@@ -1188,70 +1202,48 @@
     return failure();
   }
 
-  /// Sets up error handling for this root loop. This is how the root match
-  /// maintains a clone for handling failure and restores the proper state via
-  /// RAII.
-  auto *loopInst = rootLoop.getOperation();
-  OpBuilder builder(loopInst);
-  auto clonedLoop = cast<AffineForOp>(builder.clone(*loopInst));
-  struct Guard {
-    LogicalResult failure() {
-      loop.getInductionVar().replaceAllUsesWith(clonedLoop.getInductionVar());
-      loop.erase();
-      return mlir::failure();
-    }
-    LogicalResult success() {
-      clonedLoop.erase();
-      return mlir::success();
-    }
-    AffineForOp loop;
-    AffineForOp clonedLoop;
-  } guard{rootLoop, clonedLoop};
-
   //////////////////////////////////////////////////////////////////////////////
-  // Start vectorizing.
-  // From now on, any error triggers the scope guard above.
+  // Vectorize the scalar loop nest following a topological order. A new vector
+  // loop nest with the vectorized operations is created along the process. If
+  // vectorization succeeds, the scalar loop nest is erased. If vectorization
+  // fails, the vector loop nest is erased and the scalar loop nest is not
+  // modified.
   //////////////////////////////////////////////////////////////////////////////
-  // 1. Vectorize all the loop candidates, in inner-to-outer order.
-  // This also vectorizes the roots (AffineLoadOp) as well as registers the
-  // terminals (AffineStoreOp) for post-processing vectorization (we need to
-  // wait for all use-def chains into them to be vectorized first).
-  if (failed(vectorizeLoopsAndLoads(loops, &state))) {
-    LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ failed root vectorizeLoop");
-    return guard.failure();
-  }
-
-  // 2. Vectorize operations reached by use-def chains from root except the
-  // terminals (store operations) that need to be post-processed separately.
-  // TODO: add more as we expand.
-  if (failed(vectorizeNonTerminals(&state))) {
-    LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ failed vectorizeNonTerminals");
-    return guard.failure();
-  }
-
-  // 3. Post-process terminals.
-  // Note: we have to post-process terminals because some of their defs may not
-  // be nested under the vectorization pattern (e.g. constants defined in an
-  // encompassing scope).
-  // TODO: Use a backward slice for terminals, avoid special casing and
-  // merge implementations.
-  for (auto *op : state.terminals) {
-    if (!vectorizeOneOperation(op, &state)) { // nullptr == failure
-      LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ failed to vectorize terminals");
-      return guard.failure();
-    }
+  auto opVecResult = rootLoop.walk<WalkOrder::PreOrder>([&](Operation *op) {
+    LLVM_DEBUG(dbgs() << "[early-vect]+++++ Vectorizing: " << *op);
+    Operation *vectorOp = vectorizeOneOperation(op, state);
+    if (!vectorOp)
+      return WalkResult::interrupt();
+
+    return WalkResult::advance();
+  });
+
+  if (opVecResult.wasInterrupted()) {
+    LLVM_DEBUG(dbgs() << "[early-vect]+++++ failed vectorization for: "
+                      << rootLoop << "\n");
+    // Erase vector loop nest if it was created.
+    auto vecRootLoopIt = state.opVectorReplacement.find(rootLoop);
+    if (vecRootLoopIt != state.opVectorReplacement.end())
+      eraseLoopNest(cast<AffineForOp>(vecRootLoopIt->second));
+
+    return failure();
   }
 
-  // 4. Finish this vectorization pattern.
+  assert(state.opVectorReplacement.count(rootLoop) == 1 &&
+         "Expected vector replacement for loop nest");
   LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ success vectorizing pattern");
-  state.finishVectorizationPattern();
-  return guard.success();
+  LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorization result:\n"
+                    << *state.opVectorReplacement[rootLoop]);
+
+  // Finish this vectorization pattern.
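+  // This erases the now redundant scalar loop nest; the vector loop nest
+  // created above remains as its only replacement.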
+  state.finishVectorizationPattern(rootLoop);
+  return success();
 }
 
-/// Vectorization is a recursive procedure where anything below can fail. The
-/// root match thus needs to maintain a clone for handling failure. Each root
-/// may succeed independently but will otherwise clean after itself if anything
-/// below it fails.
+/// Extracts the matched loops and vectorizes them following a topological
+/// order. A new vector loop nest will be created if vectorization succeeds. The
+/// original loop nest won't be modified in any case.
 static LogicalResult vectorizeRootMatch(NestedMatch m,
                                         const VectorizationStrategy &strategy) {
   std::vector<SmallVector<AffineForOp, 2>> loopsToVectorize;
@@ -1272,7 +1264,7 @@
   LLVM_DEBUG(dbgs() << "\n******************************************");
   LLVM_DEBUG(dbgs() << "\n******************************************");
   LLVM_DEBUG(dbgs() << "\n[early-vect] new pattern on parent op\n");
-  LLVM_DEBUG(parentOp->print(dbgs()));
+  LLVM_DEBUG(dbgs() << *parentOp << "\n");
 
   unsigned patternDepth = pat.getDepth();
 
diff --git a/mlir/lib/Dialect/Vector/VectorUtils.cpp b/mlir/lib/Dialect/Vector/VectorUtils.cpp
--- a/mlir/lib/Dialect/Vector/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/VectorUtils.cpp
@@ -211,29 +211,29 @@
 /// TODO: could also be implemented as a collect parents followed by a
 /// filter and made available outside this file.
 template <typename T>
-static SetVector<Operation *> getParentsOfType(Operation *op) {
+static SetVector<Operation *> getParentsOfType(Block *block) {
   SetVector<Operation *> res;
-  auto *current = op;
-  while (auto *parent = current->getParentOp()) {
-    if (auto typedParent = dyn_cast<T>(parent)) {
-      assert(res.count(parent) == 0 && "Already inserted");
-      res.insert(parent);
+  auto *current = block->getParentOp();
+  while (current) {
+    if (auto typedParent = dyn_cast<T>(current)) {
+      assert(res.count(current) == 0 && "Already inserted");
+      res.insert(current);
     }
-    current = parent;
+    current = current->getParentOp();
   }
   return res;
 }
 
 /// Returns the enclosing AffineForOp, from closest to farthest.
-static SetVector<Operation *> getEnclosingforOps(Operation *op) {
-  return getParentsOfType<AffineForOp>(op);
+static SetVector<Operation *> getEnclosingforOps(Block *block) {
+  return getParentsOfType<AffineForOp>(block);
 }
 
 AffineMap mlir::makePermutationMap(
-    Operation *op, ArrayRef<Value> indices,
+    Block *insertPoint, ArrayRef<Value> indices,
     const DenseMap<Operation *, unsigned> &loopToVectorDim) {
   DenseMap<Operation *, unsigned> enclosingLoopToVectorDim;
-  auto enclosingLoops = getEnclosingforOps(op);
+  auto enclosingLoops = getEnclosingforOps(insertPoint);
   for (auto *forInst : enclosingLoops) {
     auto it = loopToVectorDim.find(forInst);
     if (it != loopToVectorDim.end()) {
@@ -243,6 +243,12 @@
   return ::makePermutationMap(indices, enclosingLoopToVectorDim);
 }
 
+AffineMap mlir::makePermutationMap(
+    Operation *op, ArrayRef<Value> indices,
+    const DenseMap<Operation *, unsigned> &loopToVectorDim) {
+  return makePermutationMap(op->getBlock(), indices, loopToVectorDim);
+}
+
 AffineMap mlir::getTransferMinorIdentityMap(ShapedType shapedType,
                                             VectorType vectorType) {
   int64_t elementVectorRank = 0;
diff --git a/mlir/test/Dialect/Affine/SuperVectorize/vectorize_1d.mlir b/mlir/test/Dialect/Affine/SuperVectorize/vectorize_1d.mlir
--- a/mlir/test/Dialect/Affine/SuperVectorize/vectorize_1d.mlir
+++ b/mlir/test/Dialect/Affine/SuperVectorize/vectorize_1d.mlir
@@ -1,16 +1,8 @@
-// RUN: mlir-opt %s -affine-super-vectorize="virtual-vector-size=128 test-fastest-varying=0" | FileCheck %s
+// RUN: mlir-opt %s -affine-super-vectorize="virtual-vector-size=128 test-fastest-varying=0" -split-input-file | FileCheck %s
 
-// Permutation maps used in vectorization.
-// CHECK-DAG: #[[$map_proj_d0d1_0:map[0-9]+]] = affine_map<(d0, d1) -> (0)>
 // CHECK-DAG: #[[$map_id1:map[0-9]+]] = affine_map<(d0) -> (d0)>
+// CHECK-DAG: #[[$map_proj_d0d1_0:map[0-9]+]] = affine_map<(d0, d1) -> (0)>
 
-#map0 = affine_map<(d0) -> (d0)>
-#mapadd1 = affine_map<(d0) -> (d0 + 1)>
-#mapadd2 = affine_map<(d0) -> (d0 + 2)>
-#mapadd3 = affine_map<(d0) -> (d0 + 3)>
-#set0 = affine_set<(i) : (i >= 0)>
-
-// Maps introduced to vectorize fastest varying memory index.
 // CHECK-LABEL: func @vec1d_1
 func @vec1d_1(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
 // CHECK-DAG: %[[C0:.*]] = constant 0 : index
@@ -37,6 +29,8 @@
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @vec1d_2
 func @vec1d_2(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
 // CHECK-DAG: %[[C0:.*]] = constant 0 : index
@@ -61,6 +55,8 @@
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @vec1d_3
 func @vec1d_3(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
 // CHECK-DAG: %[[C0:.*]] = constant 0 : index
@@ -90,6 +86,8 @@
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @vector_add_2d
 func @vector_add_2d(%M : index, %N : index) -> f32 {
   %A = alloc (%M, %N) : memref<?x?xf32>
@@ -142,6 +140,8 @@
   return %res : f32
 }
 
+// -----
+
 // CHECK-LABEL: func @vec_rejected_1
 func @vec_rejected_1(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
 // CHECK-DAG: %[[C0:.*]] = constant 0 : index
@@ -164,6 +164,8 @@
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @vec_rejected_2
 func @vec_rejected_2(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
 // CHECK-DAG: %[[C0:.*]] = constant 0 : index
@@ -186,6 +188,8 @@
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @vec_rejected_3
 func @vec_rejected_3(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
 // CHECK-DAG: %[[C0:.*]] = constant 0 : index
@@ -213,6 +217,8 @@
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @vec_rejected_4
 func @vec_rejected_4(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
 // CHECK-DAG: %[[C0:.*]] = constant 0 : index
@@ -238,6 +244,8 @@
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @vec_rejected_5
 func @vec_rejected_5(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
 // CHECK-DAG: %[[C0:.*]] = constant 0 : index
@@ -264,6 +272,8 @@
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @vec_rejected_6
 func @vec_rejected_6(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
 // CHECK-DAG: %[[C0:.*]] = constant 0 : index
@@ -292,6 +302,8 @@
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @vec_rejected_7
 func @vec_rejected_7(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
 // CHECK-DAG: %[[C0:.*]] = constant 0 : index
@@ -315,6 +327,11 @@
   return
 }
 
+// -----
+
+// CHECK-DAG: #[[$map_id1:map[0-9]+]] = affine_map<(d0) -> (d0)>
+// CHECK-DAG: #[[$map_proj_d0d1_0:map[0-9]+]] = affine_map<(d0, d1) -> (0)>
+
 // CHECK-LABEL: func @vec_rejected_8
 func @vec_rejected_8(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
 // CHECK-DAG: %[[C0:.*]] = constant 0 : index
@@ -344,6 +361,11 @@
   return
 }
 
+// -----
+
+// CHECK-DAG: #[[$map_id1:map[0-9]+]] = affine_map<(d0) -> (d0)>
+// CHECK-DAG: #[[$map_proj_d0d1_0:map[0-9]+]] = affine_map<(d0, d1) -> (0)>
+
 // CHECK-LABEL: func @vec_rejected_9
 func @vec_rejected_9(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
 // CHECK-DAG: %[[C0:.*]] = constant 0 : index
@@ -373,6 +395,10 @@
   return
 }
 
+// -----
+
+#set0 = affine_set<(i) : (i >= 0)>
+
 // CHECK-LABEL: func @vec_rejected_10
 func @vec_rejected_10(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
 // CHECK-DAG: %[[C0:.*]] = constant 0 : index
@@ -397,6 +423,8 @@
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @vec_rejected_11
 func @vec_rejected_11(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
 // CHECK-DAG: %[[C0:.*]] = constant 0 : index
@@ -424,7 +452,9 @@
   return
 }
 
-// This should not vectorize due to the sequential dependence in the scf.
+// -----
+
+// This should not vectorize due to the sequential dependence in the loop.
 // CHECK-LABEL: @vec_rejected_sequential
 func @vec_rejected_sequential(%A : memref<?xf32>) {
   %c0 = constant 0 : index
@@ -437,3 +467,66 @@
   }
   return
 }
+
+// -----
+
+// CHECK-LABEL: @vec_no_load_store_ops
+func @vec_no_load_store_ops(%a: f32, %b: f32) {
+  %cst = constant 0.000000e+00 : f32
+  affine.for %i = 0 to 128 {
+    %add = addf %a, %b : f32
+  }
+  // CHECK-DAG:  %[[bc1:.*]] = vector.broadcast
+  // CHECK-DAG:  %[[bc0:.*]] = vector.broadcast
+  // CHECK:      affine.for %{{.*}} = 0 to 128 step
+  // CHECK-NEXT:   [[add:.*]] addf %[[bc0]], %[[bc1]]
+
+  return
+}
+
+// -----
+
+// This should not be vectorized due to the unsupported block argument (%i).
+// Support for operands with linear evolution is needed.
+// CHECK-LABEL: @vec_rejected_unsupported_block_arg
+func @vec_rejected_unsupported_block_arg(%A : memref<512xi32>) {
+  affine.for %i = 0 to 512 {
+    // CHECK-NOT: vector
+    %idx = std.index_cast %i : index to i32
+    affine.store %idx, %A[%i] : memref<512xi32>
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @vec_rejected_unsupported_reduction
+func @vec_rejected_unsupported_reduction(%in: memref<128x256xf32>, %out: memref<256xf32>) {
+  %cst = constant 0.000000e+00 : f32
+  affine.for %i = 0 to 256 {
+    // CHECK-NOT: vector
+    %final_red = affine.for %j = 0 to 128 iter_args(%red_iter = %cst) -> (f32) {
+      %ld = affine.load %in[%j, %i] : memref<128x256xf32>
+      %add = addf %red_iter, %ld : f32
+      affine.yield %add : f32
+    }
+    affine.store %final_red, %out[%i] : memref<256xf32>
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @vec_rejected_unsupported_last_value
+func @vec_rejected_unsupported_last_value(%in: memref<128x256xf32>, %out: memref<256xf32>) {
+  %cst = constant 0.000000e+00 : f32
+  affine.for %i = 0 to 256 {
+    // CHECK-NOT: vector
+    %last_val = affine.for %j = 0 to 128 iter_args(%last_iter = %cst) -> (f32) {
+      %ld = affine.load %in[%j, %i] : memref<128x256xf32>
+      affine.yield %ld : f32
+    }
+    affine.store %last_val, %out[%i] : memref<256xf32>
+  }
+  return
+}