diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -7,11 +7,9 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" -#include "mlir/Dialect/PDL/IR/PDL.h" #include "mlir/Dialect/Transform/IR/TransformTypes.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Operation.h" -#include "mlir/IR/OwningOpRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/Support/Debug.h" @@ -830,25 +828,6 @@ } #endif // NDEBUG - // If the transform dialect may use PDL which may modify the IR, clone it - // before use to avoid concurrent modification in case this is being called - // from pass instances running concurrently with a shared transform script. - auto *pdlDialect = - transform->getContext()->getLoadedDialect(); - bool hasPDL = transform - .walk([pdlDialect](Operation *op) { - if (op->getDialect() == pdlDialect) - return WalkResult::interrupt(); - return WalkResult::advance(); - }) - .wasInterrupted(); - - OwningOpRef owningCopy; - if (hasPDL) { - owningCopy = OwningOpRef(transform->clone()); - transform = owningCopy.get(); - } - TransformState state(transform->getParentRegion(), payloadRoot, extraMapping, options); return state.applyTransform(transform).checkAndReport(); diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -12,7 +12,6 @@ #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/Dialect/Transform/IR/TransformTypes.h" #include "mlir/Dialect/Transform/IR/TransformUtils.h" -#include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Rewrite/FrozenRewritePatternSet.h" @@ -89,9 +88,11 @@ if (!patternOp) return failure(); + // Copy the pattern operation into a new module that is compiled and + // consumed by the PDL interpreter. OwningOpRef pdlModuleOp = ModuleOp::create(patternOp.getLoc()); - patternOp->moveBefore(pdlModuleOp->getBody(), - pdlModuleOp->getBody()->end()); + auto builder = OpBuilder::atBlockEnd(pdlModuleOp->getBody()); + builder.clone(*patternOp); PDLPatternModule patternModule(std::move(pdlModuleOp)); // Merge in the hooks owned by the dialect. Make a copy as they may be @@ -964,8 +965,6 @@ DiagnosedSilenceableFailure transform::WithPDLPatternsOp::apply(transform::TransformResults &results, transform::TransformState &state) { - OwningOpRef pdlModuleOp = - ModuleOp::create(getOperation()->getLoc()); TransformOpInterface transformOp = nullptr; for (Operation &nested : getBody().front()) { if (!isa(nested)) { diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-concurrent-source.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-concurrent-source.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Transform/test-interpreter-external-concurrent-source.mlir @@ -0,0 +1,16 @@ +// RUN: mlir-opt %s +// No need to check anything else than parsing here, this is being used by another test as data. + +transform.with_pdl_patterns { +^bb0(%arg0: !transform.any_op): + pdl.pattern @func_return : benefit(1) { + %0 = pdl.operation "func.return" + pdl.rewrite %0 with "transform.dialect" + } + + sequence %arg0 : !transform.any_op failures(propagate) { + ^bb1(%arg1: !transform.any_op): + %0 = pdl_match @func_return in %arg1 : (!transform.any_op) -> !transform.op<"func.return"> + test_print_remark_at_operand %0, "matched" : !transform.op<"func.return"> + } +} diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-concurrent.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-concurrent.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Transform/test-interpreter-external-concurrent.mlir @@ -0,0 +1,58 @@ +// RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(test-transform-dialect-interpreter{transform-file-name=%p/test-interpreter-external-concurrent-source.mlir}))" \ +// RUN: --verify-diagnostics + +// Exercising the pass on multiple functions of different lengths that may be +// processed concurrently. This should expose potential races. + +func.func @f1() { + // expected-remark @below {{matched}} + return +} + +func.func @f2() { + // expected-remark @below {{matched}} + return +} + +func.func @f3() { + call @f2() : () -> () + call @f2() : () -> () + call @f5() : () -> () + call @f7() : () -> () + call @f5() : () -> () + call @f5() : () -> () + // expected-remark @below {{matched}} + return +} + +func.func @f4() { + call @f3() : () -> () + call @f3() : () -> () + // expected-remark @below {{matched}} + return +} + +func.func @f5() { + call @f7() : () -> () + call @f7() : () -> () + call @f7() : () -> () + call @f7() : () -> () + call @f1() : () -> () + call @f1() : () -> () + call @f7() : () -> () + call @f7() : () -> () + call @f7() : () -> () + call @f7() : () -> () + // expected-remark @below {{matched}} + return +} + +func.func @f6() { + // expected-remark @below {{matched}} + return +} + +func.func @f7() { + // expected-remark @below {{matched}} + return +} diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp @@ -24,12 +24,11 @@ /// module. template -class ModulePassWrapper : public PassWrapper> { -}; +class OpPassWrapper : public PassWrapper> {}; class TestTransformDialectInterpreterPass : public transform::TransformInterpreterPassBase< - TestTransformDialectInterpreterPass, ModulePassWrapper> { + TestTransformDialectInterpreterPass, OpPassWrapper> { public: MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( TestTransformDialectInterpreterPass)