diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -95,6 +95,47 @@
   let hasVerifier = 1;
 }
 
+def ForeachOp : TransformDialectOp<"foreach",
+    [DeclareOpInterfaceMethods<TransformOpInterface>,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     DeclareOpInterfaceMethods<RegionBranchOpInterface, [
+         "getSuccessorEntryOperands"]>,
+     SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">
+    ]> {
+  let summary = "Executes the body for each payload op";
+  let description = [{
+    This op has exactly one region with exactly one block ("body"). The body is
+    executed for each payload op associated with the target operand in an
+    unbatched fashion. I.e., the block argument ("iteration variable") is always
+    mapped to exactly one payload op.
+
+    This op always reads the target handle. Furthermore, it consumes the handle
+    if there is a transform op in the body that consumes the iteration variable.
+    This op does not return anything.
+
+    The transformations inside the body are applied in order of their
+    appearance. During application, if any transformation in the body fails,
+    the entire op fails immediately, leaving the payload IR in a potentially
+    invalid state, i.e., this operation offers no transformation rollback
+    capabilities.
+  }];
+
+  let arguments = (ins PDL_Operation:$target);
+  let results = (outs);
+  let regions = (region SizedRegion<1>:$body);
+  let assemblyFormat = "$target $body attr-dict";
+
+  let extraClassDeclaration = [{
+    /// Allow the dialect prefix to be omitted.
+    static StringRef getDefaultDialect() { return "transform"; }
+
+    /// Returns the block argument that is mapped to one payload op per
+    /// iteration.
+    BlockArgument getIterationVariable() {
+      return getBody().front().getArgument(0);
+    }
+  }];
+}
+
 def GetClosestIsolatedParentOp : TransformDialectOp<"get_closest_isolated_parent",
     [DeclareOpInterfaceMethods<TransformOpInterface>,
      NavigationTransformOpTrait, MemoryEffectsOpInterface]> {
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -273,6 +273,75 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// ForeachOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::ForeachOp::apply(transform::TransformResults &results,
+                            transform::TransformState &state) {
+  // Copy the payload ops instead of keeping an ArrayRef into the state: the
+  // transforms applied by the body may update the mapping associated with the
+  // target handle, which would invalidate a view into it mid-iteration.
+  SmallVector<Operation *> payloadOps =
+      llvm::to_vector(state.getPayloadOps(getTarget()));
+  for (Operation *op : payloadOps) {
+    // Open a new region scope for the body on every iteration; it is destroyed
+    // at the end of the iteration before the next payload op is mapped.
+    auto scope = state.make_region_scope(getBody());
+    if (failed(state.mapBlockArguments(getIterationVariable(), {op})))
+      return DiagnosedSilenceableFailure::definiteFailure();
+
+    for (Operation &transform : getBody().front().without_terminator()) {
+      DiagnosedSilenceableFailure result = state.applyTransform(
+          cast<transform::TransformOpInterface>(transform));
+      // Stop at the first failure; remaining payload ops are not processed.
+      if (!result.succeeded())
+        return result;
+    }
+  }
+  return DiagnosedSilenceableFailure::success();
+}
+
+/// The target handle is consumed if any transform in the body consumes the
+/// iteration variable; otherwise it is only read.
+void transform::ForeachOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  BlockArgument iterVar = getIterationVariable();
+  if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
+        return isHandleConsumed(iterVar, cast<TransformOpInterface>(&op));
+      })) {
+    consumesHandle(getTarget(), effects);
+  } else {
+    onlyReadsHandle(getTarget(), effects);
+  }
+}
+
+/// Control enters the body from the parent op; from the body, control either
+/// branches back to the body (next payload op) or returns to the parent.
+void transform::ForeachOp::getSuccessorRegions(
+    Optional<unsigned> index, ArrayRef<Attribute> operands,
+    SmallVectorImpl<RegionSuccessor> &regions) {
+  Region *bodyRegion = &getBody();
+  if (!index) {
+    regions.emplace_back(bodyRegion, bodyRegion->getArguments());
+    return;
+  }
+
+  // Branch back to the region or the parent.
+  assert(*index == 0 && "unexpected region index");
+  regions.emplace_back(bodyRegion, bodyRegion->getArguments());
+  regions.emplace_back();
+}
+
+OperandRange
+transform::ForeachOp::getSuccessorEntryOperands(Optional<unsigned> index) {
+  // The iteration variable op handle is mapped to a subset (one op to be
+  // precise) of the payload ops of the ForeachOp operand.
+  assert(index && *index == 0 && "unexpected region index");
+  return getOperation()->getOperands();
+}
+
 //===----------------------------------------------------------------------===//
 // GetClosestIsolatedParentOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Transform/ops-invalid.mlir b/mlir/test/Dialect/Transform/ops-invalid.mlir
--- a/mlir/test/Dialect/Transform/ops-invalid.mlir
+++ b/mlir/test/Dialect/Transform/ops-invalid.mlir
@@ -184,3 +184,18 @@
 ^bb0:
   transform.yield
 }
+
+// -----
+
+transform.sequence {
+^bb0(%arg0: !pdl.operation):
+  // expected-error @below {{result #0 has more than one potential consumer}}
+  %0 = test_produce_param_or_forward_operand 42
+  // expected-note @below {{used here as operand #0}}
+  transform.foreach %0 {
+  ^bb1(%arg1: !pdl.operation):
+    transform.test_consume_operand %arg1
+  }
+  // expected-note @below {{used here as operand #0}}
+  transform.test_consume_operand %0
+}
diff --git a/mlir/test/Dialect/Transform/ops.mlir b/mlir/test/Dialect/Transform/ops.mlir
--- a/mlir/test/Dialect/Transform/ops.mlir
+++ b/mlir/test/Dialect/Transform/ops.mlir
@@ -49,3 +49,12 @@
   ^bb3(%arg3: !pdl.operation):
   }
 }
+
+// CHECK: transform.sequence
+// CHECK: foreach
+transform.sequence {
+^bb0(%arg0: !pdl.operation):
+  transform.foreach %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+  }
+}
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -597,3 +597,31 @@
     }
   }
 }
+
+// -----
+
+func.func @bar() {
+  // expected-remark @below {{transform applied}}
+  %0 = arith.constant 0 : i32
+  // expected-remark @below {{transform applied}}
+  %1 = arith.constant 1 : i32
+  return
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @const : benefit(1) {
+    %r = pdl.types
+    %0 = pdl.operation "arith.constant" -> (%r : !pdl.range<type>)
+    pdl.rewrite %0 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %f = pdl_match @const in %arg1
+    transform.foreach %f {
+    ^bb2(%arg2: !pdl.operation):
+      transform.test_print_remark_at_operand %arg2, "transform applied"
+    }
+  }
+}