diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/RegionKindInterface.h"
 
 namespace mlir {
 class TilingInterface;
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -17,6 +17,7 @@
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/EnumAttr.td"
 include "mlir/IR/OpBase.td"
+include "mlir/IR/RegionKindInterface.td"
 
 def DecomposeOp : Op, + DeclareOpInterfaceMethods] # GraphRegionNoTerminator.traits> {
+  let description = [{
+    Replace all `target` payload ops with the single op that is contained in
+    this op's region. All targets must have zero operands, the same number of
+    results as that op, and (if they have regions) be isolated from above.
+
+    This op is for debugging/experiments only.
+
+    #### Return modes
+
+    Consumes the `target` handle and produces the `replacement` handle.
+  }];
+
+  let arguments = (ins PDL_Operation:$target);
+  let results = (outs PDL_Operation:$replacement);
+  let regions = (region SizedRegion<1>:$bodyRegion);
+  let assemblyFormat = "$target attr-dict-with-keyword regions";
+  let hasVerifier = 1;
+}
+
 def ScalarizeOp : Op {
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/StringSet.h"
@@ -883,6 +884,70 @@
   return DiagnosedSilenceableFailure(success());
 }
 
+//===----------------------------------------------------------------------===//
+// ReplaceOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::ReplaceOp::apply(TransformResults &transformResults,
+                            TransformState &state) {
+  ArrayRef<Operation *> payload = state.getPayloadOps(getTarget());
+  Operation *pattern = &getBodyRegion().front().front();
+
+  // Check for invalid targets.
+  for (Operation *target : payload) {
+    if (target->getNumOperands() > 0)
+      return emitDefiniteFailure() << "expected target without operands";
+    if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
+        target->getNumRegions() > 0)
+      return emitDefiniteFailure()
+             << "expected target that is isolated from above";
+    // RewriterBase::replaceOp requires a 1:1 mapping of results; reject
+    // mismatches here instead of tripping that assertion later.
+    if (target->getNumResults() != pattern->getNumResults())
+      return emitDefiniteFailure() << "expected target with the same number "
+                                      "of results as the replacement";
+  }
+
+  // Clone and replace. Targets nested in this op's region (e.g. the pattern
+  // itself when the match walks into it) are left untouched.
+  IRRewriter rewriter(getContext());
+  SmallVector<Operation *> replacements;
+  for (Operation *target : payload) {
+    if (getOperation()->isAncestor(target))
+      continue;
+    rewriter.setInsertionPoint(target);
+    Operation *replacement = rewriter.clone(*pattern);
+    rewriter.replaceOp(target, replacement->getResults());
+    replacements.push_back(replacement);
+  }
+  transformResults.set(getReplacement().cast<OpResult>(), replacements);
+  return DiagnosedSilenceableFailure(success());
+}
+
+void transform::ReplaceOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  consumesHandle(getTarget(), effects);
+  producesHandle(getReplacement(), effects);
+  modifiesPayload(effects);
+}
+
+LogicalResult transform::ReplaceOp::verify() {
+  if (!getBodyRegion().hasOneBlock())
+    return emitOpError() << "expected one block";
+  if (!llvm::hasSingleElement(getBodyRegion().front()))
+    return emitOpError() << "expected one operation in block";
+  Operation *replacement = &getBodyRegion().front().front();
+  if (replacement->getNumOperands() > 0)
+    return replacement->emitOpError()
+           << "expected replacement without operands";
+  if (!replacement->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
+      replacement->getNumRegions() > 0)
+    return replacement->emitOpError()
+           << "expected op that is isolated from above";
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ScalarizeOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/transform-op-replace.mlir b/mlir/test/Dialect/Linalg/transform-op-replace.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-replace.mlir
@@ -0,0 +1,65 @@
+// RUN: mlir-opt -test-transform-dialect-interpreter %s -allow-unregistered-dialect -verify-diagnostics --split-input-file | FileCheck %s
+
+// CHECK: func.func @foo() {
+// CHECK:   "dummy_op"() : () -> ()
+// CHECK: }
+// CHECK-NOT: func.func @bar
+func.func @bar() {
+  "another_op"() : () -> ()
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["func.func"]} in %arg1
+  transform.structured.replace %0 {
+    func.func @foo() {
+      "dummy_op"() : () -> ()
+    }
+  }
+}
+
+// -----
+
+func.func @bar(%arg0: i1) {
+  "another_op"(%arg0) : (i1) -> ()
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["another_op"]} in %arg1
+  // expected-error @+1 {{expected target without operands}}
+  transform.structured.replace %0 {
+    "dummy_op"() : () -> ()
+  }
+}
+
+// -----
+
+func.func @bar() {
+  "another_op"() : () -> ()
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["another_op"]} in %arg1
+  transform.structured.replace %0 {
+  ^bb0(%a: i1):
+    // expected-error @+1 {{expected replacement without operands}}
+    "dummy_op"(%a) : (i1) -> ()
+  }
+}
+
+// -----
+
+func.func @bar() {
+  "another_op"() : () -> (i1)
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["another_op"]} in %arg1
+  // expected-error @+1 {{expected target with the same number of results as the replacement}}
+  transform.structured.replace %0 {
+    "dummy_op"() : () -> ()
+  }
+}