diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -1601,19 +1601,20 @@
 // -----

 // CHECK-LABEL: func @test_tracked_rewrite() {
-// CHECK-NEXT: "test.update_mapping"() {original_op = "test.replace_me"}
-// CHECK-NEXT: "test.drop_mapping"() {original_op = "test.replace_me"}
-// CHECK-NEXT: "test.update_mapping"() {original_op = "test.replace_me"}
+// CHECK-NEXT: transform.test_dummy_payload_op {new_op} : () -> i1
+// CHECK-NEXT: transform.test_dummy_payload_op {new_op} : () -> i1
+// CHECK-NEXT: return
 // CHECK-NEXT: }
 func.func @test_tracked_rewrite() {
-  %0 = "test.replace_me"() {replacement = "test.update_mapping"} : () -> (i1)
-  %1 = "test.replace_me"() {replacement = "test.drop_mapping"} : () -> (i1)
-  %2 = "test.replace_me"() {replacement = "test.update_mapping"} : () -> (i1)
+  %0 = transform.test_dummy_payload_op {replace_me} : () -> (i1)
+  %1 = transform.test_dummy_payload_op {erase_me} : () -> (i1)
+  %2 = transform.test_dummy_payload_op {replace_me} : () -> (i1)
+  func.return
 }

 transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
-  %0 = transform.structured.match ops{["test.replace_me"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %0 = transform.structured.match ops{["transform.test_dummy_payload_op"]} in %arg1 : (!transform.any_op) -> !transform.any_op
   // expected-remark @below {{2 iterations}}
   transform.test_tracked_rewrite %0 : (!transform.any_op) -> ()
   // One replacement op (test.drop_mapping) is dropped from the mapping.
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -687,33 +687,16 @@
   transform::modifiesPayload(effects);
 }

-namespace {
-/// A TrackingListener for test cases. When the replacement op is
-/// "test.update_mapping", it is considered as a replacement op in the transform
-/// state mapping. Otherwise, it is not and the original op is simply removed
-/// from the mapping.
-class TestTrackingListener : public transform::TrackingListener {
-  using transform::TrackingListener::TrackingListener;
-
-protected:
-  FailureOr<Operation *>
-  findReplacementOp(Operation *op, ValueRange newValues) const override {
-    if (newValues.size() != 1)
-      return failure();
-    Operation *replacement = newValues[0].getDefiningOp();
-    if (!replacement)
-      return failure();
-    if (replacement->getName().getStringRef() != "test.update_mapping")
-      return failure();
-    return replacement;
-  }
-};
-} // namespace
+void mlir::test::TestDummyPayloadOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  for (OpResult result : getResults())
+    transform::producesHandle(result, effects);
+}

 DiagnosedSilenceableFailure
 mlir::test::TestTrackedRewriteOp::apply(transform::TransformResults &results,
                                         transform::TransformState &state) {
-  TestTrackingListener listener(state, *this);
+  transform::ErrorCheckingTrackingListener listener(state, *this);
   IRRewriter rewriter(getContext(), &listener);

   int64_t numIterations = 0;
@@ -721,19 +704,23 @@
   // loop body. Replacement ops are not enumerated.
   for (Operation *op : state.getPayloadOps(getIn())) {
     ++numIterations;
-    rewriter.setInsertionPointToEnd(op->getBlock());
+    (void)op;

-    // Erase all payload ops. The outer loop should have only one iteration.
+    // Erase ops marked `erase_me` and replace ops marked `replace_me` with a
+    // new op of the same name that carries a `new_op` attribute.
     for (Operation *op : state.getPayloadOps(getIn())) {
-      if (op->getName().getStringRef() != "test.replace_me")
+      rewriter.setInsertionPoint(op);
+      if (op->hasAttr("erase_me")) {
+        rewriter.eraseOp(op);
         continue;
-      auto replacementName = op->getAttrOfType<StringAttr>("replacement");
-      if (!replacementName)
+      }
+      if (!op->hasAttr("replace_me"))
         continue;
+
       SmallVector<NamedAttribute> attributes;
-      attributes.emplace_back(rewriter.getStringAttr("original_op"),
-                              op->getName().getIdentifier());
-      OperationState opState(op->getLoc(), replacementName,
+      attributes.emplace_back(rewriter.getStringAttr("new_op"),
+                              rewriter.getUnitAttr());
+      OperationState opState(op->getLoc(), op->getName().getIdentifier(),
                              /*operands=*/ValueRange(),
                              /*types=*/op->getResultTypes(), attributes);
       Operation *newOp = rewriter.create(opState);
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
@@ -467,6 +467,29 @@
   let cppNamespace = "::mlir::test";
 }

+// This op is used as a payload op. It must be a registered op, so that it can
+// be created with "RewriterBase::replaceOpWithNewOp" (needed for a test case).
+// Since only ops implementing TransformOpInterface can be injected into the
+// transform dialect, this op implements the interface, even though it is not
+// used as a transform op.
+def TestDummyPayloadOp
+  : Op<Transform_Dialect, "test_dummy_payload_op",
+      [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+       TransformOpInterface]> {
+  let arguments = (ins Variadic<AnyType>:$args);
+  let results = (outs Variadic<AnyType>:$outs);
+  let assemblyFormat = "$args attr-dict `:` functional-type(operands, results)";
+  let cppNamespace = "::mlir::test";
+
+  let extraClassDeclaration = [{
+    DiagnosedSilenceableFailure apply(transform::TransformResults &results,
+                                      transform::TransformState &state) {
+      llvm_unreachable("op should not be used as a transform");
+      return DiagnosedSilenceableFailure::definiteFailure();
+    }
+  }];
+}
+
 def TestTrackedRewriteOp
   : Op<Transform_Dialect, "test_tracked_rewrite",
       [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,