diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -234,6 +234,10 @@ void interchangeGenericOp(PatternRewriter &rewriter, GenericOp genericOp, ArrayRef interchangeVector); +/// Creates a GenericOp from the given named operation `namedOp`. Assumes +/// `namedOp` is not a GenericOp and has a region builder. +GenericOp generalizeNamedOp(PatternRewriter &rewriter, LinalgOp namedOp); + /// Callback function type used to perform the allocation for the promoted /// `subView`. In `boundingSubViewsize` a best attempt is made to find the /// smallest constant value for the size of the buffer needed for each @@ -380,6 +384,9 @@ interchangeGenericOpPrecondition(GenericOp genericOp, ArrayRef interchangeVector); +/// Succeeds if `op` is a LinalgOp other than GenericOp with a region builder. +LogicalResult generalizeNamedOpPrecondition(Operation *op); + /// Promote std.subviews feeding linalg operations. LogicalResult promoteSubviewsPrecondition(Operation *op, LinalgPromotionOptions options); @@ -701,6 +708,31 @@ SmallVector interchangeVector; }; +/// +/// Linalg generalization pattern. +/// +/// Apply the `generalizeNamedOp` transformation as a pattern. +/// `filter` controls LinalgTransformMarker matching and update when specified. +/// See `generalizeNamedOp` for more details. +struct LinalgGeneralizationPattern : public RewritePattern { + // Entry point to match any LinalgOp OpInterface. + LinalgGeneralizationPattern( + MLIRContext *context, + LinalgTransformationFilter filter = LinalgTransformationFilter(), + PatternBenefit benefit = 1); + // Entry point to match a specific Linalg op.
+ LinalgGeneralizationPattern( + StringRef opName, MLIRContext *context, + LinalgTransformationFilter filter = LinalgTransformationFilter(), + PatternBenefit benefit = 1); + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override; + +private: + /// LinalgTransformMarker handles special attribute manipulations. + LinalgTransformationFilter filter; +}; + /// /// Linalg promotion patterns. /// diff --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp @@ -29,10 +29,19 @@ using namespace mlir; using namespace mlir::linalg; -// Creates a linalg.generic op from the given `namedOp`. Returns a null op if -// the given `namedOp` does not have a region builder. -static GenericOp createGenericOpFromNamedOp(LinalgOp namedOp, - PatternRewriter &rewriter) { +LogicalResult mlir::linalg::generalizeNamedOpPrecondition(Operation *op) { + LinalgOp namedOp = dyn_cast(op); + // Check if the operation is a LinalgOp but not a GenericOp. + if (!namedOp || isa(op)) + return failure(); + // Check if the operation has a region builder. + if (!namedOp.getRegionBuilder()) + return failure(); + return success(); +} + +GenericOp mlir::linalg::generalizeNamedOp(PatternRewriter &rewriter, + LinalgOp namedOp) { SmallVector inputOperands = namedOp.getInputOperands(); SmallVector outputOperands = namedOp.getOutputOperands(); SmallVector indexingMaps = namedOp.getIndexingMaps(); @@ -54,10 +63,7 @@ // Otherwise use the region builder to generate a new region. // TODO: Remove this path once all linag operations have a region attached.
auto regionBuilder = namedOp.getRegionBuilder(); - if (!regionBuilder) { - LLVM_DEBUG(llvm::dbgs() << "no region builder for op: " << namedOp << "\n"); - return nullptr; - } + assert(regionBuilder && "expect the operation to have region builder"); return rewriter.create( namedOp.getLoc(), types, inputOperands, outputOperands, indexingMaps, iterators, @@ -112,41 +118,6 @@ GenericOp createGenericOp(ConvOp convOp, OpBuilder &builder) const; }; -/// Catch-all pattern for converting all named ops with a region builder into -/// linalg.generic. -struct LinalgNamedOpGeneralizationPattern : RewritePattern { - LinalgNamedOpGeneralizationPattern(MLIRContext *context, - LinalgTransformationFilter marker, - PatternBenefit benefit = 1) - : RewritePattern(MatchAnyOpTypeTag(), benefit, context), - marker(std::move(marker)) {} - - LogicalResult matchAndRewrite(Operation *rootOp, - PatternRewriter &rewriter) const override { - auto linalgOp = dyn_cast(rootOp); - if (!linalgOp) - return failure(); - if (failed(marker.checkAndNotify(rewriter, linalgOp))) - return failure(); - - // No nothing to do for linalg.generic.
- if (isa(rootOp)) - return failure(); - - GenericOp genericOp = createGenericOpFromNamedOp(linalgOp, rewriter); - if (!genericOp) - return failure(); - - rewriter.replaceOp(rootOp, genericOp.getResults()); - marker.replaceLinalgTransformationFilter(rewriter, - genericOp.getOperation()); - return success(); - } - -private: - LinalgTransformationFilter marker; -}; - struct LinalgGeneralizationPass : public LinalgGeneralizationBase { void runOnFunction() override; @@ -187,8 +158,7 @@ void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns( RewritePatternSet &patterns, LinalgTransformationFilter marker) { - patterns.add(patterns.getContext(), - marker); + patterns.add(patterns.getContext(), marker); } std::unique_ptr> mlir::createLinalgGeneralizationPass() { diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -488,6 +488,30 @@ return success(); } +/// Linalg generalization pattern.
+mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
+    MLIRContext *context, LinalgTransformationFilter filter,
+    PatternBenefit benefit)
+    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter) {}
+
+mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
+    StringRef opName, MLIRContext *context, LinalgTransformationFilter filter,
+    PatternBenefit benefit)
+    : RewritePattern(opName, benefit, context, {}), filter(filter) {}
+
+LogicalResult mlir::linalg::LinalgGeneralizationPattern::matchAndRewrite(
+    Operation *op, PatternRewriter &rewriter) const {
+  // Precondition first: rejects non-LinalgOps cheaply and makes cast<> safe.
+  if (failed(generalizeNamedOpPrecondition(op)))
+    return failure();
+  if (failed(filter.checkAndNotify(rewriter, op)))
+    return failure();
+  GenericOp genericOp = generalizeNamedOp(rewriter, cast<LinalgOp>(op));
+  rewriter.replaceOp(op, genericOp.getResults());
+  filter.replaceLinalgTransformationFilter(rewriter, genericOp);
+  return success();
+}
+
 mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
     MLIRContext *context, LinalgTransformationFilter filter,
     LinalgPromotionOptions options, PatternBenefit benefit)