diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -373,9 +373,10 @@ SymbolNameAttr:$sym_name, TypeAttrBase<"::mlir::FunctionType", "function type attribute">:$function_type, + OptionalAttr:$sym_visibility, OptionalAttr:$arg_attrs, OptionalAttr:$res_attrs); - let regions = (region SizedRegion<1>:$body); + let regions = (region MaxSizedRegion<1>:$body); let hasCustomAssemblyFormat = 1; let hasVerifier = 1; diff --git a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h --- a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h +++ b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h @@ -31,18 +31,22 @@ namespace transform { namespace detail { /// Template-free implementation of TransformInterpreterPassBase::initialize. -LogicalResult -interpreterBaseInitializeImpl(MLIRContext *context, StringRef transformFileName, - std::shared_ptr> &module); +LogicalResult interpreterBaseInitializeImpl( + MLIRContext *context, StringRef transformFileName, + StringRef transformLibraryFileName, + std::shared_ptr> &module, + std::shared_ptr> &libraryModule); /// Template-free implementation of /// TransformInterpreterPassBase::runOnOperation. LogicalResult interpreterBaseRunOnOperationImpl( Operation *target, StringRef passName, const std::shared_ptr> &sharedTransformModule, + const std::shared_ptr> &libraryModule, const RaggedArray &extraMappings, const TransformOptions &options, const Pass::Option &transformFileName, + const Pass::Option &transformLibraryFileName, const Pass::Option &debugPayloadRootTag, const Pass::Option &debugTransformRootTag, StringRef binaryName); @@ -56,6 +60,9 @@ /// transform script. 
If empty, `debugTransformRootTag` is considered or the /// pass root operation must contain a single top-level transform op that /// will be interpreted. +/// - transformLibraryFileName: if non-empty, the name of the file containing +/// definitions of external symbols referenced in the transform script. +/// These definitions will be used to replace declarations. /// - debugPayloadRootTag: if non-empty, the value of the attribute named /// `kTransformDialectTagAttrName` indicating the single op that is /// considered the payload root of the transform interpreter; otherwise, the @@ -106,13 +113,17 @@ REQUIRE_PASS_OPTION(transformFileName); REQUIRE_PASS_OPTION(debugPayloadRootTag); REQUIRE_PASS_OPTION(debugTransformRootTag); + REQUIRE_PASS_OPTION(transformLibraryFileName); #undef REQUIRE_PASS_OPTION StringRef transformFileName = static_cast(this)->transformFileName; - return detail::interpreterBaseInitializeImpl(context, transformFileName, - sharedTransformModule); + StringRef transformLibraryFileName = + static_cast(this)->transformLibraryFileName; + return detail::interpreterBaseInitializeImpl( + context, transformFileName, transformLibraryFileName, + sharedTransformModule, transformLibraryModule); } /// Hook for passes to run additional logic in the pass before the @@ -132,9 +143,10 @@ if (failed(pass->runBeforeInterpreter(op)) || failed(detail::interpreterBaseRunOnOperationImpl( op, pass->getArgument(), sharedTransformModule, + transformLibraryModule, /*extraMappings=*/{}, options, pass->transformFileName, - pass->debugPayloadRootTag, pass->debugTransformRootTag, - binaryName)) || + pass->transformLibraryFileName, pass->debugPayloadRootTag, + pass->debugTransformRootTag, binaryName)) || failed(pass->runAfterInterpreter(op))) { return pass->signalPassFailure(); } @@ -150,12 +162,24 @@ return sharedTransformModule; } + /// Returns a read-only reference to the transform library module. 
+ const std::shared_ptr> & + getTransformLibraryModule() const { + return transformLibraryModule; + } + private: /// The separate transform module to be used for transformations, shared /// across multiple instances of the pass if it is applied in parallel to /// avoid potentially expensive cloning. MUST NOT be modified after the pass /// has been initialized. std::shared_ptr> sharedTransformModule = nullptr; + + /// The transform module containing symbol definitions that become available + /// in the transform scripts. Similar to dynamic linking for binaries. This is + /// shared across multiple instances of the pass and therefore MUST NOT be + /// modified after the pass has been initialized. + std::shared_ptr> transformLibraryModule = nullptr; }; } // namespace transform diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -573,6 +573,9 @@ getOperation(), getTarget()); assert(callee && "unverified reference to unknown symbol"); + if (callee.isExternal()) + return emitDefiniteFailure() << "unresolved external named sequence"; + // Map operands to block arguments. SmallVector> mappings; detail::prepareValueMappings(mappings, getOperands(), state); @@ -648,7 +651,10 @@ } // Carry over effects from the callee. - remapArgumentEffects(callee.getBody().front(), getOperands(), effects); + // TODO: external callees must provide attributes annotating the + // readonly/consume effects on operands. + if (!callee.isExternal()) + remapArgumentEffects(callee.getBody().front(), getOperands(), effects); // Proper effects. onlyReadsHandle(getOperands(), effects); @@ -784,9 +790,6 @@ /// verifier runs, e.g., during trait verification. 
static DiagnosedSilenceableFailure verifyNamedSequenceOp(transform::NamedSequenceOp op) { - if (op.isExternal()) - return emitSilenceableFailure(op) << "cannot be empty"; - if (Operation *parent = op->getParentWithTrait()) { if (!parent->getAttr( transform::TransformDialect::kWithNamedSequenceAttrName)) { @@ -808,6 +811,9 @@ return diag; } + if (op.isExternal() || op.getBody().empty()) + return DiagnosedSilenceableFailure::success(); + if (op.getBody().front().empty()) return emitSilenceableFailure(op) << "expected a non-empty body block"; diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp --- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp +++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp @@ -14,10 +14,13 @@ #include "mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/FunctionInterfaces.h" +#include "mlir/IR/Verifier.h" #include "mlir/IR/Visitors.h" #include "mlir/Parser/Parser.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/FileUtilities.h" +#include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Debug.h" #include "llvm/Support/FileSystem.h" @@ -157,9 +160,17 @@ printReproCall(llvm::raw_ostream &os, StringRef rootOpName, StringRef passName, const Pass::Option &debugPayloadRootTag, const Pass::Option &debugTransformRootTag, + const Pass::Option &transformLibraryFileName, StringRef binaryName) { + std::string transformLibraryOption = ""; + if (!transformLibraryFileName.empty()) { + transformLibraryOption = + llvm::formatv(" {0}={1}", transformLibraryFileName.getArgStr(), + transformLibraryFileName.getValue()) + .str(); + } os << llvm::formatv( - "{6} --pass-pipeline=\"{0}({1}{{{2}={3} {4}={5}})\"", rootOpName, + "{7} --pass-pipeline=\"{0}({1}{{{2}={3} 
{4}={5}{6}})\"", rootOpName, passName, debugPayloadRootTag.getArgStr(), debugPayloadRootTag.empty() ? StringRef(kTransformDialectTagPayloadRootValue) @@ -168,7 +179,7 @@ debugTransformRootTag.empty() ? StringRef(kTransformDialectTagTransformContainerValue) : debugTransformRootTag, - binaryName); + transformLibraryOption, binaryName); return os; } @@ -184,11 +195,12 @@ /// Saves the payload and the transform IR into a temporary file and reports /// the file name to `os`. -void saveReproToTempFile(llvm::raw_ostream &os, Operation *target, - Operation *transform, StringRef passName, - const Pass::Option &debugPayloadRootTag, - const Pass::Option &debugTransformRootTag, - StringRef binaryName) { +void saveReproToTempFile( + llvm::raw_ostream &os, Operation *target, Operation *transform, + StringRef passName, const Pass::Option &debugPayloadRootTag, + const Pass::Option &debugTransformRootTag, + const Pass::Option &transformLibraryFileName, + StringRef binaryName) { using llvm::sys::fs::TempFile; Operation *root = getRootOperation(target); @@ -213,7 +225,8 @@ os << "=== Transform Interpreter Repro ===\n"; printReproCall(os, root->getName().getStringRef(), passName, - debugPayloadRootTag, debugTransformRootTag, binaryName) + debugPayloadRootTag, debugTransformRootTag, + transformLibraryFileName, binaryName) << " " << filename << "\n"; os << "===================================\n"; } @@ -224,6 +237,7 @@ Operation *target, Operation *transform, StringRef passName, const Pass::Option &debugPayloadRootTag, const Pass::Option &debugTransformRootTag, + const Pass::Option &transformLibraryFileName, StringRef binaryName) { MLIRContext *context = target->getContext(); @@ -266,7 +280,8 @@ llvm::dbgs() << "=== Transform Interpreter Repro ===\n"; printReproCall(llvm::dbgs() << "cat <getName().getStringRef(), passName, - debugPayloadRootTag, debugTransformRootTag, binaryName) + debugPayloadRootTag, debugTransformRootTag, + transformLibraryFileName, binaryName) << "\n"; 
printModuleForRepro(llvm::dbgs(), root, transform); llvm::dbgs() << "\nEOF\n"; @@ -275,16 +290,63 @@ (void)root; DEBUG_WITH_TYPE(DEBUG_TYPE_DUMP_FILE, { saveReproToTempFile(llvm::dbgs(), target, transform, passName, - debugPayloadRootTag, debugTransformRootTag, binaryName); + debugPayloadRootTag, debugTransformRootTag, + transformLibraryFileName, binaryName); }); } +/// Replaces external function-like symbols in `block` with their definitions +/// from `definitions`; fails if a definition's function type mismatches. +static LogicalResult defineDeclaredSymbols(Block &block, ModuleOp definitions) { + SymbolTable symbolTable(definitions); + for (Operation &op : llvm::make_early_inc_range(block)) { + LLVM_DEBUG(DBGS() << op << "\n"); + auto symbol = dyn_cast<SymbolOpInterface>(op); + if (!symbol) + continue; + if (symbol->getNumRegions() == 1 && !symbol->getRegion(0).empty()) + continue; + + LLVM_DEBUG(DBGS() << "looking for definition of symbol " + << symbol.getNameAttr() << ":"); + Operation *externalSymbol = symbolTable.lookup(symbol.getNameAttr()); + if (!externalSymbol || externalSymbol->getNumRegions() != 1 || + externalSymbol->getRegion(0).empty()) { + LLVM_DEBUG(llvm::dbgs() << "not found\n"); + continue; + } + + auto symbolFunc = dyn_cast<FunctionOpInterface>(op); + auto externalSymbolFunc = dyn_cast<FunctionOpInterface>(externalSymbol); + if (!symbolFunc || !externalSymbolFunc) { + LLVM_DEBUG(llvm::dbgs() << "cannot compare types\n"); + continue; + } + + LLVM_DEBUG(llvm::dbgs() << "found @" << symbol.getName() << "\n"); + if (symbolFunc.getFunctionType() != externalSymbolFunc.getFunctionType()) { + return symbolFunc.emitError() + << "external definition has a mismatching signature (" + << externalSymbolFunc.getFunctionType() << ")"; + } + + // Building from `op` already sets the insertion point right before it. 
+ OpBuilder builder(&op); + builder.clone(*externalSymbol); + symbol->erase(); + } + + return success(); +} + LogicalResult transform::detail::interpreterBaseRunOnOperationImpl( Operation *target, StringRef passName, const std::shared_ptr<OwningOpRef<ModuleOp>> &sharedTransformModule, + const std::shared_ptr<OwningOpRef<ModuleOp>> 
&libraryModule, const RaggedArray &extraMappings, const TransformOptions &options, const Pass::Option &transformFileName, + const Pass::Option &transformLibraryFileName, const Pass::Option &debugPayloadRootTag, const Pass::Option &debugTransformRootTag, StringRef binaryName) { @@ -328,13 +390,31 @@ // Step 3 // ------ + // Copy external definitions for symbols if provided. Be aware of potential + // concurrent execution (normally, the error shouldn't be triggered unless the + // transform IR modifies itself in a pass, which is also forbidden elsewhere). + if (!sharedTransform && libraryModule && *libraryModule) { + if (!target->isProperAncestor(transformRoot)) { + InFlightDiagnostic diag = + transformRoot->emitError() + << "cannot inject transform definitions next to pass anchor op"; + diag.attachNote(target->getLoc()) << "pass anchor op"; + return diag; + } + if (failed(defineDeclaredSymbols(*transformRoot->getBlock(), + libraryModule->get()))) + return failure(); + } + + // Step 4 + // ------ // Optionally perform debug actions requested by the user to dump IR and a // repro to stderr and/or a file. 
performOptionalDebugActions(target, transformRoot, passName, debugPayloadRootTag, debugTransformRootTag, - binaryName); + transformLibraryFileName, binaryName); - // Step 4 + // Step 5 // ------ // Apply the transform to the IR return applyTransforms(payloadRoot, cast(transformRoot), @@ -343,11 +423,33 @@ LogicalResult transform::detail::interpreterBaseInitializeImpl( MLIRContext *context, StringRef transformFileName, - std::shared_ptr> &module) { + StringRef transformLibraryFileName, + std::shared_ptr> &module, + std::shared_ptr> &libraryModule) { OwningOpRef parsed; if (failed(parseTransformModuleFromFile(context, transformFileName, parsed))) return failure(); + if (parsed && failed(mlir::verify(*parsed))) + return failure(); + + OwningOpRef parsedLibrary; + if (failed(parseTransformModuleFromFile(context, transformLibraryFileName, + parsedLibrary))) + return failure(); + if (parsedLibrary && failed(mlir::verify(*parsedLibrary))) + return failure(); module = std::make_shared>(std::move(parsed)); + if (!parsedLibrary || !*parsedLibrary) + return success(); + + if (module && *module) { + if (failed(defineDeclaredSymbols(*module->get().getBody(), + parsedLibrary.get()))) + return failure(); + } else { + libraryModule = + std::make_shared>(std::move(parsedLibrary)); + } return success(); } diff --git a/mlir/test/Dialect/Transform/ops-invalid.mlir b/mlir/test/Dialect/Transform/ops-invalid.mlir --- a/mlir/test/Dialect/Transform/ops-invalid.mlir +++ b/mlir/test/Dialect/Transform/ops-invalid.mlir @@ -293,8 +293,10 @@ // ----- module attributes { transform.with_named_sequence } { - // expected-error @below {{failed to verify constraint: region with 1 blocks}} - "transform.named_sequence"() ({}) { sym_name = "external_named_sequence", function_type = () -> () } : () -> () + // expected-error @below {{expected a non-empty body block}} + "transform.named_sequence"() ({ + ^bb0: + }) { sym_name = "external_named_sequence", function_type = () -> () } : () -> () 
transform.sequence failures(propagate) { ^bb0(%arg0: !transform.any_op): diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-and-schedule.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-and-schedule.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-and-schedule.mlir @@ -0,0 +1,13 @@ +// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-file-name=%p/test-interpreter-external-symbol-decl.mlir transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \ +// RUN: --verify-diagnostics + +// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-file-name=%p/test-interpreter-external-symbol-decl.mlir transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter{transform-file-name=%p/test-interpreter-external-symbol-decl.mlir transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \ +// RUN: --verify-diagnostics + +// The external transform script has a declaration of the named sequence @foo, +// the definition of which is provided in another file. Repeated application +// of the same pass should not be a problem. Note that the same diagnostic +// produced twice at the same location only needs to be matched once. 
+ +// expected-remark @below {{message}} +module {} diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir @@ -0,0 +1,27 @@ +// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter)" \ +// RUN: --verify-diagnostics --split-input-file + +// The definition of the @foo named sequence is provided in another file. It +// will be included because of the pass option. + +module attributes {transform.with_named_sequence} { + // expected-error @below {{external definition has a mismatching signature}} + transform.named_sequence private @foo(!transform.op<"builtin.module">) + + transform.sequence failures(propagate) { + ^bb0(%arg0: !transform.op<"builtin.module">): + include @foo failures(propagate) (%arg0) : (!transform.op<"builtin.module">) -> () + } +} + +// ----- + +module attributes {transform.with_named_sequence} { + transform.named_sequence private @undefined_sequence() + + transform.sequence failures(suppress) { + ^bb0(%arg0: !transform.any_op): + // expected-error @below {{unresolved external named sequence}} + include @undefined_sequence failures(suppress) () : () -> () + } +} diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir @@ -0,0 +1,26 @@ +// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \ +// RUN: --verify-diagnostics | FileCheck %s + +// RUN: mlir-opt %s 
--pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter)" \ +// RUN: --verify-diagnostics | FileCheck %s + +// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \ +// RUN: --verify-diagnostics | FileCheck %s + +// The definition of the @foo named sequence is provided in another file. It +// will be included because of the pass option. Repeated application of the +// same pass, with or without the library option, should not be a problem. +// Note that the same diagnostic produced twice at the same location only +// needs to be matched once. + +// expected-remark @below {{message}} +module attributes {transform.with_named_sequence} { + // CHECK: transform.named_sequence @foo + // CHECK: test_print_remark_at_operand %{{.*}}, "message" + transform.named_sequence private @foo(!transform.any_op) + + transform.sequence failures(propagate) { + ^bb0(%arg0: !transform.any_op): + include @foo failures(propagate) (%arg0) : (!transform.any_op) -> () + } +} diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir @@ -0,0 +1,8 @@ +// RUN: mlir-opt %s + +module attributes {transform.with_named_sequence} { + transform.named_sequence @foo(%arg0: !transform.any_op) { + transform.test_print_remark_at_operand %arg0, "message" : !transform.any_op + transform.yield + } +} diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp --- 
a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp @@ -138,7 +138,8 @@ options = options.enableExpensiveChecks(enableExpensiveChecks); if (failed(transform::detail::interpreterBaseRunOnOperationImpl( getOperation(), getArgument(), getSharedTransformModule(), - extraMapping, options, transformFileName, debugPayloadRootTag, + getTransformLibraryModule(), extraMapping, options, + transformFileName, transformLibraryFileName, debugPayloadRootTag, debugTransformRootTag, getBinaryName()))) return signalPassFailure(); } @@ -193,6 +194,11 @@ "the given value as container IR for top-level transform ops. This " "allows user control on what transformation to apply. If empty, " "select the container of the top-level transform op.")}; + Option transformLibraryFileName{ + *this, "transform-library-file-name", llvm::cl::init(""), + llvm::cl::desc( + "Optional name of the file containing transform dialect symbol " + "definitions to be injected into the transform module.")}; }; struct TestTransformDialectEraseSchedulePass diff --git a/utils/bazel/llvm-project-overlay/mlir/test/Dialect/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/Dialect/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/test/Dialect/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/test/Dialect/BUILD.bazel @@ -15,10 +15,15 @@ "//mlir/test:lit_data", ] + glob([ "Transform/*-source.mlir", + "Transform/*-symbol-def.mlir", ]) ) for src in glob( include=["**/*.mlir"], - exclude=["Transform/*-source.mlir"] + exclude=[ + "Transform/*-source.mlir", + "Transform/*-symbol-def.mlir", + "Transform/*-symbol-decl-and-schedule.mlir", + ] ) ]