diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -845,6 +845,27 @@
   return res;
 }
 } // namespace detail
+
+/// Populates `effects` with the memory effects indicating the operation on the
+/// given handle value:
+///   - consumes = Read + Free,
+///   - produces = Allocate + Write,
+///   - onlyReads = Read.
+void consumesHandle(ValueRange handles,
+                    SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
+void producesHandle(ValueRange handles,
+                    SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
+void onlyReadsHandle(ValueRange handles,
+                     SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
+
+/// Checks whether the transform op consumes the given handle.
+bool isHandleConsumed(Value handle, transform::TransformOpInterface transform);
+
+/// Populates `effects` with the memory effects indicating the access to payload
+/// IR resource.
+void modifiesPayload(SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
+void onlyReadsPayload(SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
+
 } // namespace transform
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -174,6 +174,42 @@
   let assemblyFormat = "$pattern_name `in` $root attr-dict";
 }
 
+def ReplicateOp : TransformDialectOp<"replicate",
+    [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let summary = "Lists payload ops multiple times in the new handle";
+  let description = [{
+    Produces a new handle associated with a list of payload IR ops that is
+    computed by repeating the list of payload IR ops associated with the
+    operand handle as many times as the "pattern" handle has associated
+    operations. For example, if pattern is associated with [op1, op2] and the
+    operand handle is associated with [op3, op4, op5], the resulting handle
+    will be associated with [op3, op4, op5, op3, op4, op5].
+
+    This transformation is useful to "align" the sizes of payload IR lists
+    before a transformation that expects, e.g., identically-sized lists. For
+    example, a transformation may be parameterized by the same notional
+    per-target size computed at runtime and supplied as another handle; the
+    replication allows this size to be computed only once and used for every
+    target instead of replicating the computation itself.
+
+    Note that it is undesirable to pass a handle with duplicate operations to
+    an operation that consumes the handle. Handle consumption often indicates
+    that the associated payload IR ops are destroyed, so having the same op
+    listed more than once will lead to double-free. Single-operand
+    MergeHandlesOp may be used to deduplicate the associated list of payload IR
+    ops when necessary. Furthermore, a combination of ReplicateOp and
+    MergeHandlesOp can be used to construct arbitrary lists with repetitions.
+  }];
+
+  let arguments = (ins PDL_Operation:$pattern,
+                       Variadic<PDL_Operation>:$handles);
+  let results = (outs Variadic<PDL_Operation>:$replicated);
+  let assemblyFormat =
+    "`num` `(` $pattern `)` $handles "
+    "custom<PDLOpTypedResults>(type($replicated), ref($handles)) attr-dict";
+}
+
 def SequenceOp : TransformDialectOp<"sequence",
     [DeclareOpInterfaceMethodssecond != handle) {
       InFlightDiagnostic diag = operation->emitError()
                                 << "operation tracked by two handles";
       diag.attachNote(handle.getLoc()) << "handle";
@@ -191,9 +191,27 @@
 DiagnosedSilenceableFailure
 transform::TransformState::applyTransform(TransformOpInterface transform) {
   LLVM_DEBUG(DBGS() << "applying: " << transform << "\n");
-  if (options.getExpensiveChecksEnabled() &&
-      failed(checkAndRecordHandleInvalidation(transform))) {
-    return DiagnosedSilenceableFailure::definiteFailure();
+  if (options.getExpensiveChecksEnabled()) {
+    if (failed(checkAndRecordHandleInvalidation(transform)))
+      return DiagnosedSilenceableFailure::definiteFailure();
+
+    for (OpOperand &operand : transform->getOpOperands()) {
+      if (!isHandleConsumed(operand.get(), transform))
+        continue;
+
+      DenseSet<Operation *> seen;
+      for (Operation *op : getPayloadOps(operand.get())) {
+        if (!seen.insert(op).second) {
+          DiagnosedSilenceableFailure diag =
+              transform.emitSilenceableError()
+              << "a handle passed as operand #" << operand.getOperandNumber()
+              << " and consumed by this operation points to a payload "
+                 "operation more than once";
+          diag.attachNote(op->getLoc()) << "repeated target op";
+          return diag;
+        }
+      }
+    }
   }
 
   transform::TransformResults results(transform->getNumResults());
@@ -326,6 +344,70 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// Memory effects.
+//===----------------------------------------------------------------------===//
+
+void transform::consumesHandle(
+    ValueRange handles,
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  for (Value handle : handles) {
+    effects.emplace_back(MemoryEffects::Read::get(), handle,
+                         TransformMappingResource::get());
+    effects.emplace_back(MemoryEffects::Free::get(), handle,
+                         TransformMappingResource::get());
+  }
+}
+
+/// Returns `true` if `effects` contains an `EffectTy` effect on `ResourceTy`.
+template <typename EffectTy, typename ResourceTy>
+static bool hasEffect(ArrayRef<MemoryEffects::EffectInstance> effects) {
+  return llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
+    return isa<EffectTy>(effect.getEffect()) &&
+           isa<ResourceTy>(effect.getResource());
+  });
+}
+
+bool transform::isHandleConsumed(Value handle,
+                                 transform::TransformOpInterface transform) {
+  auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
+  SmallVector<MemoryEffects::EffectInstance> effects;
+  iface.getEffectsOnValue(handle, effects);
+  return hasEffect<MemoryEffects::Read, TransformMappingResource>(effects) &&
+         hasEffect<MemoryEffects::Free, TransformMappingResource>(effects);
+}
+
+void transform::producesHandle(
+    ValueRange handles,
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  for (Value handle : handles) {
+    effects.emplace_back(MemoryEffects::Allocate::get(), handle,
+                         TransformMappingResource::get());
+    effects.emplace_back(MemoryEffects::Write::get(), handle,
+                         TransformMappingResource::get());
+  }
+}
+
+void transform::onlyReadsHandle(
+    ValueRange handles,
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  for (Value handle : handles) {
+    effects.emplace_back(MemoryEffects::Read::get(), handle,
+                         TransformMappingResource::get());
+  }
+}
+
+void transform::modifiesPayload(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
+  effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get());
+}
+
+void transform::onlyReadsPayload(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
+}
+
 //===----------------------------------------------------------------------===//
 // Generated interface implementation.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -23,6 +23,16 @@
 
 using namespace mlir;
 
+static ParseResult parsePDLOpTypedResults(
+    OpAsmParser &parser, SmallVectorImpl<Type> &types,
+    const SmallVectorImpl<OpAsmParser::UnresolvedOperand> &handles) {
+  types.resize(handles.size(), pdl::OperationType::get(parser.getContext()));
+  return success();
+}
+
+static void printPDLOpTypedResults(OpAsmPrinter &, Operation *, TypeRange,
+                                   ValueRange) {}
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
 
@@ -354,6 +364,36 @@
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// ReplicateOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::ReplicateOp::apply(transform::TransformResults &results,
+                              transform::TransformState &state) {
+  unsigned numRepetitions = state.getPayloadOps(getPattern()).size();
+  for (const auto &en : llvm::enumerate(getHandles())) {
+    Value handle = en.value();
+    ArrayRef<Operation *> current = state.getPayloadOps(handle);
+    SmallVector<Operation *> payload;
+    payload.reserve(numRepetitions * current.size());
+    for (unsigned i = 0; i < numRepetitions; ++i)
+      llvm::append_range(payload, current);
+    results.set(getReplicated()[en.index()].cast<OpResult>(), payload);
+  }
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::ReplicateOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  onlyReadsHandle(getPattern(), effects);
+  // Replication neither modifies nor erases payload ops, so the operand
+  // handles remain valid. Consuming them would also trip the expensive
+  // duplicate-op check when re-replicating an already replicated handle.
+  onlyReadsHandle(getHandles(), effects);
+  producesHandle(getReplicated(), effects);
+}
+
 //===----------------------------------------------------------------------===//
 // SequenceOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/python/mlir/dialects/_transform_ops_ext.py b/mlir/python/mlir/dialects/_transform_ops_ext.py
--- a/mlir/python/mlir/dialects/_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_transform_ops_ext.py
@@ -59,6 +59,22 @@
         ip=ip)
 
 
+class ReplicateOp:
+
+  def __init__(self,
+               pattern: Union[Operation, Value],
+               handles: Sequence[Union[Operation, Value]],
+               *,
+               loc=None,
+               ip=None):
+    super().__init__(
+        [pdl.OperationType.get()] * len(handles),
+        _get_op_result_or_value(pattern),
+        [_get_op_result_or_value(h) for h in handles],
+        loc=loc,
+        ip=ip)
+
+
 class SequenceOp:
 
   @overload
diff --git a/mlir/test/Dialect/Transform/expensive-checks.mlir b/mlir/test/Dialect/Transform/expensive-checks.mlir
--- a/mlir/test/Dialect/Transform/expensive-checks.mlir
+++ b/mlir/test/Dialect/Transform/expensive-checks.mlir
@@ -25,3 +25,37 @@
     test_print_remark_at_operand %0, "remark"
   }
 }
+
+// -----
+
+func.func @func1() {
+  // expected-note @below {{repeated target op}}
+  return
+}
+func.func private @func2()
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @func : benefit(1) {
+    %0 = operands
+    %1 = types
+    %2 = operation "func.func"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+    rewrite %2 with "transform.dialect"
+  }
+  pdl.pattern @return : benefit(1) {
+    %0 = operands
+    %1 = types
+    %2 = operation "func.return"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+    rewrite %2 with "transform.dialect"
+  }
+
+  sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = pdl_match @func in %arg1
+    %1 = pdl_match @return in %arg1
+    %2 = replicate num(%0) %1
+    // expected-error @below {{a handle passed as operand #0 and consumed by this operation points to a payload operation more than once}}
+    test_consume_operand %2
+    test_print_remark_at_operand %0, "remark"
+  }
+}
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -569,3 +569,31 @@
     transform.test_mixed_sucess_and_silenceable %0
   }
 }
+
+// -----
+
+module {
+  func.func private @foo()
+  func.func private @bar()
+
+  transform.with_pdl_patterns {
+  ^bb0(%arg0: !pdl.operation):
+    pdl.pattern @func : benefit(1) {
+      %0 = pdl.operands
+      %1 = pdl.types
+      %2 = pdl.operation "func.func"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+      pdl.rewrite %2 with "transform.dialect"
+    }
+
+    transform.sequence %arg0 {
+    ^bb0(%arg1: !pdl.operation):
+      %0 = pdl_match @func in %arg1
+      %1 = replicate num(%0) %arg1
+      // expected-remark @below {{2}}
+      test_print_number_of_associated_payload_ir_ops %1
+      %2 = replicate num(%0) %1
+      // expected-remark @below {{4}}
+      test_print_number_of_associated_payload_ir_ops %2
+    }
+  }
+}
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -275,6 +275,18 @@
   return emitDefaultSilenceableFailure(target);
 }
 
