diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -498,6 +498,11 @@
 
 /// Verification hook for PossibleTopLevelTransformOpTrait.
 LogicalResult verifyPossibleTopLevelTransformOpTrait(Operation *op);
+
+/// Verification hook for TransformOpInterface. Checks that the op specifies
+/// memory effects on each of its operands and an 'allocate' memory effect on
+/// each of its results.
+LogicalResult verifyTransformOpInterface(Operation *op);
 } // namespace detail
 
 /// This trait is supposed to be attached to Transform dialect operations that
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
@@ -101,6 +101,10 @@
       return diag;
     }
   }];
+
+  let verify = [{
+    return ::mlir::transform::detail::verifyTransformOpInterface($_op);
+  }];
 }
 
 class TransformTypeInterfaceBase<string cppClass, string cppObjectType>
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1876,6 +1876,8 @@
   consumesHandle(getTarget(), effects);
   onlyReadsHandle(getTileSizes(), effects);
   onlyReadsHandle(getNumThreads(), effects);
+  onlyReadsHandle(getPackedNumThreads(), effects);
+  onlyReadsHandle(getPackedTileSizes(), effects);
   producesHandle(getResults(), effects);
 }
 
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -623,8 +623,8 @@
 
-/// Returns `true` if the given list of effects instances contains an instance
-/// with the effect type specified as template parameter.
-template <typename EffectTy, typename ResourceTy = SideEffects::DefaultResource>
-static bool hasEffect(ArrayRef<MemoryEffects::EffectInstance> effects) {
+/// Returns `true` if the given range of effect instances contains an instance
+/// with the effect and resource types specified as template parameters.
+template <typename EffectTy, typename ResourceTy, typename Range>
+static bool hasEffect(Range &&effects) {
   return llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
     return isa<EffectTy>(effect.getEffect()) &&
            isa<ResourceTy>(effect.getResource());
@@ -671,6 +671,56 @@
   effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
 }
 
+//===----------------------------------------------------------------------===//
+// Utilities for TransformOpInterface.
+//===----------------------------------------------------------------------===//
+
+LogicalResult transform::detail::verifyTransformOpInterface(Operation *op) {
+  // Effects are queried through MemoryEffectOpInterface; report its absence
+  // as a verification error instead of relying on an unchecked cast.
+  auto iface = dyn_cast<MemoryEffectOpInterface>(op);
+  if (!iface)
+    return op->emitError()
+           << "TransformOpInterface requires MemoryEffectsOpInterface";
+  SmallVector<MemoryEffects::EffectInstance> effects;
+  iface.getEffects(effects);
+
+  // Returns the lazily filtered subset of `effects` attached to `value`.
+  auto effectsOn = [&](Value value) {
+    return llvm::make_filter_range(
+        effects, [value](const MemoryEffects::EffectInstance &instance) {
+          return instance.getValue() == value;
+        });
+  };
+
+  // Every operand must have at least one memory effect specified on it.
+  for (OpOperand &operand : op->getOpOperands()) {
+    auto range = effectsOn(operand.get());
+    if (range.empty()) {
+      InFlightDiagnostic diag =
+          op->emitError() << "TransformOpInterface requires memory effects "
+                             "on operands to be specified";
+      diag.attachNote() << "no effects specified for operand #"
+                        << operand.getOperandNumber();
+      return diag;
+    }
+  }
+  // Every result must have an 'allocate' effect on TransformMappingResource.
+  for (OpResult result : op->getResults()) {
+    auto range = effectsOn(result);
+    if (!::hasEffect<MemoryEffects::Allocate, TransformMappingResource>(
+            range)) {
+      InFlightDiagnostic diag =
+          op->emitError() << "TransformOpInterface requires 'allocate' memory "
+                             "effect to be specified for results";
+      diag.attachNote() << "no 'allocate' effect specified for result #"
+                        << result.getResultNumber();
+      return diag;
+    }
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Entry point.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Transform/ops-invalid.mlir b/mlir/test/Dialect/Transform/ops-invalid.mlir
--- a/mlir/test/Dialect/Transform/ops-invalid.mlir
+++ b/mlir/test/Dialect/Transform/ops-invalid.mlir
@@ -210,3 +210,21 @@
   // expected-note @below {{used here as operand #0}}
   transform.test_consume_operand %0
 }
+
+// -----
+
+transform.sequence failures(suppress) {
+^bb0(%arg0: !transform.any_op):
+  // expected-error @below {{TransformOpInterface requires memory effects on operands to be specified}}
+  // expected-note @below {{no effects specified for operand #0}}
+  transform.test_required_memory_effects %arg0 : (!transform.any_op) -> !transform.any_op
+}
+
+// -----
+
+transform.sequence failures(suppress) {
+^bb0(%arg0: !transform.any_op):
+  // expected-error @below {{TransformOpInterface requires 'allocate' memory effect to be specified for results}}
+  // expected-note @below {{no 'allocate' effect specified for result #0}}
+  transform.test_required_memory_effects %arg0 {has_operand_effect} : (!transform.any_op) -> !transform.any_op
+}
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -471,7 +471,9 @@
 }
 
 void mlir::test::TestProduceNullParamOp::getEffects(
-    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::producesHandle(getOut(), effects);
+}
 
 DiagnosedSilenceableFailure
 mlir::test::TestProduceNullParamOp::apply(transform::TransformResults &results,
@@ -480,6 +482,26 @@
   return DiagnosedSilenceableFailure::success();
 }
 
+/// Declares effects only as requested by the unit attributes so that the
+/// TransformOpInterface verifier can be exercised: without `has_result_effect`
+/// the result is marked with onlyReadsHandle, so it lacks an 'allocate' effect.
+void mlir::test::TestRequiredMemoryEffectsOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  if (getHasOperandEffect())
+    transform::consumesHandle(getIn(), effects);
+
+  if (getHasResultEffect())
+    transform::producesHandle(getOut(), effects);
+  else
+    transform::onlyReadsHandle(getOut(), effects);
+}
+
+DiagnosedSilenceableFailure mlir::test::TestRequiredMemoryEffectsOp::apply(
+    transform::TransformResults &results, transform::TransformState &state) {
+  results.set(getOut().cast<OpResult>(), state.getPayloadOps(getIn()));
+  return DiagnosedSilenceableFailure::success();
+}
+
 namespace {
 /// Test extension of the Transform dialect. Registers additional ops and
 /// declares PDL as dependent dialect since the additional ops are using PDL
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
@@ -352,4 +352,16 @@
   let cppNamespace = "::mlir::test";
 }
 
+def TestRequiredMemoryEffectsOp
+  : Op<Transform_Dialect, "test_required_memory_effects",
+      [DeclareOpInterfaceMethods<TransformOpInterface>,
+       DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+  let arguments = (ins TransformHandleTypeInterface:$in,
+                       UnitAttr:$has_operand_effect,
+                       UnitAttr:$has_result_effect);
+  let results = (outs TransformHandleTypeInterface:$out);
+  let assemblyFormat = "$in attr-dict `:` functional-type(operands, results)";
+  let cppNamespace = "::mlir::test";
+}
+
 #endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD