diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -118,12 +118,17 @@
     the entire sequence fails immediately leaving the payload IR in potentially
     invalid state, i.e., this operation offers no transformation rollback
     capabilities.
+
+    This op generates as many handles as the terminating YieldOp has operands.
+    For each result, the payload ops of the corresponding YieldOp operand are
+    concatenated over all iterations and mapped to the resulting handle.
   }];
 
   let arguments = (ins PDL_Operation:$target);
-  let results = (outs);
+  let results = (outs Variadic<PDL_Operation>:$results);
   let regions = (region SizedRegion<1>:$body);
-  let assemblyFormat = "$target $body attr-dict";
+  let assemblyFormat = "$target (`->` type($results)^)? $body attr-dict";
+  let hasVerifier = 1;
 
   let extraClassDeclaration = [{
     /// Allow the dialect prefix to be omitted.
@@ -132,6 +137,9 @@
     BlockArgument getIterationVariable() {
       return getBody().front().getArgument(0);
     }
+
+    /// Returns the terminator of the body block, cast to a YieldOp.
+    transform::YieldOp getYieldOp();
   }];
 }
 
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -281,18 +281,34 @@
 transform::ForeachOp::apply(transform::TransformResults &results,
                             transform::TransformState &state) {
   ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
+  // One payload op list per result, accumulated across all loop iterations.
+  SmallVector<SmallVector<Operation *>> resultOps(getNumResults(), {});
+
   for (Operation *op : payloadOps) {
     auto scope = state.make_region_scope(getBody());
     if (failed(state.mapBlockArguments(getIterationVariable(), {op})))
       return DiagnosedSilenceableFailure::definiteFailure();
 
+    // Execute loop body.
     for (Operation &transform : getBody().front().without_terminator()) {
       DiagnosedSilenceableFailure result = state.applyTransform(
           cast<transform::TransformOpInterface>(transform));
       if (!result.succeeded())
         return result;
     }
+
+    // Append yielded payload ops to result list (if any).
+    for (unsigned i = 0; i < getNumResults(); ++i) {
+      ArrayRef<Operation *> yieldedOps =
+          state.getPayloadOps(getYieldOp().getOperand(i));
+      resultOps[i].append(yieldedOps.begin(), yieldedOps.end());
+    }
   }
+
+  // Map each result handle to the payload ops collected for it.
+  for (unsigned i = 0; i < getNumResults(); ++i)
+    results.set(getResult(i).cast<OpResult>(), resultOps[i]);
+
   return DiagnosedSilenceableFailure::success();
 }
 
@@ -306,6 +322,9 @@
   } else {
     onlyReadsHandle(getTarget(), effects);
   }
+
+  for (Value result : getResults())
+    producesHandle(result, effects);
 }
 
 void transform::ForeachOp::getSuccessorRegions(
@@ -331,6 +350,24 @@
   return getOperation()->getOperands();
 }
 
+/// Returns the terminator of the body block, which must be a YieldOp.
+transform::YieldOp transform::ForeachOp::getYieldOp() {
+  return cast<transform::YieldOp>(getBody().front().getTerminator());
+}
+
+/// Checks that the terminator yields exactly one `!pdl.operation` value per
+/// op result.
+LogicalResult transform::ForeachOp::verify() {
+  auto yieldOp = getYieldOp();
+  if (getNumResults() != yieldOp.getNumOperands())
+    return emitOpError() << "expects the same number of results as the "
+                            "terminator has operands";
+  for (Value v : yieldOp.getOperands())
+    if (!v.getType().isa<pdl::OperationType>())
+      return yieldOp->emitOpError("expects only PDL_Operation operands");
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // GetClosestIsolatedParentOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -627,3 +627,52 @@
     }
   }
 }
+
+// -----
+
+func.func @bar() {
+  scf.execute_region {
+    // expected-remark @below {{transform applied}}
+    %0 = arith.constant 0 : i32
+    scf.yield
+  }
+
+  scf.execute_region {
+    // expected-remark @below {{transform applied}}
+    %1 = arith.constant 1 : i32
+    // expected-remark @below {{transform applied}}
+    %2 = arith.constant 2 : i32
+    scf.yield
+  }
+
+  return
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @const : benefit(1) {
+    %r = pdl.types
+    %0 = pdl.operation "arith.constant" -> (%r : !pdl.range<type>)
+    pdl.rewrite %0 with "transform.dialect"
+  }
+
+  pdl.pattern @execute_region : benefit(1) {
+    %r = pdl.types
+    %0 = pdl.operation "scf.execute_region" -> (%r : !pdl.range<type>)
+    pdl.rewrite %0 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %f = pdl_match @execute_region in %arg1
+    %results = transform.foreach %f -> !pdl.operation {
+    ^bb2(%arg2: !pdl.operation):
+      %g = transform.pdl_match @const in %arg2
+      transform.yield %g : !pdl.operation
+    }
+
+    // expected-remark @below {{3}}
+    transform.test_print_number_of_associated_payload_ir_ops %results
+    transform.test_print_remark_at_operand %results, "transform applied"
+  }
+}