diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt --- a/mlir/include/mlir/Dialect/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/CMakeLists.txt @@ -26,5 +26,6 @@ add_subdirectory(SPIRV) add_subdirectory(Tensor) add_subdirectory(Tosa) +add_subdirectory(Transform) add_subdirectory(Vector) add_subdirectory(X86Vector) diff --git a/mlir/include/mlir/Dialect/Transform/CMakeLists.txt b/mlir/include/mlir/Dialect/Transform/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Transform/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/mlir/include/mlir/Dialect/Transform/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Transform/IR/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Transform/IR/CMakeLists.txt @@ -0,0 +1,8 @@ +# The dialect does not have its own ops, so just generate the dialect files. +set(LLVM_TARGET_DEFINITIONS TransformDialect.td) +mlir_tablegen(TransformDialect.h.inc -gen-dialect-decls -dialect=transform) +mlir_tablegen(TransformDialect.cpp.inc -gen-dialect-defs -dialect=transform) +add_public_tablegen_target(MLIRTransformDialectIncGen) +add_dependencies(mlir-headers MLIRTransformDialectIncGen) + +add_mlir_interface(TransformInterfaces) diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h @@ -0,0 +1,77 @@ +//===- TransformDialect.h - Transform Dialect Definition --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT_H
+#define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT_H
+
+#include "mlir/IR/Dialect.h"
+#include "mlir/Support/LLVM.h"
+
+#include "mlir/Dialect/Transform/IR/TransformDialect.h.inc"
+
+namespace mlir {
+namespace transform {
+
+/// Base class for extensions of the Transform dialect that supports injecting
+/// operations into the Transform dialect at load time. Concrete extensions are
+/// expected to derive from this class and register operations in the
+/// constructor. They can be registered with the DialectRegistry and
+/// automatically applied to the Transform dialect when it is loaded.
+template <typename DerivedTy, typename... ExtraDialects>
+class TransformDialectExtension
+    : public DialectExtension<DerivedTy, TransformDialect, ExtraDialects...> {
+  using Initializer = std::function<void(TransformDialect *)>;
+  using DialectLoader = std::function<void(MLIRContext *)>;
+
+public:
+  /// Extension application hook. Actually loads the dependent dialects and
+  /// registers the additional operations. Not expected to be called directly.
+  void apply(MLIRContext *context, TransformDialect *transformDialect,
+             ExtraDialects *...) const final {
+    for (const DialectLoader &loader : dialectLoaders)
+      loader(context);
+    for (const Initializer &init : opInitializers)
+      init(transformDialect);
+  }
+
+protected:
+  /// Injects the operation into the Transform dialect.
+  template <typename OpTy>
+  void registerTransformOp() {
+    opInitializers.push_back([](TransformDialect *transformDialect) {
+      RegisteredOperationName::insert<OpTy>(*transformDialect);
+    });
+  }
+
+  /// Injects the operations into the Transform dialect.
+  template <typename... OpTys>
+  void registerTransformOps() {
+    // FIXME: replace with C++17 fold expression when available.
+    (void)std::initializer_list<int>{(registerTransformOp<OpTys>(), 0)...};
+  }
+
+  /// Declares that this Transform dialect extension depends on the dialect
+  /// provided as template parameter. When the Transform dialect is loaded,
+  /// dependent dialects will be loaded as well. This is intended for dialects
+  /// that contain attributes and types used in creation and canonicalization of
+  /// the injected operations.
+  template <typename DialectTy>
+  void declareDependentDialect() {
+    dialectLoaders.push_back(
+        [](MLIRContext *context) { context->loadDialect<DialectTy>(); });
+  }
+
+private:
+  SmallVector<Initializer> opInitializers;
+  SmallVector<DialectLoader> dialectLoaders;
+};
+
+} // namespace transform
+} // namespace mlir
+
+#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT_H
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
@@ -0,0 +1,62 @@
+//===- TransformDialect.td - Transform dialect definition --*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT
+#define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT
+
+include "mlir/IR/OpBase.td"
+
+def Transform_Dialect : Dialect {
+  let summary = "Fine-grain transformation control dialect";
+  let description = [{
+    This dialect provides operations that can be used to control transformation
+    of the IR using a different portion of the IR. It refers to the IR being
+    transformed as payload IR, and to the IR guiding the transformation as
+    transform IR.
+
+    The main use case for this dialect is orchestrating fine-grain
+    transformations on individual operations or sets thereof. For example, it
+    may involve finding loop-like operations with specific properties (e.g.,
+    large size) in the payload IR, applying loop tiling to those and only those
+    operations, and then applying loop unrolling to the inner loops produced
+    by the previous transformations. As such, it is not intended as a
+    replacement for the pass infrastructure, nor for the pattern rewriting
+    infrastructure. In the most common case, the transform IR will be processed
+    and applied to the payload IR by a pass. Transformations expressed by the
+    transform dialect may be implemented using the pattern infrastructure or any
+    other relevant MLIR component.
+
+    The following IR gives a rough idea of what the operations in this dialect
+    may look like.
+
+    ```
+    %0 = transform.loop.find { size > 42 }
+    %1:2 = transform.loop.tile %0 { tile_sizes = [2,3,4] }
+    transform.loop.unroll %1#1
+    ```
+
+    The values defined by operations in this dialect correspond to (groups of)
+    operations in the payload IR. In the example above, `%0` corresponds to the
+    set of loops found in the payload IR that satisfy the condition, and `%1#0`
+    and `%1#1` correspond to groups of outer and inner loops, respectively,
+    produced by the tiling transformation.
+
+    This dialect is designed to be extensible, that is, clients of this dialect
+    are allowed to inject additional operations into this dialect using the
+    `TransformDialectExtension` mechanism. This allows the dialect to avoid a
+    dependency on the implementation of the transformation as well as to avoid
+    introducing dialect-specific transform dialects. In the example above,
+    the operations may have been injected by a notional `loop` dialect rather
+    than defined in this dialect, hence the common prefix.
+  }];
+
+  let name = "transform";
+  let cppNamespace = "::mlir::transform";
+}
+
+#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -0,0 +1,118 @@
+//===- TransformInterfaces.h - Transform Dialect Interfaces -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMINTERFACES_H
+#define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMINTERFACES_H
+
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/TypeID.h"
+
+namespace mlir {
+namespace transform {
+
+class TransformOpInterface;
+
+/// The state maintained across applications of various ops implementing the
+/// TransformOpInterface. The operations implementing this interface and the
+/// surrounding structure are referred to as transform IR. The operations to
+/// which transformations apply are referred to as payload IR. The state thus
+/// contains the mapping between values defined in the transform IR ops and
+/// payload IR ops. It assumes that each value in the transform IR can be used
+/// at most once (since transformations are likely to change the payload IR ops
+/// the value corresponds to). Checks that transform IR values correspond to
+/// disjoint sets of payload IR ops throughout the transformation.
+///
+/// A reference to this class is passed as an argument to "apply" methods of the
+/// transform op interface.
Thus the "apply" method can call
+/// `state.getPayloadOps( getSomeOperand() )` to obtain the list of operations
+/// associated with its operand and subject to transformation. The method is
+/// expected to populate the `TransformResults` class instance in order to
+/// update the mapping. The `applyTransform` method takes care of propagating
+/// the state of `TransformResults` into the instance of this class.
+class TransformState {
+  /// Mapping between a Value in the transform IR and the corresponding set of
+  /// operations in the payload IR.
+  using TransformOpMapping = DenseMap<Value, SmallVector<Operation *>>;
+
+public:
+  /// Creates a state for the transformation rooted at the given op.
+  explicit TransformState(Operation *root);
+
+  /// Returns the op at which the transformation state is rooted. This is
+  /// typically helpful for transformations that apply globally.
+  Operation *getTopLevel() const;
+
+  /// Returns the list of ops that the given transform IR value corresponds to.
+  /// This is helpful for transformations that apply to a particular handle.
+  ArrayRef<Operation *> getPayloadOps(Value value) const;
+
+  /// Applies the transformation specified by the given transform op and updates
+  /// the state accordingly.
+  LogicalResult applyTransform(TransformOpInterface transform);
+
+private:
+  /// Key under which the transformation root is stored in `operationMapping`.
+  static constexpr Value kTopLevelValue = Value();
+
+  /// Sets the payload IR ops associated with the given transform IR value.
+  /// Fails if this would result in multiple transform IR values with uses
+  /// corresponding to the same payload IR ops.
+  LogicalResult setPayloadOps(Value value, ArrayRef<Operation *> targets);
+
+  /// Forgets the payload IR ops associated with the given transform IR value.
+  void removePayloadOps(Value value);
+
+  /// Updates the payload IR ops associated with the given transform IR value.
+  /// The callback function is called once per associated operation and is
+  /// expected to return the modified operation or nullptr. In the latter case,
+  /// the corresponding operation is no longer associated with the transform IR
+  /// value.
+  void updatePayloadOps(Value value,
+                        function_ref<Operation *(Operation *)> callback);
+
+  /// The mapping between transform IR values and payload IR ops.
+  TransformOpMapping operationMapping;
+};
+
+/// Local mapping between values defined by a specific op implementing the
+/// TransformOpInterface and the payload IR ops they correspond to.
+class TransformResults {
+  friend class TransformState;
+
+public:
+  /// Indicates that the result of the transform IR op at the given position
+  /// corresponds to the given list of payload IR ops. Each result must be set
+  /// by the transformation exactly once.
+  void set(OpResult value, ArrayRef<Operation *> ops);
+
+private:
+  /// Creates an instance of TransformResults that expects mappings for
+  /// `numSegments` values.
+  explicit TransformResults(unsigned numSegments);
+
+  /// Gets the list of operations associated with the result at the given
+  /// position.
+  ArrayRef<Operation *> get(unsigned position) const;
+
+  /// Storage for pointers to payload IR ops that are associated with results of
+  /// a transform IR op. `segments` contains as many entries as the transform IR
+  /// op has results. Each entry is a reference to a contiguous segment in
+  /// the `operations` list that contains the pointers to operations. This
+  /// allows for operations to be stored contiguously without nested vectors and
+  /// for different segments to be set in any order.
+  SmallVector<ArrayRef<Operation *>, 2> segments;
+  SmallVector<Operation *> operations;
+};
+
+} // namespace transform
+} // namespace mlir
+
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h.inc"
+
+#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMINTERFACES_H
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
@@ -0,0 +1,51 @@
+//===- TransformInterfaces.td - Transform Op interfaces ----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the interfaces for transformation-related ops.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORM_INTERFACES_TD
+#define MLIR_DIALECT_TRANSFORM_IR_TRANSFORM_INTERFACES_TD
+
+include "mlir/IR/OpBase.td"
+
+def TransformOpInterface : OpInterface<"TransformOpInterface"> {
+  let description = [{
+    This interface is to be implemented by operations that identify
+    transformations to be performed on other operations. The former are referred
+    to as transform IR operations. The latter are referred to as payload IR
+    operations. Such transform IR operations provide a fine-grain control
+    mechanism over how transformations are applied by using and defining
+    transform IR values, referred to as handles, that correspond to sets of
+    operations in the payload IR. Transformations are applied starting from the
+    operations identified by handles, but may affect other operations as well.
+    Further restrictions may be imposed by flows that rely on transform IR
+    operations to control transformations.
+  }];
+
+  let cppNamespace = "::mlir::transform";
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/[{Applies the transformation represented by the current
+      operation. This accepts as arguments the object that must be populated
+      with results of the current transformation and a transformation state
+      object that can be used for queries, e.g., to obtain the list of
+      operations on which the transformation represented by the current op is
+      targeted.}],
+      /*returnType=*/"LogicalResult",
+      /*name=*/"apply",
+      /*arguments=*/(ins
+          "::mlir::transform::TransformResults &":$transformResults,
+          "::mlir::transform::TransformState &":$state)
+    >,
+  ];
+}
+
+#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORM_INTERFACES_TD
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -50,6 +50,7 @@
 #include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
 #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/X86Vector/X86VectorDialect.h"
