diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -19,9 +19,9 @@
 namespace mlir {
 
-std::unique_ptr<OperationPass<FuncOp>> createConvertElementwiseToLinalgPass();
+std::unique_ptr<Pass> createConvertElementwiseToLinalgPass();
 
-std::unique_ptr<OperationPass<FuncOp>> createLinalgFoldUnitExtentDimsPass();
+std::unique_ptr<Pass> createLinalgFoldUnitExtentDimsPass();
 
 std::unique_ptr<Pass> createLinalgElementwiseOpFusionPass();
 std::unique_ptr<Pass> createFoldReshapeOpsByLinearizationPass();
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -11,12 +11,15 @@
 
 include "mlir/Pass/PassBase.td"
 
-def ConvertElementwiseToLinalg : FunctionPass<"convert-elementwise-to-linalg"> {
+def ConvertElementwiseToLinalg : Pass<"convert-elementwise-to-linalg", ""> {
   let summary = "Convert ElementwiseMappable ops to linalg";
   let description = [{
     Convert ops with the `ElementwiseMappable` trait to linalg parallel loops.
 
    This pass only converts ops that operate on ranked tensors.
+
+    This can be run on any op with the FunctionLike trait and must not be
+    run on others.
   }];
   let constructor = "mlir::createConvertElementwiseToLinalgPass()";
   let dependentDialects = ["linalg::LinalgDialect", "memref::MemRefDialect"];
@@ -54,8 +57,12 @@
   let constructor = "mlir::createLinalgComprehensiveModuleBufferizePass()";
 }
 
-def LinalgFoldUnitExtentDims : FunctionPass<"linalg-fold-unit-extent-dims"> {
+def LinalgFoldUnitExtentDims : Pass<"linalg-fold-unit-extent-dims", ""> {
   let summary = "Remove unit-extent dimension in Linalg ops on tensors";
+  let description = [{
+    This can be run on any op with the FunctionLike trait and must not be
+    run on others.
+  }];
   let constructor = "mlir::createLinalgFoldUnitExtentDimsPass()";
   let options = [
     Option<"foldOneTripLoopsOnly", "fold-one-trip-loops-only", "bool",
@@ -197,7 +204,7 @@
   let dependentDialects = ["linalg::LinalgDialect"];
 }
 
-def LinalgDetensorize : FunctionPass<"linalg-detensorize"> {
+def LinalgDetensorize : Pass<"linalg-detensorize", ""> {
   let summary = "Detensorize linalg ops";
   let constructor = "mlir::createLinalgDetensorizePass()";
   let dependentDialects = [];
@@ -222,6 +229,9 @@
     In addition to detensoring individual ops, this pass detensors internal
     control flow inside a function. All blocks except for the entry block are
     detensored by converting their arguments whenever possible.
+
+    This can be run on any op with the FunctionLike trait and must not be
+    run on others.
   }];
 }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
@@ -93,10 +93,11 @@
 /// A conversion pattern for detensoring internal (non-entry) blocks within a
 /// function.
 struct FunctionNonEntryBlockConversion : public ConversionPattern {
-  FunctionNonEntryBlockConversion(StringRef functionLikeOpName,
-                                  MLIRContext *ctx, TypeConverter &converter,
+  FunctionNonEntryBlockConversion(MLIRContext *ctx, TypeConverter &converter,
                                   DenseSet<BlockArgument> blockArgsToDetensor)
-      : ConversionPattern(converter, functionLikeOpName, /*benefit=*/1, ctx),
+      : ConversionPattern(converter, MatchTraitOpTypeTag(),
+                          TypeID::get<OpTrait::FunctionLike>(), /*benefit=*/1,
+                          ctx),
         blockArgsToDetensor(blockArgsToDetensor) {}
 
   LogicalResult
@@ -235,7 +236,8 @@
   /// detensored, then:
   /// - opsToDetensor should be = {linalg.generic{add}}.
   /// - blockArgsToDetensor should be = {bb1 -> {0}, bb2 -> {0}}.
-  virtual void compute(FuncOp func, DetensorizeTypeConverter typeConverter,
+  virtual void compute(Operation *func,
+                       DetensorizeTypeConverter typeConverter,
                        DenseSet<Operation *> &opsToDetensor,
                        DenseSet<BlockArgument> &blockArgsToDetensor) = 0;
 
@@ -286,18 +288,18 @@
   /// AND can be detensored.
   class ControlFlowDetectionModel : public CostModel {
   public:
-    void compute(FuncOp func, DetensorizeTypeConverter typeConverter,
+    void compute(Operation *func, DetensorizeTypeConverter typeConverter,
                  DenseSet<Operation *> &opsToDetensor,
                  DenseSet<BlockArgument> &blockArgsToDetensor) override {
       SmallVector<Value> workList;
 
-      func.walk([&](CondBranchOp condBr) {
+      func->walk([&](CondBranchOp condBr) {
         for (auto operand : condBr.getOperands()) {
           workList.push_back(operand);
         }
       });
 
-      func.walk([&](BranchOp br) {
+      func->walk([&](BranchOp br) {
         for (auto operand : br.getOperands()) {
           workList.push_back(operand);
         }
@@ -491,21 +493,24 @@
   /// Detensorize everything that can detensored.
   class AggressiveDetensoringModel : public CostModel {
   public:
-    void compute(FuncOp func, DetensorizeTypeConverter typeConverter,
+    void compute(Operation *func, DetensorizeTypeConverter typeConverter,
                  DenseSet<Operation *> &opsToDetensor,
                  DenseSet<BlockArgument> &blockArgsToDetensor) override {
-      func.walk([&](GenericOp genericOp) {
+      func->walk([&](GenericOp genericOp) {
         if (shouldBeDetensored(genericOp, typeConverter))
           opsToDetensor.insert(genericOp);
       });
 
-      for (Block &block : llvm::drop_begin(func.getBody(), 1))
+      for (Block &block :
+           llvm::drop_begin(function_like_impl::getFunctionBody(func), 1))
         for (BlockArgument blockArgument : block.getArguments())
           blockArgsToDetensor.insert(blockArgument);
     }
   };
 
-  void runOnFunction() override {
+  void runOnOperation() override {
+    assert(getOperation()->hasTrait<OpTrait::FunctionLike>() &&
+           "DetensorizePass can only be run on FunctionLike operations");
     MLIRContext *context = &getContext();
     DetensorizeTypeConverter typeConverter;
     RewritePatternSet patterns(context);
@@ -516,12 +521,12 @@
     if (aggressiveMode.getValue()) {
       AggressiveDetensoringModel costModel;
-      costModel.compute(getFunction(), typeConverter, opsToDetensor,
+      costModel.compute(getOperation(), typeConverter, opsToDetensor,
                         blockArgsToDetensor);
     } else {
       ControlFlowDetectionModel costModel;
-      costModel.compute(getFunction(), typeConverter, opsToDetensor,
+      costModel.compute(getOperation(), typeConverter, opsToDetensor,
                         blockArgsToDetensor);
     }
 
@@ -531,24 +536,26 @@
     target.addDynamicallyLegalOp<GenericOp>(
         [&](GenericOp op) { return !opsToDetensor.count(op); });
 
-    target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
+    target.markUnknownOpDynamicallyLegal([&](Operation *op) {
      // A function is legal if all of its non-entry blocks are legal. We
      // don't legalize the entry block (i.e. the function's signature)
      // since detensoring can't happen along external calling convention
      // boundaries, which we conservatively approximate as all function
      // signatures.
-      return llvm::all_of(llvm::drop_begin(op.getBody(), 1), [&](Block &block) {
-        if (llvm::any_of(blockArgsToDetensor, [&](BlockArgument blockArgument) {
-              return blockArgument.getOwner() == &block &&
-                     !typeConverter.isLegal(blockArgument.getType());
-            })) {
-          return false;
-        }
-        return true;
-      });
-    });
+      if (op->hasTrait<OpTrait::FunctionLike>()) {
+        auto &body = function_like_impl::getFunctionBody(op);
+        return llvm::all_of(llvm::drop_begin(body, 1), [&](Block &block) {
+          if (llvm::any_of(
+                  blockArgsToDetensor, [&](BlockArgument blockArgument) {
+                    return blockArgument.getOwner() == &block &&
+                           !typeConverter.isLegal(blockArgument.getType());
+                  })) {
+            return false;
+          }
+          return true;
+        });
+      }
 
-    target.markUnknownOpDynamicallyLegal([&](Operation *op) {
       if (isNotBranchOpInterfaceOrReturnLikeOp(op) ||
          isLegalForReturnOpTypeConversionPattern(op, typeConverter,
                                                  /*returnOpAlwaysLegal*/ true))
@@ -570,8 +577,7 @@
     });
 
     patterns.insert(typeConverter, context);
-    patterns.insert<FunctionNonEntryBlockConversion>(FuncOp::getOperationName(),
-                                                     context, typeConverter,
+    patterns.insert<FunctionNonEntryBlockConversion>(context, typeConverter,
                                                      blockArgsToDetensor);
     // Since non-entry block arguments get detensorized, we also need to
     // update the control flow inside the function to reflect the correct
@@ -585,12 +591,13 @@
     populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter,
                                                    shouldConvertBranchOperand);
 
-    if (failed(applyFullConversion(getFunction(), target, std::move(patterns))))
+    if (failed(
+            applyFullConversion(getOperation(), target, std::move(patterns))))
      signalPassFailure();
 
     RewritePatternSet canonPatterns(context);
     canonPatterns.add(context);
-    if (failed(applyPatternsAndFoldGreedily(getFunction(),
+    if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                             std::move(canonPatterns))))
      signalPassFailure();
   }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -558,20 +558,23 @@
 /// Pass that removes unit-extent dims within generic ops.
 struct LinalgFoldUnitExtentDimsPass
     : public LinalgFoldUnitExtentDimsBase<LinalgFoldUnitExtentDimsPass> {
-  void runOnFunction() override {
-    FuncOp funcOp = getFunction();
-    MLIRContext *context = funcOp.getContext();
+  void runOnOperation() override {
+    auto funcOp = getOperation();
+    assert(funcOp->hasTrait<OpTrait::FunctionLike>() &&
+           "LinalgFoldUnitExtentDimsPass can only be run on FunctionLike "
+           "operations");
+    MLIRContext *context = funcOp->getContext();
     RewritePatternSet patterns(context);
     if (foldOneTripLoopsOnly)
       patterns.add(context);
     else
       populateFoldUnitExtentDimsPatterns(patterns);
-    (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
+    (void)applyPatternsAndFoldGreedily(
+        function_like_impl::getFunctionBody(funcOp), std::move(patterns));
   }
 };
 } // namespace
 
-std::unique_ptr<OperationPass<FuncOp>>
-mlir::createLinalgFoldUnitExtentDimsPass() {
+std::unique_ptr<Pass> mlir::createLinalgFoldUnitExtentDimsPass() {
   return std::make_unique<LinalgFoldUnitExtentDimsPass>();
 }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
@@ -126,8 +126,12 @@
 class ConvertElementwiseToLinalgPass
     : public ConvertElementwiseToLinalgBase<ConvertElementwiseToLinalgPass> {
-  void runOnFunction() final {
+  void runOnOperation() final {
     auto func = getOperation();
+    assert(func->hasTrait<OpTrait::FunctionLike>() &&
+           "ConvertElementwiseToLinalgPass can only be run on FunctionLike "
+           "operations");
+
     auto *context = &getContext();
     ConversionTarget target(*context);
     RewritePatternSet patterns(context);
@@ -143,7 +147,6 @@
 };
 } // namespace
 
-std::unique_ptr<OperationPass<FuncOp>>
-mlir::createConvertElementwiseToLinalgPass() {
+std::unique_ptr<Pass> mlir::createConvertElementwiseToLinalgPass() {
   return std::make_unique<ConvertElementwiseToLinalgPass>();
 }
diff --git a/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir b/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir
--- a/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir
+++ b/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-elementwise-to-linalg -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -pass-pipeline="builtin.func(convert-elementwise-to-linalg)" -split-input-file %s | FileCheck %s
 
 // In-depth checking of the linalg.generic op for a very trivial case.
 // CHECK: #[[$MAP:.*]] = affine_map<() -> ()>
diff --git a/mlir/test/Dialect/Linalg/detensorize_0d.mlir b/mlir/test/Dialect/Linalg/detensorize_0d.mlir
--- a/mlir/test/Dialect/Linalg/detensorize_0d.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_0d.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -allow-unregistered-dialect -linalg-detensorize=aggressive-mode | FileCheck %s
+// RUN: mlir-opt %s -allow-unregistered-dialect -pass-pipeline="builtin.func(linalg-detensorize{aggressive-mode})" | FileCheck %s
 
 #map = affine_map<() -> ()>
 
diff --git a/mlir/test/Dialect/Linalg/detensorize_br_operands.mlir b/mlir/test/Dialect/Linalg/detensorize_br_operands.mlir
--- a/mlir/test/Dialect/Linalg/detensorize_br_operands.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_br_operands.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -split-input-file -allow-unregistered-dialect -linalg-detensorize | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -allow-unregistered-dialect -pass-pipeline="builtin.func(linalg-detensorize)" | FileCheck %s
 
 // TODO: Detensoring breaks if %arg0 or %arg1 are passed directly as tensors. Fix that.
 func @if_true_test(%arg0: i1, %arg1: i32) -> tensor attributes {} {
diff --git a/mlir/test/Dialect/Linalg/detensorize_if.mlir b/mlir/test/Dialect/Linalg/detensorize_if.mlir
--- a/mlir/test/Dialect/Linalg/detensorize_if.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_if.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -split-input-file -allow-unregistered-dialect -linalg-detensorize | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -allow-unregistered-dialect -pass-pipeline="builtin.func(linalg-detensorize)" | FileCheck %s
 
 #map0 = affine_map<() -> ()>
 
diff --git a/mlir/test/Dialect/Linalg/detensorize_trivial.mlir b/mlir/test/Dialect/Linalg/detensorize_trivial.mlir
--- a/mlir/test/Dialect/Linalg/detensorize_trivial.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_trivial.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -linalg-detensorize=aggressive-mode | FileCheck %s -check-prefix=DET-ALL
-// RUN: mlir-opt %s -linalg-detensorize | FileCheck %s -check-prefix=DET-CF
+// RUN: mlir-opt %s -pass-pipeline="builtin.func(linalg-detensorize{aggressive-mode})" | FileCheck %s -check-prefix=DET-ALL
+// RUN: mlir-opt %s -pass-pipeline="builtin.func(linalg-detensorize)" | FileCheck %s -check-prefix=DET-CF
 
 #map0 = affine_map<() -> ()>
 
diff --git a/mlir/test/Dialect/Linalg/detensorize_while.mlir b/mlir/test/Dialect/Linalg/detensorize_while.mlir
--- a/mlir/test/Dialect/Linalg/detensorize_while.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_while.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -linalg-detensorize=aggressive-mode | FileCheck %s -check-prefix=DET-ALL
-// RUN: mlir-opt %s -linalg-detensorize | FileCheck %s -check-prefix=DET-CF
+// RUN: mlir-opt %s -pass-pipeline="builtin.func(linalg-detensorize{aggressive-mode})" | FileCheck %s -check-prefix=DET-ALL
+// RUN: mlir-opt %s -pass-pipeline="builtin.func(linalg-detensorize)" | FileCheck %s -check-prefix=DET-CF
 
 #map0 = affine_map<() -> ()>
 
diff --git a/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir b/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir
--- a/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -linalg-detensorize=aggressive-mode | FileCheck %s -check-prefix=DET-ALL
-// RUN: mlir-opt %s -linalg-detensorize | FileCheck %s -check-prefix=DET-CF
+// RUN: mlir-opt %s -pass-pipeline="builtin.func(linalg-detensorize{aggressive-mode})" | FileCheck %s -check-prefix=DET-ALL
+// RUN: mlir-opt %s -pass-pipeline="builtin.func(linalg-detensorize)" | FileCheck %s -check-prefix=DET-CF
 
 #map0 = affine_map<() -> ()>
 #map1 = affine_map<(i) -> ()>
diff --git a/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir b/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir
--- a/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -allow-unregistered-dialect -linalg-detensorize | FileCheck %s
+// RUN: mlir-opt %s -allow-unregistered-dialect -pass-pipeline="builtin.func(linalg-detensorize)" | FileCheck %s
 
 #map0 = affine_map<() -> ()>
 
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -split-input-file -linalg-fold-unit-extent-dims | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -pass-pipeline="builtin.func(linalg-fold-unit-extent-dims)" | FileCheck %s
 
 #accesses = [
   affine_map<(i, j, k, l, m) -> (i, k, m)>,
diff --git a/mlir/test/Dialect/Linalg/fold-unit-trip-loops.mlir b/mlir/test/Dialect/Linalg/fold-unit-trip-loops.mlir
--- a/mlir/test/Dialect/Linalg/fold-unit-trip-loops.mlir
+++ b/mlir/test/Dialect/Linalg/fold-unit-trip-loops.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -split-input-file -linalg-fold-unit-extent-dims="fold-one-trip-loops-only" | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -pass-pipeline="builtin.func(linalg-fold-unit-extent-dims{fold-one-trip-loops-only})" | FileCheck %s
 
 #accesses = [
   affine_map<(i, j, k, l, m) -> (i, k, m)>,
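
For context, a minimal sketch of how a C++ client might schedule the now trait-generic passes so they still only ever see FunctionLike operations, mirroring the builtin.func(...) pipelines in the updated RUN lines above. This is not part of the patch; the helper name buildLinalgFuncPipeline is hypothetical, and it assumes the PassManager and builtin FuncOp headers of this MLIR revision.

// Hypothetical usage sketch: nest the passes under builtin.func so the
// OpTrait::FunctionLike asserts added in this patch can never fire.
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"

static void buildLinalgFuncPipeline(mlir::PassManager &pm) {
  // Passes added to an OpPassManager nested on FuncOp run once per function.
  mlir::OpPassManager &funcPM = pm.nest<mlir::FuncOp>();
  funcPM.addPass(mlir::createConvertElementwiseToLinalgPass());
  funcPM.addPass(mlir::createLinalgFoldUnitExtentDimsPass());
  funcPM.addPass(mlir::createLinalgDetensorizePass());
}

The same nesting is available textually, e.g. -pass-pipeline="builtin.func(convert-elementwise-to-linalg,linalg-fold-unit-extent-dims,linalg-detensorize)".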