diff --git a/mlir/include/mlir/Dialect/Transform/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Transform/IR/CMakeLists.txt --- a/mlir/include/mlir/Dialect/Transform/IR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Transform/IR/CMakeLists.txt @@ -1,8 +1,13 @@ -# The dialect does not have its own ops, so just generate the dialect files. +# Generate the dialect files from the dialect .td. +# +# TODO: Make it possible to use XDialect instead of XOpsDialect in +# add_mlir_dialect. set(LLVM_TARGET_DEFINITIONS TransformDialect.td) mlir_tablegen(TransformDialect.h.inc -gen-dialect-decls -dialect=transform) mlir_tablegen(TransformDialect.cpp.inc -gen-dialect-defs -dialect=transform) add_public_tablegen_target(MLIRTransformDialectIncGen) add_dependencies(mlir-headers MLIRTransformDialectIncGen) +add_mlir_dialect(TransformOps transform) + add_mlir_interface(TransformInterfaces) diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td @@ -161,6 +161,7 @@ let name = "transform"; let cppNamespace = "::mlir::transform"; + let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed; let extraClassDeclaration = [{ // Make addOperations available to the TransformDialectExtension class. @@ -172,4 +173,9 @@ }]; } +// Base class for ops that belong to the transform dialect. Ops defined in +// extensions of this dialect may also use this.
+class TransformDialectOp traits = []> + : Op; + #endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -33,6 +33,14 @@ /// expected to populate the `TransformResults` class instance in order to /// update the mapping. The `applyTransform` method takes care of propagating /// the state of `TransformResults` into the instance of this class. +/// +/// When applying transform IR operations with regions, the client is expected +/// to create a RegionScope RAII object that opens a new "stack frame" for +/// values defined inside the region. The mappings from and to these values will +/// be automatically dropped when the object goes out of scope, typically at the +/// end of the "apply" function of the parent operation. If a region contains +/// blocks with arguments, the client can map those arguments to payload IR ops +/// using "mapBlockArguments". class TransformState { /// Mapping between a Value in the transform IR and the corresponding set of /// operations in the payload IR. @@ -42,9 +50,19 @@ /// currently associated with. using TransformOpReverseMapping = DenseMap; + /// Bidirectional mappings between transform IR values and payload IR + /// operations. + struct Mappings { + TransformOpMapping direct; + TransformOpReverseMapping reverse; + }; + public: - /// Creates a state for the transformation rooted at the given op. - explicit TransformState(Operation *root); + /// Creates a state for transform ops living in the given region, together + /// with a mapping frame for transform IR values defined in that region. The + /// second argument points to the root operation in the payload IR being + /// transformed, which may or may not contain the region with transform ops.
+ TransformState(Region ®ion, Operation *root); /// Returns the op at which the transformation state is rooted. This is /// typically helpful for transformations that apply globally. @@ -58,10 +76,96 @@ /// the state accordingly. LogicalResult applyTransform(TransformOpInterface transform); + /// Records the mapping between a block argument in the transform IR and a + /// list of operations in the payload IR. The arguments must be defined in + /// blocks of the currently processed transform IR region, typically after a + /// region scope is defined. + LogicalResult mapBlockArguments(BlockArgument argument, + ArrayRef operations) { +#if LLVM_ENABLE_ABI_BREAKING_CHECKS + assert(argument.getParentRegion() == regionStack.back() && + "mapping block arguments from a region other than the active one"); +#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS + return setPayloadOps(argument, operations); + } + + // Forward declarations to support limited visibility. + class RegionScope; + + /// Creates a new region scope for the given region. The region is expected to + /// be nested in the currently processed region. + // Implementation note: this method is inline but implemented outside of the + // class body to comply with visibility and full-declaration requirements. + inline RegionScope make_region_scope(Region ®ion); + + /// A RAII object maintaining a "stack frame" for a transform IR region. When + /// applying a transform IR operation that contains a region, the caller is + /// expected to create a RegionScope before applying the ops contained in the + /// region. This ensures that the mappings between values defined in the + /// transform IR region and payload IR operations are cleared when the region + /// processing ends; such values cannot be accessed outside the region. + class RegionScope { + public: + /// Forgets the mapping from or to values defined in the associated + /// transform IR region. 
+ ~RegionScope() { + state.mappings.erase(region); +#if LLVM_ENABLE_ABI_BREAKING_CHECKS + state.regionStack.pop_back(); +#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS + } + + private: + /// Creates a new scope for mappings between values defined in the given + /// transform IR region and payload IR operations. + RegionScope(TransformState &state, Region ®ion) + : state(state), region(®ion) { + auto res = state.mappings.try_emplace(this->region); + assert(res.second && "the region scope is already present"); + (void)res; +#if LLVM_ENABLE_ABI_BREAKING_CHECKS + assert(state.regionStack.back()->isProperAncestor(®ion) && + "scope started at a non-nested region"); + state.regionStack.push_back(®ion); +#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS + } + + /// Back-reference to the transform state. + TransformState &state; + + /// The region this scope is associated with. + Region *region; + + friend RegionScope TransformState::make_region_scope(Region &); + }; + friend class RegionScope; + private: /// Identifier for storing top-level value in the `operations` mapping. static constexpr Value kTopLevelValue = Value(); + /// Returns the mappings frame for the region in which the value is defined. + const Mappings &getMapping(Value value) const { + return const_cast(this)->getMapping(value); + } + Mappings &getMapping(Value value) { + auto it = mappings.find(value.getParentRegion()); + assert(it != mappings.end() && + "trying to find a mapping for a value from an unmapped region"); + return it->second; + } + + /// Returns the mappings frame for the region in which the operation resides.
+ const Mappings &getMapping(Operation *operation) const { + return const_cast(this)->getMapping(operation); + } + Mappings &getMapping(Operation *operation) { + auto it = mappings.find(operation->getParentRegion()); + assert(it != mappings.end() && + "trying to find a mapping for an operation from an unmapped region"); + return it->second; + } + /// Sets the payload IR ops associated with the given transform IR value. /// Fails if this would result in multiple transform IR values with uses /// corresponding to the same payload IR ops. For example, a hypothetical @@ -88,9 +192,19 @@ void updatePayloadOps(Value value, function_ref callback); - /// The mapping between payload IR values and transform IR ops. - TransformOpMapping operationMapping; - TransformOpReverseMapping reverseMapping; + /// The mappings between transform IR values and payload IR ops, aggregated by + /// the region in which the transform IR values are defined. + llvm::SmallDenseMap mappings; + + /// The top-level operation that contains all payload IR, typically a module. + Operation *topLevel; + +#if LLVM_ENABLE_ABI_BREAKING_CHECKS + /// A stack of nested regions that are being processed in the transform IR. + /// Each region must be an ancestor of the following regions in this list. + /// These are also the keys for "mappings". 
+ SmallVector regionStack; +#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS }; /// Local mapping between values defined by a specific op implementing the @@ -123,6 +237,10 @@ SmallVector operations; }; +TransformState::RegionScope TransformState::make_region_scope(Region ®ion) { + return RegionScope(*this, region); +} + } // namespace transform } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h @@ -0,0 +1,20 @@ +//===- TransformOps.h - Transform dialect operations ------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS_H +#define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS_H + +#include "mlir/Dialect/PDL/IR/PDLTypes.h" +#include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" + +#define GET_OP_CLASSES +#include "mlir/Dialect/Transform/IR/TransformOps.h.inc" + +#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS_H diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -0,0 +1,78 @@ +//===- TransformOps.td - Transform dialect operations ------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS +#define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS + +include "mlir/IR/OpAsmInterface.td" +include "mlir/Dialect/PDL/IR/PDLTypes.td" +include "mlir/Dialect/Transform/IR/TransformDialect.td" +include "mlir/Dialect/Transform/IR/TransformInterfaces.td" + +def SequenceOp : TransformDialectOp<"sequence", + [DeclareOpInterfaceMethods, OpAsmOpInterface, + SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">]> { + let summary = "Contains a sequence of other transform ops to apply"; + let description = [{ + The transformations indicated by the sequence are applied in order of their + appearance. Each value produced by a transformation within the sequence + corresponds to an operation or a group of operations in the payload IR. + Each value may be used at most once by another transformation operation as + the transformation is likely to replace the transformed operation with + another operation or a group thereof. In such cases, the transformation + operation is expected to produce a new value to denote the newly produced + operations that can be transformed further. During application, if any + transformation in the sequence fails, the entire sequence fails immediately + leaving the payload IR in potentially invalid state, i.e., this operation + offers no transformation rollback capabilities. + + The entry block of this operation has a single argument that maps to either + the operand if provided or the top-level container operation of the payload + IR, typically the root operation of the pass interpreting the transform + dialect. Operand omission is only allowed for sequences not contained in + another sequence. 
+ }]; + + let arguments = (ins Optional:$root); + let results = (outs Variadic:$results); + let regions = (region SizedRegion<1>:$body); + + let assemblyFormat = + "($root^)? attr-dict-with-keyword regions (`:` type($results)^)?"; + + let extraClassDeclaration = [{ + /// Allow the dialect prefix to be omitted. + static StringRef getDefaultDialect() { return "transform"; } + + Block *getBodyBlock() { + return &getBody().front(); + } + }]; + + let hasVerifier = 1; +} + +def YieldOp : TransformDialectOp<"yield", [Terminator]> { + let summary = "Yields operation handles from a transform IR region"; + let description = [{ + This terminator operation yields operation handles from regions of the + transform IR ops back to the containing op. It is not itself associated with + any transformation on the payload IR and is used for flow purposes only. + }]; + + let arguments = (ins Variadic:$operands); + let assemblyFormat = "operands attr-dict (`:` type($operands)^)?"; + + let builders = [ + OpBuilder<(ins), [{ + return build($_builder, $_state, ::mlir::ValueRange()); + }]> + ]; +} + +#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS diff --git a/mlir/lib/Dialect/Transform/IR/CMakeLists.txt b/mlir/lib/Dialect/Transform/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Transform/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Transform/IR/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_dialect_library(MLIRTransformDialect TransformDialect.cpp TransformInterfaces.cpp + TransformOps.cpp DEPENDS MLIRTransformDialectIncGen @@ -8,4 +9,6 @@ LINK_LIBS PUBLIC MLIRIR + MLIRPDL + MLIRPDLInterp ) diff --git a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp @@ -7,9 +7,15 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Transform/IR/TransformDialect.h" - -#include 
"mlir/Dialect/Transform/IR/TransformDialect.cpp.inc" +#include "mlir/Dialect/Transform/IR/TransformOps.h" using namespace mlir; -void transform::TransformDialect::initialize() {} +#include "mlir/Dialect/Transform/IR/TransformDialect.cpp.inc" + +void transform::TransformDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc" + >(); +} diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -9,6 +9,7 @@ #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Operation.h" +#include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/SmallPtrSet.h" using namespace mlir; @@ -19,16 +20,21 @@ constexpr const Value transform::TransformState::kTopLevelValue; -transform::TransformState::TransformState(Operation *root) { - operationMapping[kTopLevelValue].push_back(root); +transform::TransformState::TransformState(Region ®ion, Operation *root) + : topLevel(root) { + auto result = mappings.try_emplace(®ion); + assert(result.second && "the region scope is already present"); + (void)result; +#if LLVM_ENABLE_ABI_BREAKING_CHECKS + regionStack.push_back(®ion); +#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS } -Operation *transform::TransformState::getTopLevel() const { - return operationMapping.lookup(kTopLevelValue).front(); -} +Operation *transform::TransformState::getTopLevel() const { return topLevel; } ArrayRef transform::TransformState::getPayloadOps(Value value) const { + const TransformOpMapping &operationMapping = getMapping(value).direct; auto iter = operationMapping.find(value); assert(iter != operationMapping.end() && "unknown handle"); return iter->getSecond(); @@ -46,8 +52,9 @@ // Setting new payload for the value without cleaning it first is a misuse of // the API, assert here. 
SmallVector storedTargets(targets.begin(), targets.end()); + Mappings &mappings = getMapping(value); bool inserted = - operationMapping.insert({value, std::move(storedTargets)}).second; + mappings.direct.insert({value, std::move(storedTargets)}).second; assert(inserted && "value is already associated with another list"); (void)inserted; @@ -55,7 +62,7 @@ // expressed using the dialect and may be constructed by valid API calls from // valid IR. Emit an error here. for (Operation *op : targets) { - auto insertionResult = reverseMapping.insert({op, value}); + auto insertionResult = mappings.reverse.insert({op, value}); if (!insertionResult.second) { InFlightDiagnostic diag = op->emitError() << "operation tracked by two handles"; @@ -69,15 +76,16 @@ } void transform::TransformState::removePayloadOps(Value value) { - for (Operation *op : operationMapping[value]) - reverseMapping.erase(op); - operationMapping.erase(value); + Mappings &mappings = getMapping(value); + for (Operation *op : mappings.direct[value]) + mappings.reverse.erase(op); + mappings.direct.erase(value); } void transform::TransformState::updatePayloadOps( Value value, function_ref callback) { - auto it = operationMapping.find(value); - assert(it != operationMapping.end() && "unknown handle"); + auto it = getMapping(value).direct.find(value); + assert(it != getMapping(value).direct.end() && "unknown handle"); SmallVector &association = it->getSecond(); SmallVector updated; updated.reserve(association.size()); @@ -98,9 +106,13 @@ for (Value target : transform->getOperands()) removePayloadOps(target); - for (auto &en : llvm::enumerate(transform->getResults())) + for (auto &en : llvm::enumerate(transform->getResults())) { + assert(en.value().getDefiningOp() == transform.getOperation() && + "payload IR association for a value other than the result of the " + "current transform op"); if (failed(setPayloadOps(en.value(), results.get(en.index())))) return failure(); + } return success(); } diff --git 
a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -0,0 +1,101 @@ +//===- TransformOps.cpp - Transform dialect operations --------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Transform/IR/TransformOps.h" +#include "mlir/IR/Builders.h" + +#include "mlir/IR/OpImplementation.h" + +using namespace mlir; + +#define GET_OP_CLASSES +#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc" + +LogicalResult transform::SequenceOp::apply(transform::TransformResults &results, + transform::TransformState &state) { + SmallVector targets; + if (getRoot()) + llvm::append_range(targets, state.getPayloadOps(getRoot())); + else + targets.push_back(state.getTopLevel()); + + // Map the entry block argument to the list of operations. + auto scope = state.make_region_scope(*getBodyBlock()->getParent()); + if (failed(state.mapBlockArguments(getBodyBlock()->getArgument(0), targets))) + return failure(); + + // Apply the sequenced ops one by one. + for (Operation &transform : getBodyBlock()->without_terminator()) + if (failed(state.applyTransform(cast(transform)))) + return failure(); + + // Forward the operation mapping for values yielded from the sequence to the + // values produced by the sequence op.
+ for (const auto &pair : + llvm::zip(getBodyBlock()->getTerminator()->getOperands(), + getOperation()->getOpResults())) { + Value terminatorOperand = std::get<0>(pair); + OpResult result = std::get<1>(pair); + results.set(result, state.getPayloadOps(terminatorOperand)); + } + + return success(); +} + +LogicalResult transform::SequenceOp::verify() { + if (getBodyBlock()->getNumArguments() != 1 || + !getBodyBlock()->getArgumentTypes()[0].isa()) { + return emitOpError() + << "expected the entry block to have one argument of type " + << pdl::OperationType::get(getContext()); + } + + if (auto parent = getOperation()->getParentOfType()) { + if (!getRoot()) { + InFlightDiagnostic diag = + emitOpError() + << "expected the root operation to be provided for a nested sequence"; + diag.attachNote(parent.getLoc()) << "nested in another sequence"; + return diag; + } + } + + for (Operation &child : *getBodyBlock()) { + if (!isa(child) && + &child != &getBodyBlock()->back()) { + InFlightDiagnostic diag = + emitOpError() + << "expected children ops to implement TransformOpInterface"; + diag.attachNote(child.getLoc()) << "op without interface"; + return diag; + } + + for (OpResult result : child.getResults()) { + if (llvm::hasNItemsOrLess(result.getUses(), 1)) + continue; + InFlightDiagnostic diag = child.emitError() + << "result #" << result.getResultNumber() + << " has more than one use"; + for (OpOperand &use : result.getUses()) { + diag.attachNote(use.getOwner()->getLoc()) + << "used here as operand #" << use.getOperandNumber(); + } + return diag; + } + } + + if (getBodyBlock()->getTerminator()->getOperandTypes() != + getOperation()->getResultTypes()) { + InFlightDiagnostic diag = emitOpError() + << "expects the types of the terminator operands " + "to match the types of the result"; + diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) << "terminator"; + return diag; + } + return success(); +} diff --git a/mlir/test/Dialect/Transform/ops-invalid.mlir 
b/mlir/test/Dialect/Transform/ops-invalid.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Transform/ops-invalid.mlir @@ -0,0 +1,52 @@ +// RUN: mlir-opt %s -split-input-file -verify-diagnostics + +// expected-error @below {{expected the entry block to have one argument of type '!pdl.operation'}} +transform.sequence { +} + +// ----- + +// expected-note @below {{nested in another sequence}} +transform.sequence { +^bb0(%arg0: !pdl.operation): + // expected-error @below {{expected the root operation to be provided for a nested sequence}} + transform.sequence { + ^bb1(%arg1: !pdl.operation): + } +} + +// ----- + +// expected-error @below {{expected children ops to implement TransformOpInterface}} +transform.sequence { +^bb0(%arg0: !pdl.operation): + // expected-note @below {{op without interface}} + arith.constant 42.0 : f32 +} + +// ----- + +transform.sequence { +^bb0(%arg0: !pdl.operation): + // expected-error @below {{result #0 has more than one use}} + %0 = transform.sequence %arg0 { + ^bb1(%arg1: !pdl.operation): + } : !pdl.operation + // expected-note @below {{used here as operand #0}} + transform.sequence %0 { + ^bb2(%arg2: !pdl.operation): + } + // expected-note @below {{used here as operand #0}} + transform.sequence %0 { + ^bb3(%arg3: !pdl.operation): + } +} + +// ----- + +// expected-error @below {{expects the types of the terminator operands to match the types of the resul}} +%0 = transform.sequence { +^bb0(%arg0: !pdl.operation): + // expected-note @below {{terminator}} + transform.yield +} : !pdl.operation diff --git a/mlir/test/Dialect/Transform/ops.mlir b/mlir/test/Dialect/Transform/ops.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Transform/ops.mlir @@ -0,0 +1,12 @@ +// RUN: mlir-opt %s | mlir-opt | FileCheck %s + +// CHECK: transform.sequence +// CHECK: ^{{.+}}(%{{.+}}: !pdl.operation): +transform.sequence { +^bb0(%arg0: !pdl.operation): + // CHECK: sequence %{{.+}} + // CHECK: ^{{.+}}(%{{.+}}: !pdl.operation): + 
sequence %arg0 { + ^bb1(%arg1: !pdl.operation): + } +} diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir --- a/mlir/test/Dialect/Transform/test-interpreter.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter.mlir @@ -25,3 +25,47 @@ %2 = transform.test_produce_param_or_forward_operand from %0 transform.test_consume_operand_if_matches_param_or_fail %1[42] transform.test_consume_operand_if_matches_param_or_fail %2[42] + +// ----- + +transform.sequence { +^bb0(%arg0: !pdl.operation): + sequence %arg0 { + ^bb0(%arg1: !pdl.operation): + // expected-remark @below {{applying transformation "a"}} + test_transform_op "a" + // expected-remark @below {{applying transformation "b"}} + test_transform_op "b" + // expected-remark @below {{applying transformation "c"}} + test_transform_op "c" + } + // expected-remark @below {{applying transformation "d"}} + test_transform_op "d" + // expected-remark @below {{applying transformation "e"}} + test_transform_op "e" +} + +// ----- + +transform.sequence { +^bb0(%arg0: !pdl.operation): + %0 = test_produce_param_or_forward_operand 42 + sequence %0 { + ^bb0(%arg1: !pdl.operation): + // expected-remark @below {{succeeded}} + test_consume_operand_if_matches_param_or_fail %arg1[42] + } +} + +// ----- + +transform.sequence { +^bb0(%arg0: !pdl.operation): + %0 = sequence %arg0 { + ^bb0(%arg1: !pdl.operation): + %1 = test_produce_param_or_forward_operand 42 + yield %1 : !pdl.operation + } : !pdl.operation + // expected-remark @below {{succeeded}} + test_consume_operand_if_matches_param_or_fail %0[42] +} diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -38,31 +38,47 @@ LogicalResult apply(transform::TransformResults &results, 
transform::TransformState &state) { - emitRemark() << "applying transformation"; + InFlightDiagnostic remark = emitRemark() << "applying transformation"; + if (Attribute message = getMessage()) + remark << " " << message; + return success(); } + Attribute getMessage() { return getOperation()->getAttr("message"); } + static ParseResult parse(OpAsmParser &parser, OperationState &state) { - return success(); + StringAttr message; + OptionalParseResult result = parser.parseOptionalAttribute(message); + if (!result.hasValue()) + return success(); + + if (result.getValue().succeeded()) + state.addAttribute("message", message); + return result.getValue(); } - void print(OpAsmPrinter &printer) {} + void print(OpAsmPrinter &printer) { + if (getMessage()) + printer << " " << getMessage(); + } }; } // namespace LogicalResult mlir::test::TestProduceParamOrForwardOperandOp::apply( transform::TransformResults &results, transform::TransformState &state) { if (getOperation()->getNumOperands() != 0) { - results.set(getResult().cast(), getOperand(0).getDefiningOp()); + results.set(getResult().cast(), + getOperation()->getOperand(0).getDefiningOp()); } else { results.set(getResult().cast(), - reinterpret_cast(*parameter())); + reinterpret_cast(*getParameter())); } return success(); } LogicalResult mlir::test::TestProduceParamOrForwardOperandOp::verify() { - if (parameter().hasValue() ^ (getNumOperands() != 1)) + if (getParameter().hasValue() ^ (getNumOperands() != 1)) return emitOpError() << "expects either a parameter or an operand"; return success(); } @@ -72,9 +88,9 @@ ArrayRef payload = state.getPayloadOps(getOperand()); assert(payload.size() == 1 && "expected a single target op"); auto value = reinterpret_cast(payload[0]); - if (static_cast(value) != parameter()) { + if (static_cast(value) != getParameter()) { return emitOpError() << "expected the operand to be associated with " - << parameter() << " got " << value; + << getParameter() << " got " << value; } emitRemark() << 
"succeeded"; diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp @@ -37,7 +37,7 @@ void runOnOperation() override { ModuleOp module = getOperation(); - transform::TransformState state(module); + transform::TransformState state(module.getBodyRegion(), module); for (auto op : module.getBody()->getOps()) { if (failed(state.applyTransform(op))) diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -7699,6 +7699,7 @@ srcs = glob(["include/mlir/Dialect/Transform/IR/*.td"]), deps = [ ":OpBaseTdFiles", + ":PDLDialectTdFiles", ], ) @@ -7746,15 +7747,35 @@ deps = [":TransformDialectTdFiles"], ) +gentbl_cc_library( + name = "TransformOpsIncGen", + strip_include_prefix = "include", + tbl_outs = [ + ( + ["-gen-op-decls"], + "include/mlir/Dialect/Transform/IR/TransformOps.h.inc", + ), + ( + ["-gen-op-defs"], + "include/mlir/Dialect/Transform/IR/TransformOps.cpp.inc", + ), + ], + tblgen = ":mlir-tblgen", + td_file = "include/mlir/Dialect/Transform/IR/TransformOps.td", + deps = [":TransformDialectTdFiles"], +) + cc_library( name = "TransformDialect", srcs = glob(["lib/Dialect/Transform/IR/*.cpp"]), hdrs = glob(["include/mlir/Dialect/Transform/IR/*.h"]), deps = [ ":IR", + ":PDLDialect", ":Support", ":TransformDialectIncGen", ":TransformDialectInterfacesIncGen", + ":TransformOpsIncGen", "//llvm:Support", ], )