diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -112,9 +112,10 @@
                                 linalg::LinalgTransformationFilter());
 
 /// Create a LinalgStrategyDecomposePass.
-// TODO: atm this is applied to all supported ops. If/when we need finer control
-// this should be exposed with an opName + filter and a proper pattern.
-std::unique_ptr<OperationPass<FuncOp>> createLinalgStrategyDecomposePass();
+// TODO: if/when we need finer control, add an `opName` parameter.
+std::unique_ptr<OperationPass<FuncOp>>
+createLinalgStrategyDecomposePass(linalg::LinalgTransformationFilter filter =
+                                      linalg::LinalgTransformationFilter());
 
 /// Create a LinalgStrategyInterchangePass.
 std::unique_ptr<OperationPass<FuncOp>>
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -286,8 +286,7 @@
   ];
 }
 
-// TODO: atm this is applied to all supported ops. If/when we need finer control
-// this should be exposed with an opName + filter and a proper pattern.
+// TODO: if/when we need finer control, add an anchorOp option.
 def LinalgStrategyDecomposePass
     : FunctionPass<"linalg-strategy-decompose-pass"> {
   let summary = "Configurable pass to apply pattern-based generalization.";
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -49,12 +49,6 @@
     MLIRContext *context, SmallVectorImpl<RewritePatternSet> &patterns,
    ArrayRef<int64_t> tileSizes);
 
-/// Populates patterns to decompose high-D convolution ops into low-D ones. This
-/// is a step in progressive lowering for convolution ops, afterwards we can
-/// vectorize the low-D convolution ops.
-void populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
-                                          PatternBenefit benefit = 1);
-
 /// Populates patterns for vectorizing low-D convolution ops. This is a step in
 /// progressive lowering for convolution ops, it assume high-D convolution ops
 /// were decomposed previously.
@@ -1178,6 +1172,16 @@
     RewritePatternSet &patterns,
     LinalgTransformationFilter filter = LinalgTransformationFilter());
 
+/// Linalg decompose convolution patterns
+
+/// Populates patterns to decompose high-D convolution ops into low-D ones.
+/// This is a step in progressive lowering for convolution ops; afterwards we
+/// can vectorize the low-D convolution ops.
+void populateDecomposeConvolutionPatterns(
+    RewritePatternSet &patterns,
+    LinalgTransformationFilter filter = LinalgTransformationFilter(),
+    PatternBenefit benefit = 1);
+
 /// Linalg distribution patterns
 //
 /// Populates `patterns` with patterns to distribute linalg.tiled_loop.
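Usage sketch (illustrative, not part of the patch): with the new `filter`
parameter, decomposition can be restricted to ops carrying a transformation
marker. A minimal client-side sketch, assuming the existing
`LinalgTransformationFilter` marker mechanics and a hypothetical "decompose"
marker string (e.g. inside a function pass):

  // Build a filter that only matches linalg ops tagged "decompose" and
  // re-tags the resulting 1-D convolutions as "decomposed". Both marker
  // strings are illustrative, not defined by this patch.
  MLIRContext *ctx = funcOp.getContext();
  RewritePatternSet patterns(ctx);
  linalg::LinalgTransformationFilter filter(
      StringAttr::get(ctx, "decompose"),   // required input marker (assumed)
      StringAttr::get(ctx, "decomposed")); // marker set on rewritten ops
  linalg::populateDecomposeConvolutionPatterns(patterns, filter);
  if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
    signalPassFailure();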
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
@@ -191,16 +191,21 @@
   LinalgStrategyDecomposePass() = default;
 
+  LinalgStrategyDecomposePass(LinalgTransformationFilter filter)
+      : filter(filter) {}
+
   void runOnFunction() override {
     auto funcOp = getFunction();
     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
       return;
     RewritePatternSet decompositionPattern(funcOp.getContext());
-    populateDecomposeConvolutionPatterns(decompositionPattern);
+    populateDecomposeConvolutionPatterns(decompositionPattern, filter);
     if (failed(applyPatternsAndFoldGreedily(funcOp,
                                             std::move(decompositionPattern))))
       signalPassFailure();
   }
+
+  LinalgTransformationFilter filter;
 };
 
 /// Configurable pass to apply pattern-based linalg generalization.
@@ -478,12 +483,12 @@
                                          LinalgTransformationFilter filter) {
   return std::make_unique<LinalgStrategyGeneralizePass>(opName, filter);
 }
+
 /// Create a LinalgStrategyDecomposePass.
-// TODO: atm this is applied to all supported ops. If/when we need finer control
-// this should be exposed with an opName + filter and a proper pattern.
+// TODO: if/when we need finer control, add an `opName` parameter.
 std::unique_ptr<OperationPass<FuncOp>>
-mlir::createLinalgStrategyDecomposePass() {
-  return std::make_unique<LinalgStrategyDecomposePass>();
+mlir::createLinalgStrategyDecomposePass(LinalgTransformationFilter filter) {
+  return std::make_unique<LinalgStrategyDecomposePass>(filter);
 }
 
 /// Create a LinalgStrategyInterchangePass.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -929,31 +929,36 @@
 /// convolution ops.
 struct DownscaleSizeOneWindowed2DConvolution final
     : public OpRewritePattern<Conv2DNhwcHwcfOp> {
-  using OpRewritePattern::OpRewritePattern;
+  DownscaleSizeOneWindowed2DConvolution(
+      MLIRContext *context,
+      LinalgTransformationFilter filter = LinalgTransformationFilter(),
+      PatternBenefit benefit = 1)
+      : OpRewritePattern<Conv2DNhwcHwcfOp>(context, benefit), filter(filter) {}
 
   LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
                                 PatternRewriter &rewriter) const override {
-    auto linalgOp = cast<linalg::LinalgOp>(*convOp);
-    if (linalgOp.hasBufferSemantics())
+    if (failed(filter.checkAndNotify(rewriter, convOp)))
+      return failure();
+    if (convOp.hasBufferSemantics())
       return failure(); // To be implemented
 
     Value input = convOp.inputs().front();
-    Value filter = convOp.inputs().back();
+    Value kernel = convOp.inputs().back();
     Value output = convOp.outputs().front();
 
     auto inputType = input.getType().dyn_cast<RankedTensorType>();
-    auto filterType = filter.getType().dyn_cast<RankedTensorType>();
+    auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
     auto outputType = output.getType().dyn_cast<RankedTensorType>();
 
-    auto filterShape = filterType.getShape();
+    auto kernelShape = kernelType.getShape();
     auto outputShape = outputType.getShape();
 
     // Only handle the case where at least one of the window dimensions is
     // of size 1. Other cases can rely on tiling to reduce to such cases.
-    int64_t fhSize = filterShape[0], fwSize = filterShape[1];
+    int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
     int64_t ohSize = outputShape[1], owSize = outputShape[2];
-    bool removeH = (fhSize == 1 && ohSize == 1);
-    bool removeW = (fwSize == 1 && owSize == 1);
+    bool removeH = (khSize == 1 && ohSize == 1);
+    bool removeW = (kwSize == 1 && owSize == 1);
     if (!removeH && !removeW)
       return failure();
 
@@ -962,8 +967,8 @@
     using RTTBuilder = RankedTensorType::Builder;
     RankedTensorType newInputType =
         RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
-    RankedTensorType newFilterType =
-        RTTBuilder(filterType).dropDim((removeH ? 0 : 1));
+    RankedTensorType newKernelType =
+        RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
     RankedTensorType newOutputType =
         RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
 
@@ -971,8 +976,8 @@
     Location loc = convOp.getLoc();
     Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
         rewriter, loc, input, newInputType);
-    Value newFilter = tensor::createCanonicalRankReducingExtractSliceOp(
-        rewriter, loc, filter, newFilterType);
+    Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
+        rewriter, loc, kernel, newKernelType);
     Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
         rewriter, loc, output, newOutputType);
 
@@ -988,7 +993,7 @@
     auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
 
     auto conv1DOp = rewriter.create<Conv1DNwcWcfOp>(
-        loc, newOutputType, ValueRange{newInput, newFilter},
+        loc, newOutputType, ValueRange{newInput, newKernel},
         ValueRange{newOutput}, stridesAttr, dilationsAttr);
 
     // Insert back.
@@ -996,20 +1001,31 @@
         rewriter, loc, conv1DOp.getResult(0), output);
     rewriter.replaceOp(convOp, inserted);
 
+    filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
     return success();
   };
+
+private:
+  /// LinalgTransformMarker handles special attribute manipulations.
+  LinalgTransformationFilter filter;
 };
 
 /// Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh)
 /// dimensions into 1-D depthwise convolution ops.
 struct DownscaleDepthwiseConv2DNhwcHwcOp final
     : public OpRewritePattern<DepthwiseConv2DNhwcHwcOp> {
-  using OpRewritePattern::OpRewritePattern;
+  DownscaleDepthwiseConv2DNhwcHwcOp(
+      MLIRContext *context,
+      LinalgTransformationFilter filter = LinalgTransformationFilter(),
+      PatternBenefit benefit = 1)
+      : OpRewritePattern<DepthwiseConv2DNhwcHwcOp>(context, benefit),
+        filter(filter) {}
 
   LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp,
                                 PatternRewriter &rewriter) const override {
-    auto linalgOp = cast<linalg::LinalgOp>(*convOp);
-    if (linalgOp.hasBufferSemantics())
+    if (failed(filter.checkAndNotify(rewriter, convOp)))
+      return failure();
+    if (convOp.hasBufferSemantics())
       return failure(); // To be implemented
 
     Value input = convOp.inputs().front();
@@ -1071,15 +1087,21 @@
         rewriter, loc, conv1DOp.getResult(0), output);
     rewriter.replaceOp(convOp, inserted);
 
+    filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
     return success();
   };
+
+private:
+  /// LinalgTransformMarker handles special attribute manipulations.
+  LinalgTransformationFilter filter;
 };
 
 } // namespace
 
-void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
-                                                  PatternBenefit benefit) {
+void linalg::populateDecomposeConvolutionPatterns(
+    RewritePatternSet &patterns, LinalgTransformationFilter filter,
+    PatternBenefit benefit) {
   patterns.add<DownscaleSizeOneWindowed2DConvolution,
-               DownscaleDepthwiseConv2DNhwcHwcOp>(patterns.getContext(),
+               DownscaleDepthwiseConv2DNhwcHwcOp>(patterns.getContext(), filter,
                                                   benefit);
 }
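Usage sketch (illustrative, not part of the patch): the same filter can be
threaded through the new pass entry point. A sketch under the same assumptions
(hypothetical "decompose" marker; the nesting mirrors how the other Linalg
strategy passes are typically scheduled over functions):

  // Run the strategy-decompose pass on every function in the module, firing
  // only on ops carrying the illustrative "decompose" marker.
  PassManager pm(ctx);
  pm.addNestedPass<FuncOp>(createLinalgStrategyDecomposePass(
      linalg::LinalgTransformationFilter(StringAttr::get(ctx, "decompose"))));
  if (failed(pm.run(module)))
    module.emitError("convolution decomposition failed");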