diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -591,6 +591,36 @@
         return result;
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Given a dimension of the iteration space of a Linalg operation, finds
+        an operand in the operation that is defined on that dimension. Returns
+        whether such an operand was found or not. If found, also returns the
+        operand value and the dimension position within the operand.
+      }],
+      /*retTy=*/"LogicalResult",
+      /*methodName=*/"mapIterationSpaceDimToOperandDim",
+      /*args=*/(ins "unsigned":$dimPos,
+                    "::mlir::Value &":$operand,
+                    "unsigned &":$operandDimPos),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        // Retrieve the operand and its dimension position from the first
+        // operand with a permutation map that is defined on that dimension.
+        for (auto [i, idxMap] : llvm::enumerate($_op.getIndexingMapsArray())) {
+          if (idxMap.isProjectedPermutation()) {
+            if (auto mayOperandDim = idxMap.getResultPosition(
+                    getAffineDimExpr(dimPos, idxMap.getContext()))) {
+              operand = $_op->getOperand(i);
+              operandDimPos = *mayOperandDim;
+              return success();
+            }
+          }
+        }
+
+        return failure();
+      }]
+    >,
     //===------------------------------------------------------------------===//
     // Linalg generalization hooks.
     //===------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1115,4 +1115,45 @@
   }];
 }
 
+def MaskedVectorizeOp : Op<Transform_Dialect, "structured.masked_vectorize",
+    [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     TransformOpInterface]> {
+  let description = [{
+    Vectorize the target ops, which must be Linalg ops, with masked vectors
+    of the specified size.
+
+    The vector sizes can be either static or dynamic (SSA values). In case of
+    SSA values, the handle must be mapped to exactly one payload op with
+    exactly one index-typed result.
+
+    #### Return modes:
+
+    This operation produces a definite failure if the dynamic vector sizes
+    (SSA values) do not satisfy the constraints mentioned above. It produces a
+    silenceable failure if at least one target op is not a Linalg op or fails
+    to vectorize.
+  }];
+
+  let arguments = (ins PDL_Operation:$target,
+                       Variadic<PDL_Operation>:$vector_sizes,
+                       DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:
+                          $static_vector_sizes);
+  let results = (outs);
+  let assemblyFormat = [{
+    $target
+    `vector_sizes` custom<DynamicIndexList>($vector_sizes,
+                                            $static_vector_sizes)
+    attr-dict
+  }];
+
+  let extraClassDeclaration = [{
+    // TODO: applyToOne.
+    ::mlir::DiagnosedSilenceableFailure apply(
+        ::mlir::transform::TransformResults &transformResults,
+        ::mlir::transform::TransformState &state);
+
+    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedVectorSizes();
+  }];
+}
+
 #endif // LINALG_TRANSFORM_OPS
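For reference, a minimal transform-script sketch of how the new op is meant to be driven, covering both static and dynamic vector sizes (the payload and match patterns are illustrative, not part of this patch):

```mlir
transform.sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
  // Static vector sizes, one per iteration space dimension.
  transform.structured.masked_vectorize %0 vector_sizes [4, 8]
}

transform.sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
  // Dynamic vector size: this handle must map to exactly one payload op with
  // a single index-typed constant result, or the op fails definitively.
  %sz = transform.structured.match ops{["arith.constant"]} in %arg1
  transform.structured.masked_vectorize %0 vector_sizes [%sz, 8]
}
```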
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -344,8 +344,14 @@
 FailureOr<LinalgOp> promoteSubViews(OpBuilder &b, LinalgOp op,
                                     const LinalgPromotionOptions &options);
 
-/// Emit a suitable vector form for a Linalg op with fully static shape.
-LogicalResult vectorize(RewriterBase &builder, LinalgOp linalgOp,
+/// Emit a suitable vector form for a Linalg op. If provided,
+/// `inputVectorSizes` are used to vectorize this operation.
+/// `inputVectorSizes` must match the rank of the iteration space of the
+/// operation and the sizes must be greater than or equal to their counterpart
+/// iteration space sizes, if static. `inputVectorSizes` also allows the
+/// vectorization of operations with dynamic shapes.
+LogicalResult vectorize(RewriterBase &rewriter, LinalgOp linalgOp,
+                        ArrayRef<int64_t> inputVectorSizes = {},
                         bool vectorizeNDExtract = false);
 
 /// Emit a suitable vector form for a Copy op with fully static shape.
@@ -372,8 +378,10 @@
                                     LinalgPromotionOptions options);
 
 /// Return success if the operation can be vectorized.
-LogicalResult vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
-                                            bool vectorizeNDExtract = false);
+LogicalResult
+vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
+                              ArrayRef<int64_t> inputVectorSizes = {},
+                              bool vectorizeNDExtract = false);
 
 //===----------------------------------------------------------------------===//
 // Transformations exposed as rewrite patterns.
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -450,13 +450,13 @@
     /// source tensor and thus correspond to "dim-1" broadcasting.
    llvm::SetVector<int64_t> computeBroadcastedUnitDims();
 
-    /// Broadcast `value` to a vector of `dstShape`, knowing that exactly the 
+    /// Broadcast `value` to a vector of `dstShape`, knowing that exactly the
    /// `broadcastedDims` dimensions in the dstShape are broadcasted.
-    /// This requires (and asserts) that the broadcast is free of dim-1 
+    /// This requires (and asserts) that the broadcast is free of dim-1
    /// broadcasting.
    /// Since vector.broadcast only allows expanding leading dimensions, an extra
    /// vector.transpose may be inserted to make the broadcast possible.
-    /// `value`, `dstShape` and `broadcastedDims` must be properly specified or 
+    /// `value`, `dstShape` and `broadcastedDims` must be properly specified or
    /// the helper will assert. This means:
    /// 1. `dstShape` must not be empty.
    /// 2. `broadcastedDims` must be confined to [0 .. rank(value.getVectorType)]
@@ -1179,6 +1179,8 @@
   let assemblyFormat = "$vector attr-dict `:` type($vector) `to` type(results)";
 }
 
+// TODO: Tighten semantics so that masks and inbounds can't be used
+// simultaneously within the same transfer op.
 def Vector_TransferReadOp : Vector_Op<"transfer_read", [
     DeclareOpInterfaceMethods<VectorTransferOpInterface>,
@@ -1394,6 +1396,8 @@
   let hasVerifier = 1;
 }
 
+// TODO: Tighten semantics so that masks and inbounds can't be used
+// simultaneously within the same transfer op.
 def Vector_TransferWriteOp : Vector_Op<"transfer_write", [
     DeclareOpInterfaceMethods<VectorTransferOpInterface>,
diff --git a/mlir/include/mlir/Dialect/Vector/Interfaces/MaskableOpInterface.td b/mlir/include/mlir/Dialect/Vector/Interfaces/MaskableOpInterface.td
--- a/mlir/include/mlir/Dialect/Vector/Interfaces/MaskableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Vector/Interfaces/MaskableOpInterface.td
@@ -31,7 +31,9 @@
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return mlir::isa<mlir::vector::MaskingOpInterface>($_op->getParentOp());
+        mlir::Operation *parentOp = $_op->getParentOp();
+        return parentOp &&
+               mlir::isa<mlir::vector::MaskingOpInterface>(parentOp);
       }]>,
     InterfaceMethod<
       /*desc=*/"Returns the MaskingOpInterface masking this operation.",
@@ -54,18 +56,14 @@
         return false;
       }]>,
     InterfaceMethod<
-      /*desc=*/"Returns the mask type expected by this operation. It requires "
-               "the operation to be vectorized.",
-      /*retTy=*/"mlir::VectorType",
+      /*desc=*/"Returns the mask type expected by this operation. Mostly used"
+               " for verification purposes. It requires the operation to be "
+               "vectorized.",
+      /*retTy=*/"mlir::Type",
       /*methodName=*/"getExpectedMaskType",
       /*args=*/(ins),
       /*methodBody=*/"",
-      /*defaultImplementation=*/[{
-        // Default implementation is only aimed for operations that implement
-        // the `getVectorType()` method.
-        return $_op.getVectorType().cloneWith(/*shape=*/std::nullopt,
-                 IntegerType::get($_op.getContext(), /*width=*/1));
-      }]>,
+      /*defaultImplementation=*/"">,
   ];
 }
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
@@ -22,6 +22,12 @@
 /// Creates an instance of the `vector.mask` lowering pass.
 std::unique_ptr<Pass> createLowerVectorMaskPass();
 
+/// Populates instances of `MaskOpRewritePattern` to lower masked operations
+/// with `vector.mask`. Patterns should rewrite the `vector.mask` operation and
+/// not its nested `MaskableOpInterface`.
+void populateVectorMaskLoweringPatternsForSideEffectingOps(
    RewritePatternSet &patterns);
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
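With the default implementation gone, each maskable op now infers its own expected mask type; for transfer ops the mask follows the memory-side iteration order rather than the result vector type. A sketch mirroring the 2-D transpose test further below (`%t`, `%c0`, `%pad`, `%m` are assumed to be defined):

```mlir
// The read produces vector<4x8xf32>, but its permutation map transposes the
// dims, so the expected mask type is the transposed vector<8x4xi1>.
%v = vector.mask %m {
  vector.transfer_read %t[%c0, %c0], %pad
      {permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
      : tensor<?x?xf32>, vector<4x8xf32>
} : vector<8x4xi1> -> vector<4x8xf32>
```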
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -21,6 +21,7 @@
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/Dialect/Transform/IR/TransformUtils.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -1825,7 +1826,8 @@
     LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
     if (!linalgOp)
       return rewriter.notifyMatchFailure(op, "expected Linalg Op");
-    return vectorize(rewriter, linalgOp, vectorizeNDExtract);
+    return vectorize(rewriter, linalgOp, /*inputVectorSizes=*/{},
+                     vectorizeNDExtract);
   }
 
 private:
@@ -1873,6 +1875,85 @@
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// MaskedVectorizeOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::MaskedVectorizeOp::apply(
+    mlir::transform::TransformResults &transformResults,
+    mlir::transform::TransformState &state) {
+  IRRewriter rewriter(getContext());
+  ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
+  if (targets.empty())
+    return DiagnosedSilenceableFailure::success();
+
+  SmallVector<int64_t> vectorSizes;
+  for (OpFoldResult sz : getMixedVectorSizes()) {
+    if (sz.is<Attribute>()) {
+      auto attr = sz.get<Attribute>();
+      vectorSizes.push_back(attr.cast<IntegerAttr>().getInt());
+      continue;
+    }
+
+    ArrayRef<Operation *> szPayloads = state.getPayloadOps(sz.get<Value>());
+    if (szPayloads.size() != 1) {
+      auto diag = this->emitOpError(
+          "requires vector size handle that is mapped to 1 payload op");
+      diag.attachNote(sz.get<Value>().getLoc())
+          << "mapped to " << szPayloads.size() << " payload ops";
+      return DiagnosedSilenceableFailure::definiteFailure();
+    }
+
+    Operation *szPayloadOp = szPayloads[0];
+    if (szPayloadOp->getNumResults() != 1 ||
+        !szPayloadOp->getResult(0).getType().isIndex()) {
+      auto diag = this->emitOpError(
+          "requires vector size payload op with 1 index result");
+      diag.attachNote(szPayloadOp->getLoc()) << "vector size payload op";
+      return DiagnosedSilenceableFailure::definiteFailure();
+    }
+
+    IntegerAttr attr;
+    if (!matchPattern(szPayloadOp->getResult(0), m_Constant(&attr))) {
+      auto diag = this->emitOpError("requires constant vector size");
+      diag.attachNote(szPayloadOp->getLoc()) << "vector size payload op";
+      return DiagnosedSilenceableFailure::definiteFailure();
+    }
+
+    vectorSizes.push_back(attr.getInt());
+  }
+
+  // TODO: Check that the correct number of vectorSizes was provided.
+
+  for (Operation *target : targets) {
+    auto linalgOp = dyn_cast<LinalgOp>(target);
+    if (!linalgOp) {
+      Diagnostic diag(target->getLoc(), DiagnosticSeverity::Error);
+      diag << "cannot vectorize non-Linalg op";
+      return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
+    }
+
+    if (failed(linalg::vectorize(rewriter, linalgOp, vectorSizes))) {
+      Diagnostic diag(target->getLoc(), DiagnosticSeverity::Error);
+      diag << "failed to vectorize op";
+      return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
+    }
+  }
+
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::MaskedVectorizeOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  consumesHandle(getTarget(), effects);
+  onlyReadsHandle(getVectorSizes(), effects);
+}
+
+SmallVector<OpFoldResult> MaskedVectorizeOp::getMixedVectorSizes() {
+  OpBuilder b(getContext());
+  return getMixedValues(getStaticVectorSizes(), getVectorSizes(), b);
+}
+
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -11,25 +11,20 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Analysis/SliceAnalysis.h"
-#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
 #include "mlir/IR/AffineExpr.h"
-#include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
-#include "mlir/Pass/Pass.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/RegionUtils.h"
-#include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/TypeSwitch.h"
@@ -65,6 +60,266 @@
   return res;
 }
 
+/// Contains the vectorization state and related methods used across the
+/// vectorization process of a given operation.
+struct VectorizationState {
+  VectorizationState(RewriterBase &rewriter) : rewriterGuard(rewriter) {}
+
+  /// Initializes the vectorization state, including the computation of the
+  /// canonical vector shape for vectorization.
+  LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
+                          ArrayRef<int64_t> inputVectorSizes);
+
+  /// Returns the canonical vector shape used to vectorize the iteration
+  /// space.
+  ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }
+
+  /// Masks an operation with the canonical vector mask if the operation needs
+  /// masking. Returns the masked operation or the original operation if
+  /// masking is not needed. If provided, the canonical mask for this
+  /// operation is permuted using `maybeMaskingMap`.
+  Operation *maskOperation(RewriterBase &rewriter, Operation *opToMask,
+                           LinalgOp linalgOp,
+                           Optional<AffineMap> maybeMaskingMap = std::nullopt);
+
+private:
+  /// Initializes the iteration space static sizes using the Linalg op
+  /// information. This may become more complicated in the future.
+  void initIterSpaceStaticSizes(LinalgOp linalgOp) {
+    iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges());
+  }
+
+  /// Generates 'tensor.dim' operations for all the dynamic dimensions of the
+  /// iteration space to be vectorized and stores them in
+  /// `iterSpaceDynamicSizes`.
+  LogicalResult precomputeIterSpaceDynamicSizes(RewriterBase &rewriter,
+                                                LinalgOp linalgOp);
+
+  /// Create or retrieve an existing mask value to mask `opToMask` in the
+  /// canonical vector iteration space. If `maybeMaskingMap` is provided, the
+  /// mask is permuted using that permutation map. If a new mask is created,
+  /// it will be cached for future users.
+  Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
+                           LinalgOp linalgOp,
+                           Optional<AffineMap> maybeMaskingMap);
+
+  // Holds the compile-time static sizes of the iteration space to vectorize.
+  // Dynamic dimensions are represented using ShapedType::kDynamicSize.
+  SmallVector<int64_t> iterSpaceStaticSizes;
+
+  /// Holds the runtime sizes of the iteration spaces to vectorize. Static
+  /// dimensions are represented with an empty value.
+  SmallVector<Value> iterSpaceDynamicSizes;
+
+  /// Holds the canonical vector shape used to vectorize the iteration space.
+  SmallVector<int64_t> canonicalVecShape;
+
+  /// Holds the active masks for permutations of the canonical vector
+  /// iteration space.
+  DenseMap<AffineMap, Value> activeMaskCache;
+
+  /// Global vectorization guard for the incoming rewriter. It's initialized
+  /// when the vectorization state is initialized.
+  OpBuilder::InsertionGuard rewriterGuard;
+};
+
+/// Generates 'tensor.dim' operations for all the dynamic dimensions of the
+/// iteration space to be vectorized and stores them in
+/// `iterSpaceDynamicSizes`.
+LogicalResult
+VectorizationState::precomputeIterSpaceDynamicSizes(RewriterBase &rewriter,
+                                                    LinalgOp linalgOp) {
+  // TODO: Support 0-d vectors.
+  for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end;
+       ++vecDim) {
+    if (!ShapedType::isDynamic(iterSpaceStaticSizes[vecDim])) {
+      // Add an empty value for static dimensions.
+      iterSpaceDynamicSizes.push_back(Value());
+      continue;
+    }
+
+    // Find an operand defined on this dimension of the iteration space to
+    // extract the runtime dimension size.
+    Value operand;
+    unsigned operandDimPos;
+    if (failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand,
+                                                         operandDimPos)))
+      return failure();
+
+    Value dynamicDim = linalgOp.hasTensorSemantics()
+                           ? (Value)rewriter.create<tensor::DimOp>(
+                                 linalgOp.getLoc(), operand, operandDimPos)
+                           : (Value)rewriter.create<memref::DimOp>(
+                                 linalgOp.getLoc(), operand, operandDimPos);
+    iterSpaceDynamicSizes.push_back(dynamicDim);
+  }
+
+  return success();
+}
+
+/// Initializes the vectorization state, including the computation of the
+/// canonical vector shape for vectorization.
+// TODO: Move this to the constructor when we can remove the failure cases.
+LogicalResult
+VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
+                              ArrayRef<int64_t> inputVectorSizes) {
+  // Initialize the insertion point.
+  rewriter.setInsertionPoint(linalgOp);
+
+  if (!inputVectorSizes.empty()) {
+    // Get the canonical vector shape from the input vector sizes provided.
+    // This path should be taken to vectorize code with dynamic shapes and
+    // when using vector sizes greater than the iteration space sizes.
+    canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
+  } else {
+    // Compute the canonical vector shape from the operation shape. If there
+    // are dynamic shapes, the operation won't be vectorized.
+    canonicalVecShape = linalgOp.getStaticLoopRanges();
+  }
+
+  LDBG("Canonical vector shape: ");
+  LLVM_DEBUG(llvm::interleaveComma(canonicalVecShape, llvm::dbgs()));
+  LLVM_DEBUG(llvm::dbgs() << "\n");
+
+  // Initialize iteration space static sizes.
+  initIterSpaceStaticSizes(linalgOp);
+
+  // Extract and register the runtime value of any potential dynamic shape
+  // needed to compute a mask during vectorization.
+  if (failed(precomputeIterSpaceDynamicSizes(rewriter, linalgOp)))
+    return failure();
+
+  if (ShapedType::isDynamicShape(canonicalVecShape))
+    return failure();
+  return success();
+}
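To illustrate (a sketch distilled from the tests below, not new functionality): for a 1-D op with a dynamic iteration space vectorized with `vector_sizes [4]`, the state materializes the runtime size through the operand found by `mapIterationSpaceDimToOperandDim` and later feeds it to the mask:

```mlir
%c0 = arith.constant 0 : index
// Runtime size of iteration space dim 0, taken from the first operand whose
// projected permutation indexing map covers d0.
%dim = tensor.dim %arg0, %c0 : tensor<?xf32>
%mask = vector.create_mask %dim : vector<4xi1>
```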
+
+/// Create or retrieve an existing mask value to mask `opToMask` in the
+/// canonical vector iteration space. If `maybeMaskingMap` is provided, the
+/// mask is permuted using that permutation map. If a new mask is created, it
+/// will be cached for future users.
+Value VectorizationState::getOrCreateMaskFor(
+    RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
+    Optional<AffineMap> maybeMaskingMap) {
+  // No mask is needed if the operation is not maskable.
+  auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask);
+  if (!maskableOp)
+    return Value();
+
+  assert(!maskableOp.isMasked() &&
+         "Masking an operation that is already masked");
+
+  // If no masking map was provided, use an identity map with the loop dims.
+  assert((!maybeMaskingMap || *maybeMaskingMap) &&
+         "Unexpected null mask permutation map");
+  AffineMap maskingMap =
+      maybeMaskingMap ? *maybeMaskingMap
+                      : AffineMap::getMultiDimIdentityMap(
+                            linalgOp.getNumLoops(), rewriter.getContext());
+
+  LDBG("Masking map: " << maskingMap << "\n");
+
+  // Return the active mask for the masking map of this operation if it was
+  // already created.
+  auto activeMaskIt = activeMaskCache.find(maskingMap);
+  if (activeMaskIt != activeMaskCache.end()) {
+    Value mask = activeMaskIt->second;
+    LDBG("Reusing mask: " << mask << "\n");
+    return mask;
+  }
+
+  // Compute permuted projection of the iteration space to be masked and the
+  // corresponding mask shape. If the resulting iteration space dimensions are
+  // static and identical to the mask shape, masking is not needed for this
+  // operation.
+  // TODO: Improve this check. Only projected permutation indexing maps are
+  // supported.
+  SmallVector<int64_t> permutedStaticSizes =
+      applyPermutationMap(maskingMap, ArrayRef(iterSpaceStaticSizes));
+  SmallVector<int64_t> maskShape =
+      applyPermutationMap(maskingMap, ArrayRef(canonicalVecShape));
+  LDBG("Mask shape: ");
+  LLVM_DEBUG(llvm::interleaveComma(maskShape, llvm::dbgs()));
+  LLVM_DEBUG(llvm::dbgs() << "\n");
+
+  if (permutedStaticSizes == maskShape) {
+    LDBG("Masking is not needed for masking map: " << maskingMap << "\n");
+    activeMaskCache[maskingMap] = Value();
+    return Value();
+  }
+
+  // Compute the mask upper bound values by combining the permuted iteration
+  // space static sizes and the dynamic values.
+  SmallVector<Value> permutedDynamicSizes =
+      applyPermutationMap(maskingMap, ArrayRef(iterSpaceDynamicSizes));
+  SmallVector<Value> upperBounds;
+  for (auto [staticBound, dynBound] :
+       llvm::zip(permutedStaticSizes, permutedDynamicSizes))
+    upperBounds.push_back(ShapedType::isDynamic(staticBound)
+                              ? dynBound
+                              : rewriter.create<arith::ConstantIndexOp>(
+                                    linalgOp.getLoc(), staticBound));
+
+  assert(!maskShape.empty() && !upperBounds.empty() &&
+         "Masked 0-d vectors are not supported yet");
+
+  // Create the mask based on the dimension size values.
+  auto maskType = VectorType::get(maskShape, rewriter.getI1Type());
+  Value mask = rewriter.create<vector::CreateMaskOp>(linalgOp.getLoc(),
+                                                     maskType, upperBounds);
+  LDBG("Creating new mask: " << mask << "\n");
+  activeMaskCache[maskingMap] = mask;
+  return mask;
+}
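A hedged illustration of the early exit and the mixed static/dynamic bounds (names hypothetical): with iteration space `(d0, d1) = (?, 8)` and vector sizes `[4, 8]`, an operand indexed by `(d1)` alone projects to static sizes `(8)`, which equal its mask shape, so it is left unmasked and `Value()` is cached for that masking map; operands touching `d0` get a mask whose bounds mix the `tensor.dim` result and a constant:

```mlir
%d0 = tensor.dim %arg1, %c0 : tensor<?x8xf32>
%c8 = arith.constant 8 : index
%m  = vector.create_mask %d0, %c8 : vector<4x8xi1>
```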
+
+/// Masks an operation with the canonical vector mask if the operation needs
+/// masking. Returns the masked operation or the original operation if masking
+/// is not needed. If provided, the canonical mask for this operation is
+/// permuted using `maybeMaskingMap`.
+Operation *
+VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
+                                  LinalgOp linalgOp,
+                                  Optional<AffineMap> maybeMaskingMap) {
+  LDBG("Trying to mask: " << *opToMask << "\n");
+
+  // Create or retrieve mask for this operation.
+  Value mask =
+      getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);
+
+  if (!mask) {
+    LDBG("No mask required\n");
+    return opToMask;
+  }
+
+  // Wrap the operation with a new `vector.mask` and update the def-use chain.
+  assert(opToMask && "Expected a valid operation to mask");
+  auto opResults = opToMask->getResultTypes();
+  auto createRegionMask = [opToMask](OpBuilder &builder, Location loc) {
+    Block *insBlock = builder.getInsertionBlock();
+    // Create a block, put an op in that block. Look for a utility.
+    // Maybe in conversion pattern rewriter. Way to avoid splice.
+    // Set insertion point.
+    insBlock->getOperations().splice(
+        insBlock->begin(), opToMask->getBlock()->getOperations(), opToMask);
+    builder.create<vector::YieldOp>(loc, opToMask->getResults());
+  };
+  // TODO: Allow multiple results in vector.mask.
+  auto maskOp =
+      opResults.empty()
+          ? rewriter.create<vector::MaskOp>(opToMask->getLoc(), mask,
+                                            createRegionMask)
+          : rewriter.create<vector::MaskOp>(opToMask->getLoc(),
+                                            opToMask->getResultTypes().front(),
+                                            mask, createRegionMask);
+
+  Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();
+
+  for (auto [resIdx, resVal] : llvm::enumerate(opToMask->getResults()))
+    rewriter.replaceAllUsesExcept(resVal, maskOp.getResult(resIdx),
+                                  maskOpTerminator);
+
+  LDBG("Masked operation: " << *maskOp << "\n");
+  return maskOp;
+}
+
 /// Given an indexing `map` coming from a LinalgOp indexing, restricted to a
 /// projectedPermutation, compress the unused dimensions to serve as a
 /// permutation_map for a vector transfer operation.
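The net effect of `maskOperation`, sketched on a single read (values assumed defined): the original op is spliced into the `vector.mask` region, terminated by `vector.yield`, and external uses are rerouted to the mask op's result:

```mlir
%r = vector.mask %m {
  vector.transfer_read %t[%c0], %pad {in_bounds = [true]}
      : tensor<?xf32>, vector<4xf32>
} : vector<4xi1> -> vector<4xf32>
```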
@@ -204,35 +459,44 @@
 /// Return the produced value or null if no value is produced.
 // Note: this is a true builder that notifies the OpBuilder listener.
 // TODO: Consider moving as a static helper on the ReduceOp.
-static Value buildVectorWrite(OpBuilder &b, Value value,
-                              OpOperand *outputOperand) {
-  Operation *write;
+static Value buildVectorWrite(RewriterBase &rewriter, Value value,
+                              OpOperand *outputOperand,
+                              VectorizationState &state) {
   Location loc = value.getLoc();
   auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
-  ArrayRef<int64_t> shape = linalgOp.getShape(outputOperand);
-  auto vectorType = VectorType::get(
-      shape, getElementTypeOrSelf(outputOperand->get().getType()));
+  AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);
+  auto vectorType =
+      VectorType::get(opOperandMap.compose(state.getCanonicalVecShape()),
+                      getElementTypeOrSelf(outputOperand->get().getType()));
+
+  Operation *write;
   if (vectorType.getRank() > 0) {
-    // 0-d case is still special: do not invert the reindexing map.
-    AffineMap map =
-        reindexIndexingMap(linalgOp.getMatchingIndexingMap(outputOperand));
-    SmallVector<int64_t> transposeShape =
-        applyPermutationMap(inversePermutation(map), vectorType.getShape());
-    assert(!transposeShape.empty() && "unexpected empty transpose shape");
-    vectorType = VectorType::get(transposeShape, vectorType.getElementType());
+    AffineMap writeMap = reindexIndexingMap(opOperandMap);
     SmallVector<Value> indices(linalgOp.getRank(outputOperand),
-                               b.create<arith::ConstantIndexOp>(loc, 0));
-    value = broadcastIfNeeded(b, value, vectorType.getShape());
-    write = b.create<vector::TransferWriteOp>(
-        loc, value, outputOperand->get(), indices, map);
+                               rewriter.create<arith::ConstantIndexOp>(loc, 0));
+    value = broadcastIfNeeded(rewriter, value, vectorType.getShape());
+    write = rewriter.create<vector::TransferWriteOp>(
+        loc, value, outputOperand->get(), indices, writeMap);
   } else {
+    // 0-d case is still special: do not invert the reindexing writeMap.
     if (!value.getType().isa<VectorType>())
-      value = b.create<vector::BroadcastOp>(loc, vectorType, value);
+      value = rewriter.create<vector::BroadcastOp>(loc, vectorType, value);
     assert(value.getType() == vectorType && "incorrect type");
-    write = b.create<vector::TransferWriteOp>(
+    write = rewriter.create<vector::TransferWriteOp>(
         loc, value, outputOperand->get(), ValueRange{});
   }
-  LDBG("vectorized op: " << *write);
+
+  write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);
+
+  // If masked, set in-bounds to true. Masking guarantees that the access will
+  // be in-bounds.
+  if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) {
+    auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp());
+    SmallVector<bool> inBounds(maskedWriteOp.getVectorType().getRank(), true);
+    maskedWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
+  }
+
+  LDBG("vectorized op: " << *write << "\n");
   if (!write->getResults().empty())
     return write->getResult(0);
   return Value();
@@ -259,20 +523,22 @@
 /// CustomVectorizationHook.
 static VectorizationResult
 vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
-                     const BlockAndValueMapping &bvm, LinalgOp linalgOp,
-                     SmallVectorImpl<Value> &newResults) {
+                     const BlockAndValueMapping &bvm, VectorizationState &state,
+                     LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
   auto yieldOp = dyn_cast<linalg::YieldOp>(op);
   if (!yieldOp)
     return VectorizationResult{VectorizationStatus::Failure, nullptr};
-  for (const auto &outputs : llvm::enumerate(yieldOp.getValues())) {
+  for (const auto &output : llvm::enumerate(yieldOp.getValues())) {
     // TODO: Scan for an opportunity for reuse.
     // TODO: use a map.
-    Value vectorValue = bvm.lookup(outputs.value());
-    Value newResult = buildVectorWrite(
-        rewriter, vectorValue, linalgOp.getDpsInitOperand(outputs.index()));
+    Value vectorValue = bvm.lookup(output.value());
+    Value newResult =
+        buildVectorWrite(rewriter, vectorValue,
+                         linalgOp.getDpsInitOperand(output.index()), state);
     if (newResult)
       newResults.push_back(newResult);
   }
+
   return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
 }
 
@@ -464,7 +730,7 @@
 vectorizeOneOp(RewriterBase &rewriter, LinalgOp linalgOp, Operation *op,
                const BlockAndValueMapping &bvm,
                ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
-  LDBG("vectorize op " << *op);
+  LDBG("vectorize op " << *op << "\n");
 
   // 1. Try to apply any CustomVectorizationHook.
   if (!customVectorizationHooks.empty()) {
@@ -561,8 +827,10 @@
 /// This is not deemed a problem as we expect canonicalizations and foldings to
 /// aggressively clean up the useless work.
 static LogicalResult
-vectorizeAsLinalgGeneric(RewriterBase &rewriter, LinalgOp linalgOp,
+vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
+                         LinalgOp linalgOp,
                          SmallVectorImpl<Value> &newResults) {
+  LDBG("Vectorizing operation as linalg generic\n");
   Block *block = linalgOp.getBlock();
 
   // 2. Values defined above the region can only be broadcast for now. Make
   // them map to themselves.
@@ -575,11 +843,6 @@
   if (linalgOp.getNumDpsInits() == 0)
     return failure();
 
-  // TODO: the common vector shape is equal to the static loop sizes only when
-  // all indexing maps are projected permutations. For convs and stencils the
-  // logic will need to evolve.
-  SmallVector<int64_t> commonVectorShape = linalgOp.computeStaticLoopSizes();
-
   // 3. Turn all BBArgs into vector.transfer_read / load.
   Location loc = linalgOp.getLoc();
   Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
@@ -589,35 +852,60 @@
       bvm.map(bbarg, opOperand->get());
       continue;
     }
-    VectorType readType;
-    AffineMap map;
-    // TODO: can we keep this simplification?
-    // if (linalgOp.getShape(&opOperand).empty()) {
-    //   readType = VectorType::get({}, bbarg.getType());
-    // } else {
-    if (opOperand->getOperandNumber() < linalgOp.getNumDpsInputs()) {
-      map = inverseAndBroadcastProjectedPermutation(
-          linalgOp.getMatchingIndexingMap(opOperand));
-      readType = VectorType::get(commonVectorShape,
-                                 getElementTypeOrSelf(opOperand->get()));
+
+    // 3.a. Convert the indexing map for this input/output to a transfer read
+    // permutation map and masking map.
+    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
+
+    // Remove zeros from indexing map to use it as masking map.
+    SmallVector<int64_t> zeroPos;
+    auto results = indexingMap.getResults();
+    for (auto result : llvm::enumerate(results)) {
+      if (result.value().isa<AffineConstantExpr>()) {
+        zeroPos.push_back(result.index());
+      }
+    }
+    AffineMap maskingMap = indexingMap.dropResults(zeroPos);
+
+    AffineMap readMap;
+    SmallVector<int64_t> readVecShape;
+    if (linalgOp.isDpsInput(opOperand)) {
+      // 3.a.i. For input reads we use the canonical vector shape.
+      readMap = inverseAndBroadcastProjectedPermutation(indexingMap);
+      readVecShape = llvm::to_vector(state.getCanonicalVecShape());
     } else {
-      map = inversePermutation(
-          reindexIndexingMap(linalgOp.getMatchingIndexingMap(opOperand)));
-      readType = VectorType::get(map.compose(linalgOp.getShape(opOperand)),
-                                 getElementTypeOrSelf(opOperand->get()));
+      // 3.a.ii. For output reads (iteration-carried dependence, e.g.,
+      // reductions), the vector shape is computed by mapping the canonical
+      // vector shape to the output domain and back to the canonical domain.
+      readMap = inversePermutation(reindexIndexingMap(indexingMap));
+      readVecShape =
+          readMap.compose(indexingMap.compose(state.getCanonicalVecShape()));
     }
-    // }
-    auto shape = linalgOp.getShape(opOperand);
-    SmallVector<Value> indices(shape.size(), zero);
-    Value readValue = rewriter.create<vector::TransferReadOp>(
-        loc, readType, opOperand->get(), indices, map);
-    // Not all ops support 0-d vectors, extract the scalar for now.
+
+    auto readType =
+        VectorType::get(readVecShape, getElementTypeOrSelf(opOperand->get()));
+    SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);
+
+    Operation *read = rewriter.create<vector::TransferReadOp>(
+        loc, readType, opOperand->get(), indices, readMap);
+    read = state.maskOperation(rewriter, read, linalgOp, maskingMap);
+    Value readValue = read->getResult(0);
+
+    // 3.b. If masked, set in-bounds to true. Masking guarantees that the
+    // access will be in-bounds.
+    if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
+      SmallVector<bool> inBounds(readType.getRank(), true);
+      cast<vector::TransferReadOp>(maskOp.getMaskableOp())
+          .setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
+    }
+
+    // 3.c. Not all ops support 0-d vectors, extract the scalar for now.
     // TODO: remove this.
     if (readValue.getType().cast<VectorType>().getRank() == 0)
       readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
 
-    LDBG("new vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue);
+    LDBG("New vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue
+                                 << "\n");
     bvm.map(bbarg, readValue);
     bvm.map(opOperand->get(), readValue);
   }
@@ -627,7 +915,7 @@
   CustomVectorizationHook vectorizeYield =
      [&](Operation *op,
          const BlockAndValueMapping &bvm) -> VectorizationResult {
-    return vectorizeLinalgYield(rewriter, op, bvm, linalgOp, newResults);
+    return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp, newResults);
   };
   hooks.push_back(vectorizeYield);
 
@@ -652,12 +940,14 @@
     VectorizationResult result =
         vectorizeOneOp(rewriter, linalgOp, &op, bvm, hooks);
     if (result.status == VectorizationStatus::Failure) {
-      LDBG("failed to vectorize: " << op);
+      LDBG("failed to vectorize: " << op << "\n");
       return failure();
     }
     if (result.status == VectorizationStatus::NewOp) {
-      LDBG("new vector op: " << *result.newOp;);
-      bvm.map(op.getResults(), result.newOp->getResults());
+      Operation *maybeMaskedOp =
+          state.maskOperation(rewriter, result.newOp, linalgOp);
+      LDBG("New vector op: " << *maybeMaskedOp << "\n");
+      bvm.map(op.getResults(), maybeMaskedOp->getResults());
     }
   }
 
@@ -668,7 +958,7 @@
 // ops that may not commute (e.g. linear reduction + non-linear instructions).
 static LogicalResult reductionPreconditions(LinalgOp op) {
   if (llvm::none_of(op.getIteratorTypesArray(), isReductionIterator)) {
-    LDBG("reduction precondition failed: no reduction iterator");
+    LDBG("reduction precondition failed: no reduction iterator\n");
     return failure();
   }
   for (OpOperand *opOperand : op.getDpsInitOperands()) {
@@ -678,20 +968,69 @@
     Operation *reduceOp = matchLinalgReduction(opOperand);
     if (!reduceOp || !getCombinerOpKind(reduceOp)) {
-      LDBG("reduction precondition failed: reduction detection failed");
+      LDBG("reduction precondition failed: reduction detection failed\n");
       return failure();
     }
   }
   return success();
 }
 
-static LogicalResult vectorizeStaticLinalgOpPrecondition(
-    linalg::LinalgOp op,
-    ArrayRef<CustomVectorizationPrecondition> customPreconditions,
-    bool vectorizeNDExtract) {
+static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op) {
+  // TODO: Masking only supports dynamic generic ops without reductions for
+  // now.
+  if (!isElementwise(op) &&
+      llvm::any_of(op.getIteratorTypesArray(), [](utils::IteratorType itType) {
+        return itType != utils::IteratorType::parallel;
+      }))
+    return failure();
+
+  // TODO: 0-d vectors are not supported yet.
+  if (llvm::any_of(op.getIndexingMapsArray(), [](AffineMap map) {
+        return map.isEmpty() || map.getResults().empty();
+      }))
+    return failure();
+
+  LDBG("Dynamically-shaped op meets vectorization pre-conditions\n");
+  return success();
+}
+
+LogicalResult
+mlir::linalg::vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
+                                            ArrayRef<int64_t> inputVectorSizes,
+                                            bool vectorizeNDExtract) {
+  // Check API contract for input vector sizes.
+  if (!inputVectorSizes.empty()) {
+    assert(inputVectorSizes.size() == linalgOp.getNumLoops() &&
+           "Input vector sizes don't match the number of loops");
+    assert(!ShapedType::isDynamicShape(inputVectorSizes) &&
+           "Input vector sizes can't have dynamic dimensions");
+    assert(llvm::all_of(
+               llvm::zip(linalgOp.getStaticLoopRanges(), inputVectorSizes),
+               [](std::tuple<int64_t, int64_t> sizePair) {
+                 int64_t staticSize = std::get<0>(sizePair);
+                 int64_t inputSize = std::get<1>(sizePair);
+                 return ShapedType::isDynamic(staticSize) ||
+                        staticSize <= inputSize;
+               }) &&
+           "Input vector sizes must be greater than or equal to the iteration "
+           "space static sizes");
+  }
+
+  // TODO: Masking is only supported for dynamic shapes so input vector sizes
+  // must be empty if the op is not dynamic.
+  if (!linalgOp.hasDynamicShape() && !inputVectorSizes.empty())
+    return failure();
+
+  if (linalgOp.hasDynamicShape() &&
+      failed(vectorizeDynamicLinalgOpPrecondition(linalgOp)))
+    return failure();
+
+  SmallVector<CustomVectorizationPrecondition> customPreconditions;
+
+  // Register CustomVectorizationPrecondition for extractOp.
+  customPreconditions.push_back(tensorExtractVectorizationPrecondition);
   // All types in the body should be a supported element type for VectorType.
-  for (Operation &innerOp : op->getRegion(0).front()) {
+  for (Operation &innerOp : linalgOp->getRegion(0).front()) {
     // Check if any custom hook can vectorize the inner op.
     if (llvm::any_of(
             customPreconditions,
@@ -712,50 +1051,52 @@
       return failure();
     }
   }
-  if (isElementwise(op))
+  if (isElementwise(linalgOp))
     return success();
 
   // TODO: isaConvolutionOpInterface that can also infer from generic features.
   // But we will still need stride/dilation attributes that will be annoying to
   // reverse-engineer...
-  if (isa<ConvolutionOpInterface>(op.getOperation()))
+  if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
     return success();
   // TODO: the common vector shape is equal to the static loop sizes only when
   // all indexing maps are projected permutations. For convs and stencils the
   // logic will need to evolve.
-  if (!allIndexingsAreProjectedPermutation(op)) {
-    LDBG("precondition failed: not projected permutations");
+  if (!allIndexingsAreProjectedPermutation(linalgOp)) {
+    LDBG("precondition failed: not projected permutations\n");
    return failure();
   }
-  if (failed(reductionPreconditions(op))) {
-    LDBG("precondition failed: reduction preconditions");
+  if (failed(reductionPreconditions(linalgOp))) {
+    LDBG("precondition failed: reduction preconditions\n");
     return failure();
   }
   return success();
 }
 
-LogicalResult
-mlir::linalg::vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
-                                            bool vectorizeNDExtract) {
-  // All types must be static shape to go to vector.
-  if (linalgOp.hasDynamicShape()) {
-    LDBG("precondition failed: dynamic shape");
-    return failure();
-  }
-
-  SmallVector<CustomVectorizationPrecondition> customPreconditions;
-
-  // Register CustomVectorizationPrecondition for extractOp.
-  customPreconditions.push_back(tensorExtractVectorizationPrecondition);
-
-  return vectorizeStaticLinalgOpPrecondition(linalgOp, customPreconditions,
-                                             vectorizeNDExtract);
-}
-
+/// Emit a suitable vector form for a Linalg op. If provided,
+/// `inputVectorSizes` are used to vectorize this operation.
+/// `inputVectorSizes` must match the rank of the iteration space of the
+/// operation and the sizes must be greater than or equal to their counterpart
+/// iteration space sizes, if static. `inputVectorSizes` also allows the
+/// vectorization of operations with dynamic shapes.
 LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, LinalgOp linalgOp,
+                                      ArrayRef<int64_t> inputVectorSizes,
                                       bool vectorizeNDExtract) {
-  if (failed(vectorizeLinalgOpPrecondition(linalgOp, vectorizeNDExtract)))
+  LDBG("Attempting to vectorize:\n" << linalgOp << "\n");
+  LDBG("Input vector sizes: ");
+  LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
+  LLVM_DEBUG(llvm::dbgs() << "\n");
+
+  if (failed(vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
+                                           vectorizeNDExtract)))
     return failure();
 
+  // Initialize vectorization state.
+  VectorizationState state(rewriter);
+  if (failed(state.initState(rewriter, linalgOp, inputVectorSizes))) {
+    LDBG("Vectorization state couldn't be initialized\n");
+    return failure();
+  }
+
   SmallVector<Value> results;
   // TODO: isaConvolutionOpInterface that can also infer from generic
   // features. Will require stride/dilation attributes inference.
@@ -763,10 +1104,16 @@
   if (succeeded(convOr)) {
     llvm::append_range(results, (*convOr)->getResults());
   } else {
-    if (failed(vectorizeLinalgOpPrecondition(linalgOp, vectorizeNDExtract)))
+    if (failed(vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
+                                             vectorizeNDExtract)))
       return failure();
-    LDBG("Vectorize generic by broadcasting to a common shape: " << linalgOp);
-    if (failed(vectorizeAsLinalgGeneric(rewriter, linalgOp, results)))
+    LDBG("Vectorize generic by broadcasting to the canonical vector shape\n");
+    // TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted to
+    // 'OpBuilder' when it is passed over to some methods like
+    // 'vectorizeAsLinalgGeneric'. This is highly problematic: if we erase an
+    // op within these methods, the actual rewriter won't be notified and we
+    // will end up with read-after-free issues!
+    if (failed(vectorizeAsLinalgGeneric(rewriter, state, linalgOp, results)))
       return failure();
   }
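A rough caller-side sketch of the extended entry point (the op and sizes are illustrative; `linalgOp` is assumed to be a dynamically-shaped Linalg op):

```c++
// Each size covers one loop of the iteration space and, if the corresponding
// loop range is static, must be greater than or equal to it.
IRRewriter rewriter(linalgOp->getContext());
SmallVector<int64_t> vectorSizes = {4, 8};
if (failed(linalg::vectorize(rewriter, linalgOp, vectorSizes))) {
  // Preconditions not met (e.g., dynamically-shaped reduction); the op is
  // left untouched.
}
```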
@@ -1262,7 +1609,7 @@
   if (firstOp->getBlock() != secondOp->getBlock() ||
       !firstOp->isBeforeInBlock(secondOp)) {
     LDBG("interleavedUses precondition failed, firstOp: "
-         << *firstOp << ", second op: " << *secondOp);
+         << *firstOp << ", second op: " << *secondOp << "\n");
     return true;
   }
   for (auto v : values) {
@@ -1275,7 +1622,7 @@
         (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
       continue;
     LDBG(" found interleaved op " << *owner << ", firstOp: " << *firstOp
-         << ", second op: " << *secondOp);
+         << ", second op: " << *secondOp << "\n");
     return true;
   }
 }
diff --git a/mlir/lib/Dialect/Vector/IR/CMakeLists.txt b/mlir/lib/Dialect/Vector/IR/CMakeLists.txt
--- a/mlir/lib/Dialect/Vector/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/IR/CMakeLists.txt
@@ -5,6 +5,8 @@
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Vector/IR
 
   DEPENDS
+  MLIRMaskableOpInterfaceIncGen
+  MLIRMaskingOpInterfaceIncGen
   MLIRVectorOpsIncGen
   MLIRVectorOpsEnumsIncGen
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -447,6 +447,15 @@
   p << " : " << getVector().getType() << " into " << getDest().getType();
 }
 
+// MaskableOpInterface methods.
+
+/// Returns the mask type expected by this operation.
+Type ReductionOp::getExpectedMaskType() {
+  auto vecType = getVectorType();
+  return vecType.cloneWith(std::nullopt,
+                           IntegerType::get(vecType.getContext(), /*width=*/1));
+}
+
 Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
                                          OpBuilder &builder, Location loc,
                                          Value vector) {
@@ -3461,6 +3470,14 @@
                               [&](Twine t) { return emitOpError(t); });
 }
 
+// MaskableOpInterface methods.
+
+/// Returns the mask type expected by this operation. Mostly used for
+/// verification purposes. It requires the operation to be vectorized.
+Type TransferReadOp::getExpectedMaskType() {
+  return inferTransferReadMaskType(getVectorType(), getPermutationMap());
+}
+
 template <typename TransferOp>
 static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
   // TODO: support more aggressive createOrFold on:
@@ -3903,6 +3920,14 @@
                               [&](Twine t) { return emitOpError(t); });
 }
 
+// MaskableOpInterface methods.
+
+/// Returns the mask type expected by this operation. Mostly used for
+/// verification purposes.
+Type TransferWriteOp::getExpectedMaskType() {
+  return inferTransferWriteMaskType(getVectorType(), getPermutationMap());
+}
+
 /// Fold:
 /// ```
 ///    %t1 = ...
@@ -5377,9 +5402,10 @@
                        "expects result type to match maskable operation result type");
 
   // Mask checks.
-  if (getMask().getType() != maskableOp.getExpectedMaskType())
-    return emitOpError("expects a ") << maskableOp.getExpectedMaskType()
-                                     << " mask for the maskable operation";
+  Type expectedMaskType = maskableOp.getExpectedMaskType();
+  if (getMask().getType() != expectedMaskType)
+    return emitOpError("expects a ")
+           << expectedMaskType << " mask for the maskable operation";
 
   // Passthru checks.
   Value passthru = getPassthru();
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
@@ -109,15 +109,6 @@
   }
 };
 
-/// Populates instances of `MaskOpRewritePattern` to lower masked operations
-/// with `vector.mask`. Patterns should rewrite the `vector.mask` operation and
-/// not its nested `MaskableOpInterface`.
-void populateVectorMaskLoweringPatternsForSideEffectingOps(
-    RewritePatternSet &patterns) {
-  patterns.add<MaskedTransferReadOpPattern, MaskedTransferWriteOpPattern>(
-      patterns.getContext());
-}
-
 struct LowerVectorMaskPass
     : public vector::impl::LowerVectorMaskPassBase<LowerVectorMaskPass> {
   using Base::Base;
@@ -141,6 +132,15 @@
 
 } // namespace
 
+/// Populates instances of `MaskOpRewritePattern` to lower masked operations
+/// with `vector.mask`. Patterns should rewrite the `vector.mask` operation and
+/// not its nested `MaskableOpInterface`.
+void vector::populateVectorMaskLoweringPatternsForSideEffectingOps(
+    RewritePatternSet &patterns) {
+  patterns.add<MaskedTransferReadOpPattern, MaskedTransferWriteOpPattern>(
+      patterns.getContext());
+}
+
 std::unique_ptr<Pass> mlir::vector::createLowerVectorMaskPass() {
   return std::make_unique<LowerVectorMaskPass>();
}
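A sketch of driving the newly exposed pattern-population entry point from a custom pass (the pass itself is hypothetical, not part of this patch):

```c++
void MyLowerVectorMasks::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  vector::populateVectorMaskLoweringPatternsForSideEffectingOps(patterns);
  if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
    signalPassFailure();
}
```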
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -1,7 +1,5 @@
 // RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file | FileCheck %s
 
-// -----
-
 // CHECK-LABEL: contraction_dot
 func.func @contraction_dot(%A: memref<1584xf32>, %B: memref<1584xf32>, %C: memref<f32>) {
@@ -130,7 +128,7 @@
 // CHECK-LABEL: func @generic_output_transpose
 func.func @generic_output_transpose(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
-                                    %C: memref<32x8xf32>) { 
+                                    %C: memref<32x8xf32>) {
   // CHECK: vector.transfer_read %{{.*}} : memref<8x16xf32>, vector<8x32x16xf32>
   // CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<8x32x16xf32>
   // CHECK: %[[ACC:.*]] = vector.transfer_read %{{.*}} : memref<32x8xf32>, vector<8x32xf32>
@@ -1608,3 +1606,147 @@
   %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
   %2 = transform.structured.vectorize %1
 }
+
+// -----
+
+func.func @vectorize_dynamic_identity(%arg0: tensor<?xf32>,
+                                      %arg1: tensor<?xf32>,
+                                      %arg2: tensor<?xf32>) -> tensor<?xf32> {
+  %0 = linalg.generic { indexing_maps = [affine_map<(d0) -> (d0)>,
+                                         affine_map<(d0) -> (d0)>,
+                                         affine_map<(d0) -> (d0)>],
+                        iterator_types = ["parallel"] }
+    ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
+    outs(%arg2 : tensor<?xf32>) {
+    ^bb(%in0: f32, %in1: f32, %out: f32) :
+      %0 = arith.addf %in0, %in1 : f32
+      linalg.yield %0 : f32
+    } -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// CHECK-LABEL: @vectorize_dynamic_identity
+// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_4:.*]] = tensor.dim %{{.*}}, %[[VAL_3]] : tensor<?xf32>
+// CHECK: %[[VAL_7:.*]] = vector.create_mask %[[VAL_4]] : vector<4xi1>
+// CHECK: %[[VAL_8:.*]] = vector.mask %[[VAL_7]] { vector.transfer_read %{{.*}} {in_bounds = [true]} : tensor<?xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32>
+// CHECK: %[[VAL_10:.*]] = vector.mask %[[VAL_7]] { vector.transfer_read %{{.*}} {in_bounds = [true]} : tensor<?xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32>
+// CHECK: %[[VAL_12:.*]] = vector.mask %[[VAL_7]] { vector.transfer_read %{{.*}} {in_bounds = [true]} : tensor<?xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32>
+// CHECK: %[[VAL_13:.*]] = arith.addf %[[VAL_8]], %[[VAL_10]] : vector<4xf32>
+// CHECK: %[[VAL_14:.*]] = vector.mask %[[VAL_7]] { vector.transfer_write %{{.*}} {in_bounds = [true]} : vector<4xf32>, tensor<?xf32> } : vector<4xi1> -> tensor<?xf32>
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+  transform.structured.masked_vectorize %0 vector_sizes [4]
+}
+
+// -----
+
+func.func @vectorize_dynamic_1d_broadcast(%arg0: tensor<?xf32>,
+                                          %arg1: tensor<?xf32>,
+                                          %arg2: tensor<?xf32>) -> tensor<?xf32> {
+  %0 = linalg.generic { indexing_maps = [affine_map<(d0) -> (0)>,
+                                         affine_map<(d0) -> (d0)>,
+                                         affine_map<(d0) -> (d0)>],
+                        iterator_types = ["parallel"] }
+    ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
+    outs(%arg2 : tensor<?xf32>) {
+    ^bb(%in0: f32, %in1: f32, %out: f32) :
+      %0 = arith.addf %in0, %in1 : f32
+      linalg.yield %0 : f32
+    } -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// CHECK-LABEL: @vectorize_dynamic_1d_broadcast
+// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_4:.*]] = tensor.dim %{{.*}}, %[[VAL_3]] : tensor<?xf32>
+// CHECK: %[[VAL_7:.*]] = vector.transfer_read %{{.*}} {permutation_map = #{{.*}}} : tensor<?xf32>, vector<4xf32>
+// CHECK: %[[VAL_9:.*]] = vector.create_mask %[[VAL_4]] : vector<4xi1>
+// CHECK: %[[VAL_10:.*]] = vector.mask %[[VAL_9]] { vector.transfer_read %{{.*}} {in_bounds = [true]} : tensor<?xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32>
+// CHECK: %[[VAL_12:.*]] = vector.mask %[[VAL_9]] { vector.transfer_read %{{.*}} {in_bounds = [true]} : tensor<?xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32>
+// CHECK: %[[VAL_13:.*]] = arith.addf %[[VAL_7]], %[[VAL_10]] : vector<4xf32>
+// CHECK: %[[VAL_14:.*]] = vector.mask %{{.*}} { vector.transfer_write %[[VAL_13]], {{.*}} {in_bounds = [true]} : vector<4xf32>, tensor<?xf32> } : vector<4xi1> -> tensor<?xf32>
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+  transform.structured.masked_vectorize %0 vector_sizes [4]
+}
+
+// -----
+
+func.func @vectorize_dynamic_2d_transpose(%arg0: tensor<?x?xf32>,
+                                          %arg1: tensor<?x?xf32>,
+                                          %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>,
+                                         affine_map<(d0, d1) -> (d0, d1)>,
+                                         affine_map<(d0, d1) -> (d0, d1)>],
+                        iterator_types = ["parallel", "parallel"] }
+    ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+    outs(%arg2 : tensor<?x?xf32>) {
+    ^bb(%in0: f32, %in1: f32, %out: f32) :
+      %0 = arith.addf %in0, %in1 : f32
+      linalg.yield %0 : f32
+    } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: @vectorize_dynamic_2d_transpose
+// CHECK: %[[VAL_3:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_4:.*]] = tensor.dim %{{.*}}, %[[VAL_3]] : tensor<?x?xf32>
+// CHECK: %[[VAL_5:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_6:.*]] = tensor.dim %{{.*}}, %[[VAL_5]] : tensor<?x?xf32>
+// CHECK: %[[VAL_9:.*]] = vector.create_mask %[[VAL_6]], %[[VAL_4]] : vector<8x4xi1>
+// CHECK: %[[VAL_10:.*]] = vector.mask %[[VAL_9]] { vector.transfer_read %{{.*}} {in_bounds = [true, true], permutation_map = #{{.*}}} : tensor<?x?xf32>, vector<4x8xf32> } : vector<8x4xi1> -> vector<4x8xf32>
+// CHECK: %[[VAL_12:.*]] = vector.create_mask %[[VAL_4]], %[[VAL_6]] : vector<4x8xi1>
+// CHECK: %[[VAL_13:.*]] = vector.mask %[[VAL_12]] { vector.transfer_read %{{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<4x8xf32> } : vector<4x8xi1> -> vector<4x8xf32>
+// CHECK: %[[VAL_14:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_15:.*]] = vector.mask %[[VAL_12]] { vector.transfer_read %{{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<4x8xf32> } : vector<4x8xi1> -> vector<4x8xf32>
+// CHECK: %[[VAL_16:.*]] = arith.addf %[[VAL_10]], %[[VAL_13]] : vector<4x8xf32>
+// CHECK: %[[VAL_17:.*]] = vector.mask %[[VAL_12]] { vector.transfer_write %[[VAL_16]], %{{.*}} {in_bounds = [true, true]} : vector<4x8xf32>, tensor<?x?xf32> } : vector<4x8xi1> -> tensor<?x?xf32>
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+  transform.structured.masked_vectorize %0 vector_sizes [4, 8]
+}
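Not covered by the tests above: per `vectorizeDynamicLinalgOpPrecondition`, a dynamically-shaped op with a non-parallel iterator (or a 0-d indexing map) is expected to be rejected with a silenceable failure for now. For instance, a reduction like the following sketch should fail `masked_vectorize`:

```mlir
%0 = linalg.generic { indexing_maps = [affine_map<(d0) -> (d0)>,
                                       affine_map<(d0) -> ()>],
                      iterator_types = ["reduction"] }
  ins(%in : tensor<?xf32>) outs(%acc : tensor<f32>) {
  ^bb(%a: f32, %b: f32):
    %s = arith.addf %a, %b : f32
    linalg.yield %s : f32
  } -> tensor<f32>
```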
+ +// ----- + +func.func @vectorize_dynamic_generic_2d_broadcast(%arg0: tensor, + %arg1: tensor, + %arg2: tensor) -> tensor { + %0 = linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (0, d1)>, + affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"] } + ins(%arg0, %arg1 : tensor, tensor) + outs(%arg2 : tensor) { + ^bb(%in0: f32, %in1: f32, %out: f32) : + %0 = arith.addf %in0, %in1 : f32 + linalg.yield %0 : f32 + } -> tensor + return %0 : tensor +} + +// CHECK-LABEL: @vectorize_dynamic_generic_2d_broadcast +// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_4:.*]] = tensor.dim %{{.*}}, %[[VAL_3]] : tensor +// CHECK: %[[VAL_5:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_6:.*]] = tensor.dim %{{.*}}, %[[VAL_5]] : tensor +// CHECK: %[[VAL_9:.*]] = vector.create_mask %[[VAL_6]] : vector<8xi1> +// CHECK: %[[VAL_10:.*]] = vector.mask %[[VAL_9]] { vector.transfer_read %{{.*}} {in_bounds = [true, true], permutation_map = #{{.*}}} : tensor, vector<4x8xf32> } : vector<8xi1> -> vector<4x8xf32> +// CHECK: %[[VAL_12:.*]] = vector.create_mask %[[VAL_4]], %[[VAL_6]] : vector<4x8xi1> +// CHECK: %[[VAL_13:.*]] = vector.mask %[[VAL_12]] { vector.transfer_read %{{.*}} {in_bounds = [true, true]} : tensor, vector<4x8xf32> } : vector<4x8xi1> -> vector<4x8xf32> +// CHECK: %[[VAL_15:.*]] = vector.mask %[[VAL_12]] { vector.transfer_read %{{.*}} {in_bounds = [true, true]} : tensor, vector<4x8xf32> } : vector<4x8xi1> -> vector<4x8xf32> +// CHECK: %[[VAL_16:.*]] = arith.addf %[[VAL_10]], %[[VAL_13]] : vector<4x8xf32> +// CHECK: %[[VAL_18:.*]] = vector.mask %[[VAL_12]] { vector.transfer_write %{{.*}} {in_bounds = [true, true]} : vector<4x8xf32>, tensor } : vector<4x8xi1> -> tensor + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 + transform.structured.masked_vectorize %0 vector_sizes [4, 8] +} + diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -8308,6 +8308,7 @@ ":LinalgPassIncGen", ":LinalgStructuredOpsIncGen", ":LinalgUtils", + ":MaskableOpInterface", ":MathDialect", ":MemRefDialect", ":MemRefTransforms",