diff --git a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h
--- a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h
+++ b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h
@@ -35,7 +35,9 @@
     MLIRContext *context, StringRef transformFileName,
     StringRef transformLibraryFileName,
     std::shared_ptr<OwningOpRef<ModuleOp>> &module,
-    std::shared_ptr<OwningOpRef<ModuleOp>> &libraryModule);
+    std::shared_ptr<OwningOpRef<ModuleOp>> &libraryModule,
+    function_ref<std::optional<LogicalResult>(OpBuilder &, Location)>
+        moduleBuilder = nullptr);
 
 /// Template-free implementation of
 /// TransformInterpreterPassBase::runOnOperation.
@@ -123,7 +125,11 @@
         static_cast<Concrete *>(this)->transformLibraryFileName;
     return detail::interpreterBaseInitializeImpl(
         context, transformFileName, transformLibraryFileName,
-        sharedTransformModule, transformLibraryModule);
+        sharedTransformModule, transformLibraryModule,
+        [this](OpBuilder &builder, Location loc) {
+          return static_cast<Concrete *>(this)->constructTransformModule(
+              builder, loc);
+        });
   }
 
   /// Hook for passes to run additional logic in the pass before the
@@ -136,6 +142,17 @@
   /// fails.
   LogicalResult runAfterInterpreter(Operation *) { return success(); }
 
+  /// Hook for passes to construct the transform module programmatically.
+  /// This runs during pass initialization and only when no transform file was
+  /// provided; a parsed transform file takes precedence and this hook is then
+  /// not called. Return std::nullopt to construct nothing, failure() to abort
+  /// initialization, or success() after building ops at the insertion point
+  /// of `builder` (the end of a fresh module body).
+  std::optional<LogicalResult> constructTransformModule(OpBuilder &builder,
+                                                        Location loc) {
+    return std::nullopt;
+  }
+
   void runOnOperation() override {
     auto *pass = static_cast<Concrete *>(this);
     Operation *op = pass->getOperation();
diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
@@ -462,7 +462,9 @@
     MLIRContext *context, StringRef transformFileName,
     StringRef transformLibraryFileName,
     std::shared_ptr<OwningOpRef<ModuleOp>> &module,
-    std::shared_ptr<OwningOpRef<ModuleOp>> &libraryModule) {
+    std::shared_ptr<OwningOpRef<ModuleOp>> &libraryModule,
+    function_ref<std::optional<LogicalResult>(OpBuilder &, Location)>
+        moduleBuilder) {
   OwningOpRef<ModuleOp> parsed;
   if (failed(parseTransformModuleFromFile(context, transformFileName, parsed)))
     return failure();
@@ -476,7 +478,29 @@
   if (parsedLibrary && failed(mlir::verify(*parsedLibrary)))
     return failure();
 
+  // A transform file, when provided, takes precedence; the module builder is
+  // only consulted when nothing was parsed.
+  if (!parsed && moduleBuilder) {
+    // TODO: better location story.
+    auto location = UnknownLoc::get(context);
+    OwningOpRef<ModuleOp> localModule(
+        ModuleOp::create(location, "__transform"));
+
+    OpBuilder b(context);
+    b.setInsertionPointToEnd(localModule->getBody());
+    if (std::optional<LogicalResult> result = moduleBuilder(b, location)) {
+      if (failed(*result))
+        return failure();
+      // Verify the generated module, as is done for the parsed library above.
+      if (failed(mlir::verify(*localModule)))
+        return failure();
+      parsed = std::move(localModule);
+    }
+  }
+  // Assign unconditionally (possibly to an empty OwningOpRef) so `module` is
+  // never left null or holding a module from a previous initialization.
   module = std::make_shared<OwningOpRef<ModuleOp>>(std::move(parsed));
+
   if (!parsedLibrary || !*parsedLibrary)
     return success();
 
diff --git a/mlir/test/Dialect/Transform/test-interpreter-module-generation.mlir b/mlir/test/Dialect/Transform/test-interpreter-module-generation.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Transform/test-interpreter-module-generation.mlir
@@ -0,0 +1,4 @@
+// RUN: mlir-opt %s --test-transform-dialect-interpreter=test-module-generation=1 --verify-diagnostics
+
+// expected-remark @below {{remark from generated}}
+module {}
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
@@ -11,7 +11,9 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "TestTransformDialectExtension.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformOps.h"
 #include "mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
@@ -46,6 +48,10 @@
     return "apply transform dialect operations one by one";
   }
 
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<transform::TransformDialect>();
+  }
+
   void findOperationsByName(Operation *root, StringRef name,
                             SmallVectorImpl<Operation *> &operations) {
     root->walk([&](Operation *op) {
@@ -86,6 +92,22 @@
     return numSetValues;
   }
 
+  std::optional<LogicalResult> constructTransformModule(OpBuilder &builder,
+                                                        Location loc) {
+    if (!testModuleGeneration)
+      return std::nullopt;
+
+    builder.create<transform::SequenceOp>(
+        loc, TypeRange(), transform::FailurePropagationMode::Propagate,
+        builder.getType<transform::AnyOpType>(),
+        [](OpBuilder &b, Location nested, Value rootH) {
+          b.create<mlir::test::TestPrintRemarkAtOperandOp>(
+              nested, rootH, "remark from generated");
+          b.create<transform::YieldOp>(nested, ValueRange());
+        });
+    return success();
+  }
+
   void runOnOperation() override {
     unsigned firstSetOptions =
         numberOfSetOptions(bindFirstExtraToOps, bindFirstExtraToParams,
@@ -199,6 +221,11 @@
       llvm::cl::desc(
           "Optional name of the file containing transform dialect symbol "
           "definitions to be injected into the transform module.")};
+
+  Option<bool> testModuleGeneration{
+      *this, "test-module-generation", llvm::cl::init(false),
+      llvm::cl::desc("test the generation of the transform module during pass "
+                     "initialization, overridden by parsing")};
 };
 
 struct TestTransformDialectEraseSchedulePass