diff --git a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h --- a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h +++ b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h @@ -18,21 +18,9 @@ class DialectRegistry; namespace tensor { - -/// A specialized TrackingListener for transform ops that operate on tensor IR. -/// This listener skips cast-like tensor ops when looking for payload op -/// replacements. -class TrackingListener : public transform::TrackingListener { -public: - using transform::TrackingListener::TrackingListener; - -protected: - Operation *findReplacementOp(Operation *op, - ValueRange newValues) const override; -}; - void registerTransformDialectExtension(DialectRegistry &registry); - +void registerFindPayloadReplacementOpInterfaceExternalModels( + DialectRegistry &registry); } // namespace tensor } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td @@ -193,4 +193,39 @@ let cppNamespace = "::mlir::transform"; } +def FindPayloadReplacementOpInterface + : OpInterface<"FindPayloadReplacementOpInterface"> { + let description = [{ + This interface is queried by the `TrackingListener` and can be implemented + by payload ops to indicate that the lookup should continue with its + operands when looking for payload op replacements. + + Example: Consider the case where a tracked "test.foo" payload op is replaced + with a new "test.foo" op, but wrapped in a "tensor.reshape" op. In that + case, the mapping of the original "test.foo" op should be updated with the + new "test.foo" op.
A "tensor.reshape" is a metadata-only op that should be + skipped when inspecting the replacement values of the original "test.foo" + op. More details can be found in the `TrackingListener` documentation. + + Note: Ops that implement `CastOpInterface` do not need to implement this + interface. Such ops are skipped by default. This interface should be + implemented by cast-like/metadata-only ops that cannot implement + `CastOpInterface`. + }]; + + let cppNamespace = "::mlir::transform"; + + let methods = [ + InterfaceMethod< + /*desc=*/[{ + Return the operands at which the lookup for replacement payload ops + should continue. + }], + /*returnType=*/"::llvm::SmallVector<::mlir::Value>", + /*name=*/"getNextOperands", + /*arguments=*/(ins) + >, + ]; +} + #endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORM_INTERFACES_TD diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h @@ -50,6 +50,45 @@ /// replaced with the given values. By default, if all values are defined by /// the same op, which also has the same type as the given op, that defining /// op is used as a replacement. + /// + /// Example: A tracked "linalg.generic" with two results is replaced with two + /// values defined by (another) "linalg.generic". It is reasonable to assume + /// that the replacement "linalg.generic" represents the same "computation". + /// Therefore, the payload op mapping is updated to the defining op of the + /// replacement values. + /// + /// Counter Example: A "linalg.generic" is replaced with values defined by an + /// "scf.for". Without further investigation, the relationship between the + /// "linalg.generic" and the "scf.for" is unclear. They may not represent the + /// same computation; e.g., there may be a tiled "linalg.generic" inside the + /// loop body that represents the original computation.
Therefore, the + /// TrackingListener is conservative by default: it drops the mapping and + /// triggers the "payload replacement not found" notification. + /// + /// If no replacement op could be found according to the rules mentioned + /// above, this function tries to skip over cast-like ops that implement + /// `CastOpInterface`. + /// + /// Example: A tracked "linalg.generic" is replaced with "linalg.generic", + /// wrapped in a "tensor.cast". A cast is a metadata-only operation and it is + /// reasonable to assume that the wrapped "linalg.generic" represents the same + /// computation as the original "linalg.generic". The mapping is updated + /// accordingly. + /// + /// Certain ops (typically also metadata-only ops) are not considered casts, + /// but should be skipped nonetheless. Such ops should implement + /// `FindPayloadReplacementOpInterface` to specify with which operands the + /// lookup should continue. + /// + /// Example: A tracked "linalg.generic" is replaced with "linalg.generic", + /// wrapped in a "tensor.reshape". A reshape is a metadata-only operation but + /// not a cast. (Implementing `CastOpInterface` would be incorrect and cause + /// invalid foldings.) However, due to its `FindPayloadReplacementOpInterface` + /// implementation, the replacement op lookup continues with the wrapped + /// "linalg.generic" and the mapping is updated accordingly. + /// + /// Derived classes may override `findReplacementOp` to specify custom + /// replacement rules.
virtual Operation *findReplacementOp(Operation *op, ValueRange newValues) const; diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -159,6 +159,7 @@ shape::registerBufferizableOpInterfaceExternalModels(registry); sparse_tensor::registerBufferizableOpInterfaceExternalModels(registry); tensor::registerBufferizableOpInterfaceExternalModels(registry); + tensor::registerFindPayloadReplacementOpInterfaceExternalModels(registry); tensor::registerInferTypeOpInterfaceExternalModels(registry); tensor::registerTilingInterfaceExternalModels(registry); tensor::registerValueBoundsOpInterfaceExternalModels(registry); diff --git a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp --- a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp +++ b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp @@ -15,50 +15,68 @@ #include "mlir/Dialect/Tensor/Utils/Utils.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" -#include "llvm/ADT/TypeSwitch.h" using namespace mlir; using namespace tensor; //===----------------------------------------------------------------------===// -// TrackingListener +// FindPayloadReplacementOpInterface implementations //===----------------------------------------------------------------------===// -Operation * -tensor::TrackingListener::findReplacementOp(Operation *op, - ValueRange newValues) const { - SmallVector<Value> values(newValues.begin(), newValues.end()); - do { - if (Operation *replacement = - transform::TrackingListener::findReplacementOp(op, values)) - return replacement; - - Operation *defOp = getCommonDefiningOp(values); - if (!defOp) - return nullptr; - - // Skip cast-like operations.
- values.clear(); - llvm::TypeSwitch<Operation *>(defOp) - .Case<CastOp>([&](CastOp op) { values.push_back(op.getSource()); }) - .Case<CollapseShapeOp>( - [&](CollapseShapeOp op) { values.push_back(op.getSrc()); }) - .Case<ExpandShapeOp>( - [&](ExpandShapeOp op) { values.push_back(op.getSrc()); }) - .Case<ReshapeOp>( - [&](ReshapeOp op) { values.push_back(op.getSource()); }) - .Case<InsertSliceOp>([&](InsertSliceOp op) { - if (isCastLikeInsertSliceOp(op)) - values.push_back(op.getSource()); - }) - .Case<ExtractSliceOp>([&](ExtractSliceOp op) { - if (isCastLikeExtractSliceOp(op)) - values.push_back(op.getSource()); - }) - .Default([](Operation *op) {}); - } while (!values.empty()); - - return nullptr; +namespace { +struct ExtractSliceOpReplacementInterface + : public transform::FindPayloadReplacementOpInterface::ExternalModel< + ExtractSliceOpReplacementInterface, tensor::ExtractSliceOp> { + SmallVector<Value> getNextOperands(Operation *op) const { + auto extractSliceOp = cast<tensor::ExtractSliceOp>(op); + if (!isCastLikeExtractSliceOp(extractSliceOp)) + return {}; + return {extractSliceOp.getSource()}; + } +}; + +struct InsertSliceOpReplacementInterface + : public transform::FindPayloadReplacementOpInterface::ExternalModel< + InsertSliceOpReplacementInterface, tensor::InsertSliceOp> { + SmallVector<Value> getNextOperands(Operation *op) const { + auto insertSliceOp = cast<tensor::InsertSliceOp>(op); + if (!isCastLikeInsertSliceOp(insertSliceOp)) + return {}; + return {insertSliceOp.getSource()}; + } +}; + +struct ReshapeOpReplacementInterface + : public transform::FindPayloadReplacementOpInterface::ExternalModel< + ReshapeOpReplacementInterface, tensor::ReshapeOp> { + SmallVector<Value> getNextOperands(Operation *op) const { + auto reshapeOp = cast<tensor::ReshapeOp>(op); + return {reshapeOp.getSource()}; + } +}; + +template <typename ConcreteOp> +struct ReassociativeReshapeOpReplacementInterface + : public transform::FindPayloadReplacementOpInterface::ExternalModel< + ReassociativeReshapeOpReplacementInterface<ConcreteOp>, ConcreteOp> { + SmallVector<Value> getNextOperands(Operation *op) const { + auto reshapeOp = cast<ConcreteOp>(op); + return {reshapeOp.getSrc()}; + } +}; +} // namespace +
+void tensor::registerFindPayloadReplacementOpInterfaceExternalModels( + DialectRegistry &registry) { + registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) { + CollapseShapeOp::attachInterface< + ReassociativeReshapeOpReplacementInterface<CollapseShapeOp>>(*ctx); + ExpandShapeOp::attachInterface< + ReassociativeReshapeOpReplacementInterface<ExpandShapeOp>>(*ctx); + ExtractSliceOp::attachInterface<ExtractSliceOpReplacementInterface>(*ctx); + InsertSliceOp::attachInterface<InsertSliceOpReplacementInterface>(*ctx); + ReshapeOp::attachInterface<ReshapeOpReplacementInterface>(*ctx); + }); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -74,17 +74,35 @@ ValueRange newValues) const { assert(op->getNumResults() == newValues.size() && "invalid number of replacement values"); + SmallVector<Value> values(newValues.begin(), newValues.end()); - // If the replacement values belong to different ops, drop the mapping. - Operation *defOp = getCommonDefiningOp(newValues); - if (!defOp) - return nullptr; + do { + // If the replacement values belong to different ops, drop the mapping. + Operation *defOp = getCommonDefiningOp(values); + if (!defOp) + return nullptr; - // If the replacement op has a different type, drop the mapping. - if (op->getName() != defOp->getName()) - return nullptr; + // If the defining op has the same type, we take it as a replacement. + if (op->getName() == defOp->getName()) + return defOp; - return defOp; + values.clear(); + + // Skip through ops that implement FindPayloadReplacementOpInterface. + if (auto findReplacementOpInterface = + dyn_cast<FindPayloadReplacementOpInterface>(defOp)) { + values.assign(findReplacementOpInterface.getNextOperands()); + continue; + } + + // Skip through ops that implement CastOpInterface.
+ if (isa<CastOpInterface>(defOp)) { + values.assign(defOp->getOperands().begin(), defOp->getOperands().end()); + continue; + } + } while (!values.empty()); + + return nullptr; } LogicalResult transform::TrackingListener::notifyMatchFailure( diff --git a/mlir/test/lib/Dialect/Tensor/CMakeLists.txt b/mlir/test/lib/Dialect/Tensor/CMakeLists.txt --- a/mlir/test/lib/Dialect/Tensor/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Tensor/CMakeLists.txt @@ -10,7 +10,6 @@ MLIRPass MLIRSCFDialect MLIRTensorDialect - MLIRTensorTransformOps MLIRTensorTransforms MLIRTransformDialect MLIRTransforms diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp --- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp +++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp @@ -14,10 +14,10 @@ #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h" #include "mlir/Dialect/Tensor/Transforms/TransformUtils.h" #include "mlir/Dialect/Tensor/Transforms/Transforms.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/Dialect/Transform/IR/TransformOps.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -296,9 +296,9 @@ } namespace { -class DummyTrackingListener : public tensor::TrackingListener { +class DummyTrackingListener : public transform::TrackingListener { public: - using tensor::TrackingListener::TrackingListener; + using transform::TrackingListener::TrackingListener; // Expose `findReplacementOp` as a public function, so that it can be tested.
Operation *getReplacementOp(Operation *op, ValueRange newValues) const { diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel @@ -857,7 +857,6 @@ "//mlir:Pass", "//mlir:SCFDialect", "//mlir:TensorDialect", - "//mlir:TensorTransformOps", "//mlir:TensorTransforms", "//mlir:TransformDialect", "//mlir:Transforms",