diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -44,12 +44,12 @@ //===----------------------------------------------------------------------===// using LinalgLoops = SmallVector; -/// [DEPRECATED] Populates patterns for vectorization of all ConvN-D ops. +/// [DEPRECATED] Populate patterns for vectorization of all ConvN-D ops. void populateConvVectorizationPatterns( MLIRContext *context, SmallVectorImpl &patterns, ArrayRef tileSizes); -/// Populates patterns for vectorizing low-D convolution ops. This is a step in +/// Populate patterns for vectorizing low-D convolution ops. This is a step in /// progressive lowering for convolution ops, it assume high-D convolution ops /// were decomposed previously. void populateConvolutionVectorizationPatterns(RewritePatternSet &patterns, @@ -91,7 +91,7 @@ /// canonicalizations of named ops into another named op. void populateLinalgNamedOpConversionPatterns(RewritePatternSet &patterns); -/// Populates the given list with patterns to bufferize linalg ops. +/// Populate the given list with patterns to bufferize linalg ops. void populateLinalgBufferizePatterns( bufferization::BufferizeTypeConverter &converter, RewritePatternSet &patterns); @@ -124,7 +124,7 @@ return *this; } - /// Function that allows the caller to control when to stop fusion. Once a + /// Function to allow the caller to control when to stop fusion. Once a /// producer is deemed fusable with the consumer (structurally), this callback /// can be used to abort the fusion based on non-structural constraints. This /// is the hook for cost models to control the amount of fusion done. @@ -149,7 +149,7 @@ /// more fusion opportunities. 
void populatePushReshapeOpsPatterns(RewritePatternSet &patterns); -/// Performs standalone tiling of a single LinalgOp by `tileSizes`. +/// Perform standalone tiling of a single LinalgOp by `tileSizes`. /// and permute the loop nest according to `interchangeVector` /// The permutation is expressed as a list of integers that specify /// the new ordering of the loop nest. The length of `interchangeVector` @@ -157,7 +157,7 @@ /// An empty vector is interpreted as the identity permutation and the /// transformation returns early. /// -/// Returns a struct containing the tiled loops in the specified order +/// Return a struct containing the tiled loops in the specified order /// and the cloned op if successful, llvm::None otherwise. /// /// E.g. the permutation `(i,j,k) -> (j,k,i)` is expressed by @@ -237,7 +237,7 @@ const LinalgDependenceGraph &dependenceGraph, const LinalgTilingOptions &tilingOptions); -/// Interchanges the `iterator_types` and `iterator_maps` dimensions and adapts +/// Interchange the `iterator_types` and `iterator_maps` dimensions and adapts /// the index accesses of `op`. This is an in-place transformation controlled by /// `interchangeVector`. An empty vector is interpreted as the identity /// permutation and the transformation returns early. @@ -246,12 +246,15 @@ /// `interchangeVector = [1,2,0]`. All values in `interchangeVector` must be /// integers, in the range 0..`op.rank` without duplications /// (i.e. `[1,1,2]` is an invalid permutation). -void interchangeGenericOp(PatternRewriter &rewriter, GenericOp genericOp, - ArrayRef interchangeVector); +FailureOr interchangeGenericOp(RewriterBase &rewriter, + GenericOp genericOp, + ArrayRef interchangeVector); -/// Creates a GenericOp from the given named operation `namedOp`. Assumes -/// `namedOp` is not a GenericOp and has a region builder. 
-GenericOp generalizeNamedOp(PatternRewriter &rewriter, LinalgOp namedOp); +/// Create a GenericOp from the given named operation `namedOp` and replace +/// namedOp. +/// Return failure if `namedOp` is a GenericOp or misses a region builder. +FailureOr generalizeNamedOp(RewriterBase &rewriter, + LinalgOp namedOp); /// Callback function type used to perform the allocation for the promoted /// `subView`. In `boundingSubViewsize` a best attempt is made to find the @@ -346,7 +349,7 @@ } }; -/// Creates a new buffer using the `allocationFn` provided. The size of this +/// Create a new buffer using the `allocationFn` provided. The size of this /// buffer is the smallest constant bounding size along each dimension that can /// be computed for the size of the result of `subView`. Returns the allocated /// buffer as `fullLocalView` and the view that matches the size of the result @@ -360,7 +363,7 @@ const AllocBufferCallbackFn &allocationFn, DataLayout &layout); -/// Promotes the `subViews` into a new buffer allocated at the insertion point +/// Promote the `subViews` into a new buffer allocated at the insertion point /// `b`. Promotion occurs in 3 steps: /// 1. Create a new buffer for a full tile (i.e. not clipped at the boundary). /// 2. Take a full view on the buffer. @@ -368,24 +371,23 @@ /// Infers statically sized buffers from subViews unless `dynamicBuffers` is /// true. /// -/// Returns the modified linalg op (the modification happens in place) as well +/// Return the modified linalg op (the modification happens in place) as well /// as all the copy ops created. FailureOr promoteSubViews(OpBuilder &b, LinalgOp op, const LinalgPromotionOptions &options); /// Emit a suitable vector form for a Linalg op with fully static shape. -LogicalResult vectorizeLinalgOp(OpBuilder &builder, Operation *op, - SmallVectorImpl &newResults); +LogicalResult vectorize(RewriterBase &builder, LinalgOp linalgOp); -/// Emits a loop nest of `scf.for` with the proper body for `linalgOp`. 
+/// Emit a loop nest of `scf.for` with the proper body for `linalgOp`. FailureOr linalgOpToLoops(PatternRewriter &rewriter, LinalgOp linalgOp); -/// Emits a loop nest of `scf.parallel` with the proper body for `linalgOp`. +/// Emit a loop nest of `scf.parallel` with the proper body for `linalgOp`. FailureOr linalgOpToParallelLoops(PatternRewriter &rewriter, LinalgOp linalgOp); -/// Emits a loop nest of `affine.for` with the proper body for `linalgOp`. +/// Emit a loop nest of `affine.for` with the proper body for `linalgOp`. FailureOr linalgOpToAffineLoops(PatternRewriter &rewriter, LinalgOp linalgOp); @@ -393,28 +395,10 @@ // Preconditions that ensure the corresponding transformation succeeds and can // be applied as a rewrite pattern. //===----------------------------------------------------------------------===// -/// Emits a `generic` operation with the `indexing_maps` and `iterator_types` -/// permutated according to `permutation`. -LogicalResult -interchangeGenericOpPrecondition(GenericOp genericOp, - ArrayRef interchangeVector); - -/// Generalize named operations to generic operations. -LogicalResult generalizeNamedOpPrecondition(Operation *op); - -/// Promote std.subviews feeding linalg operations. +/// Promote memref.subviews feeding linalg-on-buffers operations. LogicalResult promoteSubviewsPrecondition(Operation *op, LinalgPromotionOptions options); -/// Return success if the operation can be vectorized. -LogicalResult vectorizeLinalgOpPrecondition(Operation *op); - -/// Return success if `op` can be vectorized assuming it is static. This allows -/// checking if an op will be vectorizable once all the dimensions are folded to -/// static values. -/// It is the same as `vectorizeLinalgOpPrecondition` for static shapes. -LogicalResult vectorizeStaticLinalgOpPrecondition(LinalgOp op); - //===----------------------------------------------------------------------===// // Transformations exposed as rewrite patterns. 
//===----------------------------------------------------------------------===// @@ -610,7 +594,7 @@ RewritePatternSet getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx); void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns); -/// Base pattern that applied the tiling transformation specified by `options`. +/// Base pattern that applies the tiling transformation specified by `options`. /// Abort and return failure in 2 cases: /// 1. if the tiling specification is invalid and tiling fails to occur. /// 2. if tiling occurs but `options.paddingValueComputationFunction` is set @@ -812,9 +796,9 @@ }; /// -/// Linalg generic interchage pattern. +/// Linalg generic interchange pattern. /// -/// Apply the `interchange` transformation as a pattern. +/// Apply the `interchange` transformation on a RewriterBase. /// `filter` controls LinalgTransformMarker matching and update when specified. /// See `interchange` for more details. struct GenericOpInterchangePattern : public OpRewritePattern { @@ -909,13 +893,11 @@ /// /// Linalg vectorization patterns. /// -/// Apply the `vectorizeLinalgOp` transformation as a pattern. -/// `filter` controls LinalgTransformMarker matching and update when specified. -/// See `vectorizeLinalgOp` for more details. - /// Empty for now, used for SFINAE purposes only. struct LinalgVectorizationOptions {}; +/// `filter` controls LinalgTransformMarker matching and update when specified. +/// See `vectorize` for more details. struct LinalgBaseVectorizationPattern : public RewritePattern { /// MatchAnyOpTag-based constructor with a mandatory `filter`. 
LinalgBaseVectorizationPattern(MLIRContext *context, diff --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp @@ -29,7 +29,7 @@ using namespace mlir; using namespace mlir::linalg; -LogicalResult mlir::linalg::generalizeNamedOpPrecondition(Operation *op) { +static LogicalResult generalizeNamedOpPrecondition(Operation *op) { LinalgOp namedOp = dyn_cast(op); // Check if the operation is a LinalgOp but not a GenericOp. if (!namedOp || isa(op)) @@ -40,8 +40,11 @@ return success(); } -GenericOp mlir::linalg::generalizeNamedOp(PatternRewriter &rewriter, - LinalgOp namedOp) { +FailureOr mlir::linalg::generalizeNamedOp(RewriterBase &rewriter, + LinalgOp namedOp) { + if (failed(generalizeNamedOpPrecondition(namedOp))) + return rewriter.notifyMatchFailure(namedOp, "preconditions not met"); + SmallVector inputOperands = namedOp.getInputOperands(); SmallVector outputOperands = namedOp.getOutputOperands(); SmallVector indexingMaps = namedOp.getIndexingMaps(); @@ -58,6 +61,7 @@ outputOperands, indexingMaps, iterators); rewriter.inlineRegionBefore(namedOp->getRegion(0), genericOp.region(), genericOp.region().begin()); + rewriter.replaceOp(namedOp, genericOp->getResults()); return genericOp; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp @@ -21,6 +21,7 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" +#include "llvm/ADT/ScopeExit.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include @@ -30,8 +31,9 @@ using namespace mlir; using namespace mlir::linalg; -LogicalResult mlir::linalg::interchangeGenericOpPrecondition( - GenericOp genericOp, 
ArrayRef interchangeVector) { +static LogicalResult +interchangeGenericOpPrecondition(GenericOp genericOp, + ArrayRef interchangeVector) { // Interchange vector must be non-empty and match the number of loops. if (interchangeVector.empty() || genericOp.getNumLoops() != interchangeVector.size()) @@ -43,31 +45,38 @@ return success(); } -void mlir::linalg::interchangeGenericOp(PatternRewriter &rewriter, - GenericOp genericOp, - ArrayRef interchangeVector) { - // 1. Compute the inverse permutation map. +FailureOr +mlir::linalg::interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp, + ArrayRef interchangeVector) { + if (failed(interchangeGenericOpPrecondition(genericOp, interchangeVector))) + return rewriter.notifyMatchFailure(genericOp, "preconditions not met"); + + // 1. Compute the inverse permutation map, it must be non-null since the + // preconditions are satisfied. MLIRContext *context = genericOp.getContext(); AffineMap permutationMap = inversePermutation( AffineMap::getPermutationMap(interchangeVector, context)); - assert(permutationMap && "expected permutation to be invertible"); - assert(interchangeVector.size() == genericOp.getNumLoops() && - "expected interchange vector to have entry for every loop"); + assert(permutationMap && "unexpected null map"); + + // Start a guarded inplace update. + rewriter.startRootUpdate(genericOp); + auto guard = + llvm::make_scope_exit([&]() { rewriter.finalizeRootUpdate(genericOp); }); // 2. Compute the interchanged indexing maps. - SmallVector newIndexingMaps; + SmallVector newIndexingMaps; for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) { AffineMap m = genericOp.getTiedIndexingMap(opOperand); if (!permutationMap.isEmpty()) m = m.compose(permutationMap); - newIndexingMaps.push_back(AffineMapAttr::get(m)); + newIndexingMaps.push_back(m); } genericOp->setAttr(getIndexingMapsAttrName(), - ArrayAttr::get(context, newIndexingMaps)); + rewriter.getAffineMapArrayAttr(newIndexingMaps)); // 3. 
Compute the interchanged iterator types. ArrayRef itTypes = genericOp.iterator_types().getValue(); - SmallVector itTypesVector; + SmallVector itTypesVector; llvm::append_range(itTypesVector, itTypes); SmallVector permutation(interchangeVector.begin(), interchangeVector.end()); @@ -91,4 +100,6 @@ indexOp, permutationMap.getSubMap(indexOp.dim()), allIndices); } } + + return genericOp; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp b/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp @@ -137,7 +137,7 @@ struct LinalgNamedOpConversionPass : public LinalgNamedOpConversionBase { LinalgNamedOpConversionPass() = default; - LinalgNamedOpConversionPass(const LinalgNamedOpConversionPass &) {} + LinalgNamedOpConversionPass(const LinalgNamedOpConversionPass &) = default; void runOnOperation() override { Operation *op = getOperation(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -623,16 +623,14 @@ GenericOp genericOp, PatternRewriter &rewriter) const { if (failed(filter.checkAndNotify(rewriter, genericOp))) return failure(); - if (failed(interchangeGenericOpPrecondition(genericOp, interchangeVector))) + + FailureOr transformedOp = + interchangeGenericOp(rewriter, genericOp, interchangeVector); + if (failed(transformedOp)) return failure(); - // TODO: figure out how this interplays with named ops. In particular this - // should break the named op property. - rewriter.updateRootInPlace(genericOp, [&]() { - interchangeGenericOp(rewriter, genericOp, interchangeVector); - // New filter if specified. - filter.replaceLinalgTransformationFilter(rewriter, genericOp); - }); + // New filter if specified. 
+ filter.replaceLinalgTransformationFilter(rewriter, genericOp); return success(); } @@ -652,12 +650,10 @@ Operation *op, PatternRewriter &rewriter) const { if (failed(filter.checkAndNotify(rewriter, op))) return failure(); - if (failed(generalizeNamedOpPrecondition(op))) + FailureOr genericOp = generalizeNamedOp(rewriter, op); + if (failed(genericOp)) return failure(); - - GenericOp genericOp = generalizeNamedOp(rewriter, op); - rewriter.replaceOp(op, genericOp.getResults()); - filter.replaceLinalgTransformationFilter(rewriter, genericOp); + filter.replaceLinalgTransformationFilter(rewriter, *genericOp); return success(); } @@ -708,19 +704,13 @@ LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite( Operation *op, PatternRewriter &rewriter) const { + // TODO: Interface-based rewrite. LinalgOp linalgOp = dyn_cast(op); if (!linalgOp) return failure(); - if (failed(filter.checkAndNotify(rewriter, linalgOp))) - return failure(); - SmallVector newResults; - if (failed(vectorizeLinalgOp(rewriter, op, newResults))) + if (failed(filter.checkAndNotify(rewriter, op))) return failure(); - if (!newResults.empty()) - rewriter.replaceOp(op, newResults); - else - rewriter.eraseOp(op); - return success(); + return vectorize(rewriter, linalgOp); } LogicalResult mlir::linalg::applyStagedPatterns( @@ -758,8 +748,8 @@ return SmallVector(nParallelLoops, getParallelIteratorTypeName()); } -/// Rewrite a PadTensorOp into a sequence of InitTensorOp, FillOp (to initialize -/// with pad_val) and GenericOp (to copy contents). +/// Rewrite a PadTensorOp into a sequence of InitTensorOp, FillOp (to +/// initialize with pad_val) and GenericOp (to copy contents). 
LogicalResult PadTensorOpTransformationPattern::matchAndRewrite( linalg::PadTensorOp padOp, PatternRewriter &rewriter) const { diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -597,8 +597,7 @@ return success(); } -LogicalResult -mlir::linalg::vectorizeStaticLinalgOpPrecondition(linalg::LinalgOp op) { +static LogicalResult vectorizeStaticLinalgOpPrecondition(linalg::LinalgOp op) { if (isElementwise(op)) return success(); // TODO: isaConvolutionOpInterface that can also infer from generic features. @@ -620,8 +619,7 @@ return success(); } -LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) { - auto linalgOp = cast(op); +static LogicalResult vectorizeLinalgOpPrecondition(LinalgOp linalgOp) { // All types must be static shape to go to vector. if (linalgOp.hasDynamicShape()) { LDBG("precondition failed: dynamic shape"); @@ -630,31 +628,32 @@ return vectorizeStaticLinalgOpPrecondition(linalgOp); } -LogicalResult -mlir::linalg::vectorizeLinalgOp(OpBuilder &b, Operation *op, - SmallVectorImpl &newResults) { - if (failed(vectorizeLinalgOpPrecondition(op))) +LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, + LinalgOp linalgOp) { + if (failed(vectorizeLinalgOpPrecondition(linalgOp))) return failure(); - auto linalgOp = cast(op); - - // TODO: isaConvolutionOpInterface that can also infer from generic features. - // But we will still need stride/dilation attributes that will be annoying to - // reverse-engineer... - if (auto convOp = dyn_cast(op)) { - FailureOr resultOrFail = vectorizeConvolution(b, convOp); - if (failed(resultOrFail)) + SmallVector results; + // TODO: isaConvolutionOpInterface that can also infer from generic + // features. Will require stride/dilation attributes inference. 
+ if (auto convOp = dyn_cast(linalgOp.getOperation())) { + LDBG("Vectorize as a conv: " << linalgOp); + FailureOr convOr = vectorizeConvolution(rewriter, convOp); + if (failed(convOr)) + return failure(); + llvm::append_range(results, (*convOr)->getResults()); + } else { + LDBG("Vectorize generic by broadcasting to a common shape: " << linalgOp); + if (failed(vectorizeAsLinalgGeneric(rewriter, linalgOp, results))) return failure(); - Operation *newOp = *resultOrFail; - llvm::append_range(newResults, newOp->getResults()); - return success(); } - LDBG("" - << "Vectorize linalg op as a generic by broadcasting to " - "maximal common shape: " - << *op); - return vectorizeAsLinalgGeneric(b, linalgOp, newResults); + if (!results.empty()) + rewriter.replaceOp(linalgOp, results); + else + rewriter.eraseOp(linalgOp); + + return success(); } //----------------------------------------------------------------------------// @@ -666,8 +665,9 @@ return attr.cast().getInt(); } -/// Given an ArrayRef of OpFoldResults, return a vector of Values. IntegerAttrs -/// are converted to ConstantIndexOps. Other attribute types are not supported. +/// Given an ArrayRef of OpFoldResults, return a vector of Values. +/// IntegerAttrs are converted to ConstantIndexOps. Other attribute types are +/// not supported. static SmallVector ofrToIndexValues(OpBuilder &builder, Location loc, ArrayRef ofrs) { SmallVector result; @@ -691,9 +691,9 @@ GenericPadTensorOpVectorizationPattern(MLIRContext *context, PatternBenefit benefit = 1) : GeneralizePadTensorOpPattern(context, tryVectorizeCopy, benefit) {} - /// Vectorize the copying of a PadTensorOp's source. This is possible if each - /// dimension size is statically know in the source type or the result type - /// (or both). + /// Vectorize the copying of a PadTensorOp's source. This is possible if + /// each dimension size is statically known in the source type or the result + /// type (or both). 
static LogicalResult tryVectorizeCopy(PatternRewriter &rewriter, PadTensorOp padOp, Value dest) { auto sourceType = padOp.getSourceType(); @@ -718,13 +718,14 @@ for (unsigned i = 0; i < sourceType.getRank(); ++i) { if (!sourceType.isDynamicDim(i)) { vecShape.push_back(sourceType.getDimSize(i)); - // Source shape is statically known: Neither read nor write are out-of- - // bounds. + // Source shape is statically known: Neither read nor write are + // out-of-bounds. readInBounds.push_back(true); writeInBounds.push_back(true); } else if (!resultType.isDynamicDim(i)) { - // Source shape is not statically known, but result shape is. Vectorize - // with size of result shape. This may be larger than the source size. + // Source shape is not statically known, but result shape is. + // Vectorize with size of result shape. This may be larger than the + // source size. vecShape.push_back(resultType.getDimSize(i)); // Read may be out-of-bounds because the result size could be larger // than the source size. @@ -749,8 +750,8 @@ padOp.getLoc(), vecType, padOp.source(), readIndices, padValue, ArrayRef{readInBounds}); - // If `dest` is a FillOp and the TransferWriteOp would overwrite the entire - // tensor, write directly to the FillOp's operand. + // If `dest` is a FillOp and the TransferWriteOp would overwrite the + // entire tensor, write directly to the FillOp's operand. if (llvm::equal(vecShape, resultType.getShape()) && llvm::all_of(writeInBounds, [](bool b) { return b; })) if (auto fill = dest.getDefiningOp()) @@ -766,8 +767,8 @@ } }; -/// Base pattern for rewriting PadTensorOps whose result is consumed by a given -/// operation type OpTy. +/// Base pattern for rewriting PadTensorOps whose result is consumed by a +/// given operation type OpTy. template struct VectorizePadTensorOpUserPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -837,10 +838,10 @@ }; /// Rewrite use of PadTensorOp result in TransferWriteOp. 
-/// This pattern rewrites TransferWriteOps that write to a padded tensor value, -/// where the same amount of padding is immediately removed again after the -/// write. In such cases, the TransferWriteOp can write to the non-padded tensor -/// value and apply out-of-bounds masking. E.g.: +/// This pattern rewrites TransferWriteOps that write to a padded tensor +/// value, where the same amount of padding is immediately removed again after +/// the write. In such cases, the TransferWriteOp can write to the non-padded +/// tensor value and apply out-of-bounds masking. E.g.: /// ``` /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1] /// : tensor<...> to tensor @@ -854,17 +855,19 @@ /// ``` /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1] /// : tensor<...> to tensor -/// %r = vector.transfer_write %vec, %0[...] : vector<17x5xf32>, tensor +/// %r = vector.transfer_write %vec, %0[...] : vector<17x5xf32>, +/// tensor /// ``` /// Note: It is important that the ExtractSliceOp %r resizes the result of the -/// TransferWriteOp to the same size as the input of the TensorPadOp (or an even -/// smaller size). Otherwise, %r's new (dynamic) dimensions would differ from -/// %r's old dimensions. +/// TransferWriteOp to the same size as the input of the TensorPadOp (or an +/// even smaller size). Otherwise, %r's new (dynamic) dimensions would differ +/// from %r's old dimensions. /// /// This rewrite is possible if: /// - Low padding is static 0. /// - `xferOp` has exactly one use, which is an ExtractSliceOp. This -/// ExtractSliceOp trims the same amount of padding that was added beforehand. +/// ExtractSliceOp trims the same amount of padding that was added +/// beforehand. /// - Single, scalar padding value. struct PadTensorOpVectorizationWithTransferWritePattern : public VectorizePadTensorOpUserPattern { @@ -922,8 +925,8 @@ /// sizes may turn out to be equal at runtime. 
bool hasSameTensorSize(Value beforePadding, tensor::ExtractSliceOp afterTrimming) const { - // If the input to PadTensorOp is a CastOp, try with with both CastOp result - // and CastOp operand. + // If the input to PadTensorOp is a CastOp, try with both CastOp + // result and CastOp operand. if (auto castOp = beforePadding.getDefiningOp()) if (hasSameTensorSize(castOp.source(), afterTrimming)) return true; @@ -950,8 +953,9 @@ if (t1.getNumDynamicDims() == 0) return true; - // All dynamic sizes must be the same. The only supported case at the moment - // is when `beforePadding` is an ExtractSliceOp (or a cast thereof). + // All dynamic sizes must be the same. The only supported case at the + // moment is when `beforePadding` is an ExtractSliceOp (or a cast + // thereof). // Apart from CastOp, only ExtractSliceOp is supported. auto beforeSlice = beforePadding.getDefiningOp(); @@ -1062,7 +1066,8 @@ // InsertSliceOp. rewriter.setInsertionPoint(insertOp); - // Generate TransferReadOp: Read entire source tensor and add high padding. + // Generate TransferReadOp: Read entire source tensor and add high + // padding. SmallVector readIndices( vecRank, rewriter.create(padOp.getLoc(), 0)); auto read = rewriter.create( @@ -1224,9 +1229,9 @@ // Forwarding patterns //----------------------------------------------------------------------------// -/// Check whether there is any interleaved use of any `values` between `firstOp` -/// and `secondOp`. Conservatively return `true` if any op or value is in a -/// different block. +/// Check whether there is any interleaved use of any `values` between +/// `firstOp` and `secondOp`. Conservatively return `true` if any op or value +/// is in a different block. static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp, ValueRange values) { if (firstOp->getBlock() != secondOp->getBlock() || @@ -1252,7 +1257,8 @@ return false; } -/// Return the unique subview use of `v` if it is indeed unique, null otherwise. 
+/// Return the unique subview use of `v` if it is indeed unique, null +/// otherwise. static memref::SubViewOp getSubViewUseIfUnique(Value v) { memref::SubViewOp subViewOp; for (auto &u : v.getUses()) { @@ -1307,7 +1313,8 @@ return failure(); LDBG("with copy " << *copyOp); - // Find the fill into `viewOrAlloc` without interleaved uses before the copy. + // Find the fill into `viewOrAlloc` without interleaved uses before the + // copy. FillOp maybeFillOp; for (auto &u : viewOrAlloc.getUses()) { if (auto newFillOp = dyn_cast(u.getOwner())) { @@ -1468,7 +1475,8 @@ /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}} /// ``` /// kw is always unrolled. - /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is > 1. + /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is + /// > 1. FailureOr conv() { if (!valid) return failure(); @@ -1483,7 +1491,8 @@ Value zero = builder.create(loc, 0); // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1. - // When strideW == 1, we can batch the contiguous loads and avoid unrolling + // When strideW == 1, we can batch the contiguous loads and avoid + // unrolling int64_t wSizeStep = strideW == 1 ? wSize : 1; Type lhsEltType = lhsShapedType.getElementType(); @@ -1500,7 +1509,8 @@ VectorType rhsType = VectorType::get({kwSize, cSize, fSize}, rhsEltType); VectorType resType = VectorType::get({nSize, wSize, fSize}, resEltType); - // Read lhs slice of size {w * strideW + kw * dilationW, c, f} @ [0, 0, 0]. + // Read lhs slice of size {w * strideW + kw * dilationW, c, f} @ [0, 0, + // 0]. Value lhs = builder.create( loc, lhsType, lhsShaped, ValueRange{zero, zero, zero}); // Read rhs slice of size {kw, c, f} @ [0, 0, 0]. @@ -1591,7 +1601,8 @@ /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}} /// ``` /// kw is always unrolled. - /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is > 1. + /// TODO: w (resp. 
kw) is unrolled when the strideW ( resp. dilationW) is + /// > 1. FailureOr dilatedConv() { if (!valid) return failure(); @@ -1605,7 +1616,8 @@ Value zero = builder.create(loc, 0); // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1. - // When strideW == 1, we can batch the contiguous loads and avoid unrolling + // When strideW == 1, we can batch the contiguous loads and avoid + // unrolling int64_t wSizeStep = strideW == 1 ? wSize : 1; Type lhsEltType = lhsShapedType.getElementType(); @@ -1621,7 +1633,8 @@ VectorType rhsType = VectorType::get({kwSize, cSize}, rhsEltType); VectorType resType = VectorType::get({nSize, wSize, cSize}, resEltType); - // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0, 0]. + // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0, + // 0]. Value lhs = builder.create( loc, lhsType, lhsShaped, ValueRange{zero, zero, zero}); // Read rhs slice of size {kw, c} @ [0, 0].