diff --git a/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td b/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td --- a/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td +++ b/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td @@ -11,7 +11,6 @@ include "mlir/Dialect/PDL/IR/PDLTypes.td" include "mlir/Dialect/Transform/IR/TransformDialect.td" -include "mlir/Dialect/Transform/IR/TransformEffects.td" include "mlir/Dialect/Transform/IR/TransformInterfaces.td" include "mlir/Dialect/Transform/IR/TransformTypes.td" include "mlir/Interfaces/SideEffectInterfaces.td" diff --git a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td --- a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td +++ b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td @@ -11,7 +11,6 @@ include "mlir/Dialect/Bufferization/IR/BufferizationEnums.td" include "mlir/Dialect/Transform/IR/TransformDialect.td" -include "mlir/Dialect/Transform/IR/TransformEffects.td" include "mlir/Dialect/Transform/IR/TransformInterfaces.td" include "mlir/Dialect/Transform/IR/TransformTypes.td" include "mlir/Dialect/PDL/IR/PDLTypes.td" diff --git a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td --- a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td +++ b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td @@ -10,7 +10,6 @@ #define GPU_TRANSFORM_OPS include "mlir/Dialect/Transform/IR/TransformDialect.td" -include "mlir/Dialect/Transform/IR/TransformEffects.td" include "mlir/Dialect/Transform/IR/TransformInterfaces.td" include "mlir/Dialect/PDL/IR/PDLTypes.td" include "mlir/Interfaces/SideEffectInterfaces.td" diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -10,7 +10,6 @@ #define LINALG_TRANSFORM_OPS include "mlir/Dialect/Transform/IR/TransformDialect.td" -include "mlir/Dialect/Transform/IR/TransformEffects.td" include "mlir/Dialect/Transform/IR/TransformInterfaces.td" include "mlir/Dialect/Transform/IR/TransformTypes.td" include "mlir/Dialect/PDL/IR/PDLTypes.td" @@ -89,7 +88,8 @@ def FuseIntoContainingOp : Op]> { + [DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> { let summary = "Fuse a producer into a containing operation."; let description = [{ @@ -125,14 +125,9 @@ This operation reads the containing op handle. }]; - let arguments = (ins Arg:$producer_op, - Arg:$containing_op); - let results = (outs Res:$fused_op); + let arguments = (ins PDL_Operation:$producer_op, + PDL_Operation:$containing_op); + let results = (outs PDL_Operation:$fused_op); let assemblyFormat = "$producer_op `into` $containing_op attr-dict"; let builders = [ diff --git a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td --- a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td +++ b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td @@ -10,7 +10,6 @@ #define MEMREF_TRANSFORM_OPS include "mlir/Dialect/Transform/IR/TransformDialect.td" -include "mlir/Dialect/Transform/IR/TransformEffects.td" include "mlir/Dialect/Transform/IR/TransformInterfaces.td" include "mlir/Dialect/PDL/IR/PDLTypes.td" include "mlir/Interfaces/SideEffectInterfaces.td" diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td --- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td +++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td @@ -10,7 +10,6 @@ #define SCF_TRANSFORM_OPS include "mlir/Dialect/Transform/IR/TransformDialect.td" -include "mlir/Dialect/Transform/IR/TransformEffects.td" include "mlir/Dialect/Transform/IR/TransformInterfaces.td" include "mlir/Dialect/Transform/IR/TransformTypes.td" include "mlir/Interfaces/SideEffectInterfaces.td" diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformEffects.td b/mlir/include/mlir/Dialect/Transform/IR/TransformEffects.td deleted file mode 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformEffects.td +++ /dev/null @@ -1,62 +0,0 @@ - -//===- TransformEffect.td - Transform side effects ---------*- tablegen -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file defines side effects and associated resources for operations in the -// Transform dialect and extensions. -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORM_EFFECTS_TD -#define MLIR_DIALECT_TRANSFORM_IR_TRANSFORM_EFFECTS_TD - -include "mlir/Interfaces/SideEffectInterfaces.td" - -//===----------------------------------------------------------------------===// -// Effects on the mapping between Transform IR values and Payload IR ops. -//===----------------------------------------------------------------------===// - -// Side effect resource corresponding to the mapping between transform IR values -// and Payload IR operations. -def TransformMappingResource - : Resource<"::mlir::transform::TransformMappingResource">; - -// Describes the creation of a new entry in the transform mapping. Should be -// accompanied by the Write effect as the entry is immediately initialized by -// any reasonable transform operation. -def TransformMappingAlloc : MemAlloc; - -// Describes the removal of an entry in the transform mapping. Typically -// accompanied by the Read effect. -def TransformMappingFree : MemFree; - -// Describes the access to the mapping. Read-only accesses can be reordered. -def TransformMappingRead : MemRead; - -// Describes a modification of an existing entry in the mapping. It is rarely -// used alone, and is mostly accompanied by the Allocate effect. -def TransformMappingWrite : MemWrite; - -//===----------------------------------------------------------------------===// -// Effects on Payload IR. -//===----------------------------------------------------------------------===// - -// Side effect resource corresponding to the Payload IR itself. -def PayloadIRResource : Resource<"::mlir::transform::PayloadIRResource">; - -// Corresponds to the read-only access to the Payload IR through some operation -// handles in the Transform IR. -def PayloadIRRead : MemRead; - -// Corresponds to the mutation of the Payload IR through an operation handle in -// the Transform IR. Should be accompanied by the Read effect for most transform -// operations (only a complete overwrite of the root op of the Payload IR is a -// write-only modification). -def PayloadIRWrite : MemWrite; - -#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORM_EFFECTS_TD diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -12,11 +12,11 @@ include "mlir/Interfaces/CastInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/IR/OpAsmInterface.td" include "mlir/IR/SymbolInterfaces.td" include "mlir/Dialect/Transform/IR/TransformAttrs.td" include "mlir/Dialect/Transform/IR/TransformDialect.td" -include "mlir/Dialect/Transform/IR/TransformEffects.td" include "mlir/Dialect/Transform/IR/TransformInterfaces.td" def AlternativesOp : TransformDialectOp<"alternatives", @@ -466,7 +466,8 @@ def WithPDLPatternsOp : TransformDialectOp<"with_pdl_patterns", [DeclareOpInterfaceMethods, NoTerminator, - OpAsmOpInterface, PossibleTopLevelTransformOpTrait, RecursiveMemoryEffects, + OpAsmOpInterface, PossibleTopLevelTransformOpTrait, + DeclareOpInterfaceMethods, SymbolTable]> { let summary = "Contains PDL patterns available for use in transforms"; let description = [{ @@ -505,8 +506,8 @@ }]; let arguments = (ins - Arg, "Root operation of the Payload IR", - [TransformMappingRead]>:$root); + Arg, "Root operation of the Payload IR" + >:$root); let regions = (region SizedRegion<1>:$body); let assemblyFormat = "($root^ `:` type($root))? attr-dict-with-keyword regions"; @@ -518,7 +519,8 @@ }]; } -def YieldOp : TransformDialectOp<"yield", [Terminator]> { +def YieldOp : TransformDialectOp<"yield", + [Terminator, DeclareOpInterfaceMethods]> { let summary = "Yields operation handles from a transform IR region"; let description = [{ This terminator operation yields operation handles from regions of the @@ -527,8 +529,8 @@ }]; let arguments = (ins - Arg, "Operation handles yielded back to the parent", - [TransformMappingRead]>:$operands); + Arg, "Operation handles yielded back to the parent" + >:$operands); let assemblyFormat = "operands attr-dict (`:` type($operands)^)?"; let builders = [ diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td --- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td +++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td @@ -10,7 +10,6 @@ #define VECTOR_TRANSFORM_OPS include "mlir/Dialect/Transform/IR/TransformDialect.td" -include "mlir/Dialect/Transform/IR/TransformEffects.td" include "mlir/Dialect/Transform/IR/TransformInterfaces.td" include "mlir/Dialect/PDL/IR/PDLTypes.td" include "mlir/Dialect/Vector/Transforms/VectorTransformsBase.td" diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp --- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp +++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp @@ -60,13 +60,14 @@ void transform::OneShotBufferizeOp::getEffects( SmallVectorImpl &effects) { - effects.emplace_back(MemoryEffects::Read::get(), getTarget(), - TransformMappingResource::get()); - // Handles that are not modules are not longer usable. - if (!getTargetIsModule()) - effects.emplace_back(MemoryEffects::Free::get(), getTarget(), - TransformMappingResource::get()); + if (!getTargetIsModule()) { + consumesHandle(getTarget(), effects); + } else { + onlyReadsHandle(getTarget(), effects); + } + + modifiesPayload(effects); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -713,6 +713,14 @@ return DiagnosedSilenceableFailure::success(); } +void transform::FuseIntoContainingOp::getEffects( + SmallVectorImpl &effects) { + consumesHandle(getProducerOp(), effects); + onlyReadsHandle(getContainingOp(), effects); + producesHandle(getFusedOp(), effects); + modifiesPayload(effects); +} + //===----------------------------------------------------------------------===// // GeneralizeOp //===----------------------------------------------------------------------===// @@ -2668,6 +2676,7 @@ onlyReadsHandle(getPackedNumThreads(), effects); onlyReadsHandle(getPackedTileSizes(), effects); producesHandle(getResults(), effects); + modifiesPayload(effects); } SmallVector TileToForeachThreadOp::getMixedNumThreads() { @@ -2997,6 +3006,7 @@ SmallVectorImpl &effects) { consumesHandle(getTarget(), effects); onlyReadsHandle(getVectorSizes(), effects); + modifiesPayload(effects); } SmallVector MaskedVectorizeOp::getMixedVectorSizes() { diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -783,6 +783,7 @@ }); }; + std::optional firstConsumedOperand = std::nullopt; for (OpOperand &operand : op->getOpOperands()) { auto range = effectsOn(operand.get()); if (range.empty()) { @@ -793,7 +794,30 @@ << operand.getOperandNumber(); return diag; } + if (::hasEffect(range)) { + InFlightDiagnostic diag = op->emitError() + << "TransformOpInterface did not expect " + "'allocate' memory effect on an operand"; + diag.attachNote() << "specified for operand #" + << operand.getOperandNumber(); + return diag; + } + if (!firstConsumedOperand && + ::hasEffect(range)) { + firstConsumedOperand = operand.getOperandNumber(); + } + } + + if (firstConsumedOperand && + !::hasEffect(effects)) { + InFlightDiagnostic diag = + op->emitError() + << "TransformOpInterface expects ops consuming operands to have a " + "'write' effect on the payload resource"; + diag.attachNote() << "consumes operand #" << *firstConsumedOperand; + return diag; } + for (OpResult result : op->getResults()) { auto range = effectsOn(result); if (!::hasEffect( @@ -806,6 +830,7 @@ return diag; } } + return success(); } diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -292,7 +292,7 @@ void transform::CastOp::getEffects( SmallVectorImpl &effects) { onlyReadsPayload(effects); - consumesHandle(getInput(), effects); + onlyReadsHandle(getInput(), effects); producesHandle(getOutput(), effects); } @@ -501,7 +501,7 @@ void transform::MergeHandlesOp::getEffects( SmallVectorImpl &effects) { - consumesHandle(getHandles(), effects); + onlyReadsHandle(getHandles(), effects); producesHandle(getResult(), effects); // There are no effects on the Payload IR as this is only a handle @@ -557,7 +557,7 @@ void transform::SplitHandlesOp::getEffects( SmallVectorImpl &effects) { - consumesHandle(getHandle(), effects); + onlyReadsHandle(getHandle(), effects); producesHandle(getResults(), effects); // There are no effects on the Payload IR as this is only a handle // manipulation. @@ -626,7 +626,7 @@ void transform::ReplicateOp::getEffects( SmallVectorImpl &effects) { onlyReadsHandle(getPattern(), effects); - consumesHandle(getHandles(), effects); + onlyReadsHandle(getHandles(), effects); producesHandle(getReplicated(), effects); } @@ -832,34 +832,62 @@ effects.emplace_back(effect.getEffect(), target, effect.getResource()); } -void transform::SequenceOp::getEffects( - SmallVectorImpl &effects) { - onlyReadsHandle(getRoot(), effects); - onlyReadsHandle(getExtraBindings(), effects); - producesHandle(getResults(), effects); +namespace { +template +using has_get_extra_bindings = decltype(std::declval().getExtraBindings()); +} // namespace + +/// Populate `effects` with transform dialect memory effects for the potential +/// top-level operation. Such operations have recursive effects from nested +/// operations. When they have an operand, we can additionally remap effects on +/// the block argument to be effects on the operand. +template +static void getPotentialTopLevelEffects( + OpTy operation, SmallVectorImpl &effects) { + transform::onlyReadsHandle(operation->getOperands(), effects); + transform::producesHandle(operation->getResults(), effects); + + if (!operation.getRoot()) { + for (Operation &op : *operation.getBodyBlock()) { + auto iface = dyn_cast(&op); + if (!iface) + continue; - if (!getRoot()) { - for (Operation &op : *getBodyBlock()) { - auto iface = cast(&op); SmallVector nestedEffects; iface.getEffects(effects); } return; } - // Carry over all effects on the argument of the entry block as those on the - // operand, this is the same value just remapped. - for (Operation &op : *getBodyBlock()) { - auto iface = cast(&op); + // Carry over all effects on arguments of the entry block as those on the + // operands, this is the same value just remapped. + for (Operation &op : *operation.getBodyBlock()) { + auto iface = dyn_cast(&op); + if (!iface) + continue; - remapEffects(iface, getBodyBlock()->getArgument(0), getRoot(), effects); - for (auto [source, target] : llvm::zip( - getBodyBlock()->getArguments().drop_front(), getExtraBindings())) { - remapEffects(iface, source, target, effects); + remapEffects(iface, operation.getBodyBlock()->getArgument(0), + operation.getRoot(), effects); + if constexpr (llvm::is_detected::value) { + for (auto [source, target] : + llvm::zip(operation.getBodyBlock()->getArguments().drop_front(), + operation.getExtraBindings())) { + remapEffects(iface, source, target, effects); + } } + + SmallVector nestedEffects; + iface.getEffectsOnResource(transform::PayloadIRResource::get(), + nestedEffects); + llvm::append_range(effects, nestedEffects); } } +void transform::SequenceOp::getEffects( + SmallVectorImpl &effects) { + getPotentialTopLevelEffects(*this, effects); +} + OperandRange transform::SequenceOp::getSuccessorEntryOperands( std::optional index) { assert(index && *index == 0 && "unexpected region index"); @@ -983,6 +1011,11 @@ return state.applyTransform(transformOp); } +void transform::WithPDLPatternsOp::getEffects( + SmallVectorImpl &effects) { + getPotentialTopLevelEffects(*this, effects); +} + LogicalResult transform::WithPDLPatternsOp::verify() { Block *body = getBodyBlock(); Operation *topLevelOp = nullptr; @@ -1065,3 +1098,12 @@ // writes into the default resource. effects.emplace_back(MemoryEffects::Write::get()); } + +//===----------------------------------------------------------------------===// +// YieldOp +//===----------------------------------------------------------------------===// + +void transform::YieldOp::getEffects( + SmallVectorImpl &effects) { + onlyReadsHandle(getOperands(), effects); +} diff --git a/mlir/test/Dialect/Transform/ops-invalid.mlir b/mlir/test/Dialect/Transform/ops-invalid.mlir --- a/mlir/test/Dialect/Transform/ops-invalid.mlir +++ b/mlir/test/Dialect/Transform/ops-invalid.mlir @@ -251,7 +251,7 @@ ^bb0(%arg0: !transform.any_op): // expected-error @below {{TransformOpInterface requires memory effects on operands to be specified}} // expected-note @below {{no effects specified for operand #0}} - transform.test_required_memory_effects %arg0 : (!transform.any_op) -> !transform.any_op + transform.test_required_memory_effects %arg0 {modifies_payload} : (!transform.any_op) -> !transform.any_op } // ----- @@ -260,5 +260,5 @@ ^bb0(%arg0: !transform.any_op): // expected-error @below {{TransformOpInterface requires 'allocate' memory effect to be specified for results}} // expected-note @below {{no 'allocate' effect specified for result #0}} - transform.test_required_memory_effects %arg0 {has_operand_effect} : (!transform.any_op) -> !transform.any_op + transform.test_required_memory_effects %arg0 {has_operand_effect, modifies_payload} : (!transform.any_op) -> !transform.any_op } diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -118,6 +118,13 @@ return DiagnosedSilenceableFailure::success(); } +void mlir::test::TestProduceParamOrForwardOperandOp::getEffects( + SmallVectorImpl &effects) { + if (getOperand()) + transform::onlyReadsHandle(getOperand(), effects); + transform::producesHandle(getRes(), effects); +} + LogicalResult mlir::test::TestProduceParamOrForwardOperandOp::verify() { if (getParameter().has_value() ^ (getNumOperands() != 1)) return emitOpError() << "expects either a parameter or an operand"; @@ -130,6 +137,14 @@ return DiagnosedSilenceableFailure::success(); } +void mlir::test::TestConsumeOperand::getEffects( + SmallVectorImpl &effects) { + transform::consumesHandle(getOperand(), effects); + if (getSecondOperand()) + transform::consumesHandle(getSecondOperand(), effects); + transform::modifiesPayload(effects); +} + DiagnosedSilenceableFailure mlir::test::TestConsumeOperandIfMatchesParamOrFail::apply( transform::TransformResults &results, transform::TransformState &state) { @@ -146,6 +161,12 @@ return DiagnosedSilenceableFailure::success(); } +void mlir::test::TestConsumeOperandIfMatchesParamOrFail::getEffects( + SmallVectorImpl &effects) { + transform::consumesHandle(getOperand(), effects); + transform::modifiesPayload(effects); +} + DiagnosedSilenceableFailure mlir::test::TestPrintRemarkAtOperandOp::apply( transform::TransformResults &results, transform::TransformState &state) { ArrayRef payload = state.getPayloadOps(getOperand()); @@ -155,6 +176,12 @@ return DiagnosedSilenceableFailure::success(); } +void mlir::test::TestPrintRemarkAtOperandOp::getEffects( + SmallVectorImpl &effects) { + transform::onlyReadsHandle(getOperand(), effects); + transform::onlyReadsPayload(effects); +} + DiagnosedSilenceableFailure mlir::test::TestAddTestExtensionOp::apply(transform::TransformResults &results, transform::TransformState &state) { @@ -187,6 +214,12 @@ return DiagnosedSilenceableFailure::success(); } +void mlir::test::TestCheckIfTestExtensionPresentOp::getEffects( + SmallVectorImpl &effects) { + transform::onlyReadsHandle(getOperand(), effects); + transform::onlyReadsPayload(effects); +} + DiagnosedSilenceableFailure mlir::test::TestRemapOperandPayloadToSelfOp::apply( transform::TransformResults &results, transform::TransformState &state) { auto *extension = state.getExtension(); @@ -199,6 +232,12 @@ return DiagnosedSilenceableFailure::success(); } +void mlir::test::TestRemapOperandPayloadToSelfOp::getEffects( + SmallVectorImpl &effects) { + transform::onlyReadsHandle(getOperand(), effects); + transform::onlyReadsPayload(effects); +} + DiagnosedSilenceableFailure mlir::test::TestRemoveTestExtensionOp::apply( transform::TransformResults &results, transform::TransformState &state) { state.removeExtension(); @@ -312,6 +351,13 @@ return DiagnosedSilenceableFailure::success(); } +void mlir::test::TestCopyPayloadOp::getEffects( + SmallVectorImpl &effects) { + transform::onlyReadsHandle(getHandle(), effects); + transform::producesHandle(getCopy(), effects); + transform::onlyReadsPayload(effects); +} + DiagnosedSilenceableFailure mlir::transform::TestDialectOpType::checkPayload( Location loc, ArrayRef payload) const { if (payload.empty()) @@ -491,6 +537,9 @@ transform::producesHandle(getOut(), effects); else transform::onlyReadsHandle(getOut(), effects); + + if (getModifiesPayload()) + transform::modifiesPayload(effects); } DiagnosedSilenceableFailure mlir::test::TestRequiredMemoryEffectsOp::apply( diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td @@ -14,10 +14,10 @@ #ifndef MLIR_TESTTRANSFORMDIALECTEXTENSION_TD #define MLIR_TESTTRANSFORMDIALECTEXTENSION_TD +include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/IR/AttrTypeBase.td" include "mlir/IR/OpBase.td" include "mlir/Dialect/Transform/IR/TransformDialect.td" -include "mlir/Dialect/Transform/IR/TransformEffects.td" include "mlir/Dialect/Transform/IR/TransformInterfaces.td" include "mlir/Dialect/PDL/IR/PDLTypes.td" @@ -41,35 +41,33 @@ def TestProduceParamOrForwardOperandOp : Op]> { + [DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> { let arguments = (ins - Arg, "", [TransformMappingRead]>:$operand, + Optional:$operand, OptionalAttr:$parameter); - let results = (outs - Res:$res); + let results = (outs PDL_Operation:$res); let assemblyFormat = "(`from` $operand^)? ($parameter^)? attr-dict"; let cppNamespace = "::mlir::test"; let hasVerifier = 1; } def TestConsumeOperand : Op]> { + [DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> { let arguments = (ins - Arg:$operand, - Arg, "", - [TransformMappingRead, TransformMappingFree]>:$second_operand); + PDL_Operation:$operand, + Optional:$second_operand); let assemblyFormat = "$operand (`,` $second_operand^)? attr-dict"; let cppNamespace = "::mlir::test"; } def TestConsumeOperandIfMatchesParamOrFail : Op]> { + [DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> { let arguments = (ins - Arg:$operand, + PDL_Operation:$operand, I64Attr:$parameter); let assemblyFormat = "$operand `[` $parameter `]` attr-dict"; let cppNamespace = "::mlir::test"; @@ -77,10 +75,10 @@ def TestPrintRemarkAtOperandOp : Op]> { + [DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> { let arguments = (ins - Arg:$operand, + TransformHandleTypeInterface:$operand, StrAttr:$message); let assemblyFormat = "$operand `,` $message attr-dict `:` type($operand)"; @@ -98,19 +96,18 @@ def TestCheckIfTestExtensionPresentOp : Op]> { - let arguments = (ins - Arg:$operand); + [DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> { + let arguments = (ins PDL_Operation:$operand); let assemblyFormat = "$operand attr-dict"; let cppNamespace = "::mlir::test"; } def TestRemapOperandPayloadToSelfOp : Op]> { - let arguments = (ins - Arg:$operand); + [DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> { + let arguments = (ins PDL_Operation:$operand); let assemblyFormat = "$operand attr-dict"; let cppNamespace = "::mlir::test"; } @@ -255,10 +252,10 @@ def TestCopyPayloadOp : Op]> { - let arguments = (ins Arg:$handle); - let results = (outs Res:$copy); + [DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> { + let arguments = (ins PDL_Operation:$handle); + let results = (outs PDL_Operation:$copy); let cppNamespace = "::mlir::test"; let assemblyFormat = "$handle attr-dict"; } @@ -358,7 +355,8 @@ DeclareOpInterfaceMethods]> { let arguments = (ins TransformHandleTypeInterface:$in, UnitAttr:$has_operand_effect, - UnitAttr:$has_result_effect); + UnitAttr:$has_result_effect, + UnitAttr:$modifies_payload); let results = (outs TransformHandleTypeInterface:$out); let assemblyFormat = "$in attr-dict `:` functional-type(operands, results)"; let cppNamespace = "::mlir::test";