diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -328,8 +328,9 @@ FailureOr promoteSubViews(OpBuilder &b, LinalgOp op, const LinalgPromotionOptions &options); -/// Emit a suitable vector form for a Linalg op with fully static shape. -LogicalResult vectorize(RewriterBase &builder, LinalgOp linalgOp); +/// Emit a suitable vector form for a Linalg op. If `vecSizesForMaskedDims` is +/// not empty, it holds one vector size per loop of `linalgOp` and is used to +/// vectorize dynamically-shaped ops with masking. +LogicalResult vectorize(RewriterBase &rewriter, LinalgOp linalgOp, + ArrayRef vecSizesForMaskedDims = {}); /// Emit a suitable vector form for a Copy op with fully static shape. LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp); diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h --- a/mlir/include/mlir/IR/AffineMap.h +++ b/mlir/include/mlir/IR/AffineMap.h @@ -170,6 +170,12 @@ /// Fails when called on a non-permutation. unsigned getPermutedPosition(unsigned input) const; + /// Extracts the permuted position where the given input index resides. + /// Returns `llvm::None` if the input index is projected out (no result of the + /// map is that dimension). Fails (asserts) when called on a non-projected-permutation. + Optional + getProjectedPermutationPermutedPosition(unsigned input) const; + /// Return true if any affine expression involves AffineDimExpr `position`. bool isFunctionOfDim(unsigned position) const { return llvm::any_of(getResults(), [&](AffineExpr e) { diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -1700,6 +1700,7 @@ LinalgOp linalgOp = dyn_cast(op); if (!linalgOp) return failure(); + // TODO: Pass vector sizes for masked dims.
return vectorize(rewriter, linalgOp); } }; diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -22,12 +22,14 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h" #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/RegionUtils.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/Sequence.h" @@ -65,6 +67,204 @@ return res; } +/// Contains the vectorization state and related methods used across the +/// vectorization process of a given operation. +struct VectorizationState { + /// Initializes the vectorization state, including the computation of the + /// canonical vector shape for vectorization. Returns failure if the state + /// cannot be initialized for `linalgOp`. + LogicalResult initState(OpBuilder &builder, LinalgOp linalgOp, + ArrayRef vecSizesForMaskedDims); + + /// Returns the canonical vector shape used to vectorize the iteration space. + ArrayRef getCanonicalVecShape() const { return canonicalVecShape; } + /// Masks an operation with the canonical vector mask if the operation needs + /// masking. Returns the masked operation or the original operation if masking + /// is not needed. If provided, the canonical mask for this operation is + /// permuted using `maybeMaskPermMap`. + Operation *maskOperation(OpBuilder &builder, Operation *opToMask, + LinalgOp linalgOp, + Optional maybeMaskPermMap = llvm::None); + + /// Holds the active masks for permutations of the canonical vector iteration + /// space, keyed by the mask permutation map.
+ DenseMap activeMaskCache; + /// Holds the runtime values of the sizes of the masked dimensions. No mask is + /// created while this is empty. + SmallVector maskedDimSizeValues; + +private: + /// Generates a 'tensor.dim' operation with the runtime size of each iteration + /// space dimension of the given operation (one per canonical vector + /// dimension) and stores them in `maskedDimSizeValues`. + LogicalResult extractDynamicVectorDimValues(OpBuilder &builder, + LinalgOp linalgOp); + + /// Create or retrieve an existing mask value to mask `opToMask` in the + /// canonical vector iteration space. If `maybeMaskPermMap` is provided, the + /// mask is permuted using that permutation map. If a new mask is created, it + /// will be cached for future users. Returns a null value if no mask is needed. + Value getOrCreateMaskFor(OpBuilder &builder, Operation *opToMask, + LinalgOp linalgOp, + Optional maybeMaskPermMap); + + /// Holds the canonical vector shape used to vectorize the iteration space. + SmallVector canonicalVecShape; +}; + +/// Given a dimension of the iteration space of an operation, finds the first +/// operand whose indexing map is a projected permutation defined on such +/// dimension. The operand and its dimension are returned through the `operand` +/// and `operandDim` out-parameters; reaching the end without a match is +/// treated as unreachable (unsupported op). +static void mapIterationSpaceDimToOperandDim(unsigned dim, LinalgOp linalgOp, + Value &operand, + unsigned &operandDim) { + // Retrieve the operand and its dimension from the first operand with a + // permutation map that is defined on such dimension. + for (auto &en : llvm::enumerate(linalgOp.getIndexingMapsArray())) { + AffineMap idxMap = en.value(); + if (idxMap.isProjectedPermutation()) { + auto mayOperandDim = idxMap.getProjectedPermutationPermutedPosition(dim); + if (mayOperandDim) { + operand = linalgOp->getOperand(en.index()); + operandDim = *mayOperandDim; + return; + } + } + } + + llvm_unreachable("Unsupported linalg op"); +} + +/// Generates a 'tensor.dim' operation with the runtime size of each iteration +/// space dimension of the given operation (one per canonical vector dimension) +/// and stores them in `maskedDimSizeValues`.
+LogicalResult +VectorizationState::extractDynamicVectorDimValues(OpBuilder &builder, + LinalgOp linalgOp) { + // The canonical vector shape must be computed before the dimension sizes. + assert(!canonicalVecShape.empty() && "The canonical vector shape is empty"); + for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) { + Value operand; + unsigned operandDim = std::numeric_limits::max(); + mapIterationSpaceDimToOperandDim(vecDim, linalgOp, operand, operandDim); + assert(operand && operandDim != std::numeric_limits::max() && + "Masked dimension mapping didn't happen"); + + auto dynDim = + builder.create(linalgOp.getLoc(), operand, operandDim); + maskedDimSizeValues.push_back(dynDim); + } + + return success(); +} + +/// Initializes the vectorization state, including the computation of the +/// canonical vector shape for vectorization. For dynamically-shaped ops, the +/// canonical vector shape is taken from `vecSizesForMaskedDims` (one static +/// size per loop) and the runtime loop sizes are materialized for masking; +/// statically-shaped ops use their static loop ranges. Unsupported cases +/// return failure before any IR is created. +LogicalResult +VectorizationState::initState(OpBuilder &builder, LinalgOp linalgOp, + ArrayRef vecSizesForMaskedDims) { + assert((vecSizesForMaskedDims.empty() || + vecSizesForMaskedDims.size() == linalgOp.getNumLoops()) && + "Sizes for masked dims don't match iteration space dims"); + + if (linalgOp.hasDynamicShape()) { + // TODO: Only support fully dynamic ops for now. This is a property of the + // op's loop ranges, not of the vector sizes, which are expected static. + if (!llvm::all_of(linalgOp.getStaticLoopRanges(), ShapedType::isDynamic)) + return failure(); + + // Masked vectorization needs a static vector size for every loop. Without + // them, the canonical vector shape would be empty. + if (vecSizesForMaskedDims.empty() || + llvm::any_of(vecSizesForMaskedDims, ShapedType::isDynamic)) + return failure(); + + canonicalVecShape.append(vecSizesForMaskedDims.begin(), + vecSizesForMaskedDims.end()); + return extractDynamicVectorDimValues(builder, linalgOp); + } + + canonicalVecShape = linalgOp.getStaticLoopRanges(); + return success(); +} + +/// Create or retrieve an existing mask value to mask `opToMask` in the +/// canonical vector iteration space. If `maybeMaskPermMap` is provided, the +/// mask is permuted using that permutation map. If a new mask is created, it +/// will be cached for future users. Returns a null value if no mask is needed. +Value VectorizationState::getOrCreateMaskFor( + OpBuilder &builder, Operation *opToMask, LinalgOp linalgOp, + Optional maybeMaskPermMap) { + // No mask is needed if no masked dim sizes provided.
+ if (maskedDimSizeValues.empty()) + return Value(); + + // No mask is needed if the operation is not maskable. + auto maskableOp = dyn_cast(opToMask); + if (!maskableOp) + return Value(); + + assert(!maskableOp.isMasked() && + "Masking an operation that is already masked"); + + // If no mask permutation map was provided, use an identity map with the loop + // dims. + AffineMap maskPermMap = + maybeMaskPermMap ? *maybeMaskPermMap + : AffineMap::getMultiDimIdentityMap( + linalgOp.getNumLoops(), builder.getContext()); + + // Return active mask for the indexing map of this operand if it was already + // created. + auto maskIt = activeMaskCache.find(maskPermMap); + if (maskIt != activeMaskCache.end()) { + Value mask = maskIt->second; + LDBG("Reusing mask: " << mask << "\n"); + return mask; + } + + // Permute the dimension values and vector sizes so that they align with the + // dimension order of the mask. + SmallVector permVecDimValues = + applyPermutationMap(maskPermMap, ArrayRef(maskedDimSizeValues)); + SmallVector permVecShape = + applyPermutationMap(maskPermMap, ArrayRef(canonicalVecShape)); + + // Create the mask based on the runtime value of the dimensions to be + // vectorized. + auto maskType = VectorType::get(permVecShape, builder.getI1Type()); + Value mask = builder.create(linalgOp.getLoc(), maskType, + permVecDimValues); + LDBG("Creating new mask: " << mask << "\n"); + activeMaskCache[maskPermMap] = mask; + return mask; +} + +/// Masks an operation with the canonical vector mask if the operation needs +/// masking. Returns the masked operation or the original operation if masking +/// is not needed. If provided, the canonical mask for this operation is +/// permuted using `maybeMaskPermMap`. +Operation * +VectorizationState::maskOperation(OpBuilder &builder, Operation *opToMask, + LinalgOp linalgOp, + Optional maybeMaskPermMap) { + // Create or retrieve mask for this operation.
+ assert(opToMask && "Expected a valid operation to mask"); + Value mask = + getOrCreateMaskFor(builder, opToMask, linalgOp, maybeMaskPermMap); + + if (!mask) { + LDBG("No mask required for: " << *opToMask << "\n"); + return opToMask; + } + + // Wrap `opToMask` with a new `vector.mask` and update D-U chain. The new op + // is created with the single result type of `opToMask`. + assert(opToMask->getNumResults() == 1 && + "Expected an operation with a single result to mask"); + auto maskOp = builder.create( + opToMask->getLoc(), opToMask->getResultTypes().front(), mask, + [opToMask](OpBuilder &builder, Location loc) { + Block *insBlock = builder.getInsertionBlock(); + insBlock->getOperations().splice( + insBlock->begin(), opToMask->getBlock()->getOperations(), opToMask); + builder.create(loc, opToMask->getResults()); + }); + + // Redirect external uses of the original results to the `vector.mask` + // results; only the terminator inside the mask region keeps using them. + Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back(); + for (auto &en : llvm::enumerate(opToMask->getResults())) + en.value().replaceAllUsesExcept(maskOp.getResult(en.index()), + maskOpTerminator); + + LDBG("Masked operation: " << *maskOp << "\n"); + return maskOp; +} + /// Given an indexing `map` coming from a LinalgOp indexing, restricted to a /// projectedPermutation, compress the unused dimensions to serve as a /// permutation_map for a vector transfer operation. @@ -198,38 +398,41 @@ /// to all `0`; where `outputOperand` is an output operand of the LinalgOp /// currently being vectorized. If `dest` has null rank, build an memref.store. -/// Return the produced value or null if no value is produced. +/// Return the produced write operation.
-static Value buildVectorWrite(OpBuilder &b, Value value, - OpOperand *outputOperand) { +static Operation *buildVectorWrite(OpBuilder &b, Value value, + OpOperand *outputOperand, + VectorizationState &state) { Operation *write; Location loc = value.getLoc(); auto linalgOp = cast(outputOperand->getOwner()); - ArrayRef shape = linalgOp.getShape(outputOperand); - auto vectorType = VectorType::get( - shape, getElementTypeOrSelf(outputOperand->get().getType())); + AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand); + // Vector type of the output operand, in operand dimension order. + auto vectorType = + VectorType::get(opOperandMap.compose(state.getCanonicalVecShape()), + getElementTypeOrSelf(outputOperand->get().getType())); + if (vectorType.getRank() > 0) { - // 0-d case is still special: do not invert the reindexing map. - AffineMap map = - reindexIndexingMap(linalgOp.getMatchingIndexingMap(outputOperand)); - SmallVector transposeShape = - applyPermutationMap(inversePermutation(map), vectorType.getShape()); - assert(!transposeShape.empty() && "unexpected empty transpose shape"); - vectorType = VectorType::get(transposeShape, vectorType.getElementType()); + AffineMap writeMap = reindexIndexingMap(opOperandMap); + // The value to store is laid out in iteration space order, whereas + // `vectorType` is in operand order. Permute the shape back so that any + // broadcast produces a value in iteration order; the transfer write itself + // applies the transposition through `writeMap`. + SmallVector<int64_t> transposeShape = + applyPermutationMap(inversePermutation(writeMap), vectorType.getShape()); + assert(!transposeShape.empty() && "unexpected empty transpose shape"); + vectorType = VectorType::get(transposeShape, vectorType.getElementType()); SmallVector indices(linalgOp.getRank(outputOperand), b.create(loc, 0)); value = broadcastIfNeeded(b, value, vectorType.getShape()); + // If masked, set in-bounds to true. Masking guarantees that the access will + // be in-bounds. + SmallVector inBounds; + if (!state.maskedDimSizeValues.empty()) + inBounds.append(vectorType.getRank(), true); + write = b.create(loc, value, outputOperand->get(), - indices, map); + indices, writeMap, + ArrayRef(inBounds)); } else { + // 0-d case is still special: there is no permutation map; scalars are + // broadcast to the 0-d vector type and written without indices.
if (!value.getType().isa()) value = b.create(loc, vectorType, value); assert(value.getType() == vectorType && "incorrect type"); write = b.create(loc, value, outputOperand->get(), ValueRange{}); } - LDBG("vectorized op: " << *write); - if (!write->getResults().empty()) - return write->getResult(0); - return Value(); + LDBG("vectorized op: " << *write << "\n"); + return write; } // Custom vectorization precondition function type. This is intented to be used @@ -253,20 +456,26 @@ /// CustomVectorizationHook. static VectorizationResult vectorizeLinalgYield(OpBuilder &b, Operation *op, - const BlockAndValueMapping &bvm, LinalgOp linalgOp, - SmallVectorImpl &newResults) { + const BlockAndValueMapping &bvm, VectorizationState &state, + LinalgOp linalgOp, SmallVectorImpl &newResults) { auto yieldOp = dyn_cast(op); if (!yieldOp) return VectorizationResult{VectorizationStatus::Failure, nullptr}; - for (const auto &outputs : llvm::enumerate(yieldOp.getValues())) { + for (const auto &output : llvm::enumerate(yieldOp.getValues())) { // TODO: Scan for an opportunity for reuse. // TODO: use a map. - Value vectorValue = bvm.lookup(outputs.value()); - Value newResult = buildVectorWrite( - b, vectorValue, linalgOp.getDpsInitOperand(outputs.index())); - if (newResult) - newResults.push_back(newResult); + Value vectorValue = bvm.lookup(output.value()); + OpOperand *opOperand = linalgOp.getDpsInitOperand(output.index()); + Operation *write = buildVectorWrite(b, vectorValue, opOperand, state); + // TODO: We mask the vector.transfer_write here because this op is + // special-cased. A linalg.yield may produce multiple vector.transfer_write + // ops and can't be mapped using BlockAndValueMapping.
+ AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(opOperand); + write = state.maskOperation(b, write, linalgOp, opOperandMap); + // Writes that produce no results add nothing to `newResults`. + newResults.append(write->result_begin(), write->result_end()); } + return VectorizationResult{VectorizationStatus::NoReplace, nullptr}; } @@ -410,7 +619,7 @@ vectorizeOneOp(OpBuilder &b, LinalgOp linalgOp, Operation *op, const BlockAndValueMapping &bvm, ArrayRef customVectorizationHooks) { - LDBG("vectorize op " << *op); + LDBG("vectorize op " << *op << "\n"); // 1. Try to apply any CustomVectorizationHook. if (!customVectorizationHooks.empty()) { @@ -506,8 +715,10 @@ /// This is not deemed a problem as we expect canonicalizations and foldings to /// aggressively clean up the useless work. static LogicalResult -vectorizeAsLinalgGeneric(OpBuilder &b, LinalgOp linalgOp, +vectorizeAsLinalgGeneric(OpBuilder &b, VectorizationState &state, + LinalgOp linalgOp, SmallVectorImpl &newResults) { + LDBG("Vectorizing operation as linalg generic\n"); Block *block = linalgOp.getBlock(); // 2. Values defined above the region can only be broadcast for now. Make them @@ -520,11 +731,6 @@ if (linalgOp.getNumDpsInits() == 0) return failure(); - // TODO: the common vector shape is equal to the static loop sizes only when - // all indexing maps are projected permutations. For convs and stencils the - // logic will need to evolve. - SmallVector commonVectorShape = linalgOp.computeStaticLoopSizes(); - // 3. Turn all BBArgs into vector.transfer_read / load. Location loc = linalgOp.getLoc(); Value zero = b.create(loc, 0); @@ -534,35 +740,47 @@ bvm.map(bbarg, opOperand->get()); continue; } - VectorType readType; - AffineMap map; - // TODO: can we keep this simplification?
- // if (linalgOp.getShape(&opOperand).empty()) { - // readType = VectorType::get({}, bbarg.getType()); - // } else { - if (opOperand->getOperandNumber() < linalgOp.getNumDpsInputs()) { - map = inverseAndBroadcastProjectedPermutation( - linalgOp.getMatchingIndexingMap(opOperand)); - readType = VectorType::get(commonVectorShape, - getElementTypeOrSelf(opOperand->get())); + + // Convert the indexing map for this input/output to a transfer read + // permutation map. For input reads we use the canonical vector shape. For + // output reads (iteration-carried dependence, e.g., reductions), the vector + // shape is computed by mapping the canonical vector shape to the output + // domain and back to the canonical domain. + AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand); + AffineMap readMap; + // Owning storage: `AffineMap::compose` returns a temporary vector, so an + // ArrayRef bound to it would dangle once the assignment completes. + SmallVector<int64_t> readVecShape; + if (linalgOp.isDpsInput(opOperand)) { + readMap = inverseAndBroadcastProjectedPermutation(indexingMap); + readVecShape = llvm::to_vector(state.getCanonicalVecShape()); } else { - map = inversePermutation( - reindexIndexingMap(linalgOp.getMatchingIndexingMap(opOperand))); - readType = VectorType::get(map.compose(linalgOp.getShape(opOperand)), - getElementTypeOrSelf(opOperand->get())); + readMap = inversePermutation(reindexIndexingMap(indexingMap)); + readVecShape = + readMap.compose(indexingMap.compose(state.getCanonicalVecShape())); } - // } - auto shape = linalgOp.getShape(opOperand); - SmallVector indices(shape.size(), zero); - Value readValue = b.create( - loc, readType, opOperand->get(), indices, map); + auto readType = + VectorType::get(readVecShape, getElementTypeOrSelf(opOperand->get())); + SmallVector indices(linalgOp.getShape(opOperand).size(), zero); + + // If masked, set in-bounds to true. Masking guarantees that the access will + // be in-bounds.
+ SmallVector inBounds; + if (!state.maskedDimSizeValues.empty()) + inBounds.append(readType.getRank(), true); + + Operation *read = b.create( + loc, readType, opOperand->get(), indices, readMap, + ArrayRef(inBounds)); + // Wrap the read in a `vector.mask` when masking is active, using the mask + // permuted by this operand's indexing map. + read = state.maskOperation(b, read, linalgOp, indexingMap); + Value readValue = read->getResult(0); + // Not all ops support 0-d vectors, extract the scalar for now. // TODO: remove this. if (readValue.getType().cast().getRank() == 0) readValue = b.create(loc, readValue); - LDBG("new vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue); + LDBG("New vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue + << "\n"); bvm.map(bbarg, readValue); bvm.map(opOperand->get(), readValue); } @@ -572,7 +790,7 @@ CustomVectorizationHook vectorizeYield = [&](Operation *op, const BlockAndValueMapping &bvm) -> VectorizationResult { - return vectorizeLinalgYield(b, op, bvm, linalgOp, newResults); + return vectorizeLinalgYield(b, op, bvm, state, linalgOp, newResults); }; hooks.push_back(vectorizeYield); @@ -596,12 +814,13 @@ for (Operation &op : block->getOperations()) { VectorizationResult result = vectorizeOneOp(b, linalgOp, &op, bvm, hooks); if (result.status == VectorizationStatus::Failure) { - LDBG("failed to vectorize: " << op); + LDBG("failed to vectorize: " << op << "\n"); return failure(); } if (result.status == VectorizationStatus::NewOp) { - LDBG("new vector op: " << *result.newOp;); - bvm.map(op.getResults(), result.newOp->getResults()); + // Ops created by the vectorization hooks are masked (if maskable) with the + // canonical, unpermuted mask when masking is active. + Operation *maybeMaskedOp = state.maskOperation(b, result.newOp, linalgOp); + LDBG("New vector op: " << *maybeMaskedOp << "\n"); + bvm.map(op.getResults(), maybeMaskedOp->getResults()); } } @@ -612,7 +831,7 @@ // ops that may not commute (e.g. linear reduction + non-linear instructions).
static LogicalResult reductionPreconditions(LinalgOp op) { if (llvm::none_of(op.getIteratorTypesArray(), isReductionIterator)) { - LDBG("reduction precondition failed: no reduction iterator"); + LDBG("reduction precondition failed: no reduction iterator\n"); return failure(); } for (OpOperand *opOperand : op.getDpsInitOperands()) { @@ -622,19 +841,43 @@ Operation *reduceOp = matchLinalgReduction(opOperand); if (!reduceOp || !getCombinerOpKind(reduceOp)) { - LDBG("reduction precondition failed: reduction detection failed"); + LDBG("reduction precondition failed: reduction detection failed\n"); return failure(); } } return success(); } -static LogicalResult vectorizeStaticLinalgOpPrecondition( - linalg::LinalgOp op, - ArrayRef customPreconditions) { +/// Returns success if the dynamically-shaped `op` meets the current +/// (restricted) preconditions for vectorization. +static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op) { + // TODO: Only dynamic generic ops are supported for now. + if (!isa(op) && !isa(op)) + return failure(); + + // TODO: Only ops whose operands are all tensors with a non-static shape are + // supported for now. + if (llvm::any_of(op.getOperation()->getOpOperands(), + [](OpOperand &opOperand) { + TensorType operandType = + opOperand.get().getType().dyn_cast(); + return !operandType || operandType.hasStaticShape(); + })) + return failure(); + + LDBG("Dynamically-shaped op meets vectorization pre-conditions\n"); + return success(); +} + +LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(LinalgOp linalgOp) { + if (linalgOp.hasDynamicShape() && + failed(vectorizeDynamicLinalgOpPrecondition(linalgOp))) + return failure(); + + SmallVector customPreconditions; + + // Register CustomVectorizationPrecondition for extractOp. + customPreconditions.push_back(tensorExtractVectorizationPrecondition); // All types in the body should be a supported element type for VectorType. - for (Operation &innerOp : op->getRegion(0).front()) { + for (Operation &innerOp : linalgOp->getRegion(0).front()) { // Check if any custom hook can vectorize the inner op.
if (llvm::any_of( customPreconditions, @@ -654,46 +897,39 @@ return failure(); } } - if (isElementwise(op)) + if (isElementwise(linalgOp)) return success(); // TODO: isaConvolutionOpInterface that can also infer from generic features. // But we will still need stride/dilation attributes that will be annoying to // reverse-engineer... - if (isa(op.getOperation())) + if (isa(linalgOp.getOperation())) return success(); // TODO: the common vector shape is equal to the static loop sizes only when // all indexing maps are projected permutations. For convs and stencils the // logic will need to evolve. - if (!allIndexingsAreProjectedPermutation(op)) { - LDBG("precondition failed: not projected permutations"); + if (!allIndexingsAreProjectedPermutation(linalgOp)) { + LDBG("precondition failed: not projected permutations\n"); return failure(); } - if (failed(reductionPreconditions(op))) { - LDBG("precondition failed: reduction preconditions"); + if (failed(reductionPreconditions(linalgOp))) { + LDBG("precondition failed: reduction preconditions\n"); return failure(); } return success(); } -LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(LinalgOp linalgOp) { - // All types must be static shape to go to vector. - if (linalgOp.hasDynamicShape()) { - LDBG("precondition failed: dynamic shape"); +LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, LinalgOp linalgOp, + ArrayRef vecSizesForMaskedDims) { + LDBG("Attempting to vectorize:\n" << linalgOp << "\n"); + if (failed(vectorizeLinalgOpPrecondition(linalgOp))) return failure(); - } - SmallVector customPreconditions; - - // Register CustomVectorizationPrecondition for extractOp. - customPreconditions.push_back(tensorExtractVectorizationPrecondition); - - return vectorizeStaticLinalgOpPrecondition(linalgOp, customPreconditions); -} - -LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, - LinalgOp linalgOp) { - if (failed(vectorizeLinalgOpPrecondition(linalgOp))) + // Initialize vectorization state. This computes the canonical vector shape + // and, for dynamically-shaped ops, the runtime sizes used for masking.
+ VectorizationState state; + if (failed(state.initState(rewriter, linalgOp, vecSizesForMaskedDims))) { + LDBG("Vectorization state couldn't be initialized\n"); return failure(); + } SmallVector results; // TODO: isaConvolutionOpInterface that can also infer from generic @@ -704,8 +940,14 @@ } else { if (failed(vectorizeLinalgOpPrecondition(linalgOp))) return failure(); - LDBG("Vectorize generic by broadcasting to a common shape: " << linalgOp); - if (failed(vectorizeAsLinalgGeneric(rewriter, linalgOp, results))) + LDBG("Vectorize generic by broadcasting to a common shape: \n" + << linalgOp << "\n"); + // TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted to + // 'OpBuilder' when it is passed over to some methods like + // 'vectorizeAsLinalgGeneric'. This is highly problematic: if we erase an op + // within these methods, the actual rewriter won't be notified and we will + // end up with use-after-free issues! + if (failed(vectorizeAsLinalgGeneric(rewriter, state, linalgOp, results))) return failure(); } @@ -1201,7 +1443,7 @@ if (firstOp->getBlock() != secondOp->getBlock() || !firstOp->isBeforeInBlock(secondOp)) { LDBG("interleavedUses precondition failed, firstOp: " - << *firstOp << ", second op: " << *secondOp); + << *firstOp << ", second op: " << *secondOp << "\n"); return true; } for (auto v : values) { @@ -1214,7 +1456,7 @@ (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner))) continue; LDBG(" found interleaved op " << *owner << ", firstOp: " << *firstOp - << ", second op: " << *secondOp); + << ", second op: " << *secondOp << "\n"); return true; } } @@ -1250,14 +1492,14 @@ !viewOrAlloc.getDefiningOp()) return failure(); - LDBG(viewOrAlloc); + LDBG(viewOrAlloc << "\n"); // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc); if (!subViewOp) return failure(); Value subView = subViewOp.getResult(); - LDBG("with subView " << subView); + LDBG("with subView " << subView << "\n"); // Find the copy into `subView` without interleaved uses. memref::CopyOp copyOp; @@ -1266,7 +1508,7 @@ assert(newCopyOp.getTarget().getType().isa()); if (newCopyOp.getTarget() != subView) continue; - LDBG("copy candidate " << *newCopyOp); + LDBG("copy candidate " << *newCopyOp << "\n"); if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView})) continue; copyOp = newCopyOp; @@ -1275,7 +1517,7 @@ } if (!copyOp) return failure(); - LDBG("with copy " << *copyOp); + LDBG("with copy " << *copyOp << "\n"); // Find the fill into `viewOrAlloc` without interleaved uses before the // copy. @@ -1285,7 +1527,7 @@ assert(newFillOp.output().getType().isa()); if (newFillOp.output() != viewOrAlloc) continue; - LDBG("fill candidate " << *newFillOp); + LDBG("fill candidate " << *newFillOp << "\n"); if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView})) continue; maybeFillOp = newFillOp; @@ -1296,7 +1538,7 @@ if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value()) return failure(); if (maybeFillOp) - LDBG("with maybeFillOp " << *maybeFillOp); + LDBG("with maybeFillOp " << *maybeFillOp << "\n"); // `in` is the subview that memref.copy reads. Replace it. Value in = copyOp.getSource(); diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp --- a/mlir/lib/IR/AffineMap.cpp +++ b/mlir/lib/IR/AffineMap.cpp @@ -336,6 +336,18 @@ llvm_unreachable("incorrect permutation request"); } +/// Extracts the permuted position where the given input index resides. +/// Returns `llvm::None` if the input index is projected out (no result of the +/// map is that dimension). Fails (asserts) when called on a non-projected-permutation.
+Optional +AffineMap::getProjectedPermutationPermutedPosition(unsigned input) const { + assert(isProjectedPermutation() && "invalid projected permutation request"); + for (unsigned i = 0, numResults = getNumResults(); i < numResults; i++) + if (getDimPosition(i) == input) + return i; + return llvm::None; +} + /// Folds the results of the application of an affine map on the provided /// operands to a constant if possible. Returns false if the folding happens, /// true otherwise. diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir --- a/mlir/test/Dialect/Linalg/vectorization.mlir +++ b/mlir/test/Dialect/Linalg/vectorization.mlir @@ -1,7 +1,5 @@ // RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file | FileCheck %s -// ----- - // CHECK-LABEL: contraction_dot func.func @contraction_dot(%A: memref<1584xf32>, %B: memref<1584xf32>, %C: memref) { @@ -130,7 +128,7 @@ // CHECK-LABEL: func @generic_output_transpose func.func @generic_output_transpose(%A: memref<8x16xf32>, %B: memref<16x32xf32>, - %C: memref<32x8xf32>) { + %C: memref<32x8xf32>) { // CHECK: vector.transfer_read %{{.*}} : memref<8x16xf32>, vector<8x32x16xf32> // CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<8x32x16xf32> // CHECK: %[[ACC:.*]] = vector.transfer_read %{{.*}} : memref<32x8xf32>, vector<8x32xf32>