diff --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp @@ -27,24 +27,36 @@ #define DEBUG_TYPE "linalg-generalization" using namespace mlir; +using namespace mlir::linalg; // Creates a linalg.generic op from the given `namedOp`. Returns a null op if // the given `namedOp` does not have a region builder. -static linalg::GenericOp createGenericOpFromNamedOp(linalg::LinalgOp namedOp, - OpBuilder &builder) { +static GenericOp createGenericOpFromNamedOp(LinalgOp namedOp, + PatternRewriter &rewriter) { + SmallVector indexingMaps = namedOp.getIndexingMaps(); + SmallVector iterators = llvm::to_vector<4>( + namedOp.iterator_types().getAsValueRange()); + SmallVector resultTypes = namedOp.getOutputTensorTypes(); + SmallVector types(resultTypes.begin(), resultTypes.end()); + + // Inline the existing region if the named operation has a region attached. + if (namedOp->getNumRegions() == 1) { + GenericOp genericOp = rewriter.create( + namedOp.getLoc(), types, namedOp.getInputs(), namedOp.getOutputs(), + indexingMaps, iterators); + rewriter.inlineRegionBefore(namedOp->getRegion(0), genericOp.region(), + genericOp.region().begin()); + return genericOp; + } + + // Otherwise use the region builder to generate a new region. + // TODO: Remove this path once all linalg operations have a region attached.
auto regionBuilder = namedOp.getRegionBuilder(); if (!regionBuilder) { LLVM_DEBUG(llvm::dbgs() << "no region builder for op: " << namedOp << "\n"); return nullptr; } - - SmallVector indexingMaps = namedOp.getIndexingMaps(); - auto iterators = llvm::to_vector<4>( - namedOp.iterator_types().getAsValueRange()); - auto resultTypes = namedOp.getOutputTensorTypes(); - SmallVector types(resultTypes.begin(), resultTypes.end()); - - return builder.create( + return rewriter.create( namedOp.getLoc(), types, namedOp.getInputs(), namedOp.getOutputs(), indexingMaps, iterators, [®ionBuilder](OpBuilder &bodyBuilder, Location loc, ValueRange) { @@ -57,27 +69,27 @@ /// Base class for all linalg generalization patterns. A subclass must provide /// the following method: -/// linalg::GenericOp createGenericOp(RootOp, PatternRewriter &) +/// GenericOp createGenericOp(RootOp, PatternRewriter &) /// for creating the generic op. // TODO: remove this pattern after migrating all manually-written named ops // into auto-generated ones.
template struct LinalgGeneralizationPattern : OpRewritePattern { LinalgGeneralizationPattern(MLIRContext *context, - linalg::LinalgTransformationFilter marker, + LinalgTransformationFilter marker, PatternBenefit benefit = 1) : OpRewritePattern(context, benefit), marker(std::move(marker)) {} LogicalResult matchAndRewrite(RootOp rootOp, PatternRewriter &rewriter) const override { - auto linalgOp = dyn_cast(rootOp.getOperation()); + auto linalgOp = dyn_cast(rootOp.getOperation()); if (!linalgOp) return failure(); if (failed(marker.checkAndNotify(rewriter, linalgOp))) return failure(); auto *pattern = static_cast(this); - linalg::GenericOp genericOp = pattern->createGenericOp(rootOp, rewriter); + GenericOp genericOp = pattern->createGenericOp(rootOp, rewriter); if (!genericOp) return failure(); @@ -88,39 +100,38 @@ } private: - linalg::LinalgTransformationFilter marker; + LinalgTransformationFilter marker; }; struct GeneralizeConvOp - : public LinalgGeneralizationPattern { + : public LinalgGeneralizationPattern { using LinalgGeneralizationPattern::LinalgGeneralizationPattern; - linalg::GenericOp createGenericOp(linalg::ConvOp, OpBuilder &rewriter) const; + GenericOp createGenericOp(ConvOp convOp, OpBuilder &builder) const; }; /// Catch-all pattern for converting all named ops with a region builder into /// linalg.generic. struct LinalgNamedOpGeneralizationPattern : RewritePattern { LinalgNamedOpGeneralizationPattern(MLIRContext *context, - linalg::LinalgTransformationFilter marker, + LinalgTransformationFilter marker, PatternBenefit benefit = 1) : RewritePattern(MatchAnyOpTypeTag(), benefit, context), marker(std::move(marker)) {} LogicalResult matchAndRewrite(Operation *rootOp, PatternRewriter &rewriter) const override { - auto linalgOp = dyn_cast(rootOp); + auto linalgOp = dyn_cast(rootOp); if (!linalgOp) return failure(); if (failed(marker.checkAndNotify(rewriter, linalgOp))) return failure(); // No nothing to do for linalg.generic and linalg.indexed_generic.
- if (isa(rootOp)) + if (isa(rootOp)) return failure(); - linalg::GenericOp genericOp = - createGenericOpFromNamedOp(linalgOp, rewriter); + GenericOp genericOp = createGenericOpFromNamedOp(linalgOp, rewriter); if (!genericOp) return failure(); @@ -131,7 +142,7 @@ } private: - linalg::LinalgTransformationFilter marker; + LinalgTransformationFilter marker; }; struct LinalgGeneralizationPass @@ -144,17 +155,17 @@ void LinalgGeneralizationPass::runOnFunction() { FuncOp func = getFunction(); RewritePatternSet patterns(&getContext()); - linalg::populateLinalgConvGeneralizationPatterns(patterns); - linalg::populateLinalgNamedOpsGeneralizationPatterns(patterns); + populateLinalgConvGeneralizationPatterns(patterns); + populateLinalgNamedOpsGeneralizationPatterns(patterns); (void)applyPatternsAndFoldGreedily(func.getBody(), std::move(patterns)); } -linalg::GenericOp GeneralizeConvOp::createGenericOp(linalg::ConvOp convOp, - OpBuilder &builder) const { +GenericOp GeneralizeConvOp::createGenericOp(ConvOp convOp, + OpBuilder &builder) const { SmallVector indexingMaps = convOp.getIndexingMaps(); auto iterators = llvm::to_vector<4>(convOp.iterator_types().getAsValueRange()); - return builder.create( + return builder.create( convOp.getLoc(), /*resultTensorTypes=*/ArrayRef(), convOp.getInputBuffers(), convOp.getOutputBuffers(), indexingMaps, iterators, @@ -162,17 +173,17 @@ Value mul = bodyBuilder.create(bodyLoc, bodyArgs[0], bodyArgs[1]); Value add = bodyBuilder.create(bodyLoc, mul, bodyArgs[2]); - bodyBuilder.create(bodyLoc, add); + bodyBuilder.create(bodyLoc, add); }); } void mlir::linalg::populateLinalgConvGeneralizationPatterns( - RewritePatternSet &patterns, linalg::LinalgTransformationFilter marker) { + RewritePatternSet &patterns, LinalgTransformationFilter marker) { patterns.add(patterns.getContext(), marker); } void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns( - RewritePatternSet &patterns, linalg::LinalgTransformationFilter marker) { + RewritePatternSet
&patterns, LinalgTransformationFilter marker) { patterns.add(patterns.getContext(), marker); } diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir --- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir +++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir @@ -441,3 +441,23 @@ // CHECK-NEXT: %[[CMP:.+]] = cmpf olt, %[[BBARG0]], %[[BBARG2]] : f32 // CHECK-NEXT: %[[RES:.+]] = select %[[CMP]], %[[BBARG0]], %[[BBARG2]] : f32 // CHECK-NEXT: linalg.yield %[[RES]] : f32 + +// ----- + +func @generalize_fill(%output: memref, %value : f32) { + linalg.fill(%output, %value) : memref, f32 + return +} + +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)> + +// CHECK: func @generalize_fill +// CHECK-SAME: (%[[ARG0:.+]]: memref, %[[VAL:.+]]: f32) + +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP0]]] +// CHECK-SAME: iterator_types = ["parallel", "parallel"]} +// CHECK-SAME: outs(%{{.+}} : memref) + +// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32) +// CHECK-NEXT: linalg.yield %[[VAL]] : f32