diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt --- a/mlir/include/mlir/Dialect/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/CMakeLists.txt @@ -27,5 +27,6 @@ add_subdirectory(SPIRV) add_subdirectory(Tensor) add_subdirectory(Tosa) +add_subdirectory(Transform) add_subdirectory(Vector) add_subdirectory(X86Vector) diff --git a/mlir/include/mlir/Dialect/Transform/CMakeLists.txt b/mlir/include/mlir/Dialect/Transform/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Transform/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/mlir/include/mlir/Dialect/Transform/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Transform/IR/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Transform/IR/CMakeLists.txt @@ -0,0 +1,8 @@ +# The dialect does not have its own ops, so just generate the dialect files. +set(LLVM_TARGET_DEFINITIONS TransformDialect.td) +mlir_tablegen(TransformDialect.h.inc -gen-dialect-decls -dialect=transform) +mlir_tablegen(TransformDialect.cpp.inc -gen-dialect-defs -dialect=transform) +add_public_tablegen_target(MLIRTransformDialectIncGen) +add_dependencies(mlir-headers MLIRTransformDialectIncGen) + +add_mlir_interface(TransformInterfaces) diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h @@ -0,0 +1,99 @@ +//===- TransformDialect.h - Transform Dialect Definition --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT_H +#define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT_H + +#include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/IR/Dialect.h" +#include "mlir/Support/LLVM.h" + +#include "mlir/Dialect/Transform/IR/TransformDialect.h.inc" + +namespace mlir { +namespace transform { + +#ifndef NDEBUG +namespace detail { +/// Asserts that the operation provided as template argument implements the +/// TransformOpInterface. This must be a dynamic assertion since interface +/// implementations may be registered at runtime. +template +static inline void checkImplementsTransformInterface(MLIRContext *context) { + // Since the operation is being inserted into the Transform dialect and the + // dialect does not implement the interface fallback, only check for the op + // itself having the interface implementation. + RegisteredOperationName opName = + *RegisteredOperationName::lookup(OpTy::getOperationName(), context); + assert(opName.hasInterface() && + "ops injected into the transform dialect must implement " + "TransformOpInterface"); +} +} // namespace detail +#endif // NDEBUG + +/// Base class for extensions of the Transform dialect that supports injecting +/// operations into the Transform dialect at load time. Concrete extensions are +/// expected to derive from it and register operations in the constructor. +/// They can be registered with the DialectRegistry and automatically applied +/// to the Transform dialect when it is loaded. +template +class TransformDialectExtension + : public DialectExtension { + using Initializer = std::function; + using DialectLoader = std::function; + +public: + /// Extension application hook. Actually loads the dependent dialects and + /// registers the additional operations. Not expected to be called directly.
+ void apply(MLIRContext *context, TransformDialect *transformDialect, + ExtraDialects *...) const final { + for (const DialectLoader &loader : dialectLoaders) + loader(context); + for (const Initializer &init : opInitializers) + init(transformDialect); + } + +protected: + /// Injects the operations into the Transform dialect. The operations must + /// implement the TransformOpInterface and the implementation must already be + /// available when the operation is injected. + template + void registerTransformOps() { + opInitializers.push_back([](TransformDialect *transformDialect) { + transformDialect->addOperations(); + +#ifndef NDEBUG + std::initializer_list{ + (detail::checkImplementsTransformInterface( + transformDialect->getContext()), + 0)...}; +#endif // NDEBUG + }); + } + + /// Declares that this Transform dialect extension depends on the dialect + /// provided as template parameter. When the Transform dialect is loaded, + /// dependent dialects will be loaded as well. This is intended for dialects + /// that contain attributes and types used in creation and canonicalization of + /// the injected operations. + template + void declareDependentDialect() { + dialectLoaders.push_back( + [](MLIRContext *context) { context->loadDialect(); }); + } + +private: + SmallVector opInitializers; + SmallVector dialectLoaders; +}; + +} // namespace transform +} // namespace mlir + +#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT_H diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td @@ -0,0 +1,175 @@ +//===- TransformDialect.td - Transform dialect definition --*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT +#define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT + +include "mlir/IR/OpBase.td" + +def Transform_Dialect : Dialect { + let summary = "Fine-grain transformation control dialect"; + let description = [{ + ## Disclaimer + + **Proceed with care: not ready for general use.** + + This dialect is evolving rapidly and may change at very short notice. To + decrease the maintenance burden and churn, only a few in-tree use cases are + currently supported in the main tree: + + - high-level transformations on "structured ops" (i.e. ops that operate on + chunks of data in a way that can be decomposed into operations on + smaller chunks of data and control flow) in Linalg, Tensor and Vector + dialects. + + *Please post a description of the intended use case on the MLIR forum and + wait for confirmation.* + + ## Overview + + This dialect provides operations that can be used to control transformation + of the IR using a different portion of the IR. It refers to the IR being + transformed as payload IR, and to the IR guiding the transformation as + transform IR. + + The main use case for this dialect is orchestrating fine-grain + transformations on individual operations or sets thereof. For example, it + may involve finding loop-like operations with specific properties (e.g., + large size) in the payload IR, applying loop tiling to those and only those + operations, and then applying loop unrolling to the inner loops produced + by the previous transformations. As such, it is not intended as a + replacement for the pass infrastructure, nor for the pattern rewriting + infrastructure. In the most common case, the transform IR will be processed + and applied to the payload IR by a pass.
Transformations expressed by the + transform dialect may be implemented using the pattern infrastructure or any + other relevant MLIR component. + + The following IR gives a rough idea of what the operations in this dialect + may look like: + + ```mlir + %0 = transform.loop.find { size > 42 } + %1:2 = transform.loop.tile %0 { tile_sizes = [2,3,4] } + transform.loop.unroll %1#1 + ``` + + The values defined by operations in this dialect correspond to (groups of) + operations in the payload IR. In the example above, `%0` corresponds to the + set of loops found in the payload IR that satisfy the condition, and `%1#0` + and `%1#1` correspond to groups of outer and inner loops, respectively, + produced by the tiling transformation. + + This dialect is designed to be extensible, that is, clients of this dialect + are allowed to inject additional operations into this dialect using the + `TransformDialectExtension` mechanism. This allows the dialect to avoid a + dependency on the implementation of the transformation as well as to avoid + introducing dialect-specific transform dialects. In the example above, + the operations may have been injected by a notional `loop` dialect rather + than defined in this dialect, hence the common prefix. + + It is recommended to prefix injected operations with one or several + dot-separated words that indicate which extension adds them. For + dialect-specific transformations, the prefix is naturally the name of the + dialect, e.g., `transform.affine.reschedule`. For dialect-agnostic + transformations (typically implemented using interfaces), the prefix may + be derived from the interface name or from a common concept, e.g., + `transform.loop.tile` may apply to any loop-like operation that implements + `TileableOpInterface`. The C++ classes for the dialect extension should + include the prefix in their name, e.g., `AffineTransformDialectExtension` or + `LoopTransformDialectExtension` in the cases above.
Unprefixed operation + names are reserved for ops defined directly in the Transform dialect. + + ## Intended Use and Integrations + + The transformation control infrastructure provided by this dialect is + positioned roughly between rewrite patterns and passes. A transformation + that is executed by a transform operation is likely to be sufficiently + complex to require at least a set of patterns to be implemented. It is also + expected to be more focused than a pass: a pass typically applies identical + transformations everywhere in the IR; a transform dialect-controlled + transformation would apply to a small subset of operations selected, e.g., + by a pattern-matching operation or generated by a previous transformation. + It is discouraged, although technically possible, to run a pass pipeline as + part of the transform op implementation. + + One of the main scenarios for using this dialect is fine-grain chaining of + transformations. For example, a loop-like operation may see its iteration + domain split into two parts, implemented as separate loops (transformation + known as index-set splitting), each of which is then transformed differently + (e.g., the first loop is tiled and the second unrolled) with the necessary + enabling and cleanup patterns around the main transformation: + + ```mlir + // %loop is produced earlier, e.g., by a pattern-matching transform op. + // ... + %parts:2 = transform.loop.split %loop { upper_bound_divisible_by = 8 } + transform.loop.tile %parts#0 { tile_sizes = [8] } + transform.loop.unroll %parts#1 { full } + ``` + + This composition would have been difficult to implement as separate passes + since the hypothetical "tiling" and "unrolling" passes would need to somehow + differentiate between the parts of the loop produced by the previous pass + (both are the same operation, and it is likely undesirable to pollute the + operation with pass-specific information).
Implementing passes that run the + combined transformation would have run into the combinatorial explosion + issue due to multiple possible transform compositions or into the need for + deep pass parameterization, the ultimate form of which is an ad-hoc dialect + to specify which transformations the pass should run. The transform dialect + provides a uniform, extensible mechanism for controlling transformations in + such cases. + + The transform dialect is supposed to be consumed by an "interpreter" pass + that drives the application of transformations. To ensure extensibility and + composability, this pass is not expected to actually perform the + transformations specified by the ops. Instead, the transformations are + implemented by the transform ops themselves via `TransformOpInterface`. The + pass serves as the entry point, handles the flow of transform operations and + takes care of bookkeeping. As such, the transform dialect does not provide + the interpreter pass. Instead, it provides a set of utilities that can be + used by clients to define their own interpreter passes or as part of a more + complex pass. For example, the mapping between values in the transform IR + and operations in the payload IR, or the function that applies the + transformations specified by ops in the given block sequentially. Note that + a transform op may have regions with further transform ops in them, with + the op itself guiding how to dispatch the transformation control flow to + those regions. This approach allows clients to decide on the relative + location of the transform IR in their input (e.g., nested modules, separate + modules, optional regions to certain operations, etc.), register additional + transform operations and perform client-specific bookkeeping. + + ## Effects on the Infrastructure + + Although scoped to a single dialect, this functionality conceptually belongs + to the MLIR infrastructure. It aims to be minimally intrusive and opt-in.
+ + Some infrastructural components may grow extra functionality to support the + transform dialect. In particular, the pattern infrastructure may add extra + hooks to identify the "main results" of a transformation or to notify + external observers about changes made to certain operations. These are not + expected to affect the existing uses of the infrastructure. + + For the sake of reusability, transformations should be implemented as + utility functions that are called from the interface methods of transform + ops rather than having the methods directly act on the payload IR. + }]; + + let name = "transform"; + let cppNamespace = "::mlir::transform"; + + let extraClassDeclaration = [{ + // Make addOperations available to the TransformDialectExtension class. + private: + using ::mlir::Dialect::addOperations; + + template + friend class TransformDialectExtension; + }]; +} + +#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -0,0 +1,131 @@ +//===- TransformInterfaces.h - Transform Dialect Interfaces -----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMINTERFACES_H +#define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMINTERFACES_H + +#include "mlir/IR/OpDefinition.h" + +namespace mlir { +namespace transform { + +class TransformOpInterface; + +/// The state maintained across applications of various ops implementing the +/// TransformOpInterface.
The operations implementing this interface and the +/// surrounding structure are referred to as transform IR. The operations to +/// which transformations apply are referred to as payload IR. The state thus +/// contains the mapping between values defined in the transform IR ops and +/// payload IR ops. It assumes that each value in the transform IR can be used +/// at most once (since transformations are likely to change the payload IR ops +/// the value corresponds to). Checks that transform IR values correspond to +/// disjoint sets of payload IR ops throughout the transformation. +/// +/// A reference to this class is passed as an argument to "apply" methods of the +/// transform op interface. Thus the "apply" method can call +/// `state.getPayloadOps( getSomeOperand() )` to obtain the list of operations +/// associated with its operand and subject to transformation. The method is +/// expected to populate the `TransformResults` class instance in order to +/// update the mapping. The `applyTransform` method takes care of propagating +/// the state of `TransformResults` into the instance of this class. +class TransformState { + /// Mapping between a Value in the transform IR and the corresponding set of + /// operations in the payload IR. + using TransformOpMapping = DenseMap>; + + /// Mapping between a payload IR operation and the transform IR value it is + /// currently associated with. + using TransformOpReverseMapping = DenseMap; + +public: + /// Creates a state for the transformation rooted at the given op. + explicit TransformState(Operation *root); + + /// Returns the op at which the transformation state is rooted. This is + /// typically helpful for transformations that apply globally. + Operation *getTopLevel() const; + + /// Returns the list of ops that the given transform IR value corresponds to. + /// This is helpful for transformations that apply to a particular handle.
+ ArrayRef getPayloadOps(Value value) const; + + /// Applies the transformation specified by the given transform op and updates + /// the state accordingly. + LogicalResult applyTransform(TransformOpInterface transform); + +private: + /// Identifier for storing the top-level value in `operationMapping`. + static constexpr Value kTopLevelValue = Value(); + + /// Sets the payload IR ops associated with the given transform IR value. + /// Fails if this would result in multiple transform IR values with uses + /// corresponding to the same payload IR ops. For example, a hypothetical + /// "find function by name" transform op would (indirectly) call this + /// function for its result. Having two such calls in a row for different + /// values, e.g. coming from different ops: + /// + /// %0 = transform.find_func_by_name { name = "myfunc" } + /// %1 = transform.find_func_by_name { name = "myfunc" } + /// + /// would lead to both values pointing to the same operation. The second call + /// to setPayloadOps will fail, unless the association with the %0 value is + /// removed first by calling update/removePayloadOps. + LogicalResult setPayloadOps(Value value, ArrayRef targets); + + /// Forgets the payload IR ops associated with the given transform IR value. + void removePayloadOps(Value value); + + /// Updates the payload IR ops associated with the given transform IR value. + /// The callback function is called once per associated operation and is + /// expected to return the modified operation or nullptr. In the latter case, + /// the corresponding operation is no longer associated with the transform IR + /// value. + void updatePayloadOps(Value value, + function_ref callback); + + /// The mappings from transform IR values to payload IR ops and back.
+ TransformOpMapping operationMapping; + TransformOpReverseMapping reverseMapping; +}; + +/// Local mapping between values defined by a specific op implementing the +/// TransformOpInterface and the payload IR ops they correspond to. +class TransformResults { + friend class TransformState; + +public: + /// Indicates that the given result of the transform IR op + /// corresponds to the given list of payload IR ops. Each result must be set + /// by the transformation exactly once. + void set(OpResult value, ArrayRef ops); + +private: + /// Creates an instance of TransformResults that expects mappings for + /// `numSegments` values. + explicit TransformResults(unsigned numSegments); + + /// Gets the list of operations associated with the result identified by its + /// number in the list of operation results. + ArrayRef get(unsigned resultNumber) const; + + /// Storage for pointers to payload IR ops that are associated with results of + /// a transform IR op. `segments` contains as many entries as the transform IR + /// op has results. Each entry is a reference to a contiguous segment in + /// the `operations` list that contains the pointers to operations. This + /// allows for operations to be stored contiguously without nested vectors and + /// for different segments to be set in any order. + SmallVector, 2> segments; + SmallVector operations; +}; + +} // namespace transform +} // namespace mlir + +#include "mlir/Dialect/Transform/IR/TransformInterfaces.h.inc" + +#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMINTERFACES_H diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td @@ -0,0 +1,52 @@ +//===- TransformInterfaces.td - Transform Op interfaces ----*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the interfaces for transformation-related ops. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORM_INTERFACES_TD +#define MLIR_DIALECT_TRANSFORM_IR_TRANSFORM_INTERFACES_TD + +include "mlir/IR/OpBase.td" + +def TransformOpInterface : OpInterface<"TransformOpInterface"> { + let description = [{ + This interface is to be implemented by operations that identify + transformations to be performed on other operations. The former are referred + to as transform IR operations. The latter are referred to as payload IR + operations. Such transform IR operations provide a fine-grain control + mechanism over how transformations are applied by using and defining + transform IR values, referred to as handles, that correspond to sets of + operations in the payload IR. Transformations are applied starting from the + operations identified by handles, but may affect other operations as well. + Further restrictions may be imposed by flows that rely on transform IR + operations to control transformations. + }]; + + let cppNamespace = "::mlir::transform"; + + let methods = [ + InterfaceMethod< + /*desc=*/[{ + Applies the transformation represented by the current operation. This + accepts as arguments the object that must be populated with results of + the current transformation and a transformation state object that can be + used for queries, e.g., to obtain the list of operations on which the + transformation represented by the current op is targeted.
+ }], + /*returnType=*/"::mlir::LogicalResult", + /*name=*/"apply", + /*arguments=*/(ins + "::mlir::transform::TransformResults &":$transformResults, + "::mlir::transform::TransformState &":$state + )>, + ]; +} + +#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORM_INTERFACES_TD diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -51,6 +51,7 @@ #include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h" #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/X86Vector/X86VectorDialect.h" @@ -92,6 +93,7 @@ shape::ShapeDialect, sparse_tensor::SparseTensorDialect, tensor::TensorDialect, + transform::TransformDialect, tosa::TosaDialect, x86vector::X86VectorDialect>(); // clang-format on diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt --- a/mlir/lib/Dialect/CMakeLists.txt +++ b/mlir/lib/Dialect/CMakeLists.txt @@ -27,6 +27,7 @@ add_subdirectory(SPIRV) add_subdirectory(Tensor) add_subdirectory(Tosa) +add_subdirectory(Transform) add_subdirectory(Utils) add_subdirectory(Vector) add_subdirectory(X86Vector) diff --git a/mlir/lib/Dialect/Transform/CMakeLists.txt b/mlir/lib/Dialect/Transform/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Transform/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/mlir/lib/Dialect/Transform/IR/CMakeLists.txt b/mlir/lib/Dialect/Transform/IR/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Transform/IR/CMakeLists.txt @@ -0,0 +1,11 @@ +add_mlir_dialect_library(MLIRTransformDialect + TransformDialect.cpp + TransformInterfaces.cpp + + DEPENDS + MLIRTransformDialectIncGen
MLIRTransformInterfacesIncGen + + LINK_LIBS PUBLIC + MLIRIR + ) diff --git a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp @@ -0,0 +1,15 @@ +//===- TransformDialect.cpp - Transform Dialect Definition ----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Transform/IR/TransformDialect.h" + +#include "mlir/Dialect/Transform/IR/TransformDialect.cpp.inc" + +using namespace mlir; + +void transform::TransformDialect::initialize() {} diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -0,0 +1,140 @@ +//===- TransformInterfaces.cpp - Transform Dialect Interfaces -------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Operation.h" +#include "llvm/ADT/SmallPtrSet.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// TransformState +//===----------------------------------------------------------------------===// + +constexpr const Value transform::TransformState::kTopLevelValue; + +transform::TransformState::TransformState(Operation *root) { + operationMapping[kTopLevelValue].push_back(root); +} + +Operation *transform::TransformState::getTopLevel() const { + return operationMapping.lookup(kTopLevelValue).front(); +} + +ArrayRef +transform::TransformState::getPayloadOps(Value value) const { + auto iter = operationMapping.find(value); + assert(iter != operationMapping.end() && "unknown handle"); + return iter->getSecond(); +} + +LogicalResult +transform::TransformState::setPayloadOps(Value value, + ArrayRef targets) { + assert(value != kTopLevelValue && + "attempting to reset the transformation root"); + + if (value.use_empty()) + return success(); + + // Setting new payload for the value without cleaning it first is a misuse of + // the API, assert here. + SmallVector storedTargets(targets.begin(), targets.end()); + bool inserted = + operationMapping.insert({value, std::move(storedTargets)}).second; + assert(inserted && "value is already associated with another list"); + (void)inserted; + + // Having multiple handles to the same operation is an error in the transform + // expressed using the dialect and may be constructed by valid API calls from + // valid IR. Emit an error here. 
+ for (Operation *op : targets) { + auto insertionResult = reverseMapping.insert({op, value}); + if (!insertionResult.second) { + InFlightDiagnostic diag = op->emitError() + << "operation tracked by two handles"; + diag.attachNote(value.getLoc()) << "handle"; + diag.attachNote(insertionResult.first->second.getLoc()) << "handle"; + return diag; + } + } + + return success(); +} + +void transform::TransformState::removePayloadOps(Value value) { + for (Operation *op : operationMapping[value]) + reverseMapping.erase(op); + operationMapping.erase(value); +} + +void transform::TransformState::updatePayloadOps( + Value value, function_ref callback) { + auto it = operationMapping.find(value); + assert(it != operationMapping.end() && "unknown handle"); + SmallVector &association = it->getSecond(); + SmallVector updated; + updated.reserve(association.size()); + + for (Operation *op : association) + if (Operation *updatedOp = callback(op)) + updated.push_back(updatedOp); + + std::swap(association, updated); +} + +LogicalResult +transform::TransformState::applyTransform(TransformOpInterface transform) { + transform::TransformResults results(transform->getNumResults()); + if (failed(transform.apply(results, *this))) + return failure(); + + for (Value target : transform->getOperands()) + removePayloadOps(target); + + for (auto &en : llvm::enumerate(transform->getResults())) + if (failed(setPayloadOps(en.value(), results.get(en.index())))) + return failure(); + + return success(); +} + +//===----------------------------------------------------------------------===// +// TransformResults +//===----------------------------------------------------------------------===// + +transform::TransformResults::TransformResults(unsigned numSegments) { + segments.resize(numSegments, + ArrayRef(nullptr, static_cast(0))); +} + +void transform::TransformResults::set(OpResult value, + ArrayRef ops) { + unsigned position = value.getResultNumber(); + assert(position < segments.size() && + "setting 
results for a non-existent handle"); + assert(segments[position].data() == nullptr && "results already set"); + unsigned start = operations.size(); + llvm::append_range(operations, ops); + segments[position] = makeArrayRef(operations).drop_front(start); +} + +ArrayRef +transform::TransformResults::get(unsigned resultNumber) const { + assert(resultNumber < segments.size() && + "querying results for a non-existent handle"); + assert(segments[resultNumber].data() != nullptr && "querying unset results"); + return segments[resultNumber]; +} + +//===----------------------------------------------------------------------===// +// Generated interface implementation. +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Transform/IR/TransformInterfaces.cpp.inc" diff --git a/mlir/test/Dialect/Transform/test-dialect-injection.mlir b/mlir/test/Dialect/Transform/test-dialect-injection.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Transform/test-dialect-injection.mlir @@ -0,0 +1,12 @@ +// RUN: mlir-opt %s | FileCheck %s + +// These ops are defined by a test extension but should be okay to roundtrip. 
+ +// CHECK: transform.test_transform_op +transform.test_transform_op + +// CHECK: = transform.test_produce_param_or_forward_operand 42 {foo = "bar"} +%0 = transform.test_produce_param_or_forward_operand 42 { foo = "bar" } + +// CHECK: transform.test_consume_operand_if_matches_param_or_fail %{{.*}}[42] +transform.test_consume_operand_if_matches_param_or_fail %0[42] diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Transform/test-interpreter.mlir @@ -0,0 +1,27 @@ +// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file --verify-diagnostics + +// expected-remark @below {{applying transformation}} +transform.test_transform_op + +// ----- + +%0 = transform.test_produce_param_or_forward_operand 42 { foo = "bar" } +// expected-remark @below {{succeeded}} +transform.test_consume_operand_if_matches_param_or_fail %0[42] + +// ----- + +%0 = transform.test_produce_param_or_forward_operand 42 { foo = "bar" } +// expected-error @below {{expected the operand to be associated with 21 got 42}} +transform.test_consume_operand_if_matches_param_or_fail %0[21] + +// ----- + +// expected-error @below {{operation tracked by two handles}} +%0 = transform.test_produce_param_or_forward_operand 42 +// expected-note @below {{handle}} +%1 = transform.test_produce_param_or_forward_operand from %0 +// expected-note @below {{handle}} +%2 = transform.test_produce_param_or_forward_operand from %0 +transform.test_consume_operand_if_matches_param_or_fail %1[42] +transform.test_consume_operand_if_matches_param_or_fail %2[42] diff --git a/mlir/test/lib/Dialect/CMakeLists.txt b/mlir/test/lib/Dialect/CMakeLists.txt --- a/mlir/test/lib/Dialect/CMakeLists.txt +++ b/mlir/test/lib/Dialect/CMakeLists.txt @@ -11,4 +11,5 @@ add_subdirectory(Tensor) add_subdirectory(Test) add_subdirectory(Tosa) +add_subdirectory(Transform) add_subdirectory(Vector) diff --git 
a/mlir/test/lib/Dialect/Transform/CMakeLists.txt b/mlir/test/lib/Dialect/Transform/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Dialect/Transform/CMakeLists.txt @@ -0,0 +1,20 @@ +set(LLVM_TARGET_DEFINITIONS TestTransformDialectExtension.td) +mlir_tablegen(TestTransformDialectExtension.h.inc -gen-op-decls) +mlir_tablegen(TestTransformDialectExtension.cpp.inc -gen-op-defs) +add_public_tablegen_target(MLIRTestTransformDialectExtensionIncGen) + +add_mlir_library(MLIRTestTransformDialect + TestTransformDialectExtension.cpp + TestTransformDialectInterpreter.cpp + + EXCLUDE_FROM_LIBMLIR + + DEPENDS + MLIRTestTransformDialectExtensionIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRPass + MLIRPDL + MLIRTransformDialect +) diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.h b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.h new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.h @@ -0,0 +1,33 @@ +//===- TestTransformDialectExtension.h --------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines an extension of the MLIR Transform dialect for testing +// purposes.
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TESTTRANSFORMDIALECTEXTENSION_H +#define MLIR_TESTTRANSFORMDIALECTEXTENSION_H + +#include "mlir/Dialect/PDL/IR/PDLTypes.h" +#include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/IR/OpDefinition.h" + +namespace mlir { +class DialectRegistry; +} // namespace mlir + +#define GET_OP_CLASSES +#include "TestTransformDialectExtension.h.inc" + +namespace test { +/// Adds the test extension of the Transform dialect to the given registry. +void registerTestTransformDialectExtension(::mlir::DialectRegistry &registry); +} // namespace test + +#endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_H diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -0,0 +1,107 @@ +//===- TestTransformDialectExtension.cpp ----------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines an extension of the MLIR Transform dialect for testing +// purposes. +// +//===----------------------------------------------------------------------===// + +#include "TestTransformDialectExtension.h" +#include "mlir/Dialect/PDL/IR/PDL.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/OpImplementation.h" + +using namespace mlir; + +namespace { +/// Transform op defined directly in C++ rather than in ODS. Applying it only +/// emits a remark.
+class TestTransformOp + : public Op { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTransformOp) + + using Op::Op; + + static ArrayRef getAttributeNames() { return {}; } + + static constexpr llvm::StringLiteral getOperationName() { + return llvm::StringLiteral("transform.test_transform_op"); + } + + LogicalResult apply(transform::TransformResults &results, + transform::TransformState &state) { + emitRemark() << "applying transformation"; + return success(); + } + + static ParseResult parse(OpAsmParser &parser, OperationState &state) { + return success(); + } + + void print(OpAsmPrinter &printer) {} +}; +} // namespace + +LogicalResult mlir::test::TestProduceParamOrForwardOperandOp::apply( + transform::TransformResults &results, transform::TransformState &state) { + if (getOperation()->getNumOperands() != 0) { // Forward the operand's defining op. + results.set(getResult().cast(), getOperand(0).getDefiningOp()); + } else { // Encode the parameter as a pointer. + results.set(getResult().cast(), + reinterpret_cast(*parameter())); + } + return success(); +} + +LogicalResult mlir::test::TestProduceParamOrForwardOperandOp::verify() { + if (parameter().hasValue() ^ (getNumOperands() != 1)) // Exactly one must be present. + return emitOpError() << "expects either a parameter or an operand"; + return success(); +} + +LogicalResult mlir::test::TestConsumeOperandIfMatchesParamOrFail::apply( + transform::TransformResults &results, transform::TransformState &state) { + ArrayRef payload = state.getPayloadOps(getOperand()); + assert(payload.size() == 1 && "expected a single target op"); + auto value = reinterpret_cast(payload[0]); + if (static_cast<uint64_t>(value) != parameter()) { // Avoid signed/unsigned compare. 
+ return emitOpError() << "expected the operand to be associated with " + << parameter() << " got " << value; + } + + emitRemark() << "succeeded"; + return success(); +} + +namespace { +/// Test extension of the Transform dialect. Registers additional ops and +/// declares PDL as a dependent dialect since the additional ops use PDL types +/// for their operands and results.
+class TestTransformDialectExtension + : public transform::TransformDialectExtension< + TestTransformDialectExtension> { +public: + TestTransformDialectExtension() { + declareDependentDialect(); + registerTransformOps(); + } +}; +} // namespace + +#define GET_OP_CLASSES +#include "TestTransformDialectExtension.cpp.inc" + +void ::test::registerTestTransformDialectExtension(DialectRegistry &registry) { + registry.addExtensions(); +} diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td @@ -0,0 +1,41 @@ +//===- TestTransformDialectExtension.td --------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the operations that are injected into the Transform +// dialect through the extension mechanism, as a test. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TESTTRANSFORMDIALECTEXTENSION_TD +#define MLIR_TESTTRANSFORMDIALECTEXTENSION_TD + +include "mlir/IR/OpBase.td" +include "mlir/Dialect/Transform/IR/TransformDialect.td" +include "mlir/Dialect/Transform/IR/TransformInterfaces.td" +include "mlir/Dialect/PDL/IR/PDLTypes.td" + +def TestProduceParamOrForwardOperandOp + : Op]> { + let arguments = (ins Optional:$operand, + OptionalAttr:$parameter); + let results = (outs PDL_Operation:$res); + let assemblyFormat = "(`from` $operand^)? ($parameter^)? 
attr-dict"; + let cppNamespace = "::mlir::test"; + let hasVerifier = 1; +} + +def TestConsumeOperandIfMatchesParamOrFail + : Op]> { + let arguments = (ins PDL_Operation:$operand, I64Attr:$parameter); + let assemblyFormat = "$operand `[` $parameter `]` attr-dict"; + let cppNamespace = "::mlir::test"; +} + +#endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp @@ -0,0 +1,57 @@ +//===- TestTransformDialectInterpreter.cpp --------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines a test pass that interprets Transform dialect operations in +// the module. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; + +namespace { +/// Test pass that applies, one by one and in order, the transform dialect ops +/// directly contained in the body of the module.
+class TestTransformDialectInterpreterPass + : public PassWrapper> { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + TestTransformDialectInterpreterPass) + + StringRef getArgument() const override { + return "test-transform-dialect-interpreter"; + } + + StringRef getDescription() const override { + return "apply transform dialect operations one by one"; + } + + void runOnOperation() override { + ModuleOp module = getOperation(); + transform::TransformState state(module); + for (auto op : + module.getBody()->getOps()) { + if (failed(state.applyTransform(op))) + return signalPassFailure(); // Stop at the first failing transform. + } + } +}; +} // namespace + +namespace mlir { +namespace test { +/// Registers the `test-transform-dialect-interpreter` test pass. +void registerTestTransformDialectInterpreterPass() { + PassRegistration reg; +} +} // namespace test +} // namespace mlir diff --git a/mlir/test/lib/Dialect/Transform/lit.local.cfg b/mlir/test/lib/Dialect/Transform/lit.local.cfg new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Dialect/Transform/lit.local.cfg @@ -0,0 +1 @@ +config.suffixes.remove('.td') diff --git a/mlir/test/mlir-opt/commandline.mlir b/mlir/test/mlir-opt/commandline.mlir --- a/mlir/test/mlir-opt/commandline.mlir +++ b/mlir/test/mlir-opt/commandline.mlir @@ -33,5 +33,6 @@ // CHECK-NEXT: tensor // CHECK-NEXT: test // CHECK-NEXT: tosa +// CHECK-NEXT: transform // CHECK-NEXT: vector // CHECK-NEXT: x86vector diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt --- a/mlir/tools/mlir-opt/CMakeLists.txt +++ b/mlir/tools/mlir-opt/CMakeLists.txt @@ -30,6 +30,7 @@ MLIRTestPass MLIRTestReducer MLIRTestRewrite + MLIRTestTransformDialect MLIRTestTransforms MLIRVectorTestPasses ) diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -107,12 +107,14 @@ void registerTestSCFUtilsPass(); void
registerTestSliceAnalysisPass(); void registerTestTensorTransforms(); +void registerTestTransformDialectInterpreterPass(); void registerTestVectorLowerings(); } // namespace test } // namespace mlir namespace test { void registerTestDialect(DialectRegistry &); +void registerTestTransformDialectExtension(DialectRegistry &); } // namespace test #ifdef MLIR_INCLUDE_TESTS @@ -196,6 +198,7 @@ mlir::test::registerTestSCFUtilsPass(); mlir::test::registerTestSliceAnalysisPass(); mlir::test::registerTestTensorTransforms(); + mlir::test::registerTestTransformDialectInterpreterPass(); mlir::test::registerTestVectorLowerings(); } #endif @@ -209,6 +212,7 @@ registerAllDialects(registry); #ifdef MLIR_INCLUDE_TESTS ::test::registerTestDialect(registry); + ::test::registerTestTransformDialectExtension(registry); #endif return mlir::asMainReturnCode( mlir::MlirOptMain(argc, argv, "MLIR modular optimizer driver\n", registry, diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -6019,6 +6019,7 @@ ":TensorTransforms", ":TosaDialect", ":TosaToLinalg", + ":TransformDialect", ":Transforms", ":TransformsPassIncGen", ":VectorOps", @@ -6079,6 +6080,7 @@ "//mlir/test:TestShapeDialect", "//mlir/test:TestTensor", "//mlir/test:TestTosaDialect", + "//mlir/test:TestTransformDialect", "//mlir/test:TestTransforms", "//mlir/test:TestTypeDialect", "//mlir/test:TestVector", @@ -7583,6 +7585,71 @@ ], ) +td_library( + name = "TransformDialectTdFiles", + srcs = glob(["include/mlir/Dialect/Transform/IR/*.td"]),  # Dialect and interface ODS files. 
+ deps = [ + ":OpBaseTdFiles", + ], +) + +gentbl_cc_library( + name = "TransformDialectInterfacesIncGen", + strip_include_prefix = "include", + tbl_outs = [ + ( + [ + "-gen-op-interface-decls", + ], + "include/mlir/Dialect/Transform/IR/TransformInterfaces.h.inc", + ), + ( + [ + "-gen-op-interface-defs", + ], + 
"include/mlir/Dialect/Transform/IR/TransformInterfaces.cpp.inc", + ), + ], + tblgen = ":mlir-tblgen", + td_file = "include/mlir/Dialect/Transform/IR/TransformInterfaces.td", + deps = [":TransformDialectTdFiles"], +) + +gentbl_cc_library( + name = "TransformDialectIncGen", + strip_include_prefix = "include", + tbl_outs = [ + ( + [ + "-gen-dialect-decls", + ], + "include/mlir/Dialect/Transform/IR/TransformDialect.h.inc", + ), + ( + [ + "-gen-dialect-defs", + ], + "include/mlir/Dialect/Transform/IR/TransformDialect.cpp.inc", + ), + ], + tblgen = ":mlir-tblgen", + td_file = "include/mlir/Dialect/Transform/IR/TransformDialect.td", + deps = [":TransformDialectTdFiles"], +) + +cc_library( + name = "TransformDialect",  # No ops of its own; extensions inject them. + srcs = glob(["lib/Dialect/Transform/IR/*.cpp"]), + hdrs = glob(["include/mlir/Dialect/Transform/IR/*.h"]), + deps = [ + ":IR", + ":Support", + ":TransformDialectIncGen", + ":TransformDialectInterfacesIncGen", + "//llvm:Support", + ], +) + td_library( name = "ComplexOpsTdFiles", srcs = [ diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel @@ -192,6 +192,51 @@ ], ) +td_library( + name = "TransformDialectTdFiles", + srcs = glob(["lib/Dialect/Transform/*.td"]), + deps = [ + "//mlir:OpBaseTdFiles", + ], +) + +gentbl_cc_library( + name = "TestTransformDialectExtensionIncGen", + strip_include_prefix = "lib/Dialect/Transform", + tbl_outs = [ + ( + ["-gen-op-decls"], + "lib/Dialect/Transform/TestTransformDialectExtension.h.inc", + ), + ( + ["-gen-op-defs"], + "lib/Dialect/Transform/TestTransformDialectExtension.cpp.inc", + ), + ], + tblgen = "//mlir:mlir-tblgen", + td_file = "lib/Dialect/Transform/TestTransformDialectExtension.td", + test = True, + deps = [ + ":TransformDialectTdFiles", + "//mlir:PDLDialectTdFiles", + 
"//mlir:TransformDialectTdFiles", + ], +) + +cc_library(
+ name = "TestTransformDialect",  # Test dialect extension and interpreter pass. + srcs = glob(["lib/Dialect/Transform/*.cpp"]), + hdrs = glob(["lib/Dialect/Transform/*.h"]), + includes = ["lib/Dialect/Transform"], + deps = [ + ":TestTransformDialectExtensionIncGen", + "//mlir:IR", + "//mlir:PDLDialect", + "//mlir:Pass", + "//mlir:TransformDialect", + ], +) + cc_library( name = "TestDialect", srcs = glob(["lib/Dialect/Test/*.cpp"]),