+DiagnosedSilenceableFailure
+mlir::test::TestPrintNumberOfAssociatedPayloadIROps::apply(
+    transform::TransformResults &results, transform::TransformState &state) {
+  emitRemark() << state.getPayloadOps(getHandle()).size();
+  return DiagnosedSilenceableFailure::success();
+}
+
+void mlir::test::TestPrintNumberOfAssociatedPayloadIROps::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::onlyReadsHandle(getHandle(), effects);
+}
+
 namespace {
 /// Test extension of the Transform dialect. Registers additional ops and
 /// declares PDL as dependent dialect since the additional ops are using PDL
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
@@ -212,4 +212,13 @@
   }];
 }
 
+def TestPrintNumberOfAssociatedPayloadIROps
+  : Op<Transform_Dialect, "test_print_number_of_associated_payload_ir_ops",
+       [DeclareOpInterfaceMethods<TransformOpInterface>,
+        DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+  let arguments = (ins PDL_Operation:$handle);
+  let assemblyFormat = "$handle attr-dict";
+  let cppNamespace = "::mlir::test";
+}
+
 #endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD
diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
--- a/mlir/test/python/dialects/transform.py
+++ b/mlir/test/python/dialects/transform.py
@@ -94,3 +94,19 @@
   # CHECK: transform.sequence
   # CHECK: ^{{.*}}(%[[ARG1:.+]]: !pdl.operation):
   # CHECK:   = merge_handles %[[ARG1]]
+
+
+@run
+def testReplicateOp():
+  with_pdl = transform.WithPDLPatternsOp()
+  with InsertionPoint(with_pdl.body):
+    sequence = transform.SequenceOp(with_pdl.bodyTarget)
+    with InsertionPoint(sequence.body):
+      m1 = transform.PDLMatchOp(sequence.bodyTarget, "first")
+      m2 = transform.PDLMatchOp(sequence.bodyTarget, "second")
+      transform.ReplicateOp(m1, [m2])
+      transform.YieldOp()
+  # CHECK-LABEL: TEST: testReplicateOp
+  # CHECK: %[[FIRST:.+]] = pdl_match
+  # CHECK: %[[SECOND:.+]] = pdl_match
+  # CHECK: %{{.*}} = replicate num(%[[FIRST]]) %[[SECOND]]