@@ -90,6 +91,7 @@
                   shape::ShapeDialect,
                   sparse_tensor::SparseTensorDialect,
                   tensor::TensorDialect,
                   tosa::TosaDialect,
+                  transform::TransformDialect,
                   x86vector::X86VectorDialect>();
   // clang-format on
diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt
--- a/mlir/lib/Dialect/CMakeLists.txt
+++ b/mlir/lib/Dialect/CMakeLists.txt
@@ -26,6 +26,7 @@
 add_subdirectory(SPIRV)
 add_subdirectory(Tensor)
 add_subdirectory(Tosa)
+add_subdirectory(Transform)
 add_subdirectory(Utils)
 add_subdirectory(Vector)
 add_subdirectory(X86Vector)
diff --git a/mlir/lib/Dialect/Transform/CMakeLists.txt b/mlir/lib/Dialect/Transform/CMakeLists.txt
new file mode 100644
--- /dev/null
+++
b/mlir/lib/Dialect/Transform/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(IR)
diff --git a/mlir/lib/Dialect/Transform/IR/CMakeLists.txt b/mlir/lib/Dialect/Transform/IR/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Transform/IR/CMakeLists.txt
@@ -0,0 +1,11 @@
+add_mlir_dialect_library(MLIRTransformDialect
+  TransformDialect.cpp
+  TransformInterfaces.cpp
+
+  DEPENDS
+  MLIRTransformDialectIncGen
+  MLIRTransformInterfacesIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  )
diff --git a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
@@ -0,0 +1,15 @@
+//===- TransformDialect.cpp - Transform Dialect Definition ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+
+#include "mlir/Dialect/Transform/IR/TransformDialect.cpp.inc"
+
+using namespace mlir;
+
+void transform::TransformDialect::initialize() {}
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -0,0 +1,159 @@
+//===- TransformInterfaces.cpp - Transform Dialect Interfaces -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Operation.h"
+#include "llvm/ADT/SmallPtrSet.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// TransformState
+//===----------------------------------------------------------------------===//
+
+constexpr const Value transform::TransformState::kTopLevelValue;
+
+transform::TransformState::TransformState(Operation *root) {
+  operationMapping[kTopLevelValue].push_back(root);
+}
+
+Operation *transform::TransformState::getTopLevel() const {
+  // Use find() rather than lookup(): the latter returns the list by value.
+  auto iter = operationMapping.find(kTopLevelValue);
+  assert(iter != operationMapping.end() && "missing transformation root");
+  return iter->getSecond().front();
+}
+
+ArrayRef<Operation *>
+transform::TransformState::getPayloadOps(Value value) const {
+  auto iter = operationMapping.find(value);
+  assert(iter != operationMapping.end() && "unknown handle");
+  return iter->getSecond();
+}
+
+LogicalResult
+transform::TransformState::setPayloadOps(Value value,
+                                         ArrayRef<Operation *> targets) {
+  assert(value != kTopLevelValue &&
+         "attempting to reset the transformation root");
+
+  if (value.use_empty())
+    return success();
+
+  SmallVector<Operation *> storedTargets(targets.begin(), targets.end());
+  bool inserted =
+      operationMapping.insert({value, std::move(storedTargets)}).second;
+  assert(inserted && "value is already associated with another list");
+  (void)inserted;
+
+  // The stored list is a copy of `targets`, so build the lookup set from the
+  // latter directly instead of copying the list back out of the map.
+  llvm::SmallPtrSet<Operation *, 4> currentOperationSet(targets.begin(),
+                                                         targets.end());
+  for (const auto &kvp : operationMapping) {
+    // Skip the value being set as well as the root entry: the latter is keyed
+    // by a null Value, which is not a handle and has no location to report.
+    if (kvp.getFirst() == value || kvp.getFirst() == kTopLevelValue)
+      continue;
+    for (Operation *trackedOp : kvp.getSecond()) {
+      if (currentOperationSet.contains(trackedOp)) {
+        InFlightDiagnostic diag = trackedOp->emitError()
+                                  << "operation tracked by two handles";
+        diag.attachNote(value.getLoc()) << "handle";
+        diag.attachNote(kvp.getFirst().getLoc()) << "handle";
+        return diag;
+      }
+    }
+  }
+
+  return success();
+}
+
+void transform::TransformState::removePayloadOps(Value value) {
+  operationMapping.erase(value);
+}
+
+void transform::TransformState::updatePayloadOps(
+    Value value, function_ref<Operation *(Operation *)> callback) {
+  auto it = operationMapping.find(value);
+  assert(it != operationMapping.end() && "unknown handle");
+  SmallVector<Operation *> &association = it->getSecond();
+  SmallVector<Operation *> updated;
+  updated.reserve(association.size());
+
+  for (Operation *op : association)
+    if (Operation *updatedOp = callback(op))
+      updated.push_back(updatedOp);
+
+  std::swap(association, updated);
+}
+
+LogicalResult
+transform::TransformState::applyTransform(TransformOpInterface transform) {
+  transform::TransformResults results(transform->getNumResults());
+  if (failed(transform.apply(results, *this)))
+    return failure();
+
+  for (Value target : transform->getOperands())
+    removePayloadOps(target);
+
+  for (auto &en : llvm::enumerate(transform->getResults()))
+    if (failed(setPayloadOps(en.value(), results.get(en.index()))))
+      return failure();
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// TransformResults
+//===----------------------------------------------------------------------===//
+
+transform::TransformResults::TransformResults(unsigned numSegments) {
+  segments.resize(numSegments,
+                  ArrayRef<Operation *>(nullptr, static_cast<size_t>(0)));
+}
+
+void transform::TransformResults::set(OpResult value,
+                                      ArrayRef<Operation *> ops) {
+  unsigned position = value.getResultNumber();
+  assert(position < segments.size() &&
+         "setting results for a non-existent handle");
+  assert(segments[position].data() == nullptr && "results already set");
+
+  // Appending to `operations` may reallocate its storage and leave the
+  // segments set by earlier calls dangling. Record their offsets while the
+  // current storage is still valid and re-point them after appending.
+  SmallVector<int64_t, 2> offsets;
+  offsets.reserve(segments.size());
+  for (ArrayRef<Operation *> segment : segments)
+    offsets.push_back(segment.data() ? segment.data() - operations.data() : -1);
+
+  unsigned start = operations.size();
+  llvm::append_range(operations, ops);
+  for (unsigned i = 0, e = segments.size(); i < e; ++i) {
+    if (offsets[i] >= 0)
+      segments[i] =
+          makeArrayRef(operations).slice(offsets[i], segments[i].size());
+  }
+  segments[position] = makeArrayRef(operations).drop_front(start);
+}
+
+ArrayRef<Operation *>
+transform::TransformResults::get(unsigned position) const {
+  assert(position < segments.size() &&
+         "querying results for a non-existent handle");
+  assert(segments[position].data() != nullptr && "querying unset results");
+  return segments[position];
+}
+
+//===----------------------------------------------------------------------===//
+// Generated interface implementation.
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.cpp.inc"
diff --git a/mlir/test/Dialect/Transform/test-dialect-injection.mlir b/mlir/test/Dialect/Transform/test-dialect-injection.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Transform/test-dialect-injection.mlir
@@ -0,0 +1,12 @@
+// RUN: mlir-opt %s | FileCheck %s
+
+// These ops are defined by a test extension but should be okay to roundtrip.
+
+// CHECK: transform.test_transform_op
+transform.test_transform_op
+
+// CHECK: = transform.test_produce_42_or_forward_operand {foo = "bar"}
+%0 = transform.test_produce_42_or_forward_operand { foo = "bar" }
+
+// CHECK: transform.test_consume_operand_if_matches_param_or_fail %{{.*}}[42]
+transform.test_consume_operand_if_matches_param_or_fail %0[42]
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file --verify-diagnostics
+
+// expected-remark @below {{applying transformation}}
+transform.test_transform_op
+
+// -----
+
+%0 = transform.test_produce_42_or_forward_operand { foo = "bar" }
+// expected-remark @below {{succeeded}}
+transform.test_consume_operand_if_matches_param_or_fail %0[42]
+
+// -----
+
+%0 = transform.test_produce_42_or_forward_operand { foo = "bar" }
+// expected-error @below {{expected the operand to be associated with 21 got 42}}
+transform.test_consume_operand_if_matches_param_or_fail
%0[21] + +// ----- + +// expected-error @below {{operation tracked by two handles}} +%0 = transform.test_produce_42_or_forward_operand +// expected-note @below {{handle}} +%1 = transform.test_produce_42_or_forward_operand from %0 +// expected-note @below {{handle}} +%2 = transform.test_produce_42_or_forward_operand from %0 +transform.test_consume_operand_if_matches_param_or_fail %1[42] +transform.test_consume_operand_if_matches_param_or_fail %2[42] diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt --- a/mlir/test/lib/Dialect/Test/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt @@ -36,6 +36,11 @@ mlir_tablegen(TestPatterns.inc -gen-rewriters) add_public_tablegen_target(MLIRTestOpsIncGen) +set(LLVM_TARGET_DEFINITIONS TestTransformDialectExtension.td) +mlir_tablegen(TestTransformDialectExtension.h.inc -gen-op-decls) +mlir_tablegen(TestTransformDialectExtension.cpp.inc -gen-op-defs) +add_public_tablegen_target(MLIRTestTransformDialectExtensionIncGen) + # Exclude tests from libMLIR.so add_mlir_library(MLIRTestDialect TestAttributes.cpp @@ -43,6 +48,8 @@ TestInterfaces.cpp TestPatterns.cpp TestTraits.cpp + TestTransformDialectExtension.cpp + TestTransformDialectInterpreter.cpp TestTypes.cpp EXCLUDE_FROM_LIBMLIR @@ -50,8 +57,9 @@ DEPENDS MLIRTestAttrDefIncGen MLIRTestInterfaceIncGen - MLIRTestTypeDefIncGen MLIRTestOpsIncGen + MLIRTestTransformDialectExtensionIncGen + MLIRTestTypeDefIncGen LINK_LIBS PUBLIC MLIRControlFlowInterfaces @@ -69,6 +77,7 @@ MLIRPass MLIRReduce MLIRTensor + MLIRTransformDialect MLIRTransformUtils MLIRTransforms ) diff --git a/mlir/test/lib/Dialect/Test/TestTransformDialectExtension.h b/mlir/test/lib/Dialect/Test/TestTransformDialectExtension.h new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Dialect/Test/TestTransformDialectExtension.h @@ -0,0 +1,33 @@ +//===- TestTransformDialectExtension.h --------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the 
Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines an extension of the MLIR Transform dialect for testing +// purposes. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TESTTRANSFORMDIALECTEXTENSION_H +#define MLIR_TESTTRANSFORMDIALECTEXTENSION_H + +#include "mlir/Dialect/PDL/IR/PDLTypes.h" +#include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/IR/OpDefinition.h" + +namespace mlir { +class DialectRegistry; +} // namespace mlir + +#define GET_OP_CLASSES +#include "TestTransformDialectExtension.h.inc" + +namespace test { +/// Registers the test extension to the Transform dialect. +void registerTestTransformDialectExtension(::mlir::DialectRegistry ®istry); +} // namespace test + +#endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_H diff --git a/mlir/test/lib/Dialect/Test/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Test/TestTransformDialectExtension.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Dialect/Test/TestTransformDialectExtension.cpp @@ -0,0 +1,100 @@ +//===- TestTransformDialectExtension.cpp ----------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines an extension of the MLIR Transform dialect for testing +// purposes. 
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestTransformDialectExtension.h"
+#include "mlir/Dialect/PDL/IR/PDL.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/OpImplementation.h"
+
+using namespace mlir;
+
+namespace {
+/// Simple transform op defined outside of the dialect. Just emits a remark when
+/// applied.
+class TestTransformOp
+    : public Op<TestTransformOp, transform::TransformOpInterface::Trait> {
+public:
+  using Op::Op;
+
+  static ArrayRef<StringRef> getAttributeNames() { return {}; }
+
+  static constexpr llvm::StringLiteral getOperationName() {
+    return llvm::StringLiteral("transform.test_transform_op");
+  }
+
+  LogicalResult apply(transform::TransformResults &results,
+                      transform::TransformState &state) {
+    emitRemark() << "applying transformation";
+    return success();
+  }
+
+  static ParseResult parse(OpAsmParser &parser, OperationState &state) {
+    return success();
+  }
+
+  void print(OpAsmPrinter &printer) {}
+};
+} // namespace
+
+LogicalResult mlir::test::TestProduce42OrForwardOperandOp::apply(
+    transform::TransformResults &results, transform::TransformState &state) {
+  if (getOperation()->getNumOperands() != 0) {
+    results.set(getResult().cast<OpResult>(), getOperand(0).getDefiningOp());
+  } else {
+    results.set(getResult().cast<OpResult>(),
+                reinterpret_cast<Operation *>(42));
+  }
+  return success();
+}
+
+LogicalResult mlir::test::TestConsumeOperandIfMatchesParamOrFail::apply(
+    transform::TransformResults &results, transform::TransformState &state) {
+  ArrayRef<Operation *> payload = state.getPayloadOps(getOperand());
+  if (payload.size() != 1)
+    return emitOpError() << "expected a single target op";
+  auto value = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(payload[0]));
+  if (value != parameter())
+    return emitOpError() << "expected the operand to be associated with "
+                         << parameter() << " got " << value;
+
+  emitRemark() << "succeeded";
+  return success();
+}
+
+namespace {
+/// Test extension of the Transform dialect. Registers additional ops and
+/// declares PDL as dependent dialect since the additional ops are using PDL
+/// types for operands and results.
+class TestTransformDialectExtension
+    : public transform::TransformDialectExtension<
+          TestTransformDialectExtension> {
+public:
+  TestTransformDialectExtension() {
+    declareDependentDialect<pdl::PDLDialect>();
+    registerTransformOp<TestTransformOp>();
+    registerTransformOps<
+#define GET_OP_LIST
+#include "TestTransformDialectExtension.cpp.inc"
+        >();
+  }
+};
+} // namespace
+
+#define GET_OP_CLASSES
+#include "TestTransformDialectExtension.cpp.inc"
+
+void ::test::registerTestTransformDialectExtension(DialectRegistry &registry) {
+  registry.addExtensions<TestTransformDialectExtension>();
+}
diff --git a/mlir/test/lib/Dialect/Test/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Test/TestTransformDialectExtension.td
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Dialect/Test/TestTransformDialectExtension.td
@@ -0,0 +1,39 @@
+//===- TestTransformDialectExtension.td --------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the operations that are injected into the Transform
+// dialect through the extension mechanism, as a test.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TESTTRANSFORMDIALECTEXTENSION_TD
+#define MLIR_TESTTRANSFORMDIALECTEXTENSION_TD
+
+include "mlir/IR/OpBase.td"
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
+include "mlir/Dialect/PDL/IR/PDLTypes.td"
+
+def TestProduce42OrForwardOperandOp
+  : Op<Transform_Dialect, "test_produce_42_or_forward_operand",
+       [DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let arguments = (ins Optional<PDL_Operation>:$operand);
+  let results = (outs PDL_Operation:$res);
+  let assemblyFormat = "(`from` $operand^)? attr-dict";
+  let cppNamespace = "::mlir::test";
+}
+
+def TestConsumeOperandIfMatchesParamOrFail
+  : Op<Transform_Dialect, "test_consume_operand_if_matches_param_or_fail",
+       [DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let arguments = (ins PDL_Operation:$operand, I64Attr:$parameter);
+  let assemblyFormat = "$operand `[` $parameter `]` attr-dict";
+  let cppNamespace = "::mlir::test";
+}
+
+#endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD
diff --git a/mlir/test/lib/Dialect/Test/TestTransformDialectInterpreter.cpp b/mlir/test/lib/Dialect/Test/TestTransformDialectInterpreter.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Dialect/Test/TestTransformDialectInterpreter.cpp
@@ -0,0 +1,53 @@
+//===- TestTransformDialectInterpreter.cpp --------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines a test pass that interprets Transform dialect operations in
+// the module.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+/// Simple pass that applies transform dialect ops directly contained in a
+class TestTransformDialectInterpreterPass + : public PassWrapper> { +public: + StringRef getArgument() const override { + return "test-transform-dialect-interpreter"; + } + + StringRef getDescription() const override { + return "apply transform dialect operations one by one"; + } + + void runOnOperation() override { + ModuleOp module = getOperation(); + transform::TransformState state(module); + for (auto op : + module.getBody()->getOps()) { + if (failed(state.applyTransform(op))) + return signalPassFailure(); + } + } +}; +} // namespace + +namespace mlir { +namespace test { +/// Registers the test pass for applying transform dialect ops. +void registerTestTransformDialectInterpreterPass() { + PassRegistration reg; +} +} // namespace test +} // namespace mlir \ No newline at end of file diff --git a/mlir/test/mlir-opt/commandline.mlir b/mlir/test/mlir-opt/commandline.mlir --- a/mlir/test/mlir-opt/commandline.mlir +++ b/mlir/test/mlir-opt/commandline.mlir @@ -32,5 +32,6 @@ // CHECK-NEXT: tensor // CHECK-NEXT: test // CHECK-NEXT: tosa +// CHECK-NEXT: transform // CHECK-NEXT: vector // CHECK-NEXT: x86vector diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -107,12 +107,14 @@ void registerTestSCFUtilsPass(); void registerTestSliceAnalysisPass(); void registerTestTensorTransforms(); +void registerTestTransformDialectInterpreterPass(); void registerTestVectorLowerings(); } // namespace test } // namespace mlir namespace test { void registerTestDialect(DialectRegistry &); +void registerTestTransformDialectExtension(DialectRegistry &); } // namespace test #ifdef MLIR_INCLUDE_TESTS @@ -196,6 +198,7 @@ mlir::test::registerTestSCFUtilsPass(); mlir::test::registerTestSliceAnalysisPass(); mlir::test::registerTestTensorTransforms(); + mlir::test::registerTestTransformDialectInterpreterPass(); mlir::test::registerTestVectorLowerings(); } #endif @@ -209,6 +212,7 @@ 
registerAllDialects(registry); #ifdef MLIR_INCLUDE_TESTS ::test::registerTestDialect(registry); + ::test::registerTestTransformDialectExtension(registry); #endif return mlir::asMainReturnCode( mlir::MlirOptMain(argc, argv, "MLIR modular optimizer driver\n", registry, diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -5980,6 +5980,7 @@ ":TensorTransforms", ":TosaDialect", ":TosaToLinalg", + ":TransformDialect", ":Transforms", ":TransformsPassIncGen", ":VectorOps", @@ -7543,6 +7544,71 @@ ], ) +td_library( + name = "TransformDialectTdFiles", + srcs = glob(["include/mlir/Dialect/Transform/IR/*.td"]), + deps = [ + ":OpBaseTdFiles", + ], +) + +gentbl_cc_library( + name = "TransformDialectInterfacesIncGen", + strip_include_prefix = "include", + tbl_outs = [ + ( + [ + "-gen-op-interface-decls", + ], + "include/mlir/Dialect/Transform/IR/TransformInterfaces.h.inc", + ), + ( + [ + "-gen-op-interface-defs", + ], + "include/mlir/Dialect/Transform/IR/TransformInterfaces.cpp.inc", + ), + ], + tblgen = ":mlir-tblgen", + td_file = "include/mlir/Dialect/Transform/IR/TransformInterfaces.td", + deps = [":TransformDialectTdFiles"], +) + +gentbl_cc_library( + name = "TransformDialectIncGen", + strip_include_prefix = "include", + tbl_outs = [ + ( + [ + "-gen-dialect-decls", + ], + "include/mlir/Dialect/Transform/IR/TransformDialect.h.inc", + ), + ( + [ + "-gen-dialect-defs", + ], + "include/mlir/Dialect/Transform/IR/TransformDialect.cpp.inc", + ), + ], + tblgen = ":mlir-tblgen", + td_file = "include/mlir/Dialect/Transform/IR/TransformDialect.td", + deps = [":TransformDialectTdFiles"], +) + +cc_library( + name = "TransformDialect", + srcs = glob(["lib/Dialect/Transform/IR/*.cpp"]), + hdrs = glob(["include/mlir/Dialect/Transform/IR/*.h"]), + deps = [ + ":IR", + ":Support", + ":TransformDialectIncGen", + 
":TransformDialectInterfacesIncGen", + "//llvm:Support", + ], +) + td_library( name = "ComplexOpsTdFiles", srcs = [ diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel @@ -192,6 +192,29 @@ ], ) +gentbl_cc_library( + name = "TestTransformDialectExtensionIncGen", + strip_include_prefix = "lib/Dialect/Test", + tbl_outs = [ + ( + ["-gen-op-decls"], + "lib/Dialect/Test/TestTransformDialectExtension.h.inc", + ), + ( + ["-gen-op-defs"], + "lib/Dialect/Test/TestTransformDialectExtension.cpp.inc", + ), + ], + tblgen = "//mlir:mlir-tblgen", + td_file = "lib/Dialect/Test/TestTransformDialectExtension.td", + test = True, + deps = [ + ":TestOpTdFiles", + "//mlir:PDLDialectTdFiles", + "//mlir:TransformDialectTdFiles", + ], +) + cc_library( name = "TestDialect", srcs = glob(["lib/Dialect/Test/*.cpp"]), @@ -203,6 +226,7 @@ ":TestAttrDefsIncGen", ":TestInterfacesIncGen", ":TestOpsIncGen", + ":TestTransformDialectExtensionIncGen", ":TestTypeDefsIncGen", "//llvm:Support", "//mlir:ArithmeticDialect", @@ -218,12 +242,14 @@ "//mlir:LLVMDialect", "//mlir:LinalgInterfaces", "//mlir:LinalgOps", + "//mlir:PDLDialect", "//mlir:Pass", "//mlir:Reducer", "//mlir:SideEffects", "//mlir:StandardOpsTransforms", "//mlir:Support", "//mlir:TensorDialect", + "//mlir:TransformDialect", "//mlir:TransformUtils", "//mlir:Transforms", ],