diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -591,6 +591,35 @@
         return result;
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Given a dimension of the iteration space of a Linalg operation, finds
+        an operand in the operation that is defined on that dimension. Returns
+        whether such an operand was found. If found, also returns the operand
+        value and the dimension position within the operand.
+      }],
+      /*retTy=*/"LogicalResult",
+      /*methodName=*/"mapIterationSpaceDimToOperandDim",
+      /*args=*/(ins "unsigned":$dimPos,
+                    "::mlir::Value &":$operand,
+                    "unsigned &":$operandDimPos),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        // Retrieve the operand and its dimension position from the first
+        // operand with a permutation map that is defined on the given
+        // dimension.
+        for (auto [i, idxMap] : llvm::enumerate($_op.getIndexingMapsArray())) {
+          if (idxMap.isProjectedPermutation()) {
+            if (auto mayOperandDim = idxMap.getResultPosition(dimPos)) {
+              operand = $_op->getOperand(i);
+              operandDimPos = *mayOperandDim;
+              return success();
+            }
+          }
+        }
+
+        return failure();
+      }]
+    >,
     //===------------------------------------------------------------------===//
     // Linalg generalization hooks.
     //===------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -337,8 +337,14 @@
 FailureOr<LinalgOp> promoteSubViews(OpBuilder &b, LinalgOp op,
                                     const LinalgPromotionOptions &options);
 
-/// Emit a suitable vector form for a Linalg op with fully static shape.
-LogicalResult vectorize(RewriterBase &builder, LinalgOp linalgOp);
+/// Emit a suitable vector form for a Linalg op. If provided, `inputVectorSizes`
+/// are used to vectorize this operation. `inputVectorSizes` must match the
+/// rank of the iteration space of the operation and the sizes must be greater
+/// than or equal to their counterpart iteration space sizes, if static.
+/// `inputVectorSizes` also allows the vectorization of operations with dynamic
+/// shapes.
+LogicalResult vectorize(RewriterBase &rewriter, LinalgOp linalgOp,
+                        ArrayRef<int64_t> inputVectorSizes = {});
 
 /// Emit a suitable vector form for a Copy op with fully static shape.
 LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);
@@ -364,7 +370,9 @@
                                     LinalgPromotionOptions options);
 
 /// Return success if the operation can be vectorized.
-LogicalResult vectorizeLinalgOpPrecondition(LinalgOp linalgOp);
+LogicalResult
+vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
+                              ArrayRef<int64_t> inputVectorSizes = {});
 
 //===----------------------------------------------------------------------===//
 // Transformations exposed as rewrite patterns.
diff --git a/mlir/include/mlir/Dialect/Vector/Interfaces/MaskableOpInterface.td b/mlir/include/mlir/Dialect/Vector/Interfaces/MaskableOpInterface.td
--- a/mlir/include/mlir/Dialect/Vector/Interfaces/MaskableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Vector/Interfaces/MaskableOpInterface.td
@@ -54,18 +54,14 @@
         return false;
     }]>,
     InterfaceMethod<
-      /*desc=*/"Returns the mask type expected by this operation.
It requires " - "the operation to be vectorized.", - /*retTy=*/"mlir::VectorType", + /*desc=*/"Returns the mask type expected by this operation. Mostly used" + " for verification purposes. It requires the operation to be " + "vectorized.", + /*retTy=*/"mlir::Type", /*methodName=*/"getExpectedMaskType", /*args=*/(ins), /*methodBody=*/"", - /*defaultImplementation=*/[{ - // Default implementation is only aimed for operations that implement the - // `getVectorType()` method. - return $_op.getVectorType().cloneWith(/*shape=*/llvm::None, - IntegerType::get($_op.getContext(), /*width=*/1)); - }]>, + /*defaultImplementation=*/"">, ]; } diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h @@ -22,6 +22,12 @@ /// Creates an instance of the `vector.mask` lowering pass. std::unique_ptr createLowerVectorMaskPass(); +/// Populates instances of `MaskOpRewritePattern` to lower masked operations +/// with `vector.mask`. Patterns should rewrite the `vector.mask` operation and +/// not its nested `MaskableOpInterface`. +void populateVectorMaskLoweringPatternsForSideEffectingOps( + RewritePatternSet &patterns); + //===----------------------------------------------------------------------===// // Registration //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h --- a/mlir/include/mlir/IR/AffineMap.h +++ b/mlir/include/mlir/IR/AffineMap.h @@ -171,6 +171,12 @@ /// expression in results. Optional getResultPosition(unsigned input) const; + /// Extracts the permuted position where the given input index resides. + /// Returns `llvm::None` if the input index is projected. Asserts on + /// non-projected permutation maps. + Optional + getPermutedPositionOfProjectedPermutation(unsigned input) const; + /// Return true if any affine expression involves AffineDimExpr `position`. bool isFunctionOfDim(unsigned position) const { return llvm::any_of(getResults(), [&](AffineExpr e) { diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -1753,6 +1753,7 @@ LinalgOp linalgOp = dyn_cast(op); if (!linalgOp) return failure(); + // TODO: Pass vector sizes for masked dims. 
return vectorize(rewriter, linalgOp); } }; diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -11,25 +11,20 @@ //===----------------------------------------------------------------------===// #include "mlir/Analysis/SliceAnalysis.h" -#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" +#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h" #include "mlir/IR/AffineExpr.h" -#include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" -#include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "mlir/Transforms/RegionUtils.h" -#include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/Sequence.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/TypeSwitch.h" @@ -65,6 +60,250 @@ return res; } +/// Contains the vectorization state and related methods used across the +/// vectorization process of a given operation. +struct VectorizationState { + /// Initializes the vectorization state, including the computation of the + /// canonical vector shape for vectorization. + LogicalResult initState(OpBuilder &builder, LinalgOp linalgOp, + ArrayRef inputVectorSizes); + + /// Returns the canonical vector shape used to vectorize the iteration space. + ArrayRef getCanonicalVecShape() const { return canonicalVecShape; } + + /// Masks an operation with the canonical vector mask if the operation needs + /// masking. Returns the masked operation or the original operation if masking + /// is not needed. If provided, the canonical mask for this operation is + /// permuted using `maybeMaskingMap`. + Operation *maskOperation(OpBuilder &builder, Operation *opToMask, + LinalgOp linalgOp, + Optional maybeMaskingMap = llvm::None); + +private: + /// Initializes the iteration space static sizes using the Linalg op + /// information. This may become more complicated in the future. + void initIterSpaceStaticSizes(LinalgOp linalgOp) { + iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges()); + } + + /// Generates 'tensor.dim' operations for all the dynamic dimensions of the + /// iteration space to be vectorized and store them in + /// `iterSpaceDynamicSizes`. + LogicalResult precomputeIterSpaceDynamicSizes(OpBuilder &builder, + LinalgOp linalgOp); + + /// Create or retrieve an existing mask value to mask `opToMask` in the + /// canonical vector iteration space. If `maybeMaskingMap` the mask is + /// permuted using that permutation map. If a new mask is created, it will be + /// cached for future users. + Value getOrCreateMaskFor(OpBuilder &builder, Operation *opToMask, + LinalgOp linalgOp, + Optional maybeMaskingMap); + + // Holds the compile-time static sizes of the iteration space to vectorize. + // Dynamic dimensions are represented using ShapedType::kDynamicSize. + SmallVector iterSpaceStaticSizes; + + /// Holds the runtime sizes of the iteration spaces to vectorize. 
Static
+  /// dimensions are represented with an empty value.
+  SmallVector<Value> iterSpaceDynamicSizes;
+
+  /// Holds the canonical vector shape used to vectorize the iteration space.
+  SmallVector<int64_t> canonicalVecShape;
+
+  /// Holds the active masks for permutations of the canonical vector iteration
+  /// space.
+  DenseMap<AffineMap, Value> activeMaskCache;
+};
+
+/// Generates 'tensor.dim' (or 'memref.dim') operations for all the dynamic
+/// dimensions of the iteration space to be vectorized and stores them in
+/// `iterSpaceDynamicSizes`.
+LogicalResult
+VectorizationState::precomputeIterSpaceDynamicSizes(OpBuilder &builder,
+                                                    LinalgOp linalgOp) {
+  // TODO: The following assert doesn't hold for 0-d vectors yet.
+  // assert(!canonicalVecShape.empty() && "Uninitialized canonical vector shape");
+  for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
+    if (!ShapedType::isDynamic(iterSpaceStaticSizes[vecDim])) {
+      // Add an empty value for static dimensions.
+      iterSpaceDynamicSizes.push_back(Value());
+      continue;
+    }
+
+    // Find an operand defined on this dimension of the iteration space to
+    // extract the runtime dimension size.
+    Value operand;
+    unsigned operandDimPos;
+    if (failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand,
+                                                         operandDimPos)))
+      return failure();
+
+    Value dynamicDim = linalgOp.hasTensorSemantics()
+                           ? (Value)builder.create<tensor::DimOp>(
+                                 linalgOp.getLoc(), operand, operandDimPos)
+                           : (Value)builder.create<memref::DimOp>(
+                                 linalgOp.getLoc(), operand, operandDimPos);
+    iterSpaceDynamicSizes.push_back(dynamicDim);
+  }
+
+  return success();
+}
+
+/// Initializes the vectorization state, including the computation of the
+/// canonical vector shape for vectorization.
+LogicalResult
+VectorizationState::initState(OpBuilder &builder, LinalgOp linalgOp,
+                              ArrayRef<int64_t> inputVectorSizes) {
+  if (!inputVectorSizes.empty()) {
+    // Get the canonical vector shape from the input vector sizes provided.
+    // This path should be taken to vectorize code with dynamic shapes and
+    // when using vector sizes greater than the iteration space sizes.
+    canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
+  } else {
+    // Compute the canonical vector shape from the operation shape. If there
+    // are dynamic shapes, the operation won't be vectorized.
+    canonicalVecShape = linalgOp.getStaticLoopRanges();
+  }
+
+  LDBG("Canonical vector shape: ");
+  LLVM_DEBUG(llvm::interleaveComma(canonicalVecShape, llvm::dbgs()));
+  LLVM_DEBUG(llvm::dbgs() << "\n");
+
+  // Initialize iteration space static sizes.
+  initIterSpaceStaticSizes(linalgOp);
+
+  // Extract and register the runtime value of any potential dynamic shape
+  // needed to compute a mask during vectorization.
+  if (failed(precomputeIterSpaceDynamicSizes(builder, linalgOp)))
+    return failure();
+
+  if (ShapedType::isDynamicShape(canonicalVecShape))
+    return failure();
+  return success();
+}
+
+/// Create or retrieve an existing mask value to mask `opToMask` in the
+/// canonical vector iteration space. If `maybeMaskingMap` is provided, the
+/// mask is permuted using that permutation map. If a new mask is created, it
+/// will be cached for future uses.
+Value VectorizationState::getOrCreateMaskFor(
+    OpBuilder &builder, Operation *opToMask, LinalgOp linalgOp,
+    Optional<AffineMap> maybeMaskingMap) {
+  // No mask is needed if the operation is not maskable.
+  auto maskableOp = dyn_cast<MaskableOpInterface>(opToMask);
+  if (!maskableOp)
+    return Value();
+
+  assert(!maskableOp.isMasked() &&
+         "Masking an operation that is already masked");
+
+  // If no masking map was provided, use an identity map with the loop dims.
+  assert((!maybeMaskingMap || *maybeMaskingMap) &&
+         "Unexpected null mask permutation map");
+  AffineMap maskingMap =
+      maybeMaskingMap ? *maybeMaskingMap
+                      : AffineMap::getMultiDimIdentityMap(
+                            linalgOp.getNumLoops(), builder.getContext());
+  LDBG("Masking map: " << maskingMap << "\n");
+
+  // Return the active mask for the masking map of this operation if it was
+  // already created.
+  if (activeMaskCache.count(maskingMap)) {
+    Value mask = activeMaskCache[maskingMap];
+    LDBG("Reusing mask: " << mask << "\n");
+    return mask;
+  }
+
+  // Compute the permuted projection of the iteration space to be masked and
+  // the corresponding mask shape. If the resulting iteration space dimensions
+  // are static and identical to the mask shape, masking is not needed for
+  // this operation.
+  // TODO: Improve this check. Only projected permutation indexing maps are
+  // supported.
+  SmallVector<int64_t> permutedStaticSizes =
+      applyPermutationMap(maskingMap, ArrayRef<int64_t>(iterSpaceStaticSizes));
+  SmallVector<int64_t> maskShape =
+      applyPermutationMap(maskingMap, ArrayRef<int64_t>(canonicalVecShape));
+  LDBG("Mask shape: ");
+  LLVM_DEBUG(llvm::interleaveComma(maskShape, llvm::dbgs()));
+  LLVM_DEBUG(llvm::dbgs() << "\n");
+
+  if (permutedStaticSizes == maskShape) {
+    LDBG("Masking is not needed for masking map: " << maskingMap);
+    activeMaskCache[maskingMap] = Value();
+    return Value();
+  }
+
+  // Compute the mask upper bound values by combining the permuted iteration
+  // space static sizes and the dynamic values.
+  SmallVector<Value> permutedDynamicSizes =
+      applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceDynamicSizes));
+  SmallVector<Value> upperBounds;
+  for (auto [staticBound, dynBound] :
+       llvm::zip(permutedStaticSizes, permutedDynamicSizes))
+    upperBounds.push_back(ShapedType::isDynamic(staticBound)
+                              ? dynBound
+                              : builder.create<arith::ConstantIndexOp>(
+                                    linalgOp.getLoc(), staticBound));
+
+  assert(!maskShape.empty() && !upperBounds.empty() &&
+         "Masked 0-d vectors are not supported yet");
+
+  // Create the mask based on the upper bound values.
+  auto maskType = VectorType::get(maskShape, builder.getI1Type());
+  Value mask = builder.create<vector::CreateMaskOp>(linalgOp.getLoc(),
+                                                    maskType, upperBounds);
+  LDBG("Creating new mask: " << mask << "\n");
+  activeMaskCache[maskingMap] = mask;
+  return mask;
+}
+
+/// Masks an operation with the canonical vector mask if the operation needs
+/// masking. Returns the masked operation or the original operation if masking
+/// is not needed. If provided, the canonical mask for this operation is
+/// permuted using `maybeMaskingMap`.
+Operation *
+VectorizationState::maskOperation(OpBuilder &builder, Operation *opToMask,
+                                  LinalgOp linalgOp,
+                                  Optional<AffineMap> maybeMaskingMap) {
+  LDBG("Trying to mask: " << *opToMask << "\n");
+
+  // Create or retrieve the mask for this operation.
+  Value mask =
+      getOrCreateMaskFor(builder, opToMask, linalgOp, maybeMaskingMap);
+
+  if (!mask) {
+    LDBG("No mask required\n");
+    return opToMask;
+  }
+
+  // Wrap the operation with a new `vector.mask` and update its def-use chain.
+  assert(opToMask && "Expected a valid operation to mask");
+  auto opResults = opToMask->getResultTypes();
+  auto createRegionMask = [opToMask](OpBuilder &builder, Location loc) {
+    Block *insBlock = builder.getInsertionBlock();
+    insBlock->getOperations().splice(
+        insBlock->begin(), opToMask->getBlock()->getOperations(), opToMask);
+    builder.create<vector::YieldOp>(loc, opToMask->getResults());
+  };
+  // TODO: Allow multiple results in vector.mask.
+  auto maskOp =
+      opResults.empty()
+          ?
builder.create(opToMask->getLoc(), mask, + createRegionMask) + : builder.create(opToMask->getLoc(), + opToMask->getResultTypes().front(), + mask, createRegionMask); + + Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back(); + for (auto &en : llvm::enumerate(opToMask->getResults())) + en.value().replaceAllUsesExcept(maskOp.getResult(en.index()), + maskOpTerminator); + + LDBG("Masked operation: " << *maskOp << "\n"); + return maskOp; +} + /// Given an indexing `map` coming from a LinalgOp indexing, restricted to a /// projectedPermutation, compress the unused dimensions to serve as a /// permutation_map for a vector transfer operation. @@ -198,38 +437,34 @@ /// to all `0`; where `outputOperand` is an output operand of the LinalgOp /// currently being vectorized. If `dest` has null rank, build an memref.store. /// Return the produced value or null if no value is produced. -static Value buildVectorWrite(OpBuilder &b, Value value, - OpOperand *outputOperand) { +static Operation *buildVectorWrite(OpBuilder &b, Value value, + OpOperand *outputOperand, + VectorizationState &state) { Operation *write; Location loc = value.getLoc(); auto linalgOp = cast(outputOperand->getOwner()); - ArrayRef shape = linalgOp.getShape(outputOperand); - auto vectorType = VectorType::get( - shape, getElementTypeOrSelf(outputOperand->get().getType())); + AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand); + auto vectorType = + VectorType::get(opOperandMap.compose(state.getCanonicalVecShape()), + getElementTypeOrSelf(outputOperand->get().getType())); + if (vectorType.getRank() > 0) { - // 0-d case is still special: do not invert the reindexing map. - AffineMap map = - reindexIndexingMap(linalgOp.getMatchingIndexingMap(outputOperand)); - SmallVector transposeShape = - applyPermutationMap(inversePermutation(map), vectorType.getShape()); - assert(!transposeShape.empty() && "unexpected empty transpose shape"); - vectorType = VectorType::get(transposeShape, vectorType.getElementType()); + AffineMap writeMap = reindexIndexingMap(opOperandMap); SmallVector indices(linalgOp.getRank(outputOperand), b.create(loc, 0)); value = broadcastIfNeeded(b, value, vectorType.getShape()); write = b.create(loc, value, outputOperand->get(), - indices, map); + indices, writeMap); } else { + // 0-d case is still special: do not invert the reindexing writeMap. if (!value.getType().isa()) value = b.create(loc, vectorType, value); assert(value.getType() == vectorType && "incorrect type"); write = b.create(loc, value, outputOperand->get(), ValueRange{}); } - LDBG("vectorized op: " << *write); - if (!write->getResults().empty()) - return write->getResult(0); - return Value(); + LDBG("vectorized op: " << *write << "\n"); + return write; } // Custom vectorization precondition function type. This is intented to be used @@ -253,20 +488,35 @@ /// CustomVectorizationHook. static VectorizationResult vectorizeLinalgYield(OpBuilder &b, Operation *op, - const BlockAndValueMapping &bvm, LinalgOp linalgOp, - SmallVectorImpl &newResults) { + const BlockAndValueMapping &bvm, VectorizationState &state, + LinalgOp linalgOp, SmallVectorImpl &newResults) { auto yieldOp = dyn_cast(op); if (!yieldOp) return VectorizationResult{VectorizationStatus::Failure, nullptr}; - for (const auto &outputs : llvm::enumerate(yieldOp.getValues())) { + for (const auto &output : llvm::enumerate(yieldOp.getValues())) { // TODO: Scan for an opportunity for reuse. // TODO: use a map. 
- Value vectorValue = bvm.lookup(outputs.value()); - Value newResult = buildVectorWrite( - b, vectorValue, linalgOp.getDpsInitOperand(outputs.index())); - if (newResult) - newResults.push_back(newResult); + Value vectorValue = bvm.lookup(output.value()); + OpOperand *opOperand = linalgOp.getDpsInitOperand(output.index()); + Operation *write = buildVectorWrite( + b, vectorValue, linalgOp.getDpsInitOperand(output.index()), state); + // TODO: We mask the transfer.transfer_write here because this op is + // special-cased. A linalg.yield may produced multiple vector.transfer_write + // ops and can't be mapped using BlockAndValueMapping. + AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(opOperand); + write = state.maskOperation(b, write, linalgOp, opOperandMap); + newResults.append(write->result_begin(), write->result_end()); + + // If masked, set in-bounds to true. Masking guarantees that the access will + // be in-bounds. + if (auto maskOp = dyn_cast(write)) { + auto maskedWriteOp = + cast(maskOp.getMaskableOp()); + SmallVector inBounds(maskedWriteOp.getVectorType().getRank(), true); + maskedWriteOp.setInBoundsAttr(b.getBoolArrayAttr(inBounds)); + } } + return VectorizationResult{VectorizationStatus::NoReplace, nullptr}; } @@ -410,7 +660,7 @@ vectorizeOneOp(OpBuilder &b, LinalgOp linalgOp, Operation *op, const BlockAndValueMapping &bvm, ArrayRef customVectorizationHooks) { - LDBG("vectorize op " << *op); + LDBG("vectorize op " << *op << "\n"); // 1. Try to apply any CustomVectorizationHook. if (!customVectorizationHooks.empty()) { @@ -506,8 +756,10 @@ /// This is not deemed a problem as we expect canonicalizations and foldings to /// aggressively clean up the useless work. static LogicalResult -vectorizeAsLinalgGeneric(OpBuilder &b, LinalgOp linalgOp, +vectorizeAsLinalgGeneric(OpBuilder &b, VectorizationState &state, + LinalgOp linalgOp, SmallVectorImpl &newResults) { + LDBG("Vectorizing operation as linalg generic\n"); Block *block = linalgOp.getBlock(); // 2. Values defined above the region can only be broadcast for now. Make them @@ -520,11 +772,6 @@ if (linalgOp.getNumDpsInits() == 0) return failure(); - // TODO: the common vector shape is equal to the static loop sizes only when - // all indexing maps are projected permutations. For convs and stencils the - // logic will need to evolve. - SmallVector commonVectorShape = linalgOp.computeStaticLoopSizes(); - // 3. Turn all BBArgs into vector.transfer_read / load. Location loc = linalgOp.getLoc(); Value zero = b.create(loc, 0); @@ -534,35 +781,60 @@ bvm.map(bbarg, opOperand->get()); continue; } - VectorType readType; - AffineMap map; - // TODO: can we keep this simplification? - // if (linalgOp.getShape(&opOperand).empty()) { - // readType = VectorType::get({}, bbarg.getType()); - // } else { - if (opOperand->getOperandNumber() < linalgOp.getNumDpsInputs()) { - map = inverseAndBroadcastProjectedPermutation( - linalgOp.getMatchingIndexingMap(opOperand)); - readType = VectorType::get(commonVectorShape, - getElementTypeOrSelf(opOperand->get())); + + // 3.a. Convert the indexing map for this input/output to a transfer read + // permutation map and masking map. + AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand); + + // Remove zeros from indexing map to use it as masking map. 
+ SmallVector zeroPos; + auto results = indexingMap.getResults(); + for (auto result : llvm::enumerate(results)) { + if (result.value().isa()) { + zeroPos.push_back(result.index()); + } + } + AffineMap maskMap = indexingMap.dropResults(zeroPos); + + AffineMap readMap; + SmallVector readVecShape; + if (linalgOp.isDpsInput(opOperand)) { + // 3.a.i. For input reads we use the canonical vector shape. + readMap = inverseAndBroadcastProjectedPermutation(indexingMap); + readVecShape = llvm::to_vector(state.getCanonicalVecShape()); } else { - map = inversePermutation( - reindexIndexingMap(linalgOp.getMatchingIndexingMap(opOperand))); - readType = VectorType::get(map.compose(linalgOp.getShape(opOperand)), - getElementTypeOrSelf(opOperand->get())); + // 3.a.ii. For output reads (iteration-carried dependence, e.g., + // reductions), the vector shape is computed by mapping the canonical + // vector shape to the output domain and back to the canonical domain. + readMap = inversePermutation(reindexIndexingMap(indexingMap)); + readVecShape = + readMap.compose(indexingMap.compose(state.getCanonicalVecShape())); } - // } - auto shape = linalgOp.getShape(opOperand); - SmallVector indices(shape.size(), zero); - Value readValue = b.create( - loc, readType, opOperand->get(), indices, map); - // Not all ops support 0-d vectors, extract the scalar for now. + auto readType = + VectorType::get(readVecShape, getElementTypeOrSelf(opOperand->get())); + SmallVector indices(linalgOp.getShape(opOperand).size(), zero); + + Operation *read = b.create( + loc, readType, opOperand->get(), indices, readMap); + read = state.maskOperation(b, read, linalgOp, maskMap); + Value readValue = read->getResult(0); + + // 3.b. If masked, set in-bounds to true. Masking guarantees that the access will + // be in-bounds. + if (auto maskOp = dyn_cast(read)) { + SmallVector inBounds(readType.getRank(), true); + cast(maskOp.getMaskableOp()) + .setInBoundsAttr(b.getBoolArrayAttr(inBounds)); + } + + // 3.c. Not all ops support 0-d vectors, extract the scalar for now. // TODO: remove this. if (readValue.getType().cast().getRank() == 0) readValue = b.create(loc, readValue); - LDBG("new vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue); + LDBG("New vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue + << "\n"); bvm.map(bbarg, readValue); bvm.map(opOperand->get(), readValue); } @@ -572,7 +844,7 @@ CustomVectorizationHook vectorizeYield = [&](Operation *op, const BlockAndValueMapping &bvm) -> VectorizationResult { - return vectorizeLinalgYield(b, op, bvm, linalgOp, newResults); + return vectorizeLinalgYield(b, op, bvm, state, linalgOp, newResults); }; hooks.push_back(vectorizeYield); @@ -596,12 +868,13 @@ for (Operation &op : block->getOperations()) { VectorizationResult result = vectorizeOneOp(b, linalgOp, &op, bvm, hooks); if (result.status == VectorizationStatus::Failure) { - LDBG("failed to vectorize: " << op); + LDBG("failed to vectorize: " << op << "\n"); return failure(); } if (result.status == VectorizationStatus::NewOp) { - LDBG("new vector op: " << *result.newOp;); - bvm.map(op.getResults(), result.newOp->getResults()); + Operation *maybeMaskedOp = state.maskOperation(b, result.newOp, linalgOp); + LDBG("New vector op: " << *maybeMaskedOp << "\n"); + bvm.map(op.getResults(), maybeMaskedOp->getResults()); } } @@ -612,7 +885,7 @@ // ops that may not commute (e.g. linear reduction + non-linear instructions). 
 static LogicalResult reductionPreconditions(LinalgOp op) {
   if (llvm::none_of(op.getIteratorTypesArray(), isReductionIterator)) {
-    LDBG("reduction precondition failed: no reduction iterator");
+    LDBG("reduction precondition failed: no reduction iterator\n");
     return failure();
   }
   for (OpOperand *opOperand : op.getDpsInitOperands()) {
@@ -622,19 +895,67 @@
     Operation *reduceOp = matchLinalgReduction(opOperand);
     if (!reduceOp || !getCombinerOpKind(reduceOp)) {
-      LDBG("reduction precondition failed: reduction detection failed");
+      LDBG("reduction precondition failed: reduction detection failed\n");
       return failure();
     }
   }
   return success();
 }
 
-static LogicalResult vectorizeStaticLinalgOpPrecondition(
-    linalg::LinalgOp op,
-    ArrayRef<CustomVectorizationPrecondition> customPreconditions) {
+static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op) {
+  // TODO: Masking only supports dynamic generic ops without reductions for
+  // now.
+  if (!isElementwise(op) &&
+      llvm::any_of(op.getIteratorTypesArray(), [](utils::IteratorType itType) {
+        return itType != utils::IteratorType::parallel;
+      }))
+    return failure();
+
+  // TODO: 0-d vectors are not supported yet.
+  if (llvm::any_of(op.getIndexingMapsArray(), [](AffineMap map) {
+        return map.isEmpty() || map.getResults().empty();
+      }))
+    return failure();
+
+  LDBG("Dynamically-shaped op meets vectorization preconditions\n");
+  return success();
+}
+
+LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(
+    LinalgOp linalgOp, ArrayRef<int64_t> inputVectorSizes) {
+  // Check API contract for input vector sizes.
+  if (!inputVectorSizes.empty()) {
+    assert(inputVectorSizes.size() == linalgOp.getNumLoops() &&
+           "Input vector sizes don't match the number of loops");
+    assert(!ShapedType::isDynamicShape(inputVectorSizes) &&
+           "Input vector sizes can't have dynamic dimensions");
+    assert(llvm::all_of(
+               llvm::zip(linalgOp.getStaticLoopRanges(), inputVectorSizes),
+               [](std::tuple<int64_t, int64_t> sizePair) {
+                 int64_t staticSize = std::get<0>(sizePair);
+                 int64_t inputSize = std::get<1>(sizePair);
+                 return ShapedType::isDynamic(staticSize) ||
+                        staticSize <= inputSize;
+               }) &&
+           "Input vector sizes must be greater than or equal to the iteration "
+           "space static sizes");
+  }
+
+  // TODO: Masking is only supported for dynamic shapes, so input vector sizes
+  // must be empty if the op is not dynamic.
+  if (!linalgOp.hasDynamicShape() && !inputVectorSizes.empty())
+    return failure();
+
+  if (linalgOp.hasDynamicShape() &&
+      failed(vectorizeDynamicLinalgOpPrecondition(linalgOp)))
+    return failure();
+
+  SmallVector<CustomVectorizationPrecondition> customPreconditions;
+
+  // Register CustomVectorizationPrecondition for extractOp.
+  customPreconditions.push_back(tensorExtractVectorizationPrecondition);
   // All types in the body should be a supported element type for VectorType.
-  for (Operation &innerOp : op->getRegion(0).front()) {
+  for (Operation &innerOp : linalgOp->getRegion(0).front()) {
     // Check if any custom hook can vectorize the inner op.
     if (llvm::any_of(
             customPreconditions,
@@ -654,46 +975,49 @@
       return failure();
     }
   }
-  if (isElementwise(op))
+  if (isElementwise(linalgOp))
     return success();
   // TODO: isaConvolutionOpInterface that can also infer from generic features.
   // But we will still need stride/dilation attributes that will be annoying to
   // reverse-engineer...
-  if (isa<ConvolutionOpInterface>(op.getOperation()))
+  if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
    return success();
   // TODO: the common vector shape is equal to the static loop sizes only when
   // all indexing maps are projected permutations. For convs and stencils the
   // logic will need to evolve.
-  if (!allIndexingsAreProjectedPermutation(op)) {
-    LDBG("precondition failed: not projected permutations");
+  if (!allIndexingsAreProjectedPermutation(linalgOp)) {
+    LDBG("precondition failed: not projected permutations\n");
     return failure();
   }
-  if (failed(reductionPreconditions(op))) {
-    LDBG("precondition failed: reduction preconditions");
+  if (failed(reductionPreconditions(linalgOp))) {
+    LDBG("precondition failed: reduction preconditions\n");
     return failure();
   }
   return success();
 }
 
-LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(LinalgOp linalgOp) {
-  // All types must be static shape to go to vector.
-  if (linalgOp.hasDynamicShape()) {
-    LDBG("precondition failed: dynamic shape");
+/// Emit a suitable vector form for a Linalg op. If provided, `inputVectorSizes`
+/// are used to vectorize this operation. `inputVectorSizes` must match the
+/// rank of the iteration space of the operation and the sizes must be greater
+/// than or equal to their counterpart iteration space sizes, if static.
+/// `inputVectorSizes` also allows the vectorization of operations with dynamic
+/// shapes.
+LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, LinalgOp linalgOp,
+                                      ArrayRef<int64_t> inputVectorSizes) {
+  LDBG("Attempting to vectorize:\n" << linalgOp << "\n");
+  LDBG("Input vector sizes: ");
+  LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
+  LLVM_DEBUG(llvm::dbgs() << "\n");
+
+  if (failed(vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes)))
     return failure();
-  }
-
-  SmallVector<CustomVectorizationPrecondition> customPreconditions;
-
-  // Register CustomVectorizationPrecondition for extractOp.
-  customPreconditions.push_back(tensorExtractVectorizationPrecondition);
-
-  return vectorizeStaticLinalgOpPrecondition(linalgOp, customPreconditions);
-}
-
-LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter,
-                                      LinalgOp linalgOp) {
-  if (failed(vectorizeLinalgOpPrecondition(linalgOp)))
+  // Initialize vectorization state.
+  VectorizationState state;
+  if (failed(state.initState(rewriter, linalgOp, inputVectorSizes))) {
+    LDBG("Vectorization state couldn't be initialized\n");
     return failure();
+  }
 
   SmallVector<Value> results;
   // TODO: isaConvolutionOpInterface that can also infer from generic
@@ -704,8 +1028,13 @@
   } else {
     if (failed(vectorizeLinalgOpPrecondition(linalgOp)))
       return failure();
-    LDBG("Vectorize generic by broadcasting to a common shape: " << linalgOp);
-    if (failed(vectorizeAsLinalgGeneric(rewriter, linalgOp, results)))
+    LDBG("Vectorize generic by broadcasting to the canonical vector shape\n");
+    // TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted to
+    // 'OpBuilder' when it is passed over to some methods like
+    // 'vectorizeAsLinalgGeneric'. This is highly problematic: if we erase an
+    // op within these methods, the actual rewriter won't be notified and we
+    // will end up with read-after-free issues!
+ if (failed(vectorizeAsLinalgGeneric(rewriter, state, linalgOp, results))) return failure(); } @@ -1201,7 +1530,7 @@ if (firstOp->getBlock() != secondOp->getBlock() || !firstOp->isBeforeInBlock(secondOp)) { LDBG("interleavedUses precondition failed, firstOp: " - << *firstOp << ", second op: " << *secondOp); + << *firstOp << ", second op: " << *secondOp << "\n"); return true; } for (auto v : values) { @@ -1214,7 +1543,7 @@ (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner))) continue; LDBG(" found interleaved op " << *owner << ", firstOp: " << *firstOp - << ", second op: " << *secondOp); + << ", second op: " << *secondOp << "\n"); return true; } } @@ -1250,14 +1579,14 @@ !viewOrAlloc.getDefiningOp()) return failure(); - LDBG(viewOrAlloc); + LDBG(viewOrAlloc << "\n"); // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`. memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc); if (!subViewOp) return failure(); Value subView = subViewOp.getResult(); - LDBG("with subView " << subView); + LDBG("with subView " << subView << "\n"); // Find the copy into `subView` without interleaved uses. memref::CopyOp copyOp; @@ -1266,7 +1595,7 @@ assert(newCopyOp.getTarget().getType().isa()); if (newCopyOp.getTarget() != subView) continue; - LDBG("copy candidate " << *newCopyOp); + LDBG("copy candidate " << *newCopyOp << "\n"); if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView})) continue; copyOp = newCopyOp; @@ -1275,7 +1604,7 @@ } if (!copyOp) return failure(); - LDBG("with copy " << *copyOp); + LDBG("with copy " << *copyOp << "\n"); // Find the fill into `viewOrAlloc` without interleaved uses before the // copy. @@ -1285,7 +1614,7 @@ assert(newFillOp.output().getType().isa()); if (newFillOp.output() != viewOrAlloc) continue; - LDBG("fill candidate " << *newFillOp); + LDBG("fill candidate " << *newFillOp << "\n"); if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView})) continue; maybeFillOp = newFillOp; @@ -1296,7 +1625,7 @@ if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value()) return failure(); if (maybeFillOp) - LDBG("with maybeFillOp " << *maybeFillOp); + LDBG("with maybeFillOp " << *maybeFillOp << "\n"); // `in` is the subview that memref.copy reads. Replace it. Value in = copyOp.getSource(); diff --git a/mlir/lib/Dialect/Vector/IR/CMakeLists.txt b/mlir/lib/Dialect/Vector/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Vector/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Vector/IR/CMakeLists.txt @@ -5,6 +5,8 @@ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Vector/IR DEPENDS + MLIRMaskableOpInterfaceIncGen + MLIRMaskingOpInterfaceIncGen MLIRVectorOpsIncGen MLIRVectorOpsEnumsIncGen diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -447,6 +447,15 @@ p << " : " << getVector().getType() << " into " << getDest().getType(); } +// MaskableOpInterface methods. + +/// Returns the mask type expected by this operation. +Type ReductionOp::getExpectedMaskType() { + auto vecType = getVectorType(); + return vecType.cloneWith(llvm::None, + IntegerType::get(vecType.getContext(), /*width=*/1)); +} + Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op, OpBuilder &builder, Location loc, Value vector) { @@ -3464,6 +3473,14 @@ [&](Twine t) { return emitOpError(t); }); } +// MaskableOpInterface methods. + +/// Returns the mask type expected by this operation. 
Mostly used for +/// verification purposes. It requires the operation to be vectorized." +Type TransferReadOp::getExpectedMaskType() { + return inferTransferReadMaskType(getVectorType(), getPermutationMap()); +} + template static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) { // TODO: support more aggressive createOrFold on: @@ -3906,6 +3923,14 @@ [&](Twine t) { return emitOpError(t); }); } +// MaskableOpInterface methods. + +/// Returns the mask type expected by this operation. Mostly used for +/// verification purposes. +Type TransferWriteOp::getExpectedMaskType() { + return inferTransferWriteMaskType(getVectorType(), getPermutationMap()); +} + /// Fold: /// ``` /// %t1 = ... @@ -5380,9 +5405,10 @@ "expects result type to match maskable operation result type"); // Mask checks. - if (getMask().getType() != maskableOp.getExpectedMaskType()) - return emitOpError("expects a ") << maskableOp.getExpectedMaskType() - << " mask for the maskable operation"; + Type expectedMaskType = maskableOp.getExpectedMaskType(); + if (getMask().getType() != expectedMaskType) + return emitOpError("expects a ") + << expectedMaskType << " mask for the maskable operation"; // Passthru checks. Value passthru = getPassthru(); diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp @@ -109,15 +109,6 @@ } }; -/// Populates instances of `MaskOpRewritePattern` to lower masked operations -/// with `vector.mask`. Patterns should rewrite the `vector.mask` operation and -/// not its nested `MaskableOpInterface`. -void populateVectorMaskLoweringPatternsForSideEffectingOps( - RewritePatternSet &patterns) { - patterns.add( - patterns.getContext()); -} - struct LowerVectorMaskPass : public vector::impl::LowerVectorMaskPassBase { using Base::Base; @@ -141,6 +132,15 @@ } // namespace +/// Populates instances of `MaskOpRewritePattern` to lower masked operations +/// with `vector.mask`. Patterns should rewrite the `vector.mask` operation and +/// not its nested `MaskableOpInterface`. +void vector::populateVectorMaskLoweringPatternsForSideEffectingOps( + RewritePatternSet &patterns) { + patterns.add( + patterns.getContext()); +} + std::unique_ptr mlir::vector::createLowerVectorMaskPass() { return std::make_unique(); } diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp --- a/mlir/lib/IR/AffineMap.cpp +++ b/mlir/lib/IR/AffineMap.cpp @@ -338,6 +338,18 @@ return llvm::None; } +/// Extracts the permuted position where the given input index resides. +/// Returns `llvm::None` if the input index is projected. Asserts on +/// non-projected permutation maps. +Optional +AffineMap::getPermutedPositionOfProjectedPermutation(unsigned input) const { + assert(isProjectedPermutation() && "invalid projected permutation request"); + for (unsigned i = 0, numResults = getNumResults(); i < numResults; i++) + if (getDimPosition(i) == input) + return i; + return llvm::None; +} + /// Folds the results of the application of an affine map on the provided /// operands to a constant if possible. Returns false if the folding happens, /// true otherwise. 
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir --- a/mlir/test/Dialect/Linalg/vectorization.mlir +++ b/mlir/test/Dialect/Linalg/vectorization.mlir @@ -1,7 +1,5 @@ // RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file | FileCheck %s -// ----- - // CHECK-LABEL: contraction_dot func.func @contraction_dot(%A: memref<1584xf32>, %B: memref<1584xf32>, %C: memref) { @@ -130,7 +128,7 @@ // CHECK-LABEL: func @generic_output_transpose func.func @generic_output_transpose(%A: memref<8x16xf32>, %B: memref<16x32xf32>, - %C: memref<32x8xf32>) { + %C: memref<32x8xf32>) { // CHECK: vector.transfer_read %{{.*}} : memref<8x16xf32>, vector<8x32x16xf32> // CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<8x32x16xf32> // CHECK: %[[ACC:.*]] = vector.transfer_read %{{.*}} : memref<32x8xf32>, vector<8x32xf32> diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -8330,6 +8330,7 @@ ":LinalgPassIncGen", ":LinalgStructuredOpsIncGen", ":LinalgUtils", + ":MaskableOpInterface", ":MathDialect", ":MemRefDialect", ":MemRefTransforms",
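Reviewer note (not part of the patch): below is a minimal sketch of how a client could drive the new `vectorize` entry point once vector sizes are plumbed through; the `transform.structured.vectorize` path above still passes no sizes, per its TODO. The helper name, the chosen sizes {8, 32}, and the 2-D assumption are illustrative only.

// Illustrative driver, assuming the usual Linalg/Func includes and that every
// candidate op has a 2-D iteration space (inputVectorSizes must match the
// number of loops and be >= the static loop sizes).
static LogicalResult vectorizeAllWithFixedSizes(RewriterBase &rewriter,
                                                func::FuncOp funcOp) {
  SmallVector<LinalgOp> candidates;
  funcOp.walk([&](LinalgOp linalgOp) { candidates.push_back(linalgOp); });
  for (LinalgOp linalgOp : candidates) {
    rewriter.setInsertionPoint(linalgOp);
    // In this patch, explicit vector sizes are only accepted for ops with
    // dynamic shapes; static-shaped ops keep using their static loop bounds.
    SmallVector<int64_t> vectorSizes;
    if (linalgOp.hasDynamicShape())
      vectorSizes.assign({8, 32});
    if (failed(linalg::vectorize(rewriter, linalgOp, vectorSizes)))
      return failure();
  }
  return success();
}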
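Since the change adds no masking-specific test yet, here is a hand-written sketch (not taken from this patch) of the IR shape the masked path is intended to produce for a dynamically shaped elementwise op, assuming a canonical vector size of 8; exact value names and op ordering will differ.

// Input: a dynamic 1-D elementwise linalg.generic.
#map = affine_map<(d0) -> (d0)>
func.func @add_dynamic(%a: tensor<?xf32>, %b: tensor<?xf32>,
                       %init: tensor<?xf32>) -> tensor<?xf32> {
  %0 = linalg.generic {indexing_maps = [#map, #map, #map],
                       iterator_types = ["parallel"]}
      ins(%a, %b : tensor<?xf32>, tensor<?xf32>)
      outs(%init : tensor<?xf32>) {
  ^bb0(%in: f32, %in_1: f32, %out: f32):
    %1 = arith.addf %in, %in_1 : f32
    linalg.yield %1 : f32
  } -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

// Expected vectorized form with inputVectorSizes = {8} (sketch):
//   %c0   = arith.constant 0 : index
//   %dim  = tensor.dim %a, %c0 : tensor<?xf32>
//   %mask = vector.create_mask %dim : vector<8xi1>
//   %lhs  = vector.mask %mask {
//             vector.transfer_read %a[%c0], %pad {in_bounds = [true]}
//               : tensor<?xf32>, vector<8xf32>
//           } : vector<8xi1> -> vector<8xf32>
//   ... analogous masked reads for %b and %init ...
//   %sum  = arith.addf %lhs, %rhs : vector<8xf32>
//   %res  = vector.mask %mask {
//             vector.transfer_write %sum, %init[%c0] {in_bounds = [true]}
//               : vector<8xf32>, tensor<?xf32>
//           } : vector<8xi1> -> tensor<?xf32>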