diff --git a/mlir/docs/Dialects/Transform.md b/mlir/docs/Dialects/Transform.md --- a/mlir/docs/Dialects/Transform.md +++ b/mlir/docs/Dialects/Transform.md @@ -2,6 +2,8 @@ [TOC] +[include "Dialects/TransformTypes.md"] + [include "Dialects/TransformOps.md"] ## Bufferization Transform Operations @@ -16,4 +18,6 @@ [include "Dialects/LinalgStructuredTransformOps.md"] +[include "Dialects/TransformTypeInterfaces.md"] + [include "Dialects/TransformOpInterfaces.md"] diff --git a/mlir/include/mlir/Dialect/Transform/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Transform/IR/CMakeLists.txt --- a/mlir/include/mlir/Dialect/Transform/IR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Transform/IR/CMakeLists.txt @@ -8,6 +8,13 @@ add_public_tablegen_target(MLIRTransformDialectIncGen) add_dependencies(mlir-headers MLIRTransformDialectIncGen) +set(LLVM_TARGET_DEFINITIONS TransformTypes.td) +mlir_tablegen(TransformTypes.h.inc -gen-typedef-decls) +mlir_tablegen(TransformTypes.cpp.inc -gen-typedef-defs) +add_public_tablegen_target(MLIRTransformTypesIncGen) +add_dependencies(mlir-headers MLIRTransformTypesIncGen) +add_mlir_doc(TransformTypes TransformTypes Dialects/ -gen-typedef-docs) + set(LLVM_TARGET_DEFINITIONS TransformAttrs.td) mlir_tablegen(TransformDialectEnums.h.inc -gen-enum-decls) mlir_tablegen(TransformDialectEnums.cpp.inc -gen-enum-defs) @@ -17,5 +24,13 @@ add_mlir_dialect(TransformOps transform) add_mlir_doc(TransformOps TransformOps Dialects/ -gen-dialect-doc -dialect=transform) +# Contrary to what the name claims, this only produces the _op_ interface. add_mlir_interface(TransformInterfaces) add_mlir_doc(TransformInterfaces TransformOpInterfaces Dialects/ -gen-op-interface-docs) + +set(LLVM_TARGET_DEFINITIONS TransformInterfaces.td) +mlir_tablegen(TransformTypeInterfaces.h.inc -gen-type-interface-decls) +mlir_tablegen(TransformTypeInterfaces.cpp.inc -gen-type-interface-defs) +add_public_tablegen_target(MLIRTransformDialectTypeInterfacesIncGen) +add_dependencies(mlir-headers MLIRTransformDialectTypeInterfacesIncGen) +add_mlir_doc(TransformInterfaces TransformTypeInterfaces Dialects/ -gen-type-interface-docs) diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h @@ -13,6 +13,7 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Support/LLVM.h" +#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/StringMap.h" namespace mlir { @@ -37,6 +38,11 @@ "ops injected into the transform dialect must implement " "MemoryEffectsOpInterface"); } + +/// Asserts that the type provided as template argument implements the +/// TransformTypeInterface. This must be a dynamic assertion since interface +/// implementations may be registered at runtime. +void checkImplementsTransformTypeInterface(TypeID typeID, MLIRContext *context); } // namespace detail #endif // NDEBUG } // namespace transform @@ -120,6 +126,18 @@ }); } + /// Injects the types into the Transform dialect. The types must implement + /// the TransformTypeInterface and the implementation must be already + /// available when the type is injected. Furthermore, the types must provide + /// a `getMnemonic` static method returning an object convertible to + /// `StringRef` that is unique across all injected types. + template + void registerTypes() { + opInitializers.push_back([](TransformDialect *transformDialect) { + transformDialect->addTypesChecked(); + }); + } + /// Declares that this Transform dialect extension depends on the dialect /// provided as template parameter. When the Transform dialect is loaded, /// dependent dialects will be loaded as well. This is intended for dialects @@ -182,6 +200,25 @@ bool buildOnly; }; +template +void TransformDialect::addTypeIfNotRegistered() { + // Use the address of the parse method as a proxy for identifying whether we + // are registering the same type class for the same mnemonic. + StringRef mnemonic = Type::getMnemonic(); + auto [it, inserted] = typeParsingHooks.try_emplace(mnemonic, Type::parse); + if (!inserted) { + const ExtensionTypeParsingHook &parsingHook = it->getValue(); + if (*parsingHook.target() != &Type::parse) + reportDuplicateTypeRegistration(mnemonic); + } + typePrintingHooks.try_emplace( + TypeID::get(), +[](mlir::Type type, AsmPrinter &printer) { + printer << Type::getMnemonic(); + cast(type).print(printer); + }); + addTypes(); +} + /// A wrapper for transform dialect extensions that forces them to be /// constructed in the build-only mode. template diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td @@ -315,6 +315,23 @@ const ::llvm::StringMap<::mlir::PDLConstraintFunction> & getPDLConstraintHooks() const; + /// Parses a type registered by this dialect or one of its extensions. + ::mlir::Type parseType(::mlir::DialectAsmParser &parser) const override; + + /// Prints a type registered by this dialect or one of its extensions. + void printType(::mlir::Type type, + ::mlir::DialectAsmPrinter &printer) const override; + + /// Parser callback for an individual type registered by this dialect or + /// its extensions. + using ExtensionTypeParsingHook = + std::function<::mlir::Type (::mlir::AsmParser &)>; + + /// Printer callback for an individual type registered by this dialect or + /// its extensions. + using ExtensionTypePrintingHook = + std::function; + private: template void addOperationIfNotRegistered() { @@ -344,6 +361,28 @@ #endif // NDEBUG } + /// Registers the types specified as template parameters with the + /// Transform dialect. Checks that they meet the requirements for + /// Transform IR types. + template + void addTypesChecked() { + (addTypeIfNotRegistered(), ...); + + #ifndef NDEBUG + (detail::checkImplementsTransformTypeInterface( + TypeID::get(), getContext()), ...); + #endif // NDEBUG + } + + /// Implementation of the type registration for a single type, should + /// not be called directly, use addTypesChecked instead. + template + void addTypeIfNotRegistered(); + + /// Reports a repeated registration error of a type with the given + /// mnemonic. + [[noreturn]] void reportDuplicateTypeRegistration(StringRef mnemonic); + template friend class TransformDialectExtension; @@ -352,9 +391,23 @@ void mergeInPDLMatchHooks( ::llvm::StringMap<::mlir::PDLConstraintFunction> &&constraintFns); + //===----------------------------------------------------------------===// + // Data fields + //===----------------------------------------------------------------===// + /// A container for PDL constraint function that can be used by /// operations in this dialect. - PDLPatternModule pdlMatchHooks; + ::mlir::PDLPatternModule pdlMatchHooks; + + /// A map from type mnemonic to its parsing function for the remainder of + /// the syntax. The parser has access to the mnemonic, so it is used for + /// further dispatch. + ::llvm::StringMap typeParsingHooks; + + /// A map from type TypeID to its printing function. No need to do string + /// lookups when the type is fully constructed. + ::llvm::DenseMap<::mlir::TypeID, ExtensionTypePrintingHook> + typePrintingHooks; }]; } diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -12,6 +12,7 @@ #include "mlir/IR/OpDefinition.h" #include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Support/LogicalResult.h" #include "llvm/ADT/ScopeExit.h" namespace mlir { @@ -279,13 +280,16 @@ /// list of operations in the payload IR. The arguments must be defined in /// blocks of the currently processed transform IR region, typically after a /// region scope is defined. - void mapBlockArguments(BlockArgument argument, - ArrayRef operations) { + /// + /// Returns failure if the payload does not satisfy the conditions associated + /// with the type of the handle value. + LogicalResult mapBlockArguments(BlockArgument argument, + ArrayRef operations) { #if LLVM_ENABLE_ABI_BREAKING_CHECKS assert(argument.getParentRegion() == regionStack.back() && "mapping block arguments from a region other than the active one"); #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS - setPayloadOps(argument, operations); + return setPayloadOps(argument, operations); } // Forward declarations to support limited visibility. @@ -478,7 +482,10 @@ /// is invalid given the transformation "consumes" the handle as expressed /// by side effects. Practically, a transformation consuming a handle means /// that the associated payload operation may no longer exist. - void setPayloadOps(Value value, ArrayRef targets); + /// + /// Returns failure if the payload does not satisfy the conditions associated + /// with the type of the handle value. + LogicalResult setPayloadOps(Value value, ArrayRef targets); /// Forgets the payload IR ops associated with the given transform IR value. void removePayloadOps(Value value); @@ -488,8 +495,12 @@ /// expected to return the modified operation or nullptr. In the latter case, /// the corresponding operation is no longer associated with the transform IR /// value. - void updatePayloadOps(Value value, - function_ref callback); + /// + /// Returns failure if the payload does not satisfy the conditions associated + /// with the type of the handle value. + LogicalResult + updatePayloadOps(Value value, + function_ref callback); /// If the operand is a handle consumed by the operation, i.e. has the "free" /// memory effect associated with it, identifies other handles that are @@ -574,9 +585,9 @@ /// Maps the only block argument of the op with PossibleTopLevelTransformOpTrait /// to either the list of operations associated with its operand or the root of /// the payload IR, depending on what is available in the context. -void mapPossibleTopLevelTransformOpBlockArguments(TransformState &state, - Operation *op, - Region ®ion); +LogicalResult +mapPossibleTopLevelTransformOpBlockArguments(TransformState &state, + Operation *op, Region ®ion); /// Verification hook for PossibleTopLevelTransformOpTrait. LogicalResult verifyPossibleTopLevelTransformOpTrait(Operation *op); @@ -613,17 +624,17 @@ /// Sets up the mapping between the entry block of the given region of this op /// and the relevant list of Payload IR operations in the given state. The /// state is expected to be already scoped at the region of this operation. - void mapBlockArguments(TransformState &state, Region ®ion) { + LogicalResult mapBlockArguments(TransformState &state, Region ®ion) { assert(region.getParentOp() == this->getOperation() && "op comes from the wrong region"); - detail::mapPossibleTopLevelTransformOpBlockArguments( + return detail::mapPossibleTopLevelTransformOpBlockArguments( state, this->getOperation(), region); } - void mapBlockArguments(TransformState &state) { + LogicalResult mapBlockArguments(TransformState &state) { assert( this->getOperation()->getNumRegions() == 1 && "must indicate the region to map if the operation has more than one"); - mapBlockArguments(state, this->getOperation()->getRegion(0)); + return mapBlockArguments(state, this->getOperation()->getRegion(0)); } }; diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td @@ -81,6 +81,31 @@ }]; } +def TransformTypeInterface : TypeInterface<"TransformTypeInterface"> { + let description = [{ + Types that can be used for Transform dialect handle values. Such types + define the properties of Payload IR operations associated with the handle. + A user of such a handle can assume that these properties have been verified + for any Payload IR operation associated with it. + }]; + + let cppNamespace = "::mlir::transform"; + + let methods = [ + InterfaceMethod< + /*desc=*/[{ + Checks if the given list of associated Payload IR operations satisfy + the conditions defined by this type. If not, produces a silenceable + error at the specified location. + }], + /*returnType=*/"::mlir::DiagnosedSilenceableFailure", + /*name=*/"checkPayload", + /*arguments=*/(ins "::mlir::Location":$loc, + "::mlir::ArrayRef<::mlir::Operation *>":$payload) + > + ]; +} + def FunctionalStyleTransformOpTrait : NativeOpTrait<"FunctionalStyleTransformOpTrait"> { let cppNamespace = "::mlir::transform"; diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h @@ -10,10 +10,11 @@ #define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS_H #include "mlir/Dialect/PDL/IR/PDLTypes.h" -#include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/Dialect/Transform/IR/TransformTypes.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/SymbolTable.h" +#include "mlir/Interfaces/CastInterfaces.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" namespace mlir { diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -9,6 +9,7 @@ #ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS #define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS +include "mlir/Interfaces/CastInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/IR/OpAsmInterface.td" include "mlir/IR/SymbolInterfaces.td" @@ -96,6 +97,23 @@ let hasVerifier = 1; } +def CastOp : TransformDialectOp<"cast", + [TransformOpInterface, TransformEachOpTrait, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> { + // TODO: temporarily fallback support for casting from PDL_Operation type. + let arguments = (ins AnyType:$input); + let results = (outs AnyType:$output); + let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)"; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::llvm::SmallVectorImpl<::mlir::Operation *> &results, + ::mlir::transform::TransformState &state); + }]; +} + def ForeachOp : TransformDialectOp<"foreach", [DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformTypes.h b/mlir/include/mlir/Dialect/Transform/IR/TransformTypes.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformTypes.h @@ -0,0 +1,26 @@ +//===- TransformTypes.h - Transform dialect types ---------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMTYPES_H +#define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMTYPES_H + +#include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/IR/Types.h" +#include "mlir/Support/LLVM.h" + +namespace mlir { +class DiagnosedSilenceableFailure; +class Operation; +class Type; +} // namespace mlir + +#include "mlir/Dialect/Transform/IR/TransformTypeInterfaces.h.inc" +#define GET_TYPEDEF_CLASSES +#include "mlir/Dialect/Transform/IR/TransformTypes.h.inc" + +#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMTYPES_H diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformTypes.td b/mlir/include/mlir/Dialect/Transform/IR/TransformTypes.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformTypes.td @@ -0,0 +1,26 @@ +//===- TransformTypes.td - Transform dialect types ---------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMTYPES +#define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMTYPES + +include "mlir/IR/AttrTypeBase.td" +include "mlir/Dialect/Transform/IR/TransformInterfaces.td" +include "mlir/Dialect/Transform/IR/TransformDialect.td" + +def Transform_AnyOpType : TypeDef]> { + let description = [{ + Transform IR handle that can be associated with a list of arbitrary + Payload IR operations. + }]; + let mnemonic = "any_op"; + let assemblyFormat = ""; +} + +#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMTYPES diff --git a/mlir/lib/Dialect/Transform/IR/CMakeLists.txt b/mlir/lib/Dialect/Transform/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Transform/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Transform/IR/CMakeLists.txt @@ -2,6 +2,7 @@ TransformDialect.cpp TransformInterfaces.cpp TransformOps.cpp + TransformTypes.cpp DEPENDS MLIRTransformDialectIncGen @@ -9,6 +10,7 @@ LINK_LIBS PUBLIC MLIRIR + MLIRParser MLIRPDLDialect MLIRPDLInterpDialect MLIRRewrite diff --git a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp @@ -10,18 +10,32 @@ #include "mlir/Dialect/PDL/IR/PDL.h" #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" #include "mlir/Dialect/Transform/IR/TransformOps.h" +#include "mlir/Dialect/Transform/IR/TransformTypes.h" +#include "mlir/IR/DialectImplementation.h" using namespace mlir; #include "mlir/Dialect/Transform/IR/TransformDialect.cpp.inc" +#ifndef NDEBUG +void transform::detail::checkImplementsTransformTypeInterface( + TypeID typeID, MLIRContext *context) { + const auto &abstractType = AbstractType::lookup(typeID, context); + assert(abstractType.hasInterface(TransformTypeInterface::getInterfaceID())); +} +#endif // NDEBUG + void transform::TransformDialect::initialize() { - // Using the checked version to enable the same assertions as for the ops from - // extensions. + // Using the checked versions to enable the same assertions as for the ops + // from extensions. addOperationsChecked< #define GET_OP_LIST #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc" >(); + addTypesChecked< +#define GET_TYPEDEF_LIST +#include "mlir/Dialect/Transform/IR/TransformTypes.cpp.inc" + >(); } void transform::TransformDialect::mergeInPDLMatchHooks( @@ -36,4 +50,36 @@ return pdlMatchHooks.getConstraintFunctions(); } +Type transform::TransformDialect::parseType(DialectAsmParser &parser) const { + StringRef keyword; + SMLoc loc = parser.getCurrentLocation(); + if (failed(parser.parseKeyword(&keyword))) + return nullptr; + + auto it = typeParsingHooks.find(keyword); + if (it == typeParsingHooks.end()) { + parser.emitError(loc) << "unknown type mnemonic: " << keyword; + return nullptr; + } + + return it->getValue()(parser); +} + +void transform::TransformDialect::printType(Type type, + DialectAsmPrinter &printer) const { + auto it = typePrintingHooks.find(type.getTypeID()); + assert(it != typePrintingHooks.end() && "printing unknown type"); + it->getSecond()(type, printer); +} + +void transform::TransformDialect::reportDuplicateTypeRegistration( + StringRef mnemonic) { + std::string buffer; + llvm::raw_string_ostream msg(buffer); + msg << "error: extensible dialect type '" << mnemonic + << "' is already registered with a different implementation"; + msg.flush(); + llvm::report_fatal_error(StringRef(buffer)); +} + #include "mlir/Dialect/Transform/IR/TransformDialectEnums.cpp.inc" diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -8,6 +8,7 @@ #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/Dialect/PDL/IR/PDLTypes.h" +#include "mlir/Dialect/Transform/IR/TransformTypes.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Operation.h" #include "llvm/ADT/STLExtras.h" @@ -60,14 +61,22 @@ return success(found); } -void transform::TransformState::setPayloadOps(Value value, - ArrayRef targets) { +LogicalResult +transform::TransformState::setPayloadOps(Value value, + ArrayRef targets) { assert(value != kTopLevelValue && "attempting to reset the transformation root"); // TODO: this may go now if (value.use_empty()) - return; + return success(); + + if (auto iface = value.getType().dyn_cast()) { + DiagnosedSilenceableFailure result = + iface.checkPayload(value.getLoc(), targets); + if (failed(result.checkAndReport())) + return failure(); + } // Setting new payload for the value without cleaning it first is a misuse of // the API, assert here. @@ -80,6 +89,8 @@ for (Operation *op : targets) mappings.reverse[op].push_back(value); + + return success(); } void transform::TransformState::dropReverseMapping(Mappings &mappings, @@ -100,7 +111,7 @@ mappings.direct.erase(value); } -void transform::TransformState::updatePayloadOps( +LogicalResult transform::TransformState::updatePayloadOps( Value value, function_ref callback) { Mappings &mappings = getMapping(value); auto it = mappings.direct.find(value); @@ -117,7 +128,15 @@ } } + if (auto iface = value.getType().dyn_cast()) { + DiagnosedSilenceableFailure result = + iface.checkPayload(value.getLoc(), updated); + if (failed(result.checkAndReport())) + return failure(); + } + std::swap(association, updated); + return success(); } void transform::TransformState::recordHandleInvalidationOne( @@ -253,7 +272,8 @@ assert(result.getDefiningOp() == transform.getOperation() && "payload IR association for a value other than the result of the " "current transform op"); - setPayloadOps(result, results.get(result.getResultNumber())); + if (failed(setPayloadOps(result, results.get(result.getResultNumber())))) + return DiagnosedSilenceableFailure::definiteFailure(); } printOnFailureRAII.release(); @@ -278,9 +298,12 @@ return failure(); for (Value handle : handles) { - state.updatePayloadOps(handle, [&](Operation *current) { - return current == op ? replacement : current; - }); + LogicalResult result = + state.updatePayloadOps(handle, [&](Operation *current) { + return current == op ? replacement : current; + }); + if (failed(result)) + return failure(); } return success(); } @@ -317,7 +340,7 @@ // Utilities for PossibleTopLevelTransformOpTrait. //===----------------------------------------------------------------------===// -void transform::detail::mapPossibleTopLevelTransformOpBlockArguments( +LogicalResult transform::detail::mapPossibleTopLevelTransformOpBlockArguments( TransformState &state, Operation *op, Region ®ion) { SmallVector targets; if (op->getNumOperands() != 0) @@ -325,7 +348,7 @@ else targets.push_back(state.getTopLevel()); - state.mapBlockArguments(region.front().getArgument(0), targets); + return state.mapBlockArguments(region.front().getArgument(0), targets); } LogicalResult diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -15,6 +15,7 @@ #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Rewrite/FrozenRewritePatternSet.h" #include "mlir/Rewrite/PatternApplicator.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/Support/Debug.h" @@ -226,7 +227,8 @@ for (Operation *clone : clones) clone->erase(); }); - state.mapBlockArguments(reg.front().getArgument(0), clones); + if (failed(state.mapBlockArguments(reg.front().getArgument(0), clones))) + return DiagnosedSilenceableFailure::definiteFailure(); bool failed = false; for (Operation &transform : reg.front().without_terminator()) { @@ -291,6 +293,35 @@ // ForeachOp //===----------------------------------------------------------------------===// +DiagnosedSilenceableFailure +transform::CastOp::applyToOne(Operation *target, + SmallVectorImpl &results, + transform::TransformState &state) { + results.push_back(target); + return DiagnosedSilenceableFailure::success(); +} + +void transform::CastOp::getEffects( + SmallVectorImpl &effects) { + onlyReadsPayload(effects); + consumesHandle(getInput(), effects); + producesHandle(getOutput(), effects); +} + +bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { + assert(inputs.size() == 1 && "expected one input"); + assert(outputs.size() == 1 && "expected one output"); + return llvm::all_of( + std::initializer_list{inputs.front(), outputs.front()}, + [](Type ty) { + return ty.isa(); + }); +} + +//===----------------------------------------------------------------------===// +// ForeachOp +//===----------------------------------------------------------------------===// + DiagnosedSilenceableFailure transform::ForeachOp::apply(transform::TransformResults &results, transform::TransformState &state) { @@ -299,7 +330,8 @@ for (Operation *op : payloadOps) { auto scope = state.make_region_scope(getBody()); - state.mapBlockArguments(getIterationVariable(), {op}); + if (failed(state.mapBlockArguments(getIterationVariable(), {op}))) + return DiagnosedSilenceableFailure::definiteFailure(); // Execute loop body. for (Operation &transform : getBody().front().without_terminator()) { @@ -572,7 +604,8 @@ transform::TransformState &state) { // Map the entry block argument to the list of operations. auto scope = state.make_region_scope(*getBodyBlock()->getParent()); - mapBlockArguments(state); + if (failed(mapBlockArguments(state))) + return DiagnosedSilenceableFailure::definiteFailure(); // Apply the sequenced ops one by one. for (Operation &transform : getBodyBlock()->without_terminator()) { @@ -766,7 +799,8 @@ [&]() { state.removeExtension(); }); auto scope = state.make_region_scope(getBody()); - mapBlockArguments(state); + if (failed(mapBlockArguments(state))) + return DiagnosedSilenceableFailure::definiteFailure(); return state.applyTransform(transformOp); } diff --git a/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp b/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp @@ -0,0 +1,37 @@ +//===- TransformTypes.cpp - Transform Dialect Type Definitions ------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Transform/IR/TransformTypes.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/Types.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Compiler.h" + +using namespace mlir; + +#include "mlir/Dialect/Transform/IR/TransformTypeInterfaces.cpp.inc" + +// These are automatically generated by ODS but are not used as the Transform +// dialect uses a different dispatch mechanism to support dialect extensions. +LLVM_ATTRIBUTE_UNUSED static OptionalParseResult +generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value); +LLVM_ATTRIBUTE_UNUSED static LogicalResult +generatedTypePrinter(Type def, AsmPrinter &printer); + +#define GET_TYPEDEF_CLASSES +#include "mlir/Dialect/Transform/IR/TransformTypes.cpp.inc" + +DiagnosedSilenceableFailure +transform::AnyOpType::checkPayload(Location loc, + ArrayRef payload) const { + return DiagnosedSilenceableFailure::success(); +} diff --git a/mlir/test/Dialect/Transform/ops.mlir b/mlir/test/Dialect/Transform/ops.mlir --- a/mlir/test/Dialect/Transform/ops.mlir +++ b/mlir/test/Dialect/Transform/ops.mlir @@ -58,3 +58,10 @@ ^bb1(%arg1: !pdl.operation): } } + +// CHECK: transform.sequence +transform.sequence failures(propagate) { +^bb0(%arg0: !pdl.operation): + // CHECK: cast %{{.*}} : !pdl.operation to !transform.any_op + %0 = cast %arg0: !pdl.operation to !transform.any_op +} diff --git a/mlir/test/Dialect/Transform/test-dialect-injection.mlir b/mlir/test/Dialect/Transform/test-dialect-injection.mlir --- a/mlir/test/Dialect/Transform/test-dialect-injection.mlir +++ b/mlir/test/Dialect/Transform/test-dialect-injection.mlir @@ -1,6 +1,7 @@ // RUN: mlir-opt %s | FileCheck %s -// These ops are defined by a test extension but should be okay to roundtrip. +// These types and ops are defined by a test extension but should be okay to +// roundtrip. // CHECK: transform.test_transform_op transform.test_transform_op @@ -10,3 +11,7 @@ // CHECK: transform.test_consume_operand_if_matches_param_or_fail %{{.*}}[42] transform.test_consume_operand_if_matches_param_or_fail %0[42] + +// Ensure that the extension type is roundtripped correctly. +// CHECK: transform.cast %{{.*}} : !pdl.operation to !transform.test_dialect_op +%1 = transform.cast %0: !pdl.operation to !transform.test_dialect_op diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir --- a/mlir/test/Dialect/Transform/test-interpreter.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter.mlir @@ -798,5 +798,46 @@ // Silenceable failure and all handles are now empty. %h_2:3 = split_handles %muli_2 in [3] // expected-remark @below {{0}} - transform.test_print_number_of_associated_payload_ir_ops %h_2#0 + transform.test_print_number_of_associated_payload_ir_ops %h_2#0 +} + +// ----- + +"test.some_op"() : () -> () +"other_dialect.other_op"() : () -> () + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + pdl.pattern @some : benefit(1) { + %0 = pdl.operation "test.some_op" + pdl.rewrite %0 with "transform.dialect" + } + + sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = pdl_match @some in %arg1 + %2 = transform.cast %0 : !pdl.operation to !transform.test_dialect_op + transform.cast %2 : !transform.test_dialect_op to !pdl.operation + } +} + +// ----- + +"test.some_op"() : () -> () +"other_dialect.other_op"() : () -> () + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + pdl.pattern @other : benefit(1) { + %0 = pdl.operation "other_dialect.other_op" + pdl.rewrite %0 with "transform.dialect" + } + + sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = pdl_match @other in %arg1 + // expected-error @below {{expected the payload operation to belong to the 'test' dialect}} + %2 = transform.cast %0 : !pdl.operation to !transform.test_dialect_op + transform.cast %2 : !transform.test_dialect_op to !pdl.operation + } } diff --git a/mlir/test/lib/Dialect/Transform/CMakeLists.txt b/mlir/test/lib/Dialect/Transform/CMakeLists.txt --- a/mlir/test/lib/Dialect/Transform/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Transform/CMakeLists.txt @@ -1,6 +1,8 @@ set(LLVM_TARGET_DEFINITIONS TestTransformDialectExtension.td) mlir_tablegen(TestTransformDialectExtension.h.inc -gen-op-decls) mlir_tablegen(TestTransformDialectExtension.cpp.inc -gen-op-defs) +mlir_tablegen(TestTransformDialectExtensionTypes.h.inc -gen-typedef-decls -typedefs-dialect=transform) +mlir_tablegen(TestTransformDialectExtensionTypes.cpp.inc -gen-typedef-defs -typedefs-dialect=transform) add_public_tablegen_target(MLIRTestTransformDialectExtensionIncGen) add_mlir_library(MLIRTestTransformDialect diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.h b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.h --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.h +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.h @@ -16,6 +16,7 @@ #include "mlir/Dialect/PDL/IR/PDLTypes.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/Dialect/Transform/IR/TransformTypes.h" #include "mlir/IR/OpImplementation.h" namespace mlir { @@ -25,6 +26,9 @@ #define GET_OP_CLASSES #include "TestTransformDialectExtension.h.inc" +#define GET_TYPEDEF_CLASSES +#include "TestTransformDialectExtensionTypes.h.inc" + namespace test { /// Registers the test extension to the Transform dialect. void registerTestTransformDialectExtension(::mlir::DialectRegistry ®istry); diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -17,6 +17,8 @@ #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/IR/OpImplementation.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Compiler.h" using namespace mlir; @@ -310,6 +312,22 @@ return DiagnosedSilenceableFailure::success(); } +DiagnosedSilenceableFailure mlir::transform::TestDialectOpType::checkPayload( + Location loc, ArrayRef payload) const { + if (payload.empty()) + return DiagnosedSilenceableFailure::success(); + + for (Operation *op : payload) { + if (op->getName().getDialectNamespace() != "test") { + Diagnostic diag(loc, DiagnosticSeverity::Error); + diag << "expected the payload operation to belong to the 'test' dialect"; + return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); + } + } + + return DiagnosedSilenceableFailure::success(); +} + namespace { /// Test extension of the Transform dialect. Registers additional ops and /// declares PDL as dependent dialect since the additional ops are using PDL @@ -327,6 +345,10 @@ #define GET_OP_LIST #include "TestTransformDialectExtension.cpp.inc" >(); + registerTypes< +#define GET_TYPEDEF_LIST +#include "TestTransformDialectExtensionTypes.cpp.inc" + >(); } }; } // namespace @@ -334,6 +356,16 @@ #define GET_OP_CLASSES #include "TestTransformDialectExtension.cpp.inc" +// These are automatically generated by ODS but are not used as the Transform +// dialect uses a different dispatch mechanism to support dialect extensions. +LLVM_ATTRIBUTE_UNUSED static OptionalParseResult +generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value); +LLVM_ATTRIBUTE_UNUSED static LogicalResult +generatedTypePrinter(Type def, AsmPrinter &printer); + +#define GET_TYPEDEF_CLASSES +#include "TestTransformDialectExtensionTypes.cpp.inc" + void ::test::registerTestTransformDialectExtension(DialectRegistry ®istry) { registry.addExtensions(); } diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td @@ -14,12 +14,21 @@ #ifndef MLIR_TESTTRANSFORMDIALECTEXTENSION_TD #define MLIR_TESTTRANSFORMDIALECTEXTENSION_TD +include "mlir/IR/AttrTypeBase.td" include "mlir/IR/OpBase.td" include "mlir/Dialect/Transform/IR/TransformDialect.td" include "mlir/Dialect/Transform/IR/TransformEffects.td" include "mlir/Dialect/Transform/IR/TransformInterfaces.td" include "mlir/Dialect/PDL/IR/PDLTypes.td" +def TestTransformTestDialectHandleType + : TypeDef]> { + let description = [{Handle pointing to an op from the Test dialect.}]; + let mnemonic = "test_dialect_op"; + let assemblyFormat = ""; +} + def TestProduceParamOrForwardOperandOp : Op]> { diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -8396,6 +8396,7 @@ name = "TransformDialectTdFiles", srcs = glob(["include/mlir/Dialect/Transform/IR/*.td"]), deps = [ + ":CastInterfacesTdFiles", ":ControlFlowInterfacesTdFiles", ":OpBaseTdFiles", ":PDLDialectTdFiles", @@ -8441,6 +8442,18 @@ ], "include/mlir/Dialect/Transform/IR/TransformInterfaces.cpp.inc", ), + ( + [ + "-gen-type-interface-decls", + ], + "include/mlir/Dialect/Transform/IR/TransformTypeInterfaces.h.inc", + ), + ( + [ + "-gen-type-interface-defs", + ], + "include/mlir/Dialect/Transform/IR/TransformTypeInterfaces.cpp.inc", + ), ], tblgen = ":mlir-tblgen", td_file = "include/mlir/Dialect/Transform/IR/TransformInterfaces.td", @@ -8487,6 +8500,24 @@ deps = [":TransformDialectTdFiles"], ) +gentbl_cc_library( + name = "TransformTypesIncGen", + strip_include_prefix = "include", + tbl_outs = [ + ( + ["-gen-typedef-decls"], + "include/mlir/Dialect/Transform/IR/TransformTypes.h.inc", + ), + ( + ["-gen-typedef-defs"], + "include/mlir/Dialect/Transform/IR/TransformTypes.cpp.inc", + ), + ], + tblgen = ":mlir-tblgen", + td_file = "include/mlir/Dialect/Transform/IR/TransformTypes.td", + deps = [":TransformDialectTdFiles"], +) + cc_library( name = "TransformDialect", srcs = glob(["lib/Dialect/Transform/IR/*.cpp"]), @@ -8503,6 +8534,7 @@ ":TransformDialectIncGen", ":TransformDialectInterfacesIncGen", ":TransformOpsIncGen", + ":TransformTypesIncGen", "//llvm:Support", ], ) diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel @@ -266,6 +266,20 @@ ["-gen-op-defs"], "lib/Dialect/Transform/TestTransformDialectExtension.cpp.inc", ), + ( + [ + "-gen-typedef-decls", + "-typedefs-dialect=transform", + ], + "lib/Dialect/Transform/TestTransformDialectExtensionTypes.h.inc", + ), + ( + [ + "-gen-typedef-defs", + "-typedefs-dialect=transform", + ], + "lib/Dialect/Transform/TestTransformDialectExtensionTypes.cpp.inc", + ), ], tblgen = "//mlir:mlir-tblgen", td_file = "lib/Dialect/Transform/TestTransformDialectExtension.td", @@ -284,6 +298,7 @@ includes = ["lib/Dialect/Transform"], deps = [ ":TestTransformDialectExtensionIncGen", + "//llvm:Support", "//mlir:IR", "//mlir:PDLDialect", "//mlir:Pass",