diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td @@ -16,19 +16,18 @@ let description = [{ ## Disclaimer - ** Proceed with care: not ready for general use. ** + **This dialect is actively developed and may change frequently.** - This dialect is evolving rapidly and may change on a very short notice. To - decrease the maintenance burden and churn, only a few in-tree use cases are - currently supported in the main tree: + To decrease the maintenance burden and churn, please post a description of + the intended use case on the MLIR forum. A few in-tree use cases are + currently supported: - high-level transformations on "structured ops" (i.e. ops that operate on chunks of data in a way that can be decomposed into operations on smaller chunks of data and control flow) in Linalg, Tensor and Vector - dialects. - - *Please post a description of the intended use case on the MLIR forum and - wait for confirmation.* + dialects; + - loop transformations in the SCF dialect. + ## Overview @@ -79,6 +78,18 @@ expected to have the `PossibleTopLevelTransformOpTrait` and may be used without arguments. + A program transformation expressed using the Transform dialect can be + programmatically triggered by calling: + + ```c++ + LogicalResult transform::applyTransforms(Operation *payloadRoot, + TransformOpInterface transform, + const TransformOptions &options); + ``` + + that applies the transformations specified by the top-level `transform` to + payload IR contained in `payloadRoot`. 
+ ## Dialect Extension Mechanism This dialect is designed to be extensible, that is, clients of this dialect diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -206,6 +206,16 @@ bool expensiveChecksEnabled = true; }; +/// Entry point to the Transform dialect infrastructure. Applies the +/// transformation specified by `transform` to payload IR contained in +/// `payloadRoot`. The `transform` operation may contain other operations that +/// will be executed following the internal logic of the operation. It must +/// have the `PossibleTopLevelTransformOpTrait` trait and not have any operands. +/// This function internally keeps track of the transformation state. +LogicalResult +applyTransforms(Operation *payloadRoot, TransformOpInterface transform, + const TransformOptions &options = TransformOptions()); + /// The state maintained across applications of various ops implementing the /// TransformOpInterface. The operations implementing this interface and the /// surrounding structure are referred to as transform IR. The operations to @@ -250,15 +260,11 @@ TransformOpReverseMapping reverse; }; -public: - /// Creates a state for transform ops living in the given region. The parent - /// operation of the region. The second argument points to the root operation - /// in the payload IR being transformed, which may or may not contain the - /// region with transform ops. Additional options can be provided through the - /// trailing configuration object. - TransformState(Region &region, Operation *root, - const TransformOptions &options = TransformOptions()); + friend LogicalResult applyTransforms(Operation *payloadRoot, + TransformOpInterface transform, + const TransformOptions &options); +public: /// Returns the op at which the transformation state is rooted.
This is /// typically helpful for transformations that apply globally. Operation *getTopLevel() const; @@ -438,6 +444,13 @@ /// Identifier for storing top-level value in the `operations` mapping. static constexpr Value kTopLevelValue = Value(); + /// Creates a state for transform ops living in the given region. The second + /// argument points to the root operation in the payload IR being transformed, + /// which may or may not contain the region with transform ops. Additional + /// options can be provided through the trailing configuration object. + TransformState(Region *region, Operation *payloadRoot, + const TransformOptions &options = TransformOptions()); + /// Returns the mappings frame for the reigon in which the value is defined. const Mappings &getMapping(Value value) const { return const_cast<TransformState *>(this)->getMapping(value); diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -12,6 +12,7 @@ #include "mlir/IR/Operation.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" #define DEBUG_TYPE "transform-dialect" #define DEBUG_PRINT_AFTER_ALL "transform-dialect-print-top-level-after-all" @@ -25,14 +26,15 @@ constexpr const Value transform::TransformState::kTopLevelValue; -transform::TransformState::TransformState(Region &region, Operation *root, +transform::TransformState::TransformState(Region *region, + Operation *payloadRoot, const TransformOptions &options) - : topLevel(root), options(options) { - auto result = mappings.try_emplace(&region); + : topLevel(payloadRoot), options(options) { + auto result = mappings.try_emplace(region); assert(result.second && "the region scope is already present"); (void)result; #if LLVM_ENABLE_ABI_BREAKING_CHECKS - regionStack.push_back(&region); + regionStack.push_back(region); #endif //
LLVM_ENABLE_ABI_BREAKING_CHECKS } @@ -447,6 +449,27 @@ effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get()); } +//===----------------------------------------------------------------------===// +// Entry point. +//===----------------------------------------------------------------------===// + +LogicalResult transform::applyTransforms(Operation *payloadRoot, + TransformOpInterface transform, + const TransformOptions &options) { + // Check in all build modes, and report a diagnostic rather than aborting + // the process: transform IR may come from user input. An op with operands + // cannot be applied here because nothing maps those operands to payload + // ops in the freshly created state. + if (!transform->hasTrait<PossibleTopLevelTransformOpTrait>() || + transform->getNumOperands() != 0) { + return transform->emitError() + << "expected transform to start at the top-level transform op"; + } + + TransformState state(transform->getParentRegion(), payloadRoot, options); + return state.applyTransform(transform).checkAndReport(); +} + //===----------------------------------------------------------------------===// // Generated interface implementation. //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir --- a/mlir/test/Dialect/Transform/test-interpreter.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter.mlir @@ -1,29 +1,41 @@ // RUN: mlir-opt %s --test-transform-dialect-interpreter -allow-unregistered-dialect --split-input-file --verify-diagnostics -// expected-remark @below {{applying transformation}} -transform.test_transform_op +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op): + // expected-remark @below {{applying transformation}} + transform.test_transform_op +} // ----- -%0 = transform.test_produce_param_or_forward_operand 42 { foo = "bar" } -// expected-remark @below {{succeeded}} -transform.test_consume_operand_if_matches_param_or_fail %0[42] +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op): + %0 =
transform.test_produce_param_or_forward_operand 42 { foo = "bar" } + // expected-remark @below {{succeeded}} + transform.test_consume_operand_if_matches_param_or_fail %0[42] +} // ----- -%0 = transform.test_produce_param_or_forward_operand 42 { foo = "bar" } -// expected-error @below {{expected the operand to be associated with 21 got 42}} -transform.test_consume_operand_if_matches_param_or_fail %0[21] +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op): + %0 = transform.test_produce_param_or_forward_operand 42 { foo = "bar" } + // expected-error @below {{expected the operand to be associated with 21 got 42}} + transform.test_consume_operand_if_matches_param_or_fail %0[21] +} // ----- // It is okay to have multiple handles to the same payload op as long // as only one of them is consumed. The expensive checks mode is necessary // to detect double-consumption. -%0 = transform.test_produce_param_or_forward_operand 42 { foo = "bar" } -%1 = transform.test_copy_payload %0 -// expected-remark @below {{succeeded}} -transform.test_consume_operand_if_matches_param_or_fail %0[42] +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op): + %0 = transform.test_produce_param_or_forward_operand 42 { foo = "bar" } + %1 = transform.test_copy_payload %0 + // expected-remark @below {{succeeded}} + transform.test_consume_operand_if_matches_param_or_fail %0[42] +} // ----- diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp @@ -41,13 +41,12 @@ void runOnOperation() override { ModuleOp module = getOperation(); - transform::TransformState state( - module.getBodyRegion(), module, - transform::TransformOptions().enableExpensiveChecks( - enableExpensiveChecks)); + transform::TransformOptions options; + options.enableExpensiveChecks(enableExpensiveChecks); + // Each top-level transform op is applied with its own transform state. for (auto op : module.getBody()->getOps<transform::TransformOpInterface>()) {
- if (failed(state.applyTransform(op).checkAndReport())) + if (failed(transform::applyTransforms(module, op, options))) return signalPassFailure(); } }