diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -989,7 +989,7 @@
               return indexAttr.cast<IntegerAttr>().getInt();
             })));
       return reassociationIndices;
-    };
+    }
     RankedTensorType getSrcType() {
       return getSrc().getType().cast<RankedTensorType>();
     }
diff --git a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h
@@ -0,0 +1,40 @@
+//===- TensorTransformOps.h - Tensor transformation ops ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TENSOR_TRANSFORMOPS_TENSORTRANSFORMOPS_H
+#define MLIR_DIALECT_TENSOR_TRANSFORMOPS_TENSORTRANSFORMOPS_H
+
+#include "mlir/Dialect/Transform/IR/TransformOps.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+namespace tensor {
+
+/// A specialized TrackingListener for transform ops that operate on tensor IR.
+/// If the default replacement lookup fails, this listener looks through
+/// cast-like tensor ops (tensor.cast, tensor.collapse_shape,
+/// tensor.expand_shape and tensor.reshape) when looking for payload op
+/// replacements.
+class TrackingListener : public transform::TrackingListener {
+public:
+  explicit TrackingListener(transform::TransformState &state)
+      : transform::TrackingListener(state) {}
+
+protected:
+  /// Try the default lookup of transform::TrackingListener first. If it fails
+  /// and all replacement values are defined by the same cast-like tensor op,
+  /// retry with that op's source; this repeats through chains of such ops.
+  Operation *findReplacementOp(Operation *op,
+                               ValueRange newValues) const override;
+};
+
+} // namespace tensor
+} // namespace mlir
+
+#endif // MLIR_DIALECT_TENSOR_TRANSFORMOPS_TENSORTRANSFORMOPS_H
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
@@ -15,6 +15,7 @@
 #include "mlir/IR/FunctionInterfaces.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/Interfaces/CallInterfaces.h"
 #include "mlir/Interfaces/CastInterfaces.h"
@@ -31,6 +32,44 @@
 using SequenceBodyBuilderArgsFn =
     ::llvm::function_ref<void(::mlir::OpBuilder &, ::mlir::Location,
                               ::mlir::BlockArgument, ::mlir::ValueRange)>;
+
+/// A listener that updates a TransformState based on IR modifications. This
+/// listener can be used during a greedy pattern rewrite to keep the transform
+/// state up-to-date.
+class TrackingListener : public RewriterBase::Listener,
+                         public TransformState::Extension {
+public:
+  explicit TrackingListener(TransformState &state)
+      : TransformState::Extension(state) {}
+
+protected:
+  /// Return a replacement payload op for the given op, which is going to be
+  /// replaced with the given values. By default, if all values are defined by
+  /// the same newly-created op, which also has the same type as the given op,
+  /// that defining op is used as a replacement.
+  virtual Operation *findReplacementOp(Operation *op,
+                                       ValueRange newValues) const;
+
+  /// Return "true" if the given op was reported via notifyOperationInserted
+  /// and has not been removed since.
+  bool isNewOp(Operation *op) const;
+
+  /// Return the single op that defines all given values (if any). Empty values
+  /// are skipped; nullptr is returned if a non-empty value is a block argument
+  /// or if the values are defined by different ops.
+  static Operation *getCommonDefiningOp(ValueRange values);
+
+private:
+  void notifyOperationInserted(Operation *op) override;
+
+  void notifyOperationRemoved(Operation *op) override;
+
+  void notifyOperationReplaced(Operation *op, ValueRange newValues) override;
+
+  /// Ops that were newly created during the transform, keyed by op name.
+  DenseMap<OperationName, DenseSet<Operation *>> newOps;
+};
+
 } // namespace transform
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Tensor/CMakeLists.txt b/mlir/lib/Dialect/Tensor/CMakeLists.txt
--- a/mlir/lib/Dialect/Tensor/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/CMakeLists.txt
@@ -1,3 +1,4 @@
 add_subdirectory(IR)
 add_subdirectory(Transforms)
+add_subdirectory(TransformOps)
 add_subdirectory(Utils)
diff --git a/mlir/lib/Dialect/Tensor/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/Tensor/TransformOps/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Tensor/TransformOps/CMakeLists.txt
@@ -0,0 +1,11 @@
+add_mlir_dialect_library(MLIRTensorTransformOps
+  TensorTransformOps.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tensor/TransformOps
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRTensorDialect
+  MLIRTransformDialect
+)
diff --git a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
@@ -0,0 +1,55 @@
+//===- TensorTransformOps.cpp - Implementation of tensor transform ops ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h"
+
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+using namespace mlir;
+using namespace tensor;
+
+//===----------------------------------------------------------------------===//
+// TrackingListener
+//===----------------------------------------------------------------------===//
+
+Operation *
+tensor::TrackingListener::findReplacementOp(Operation *op,
+                                            ValueRange newValues) const {
+  SmallVector<Value> values(newValues.begin(), newValues.end());
+  while (true) {
+    if (Operation *replacement =
+            transform::TrackingListener::findReplacementOp(op, values))
+      return replacement;
+
+    Operation *defOp = getCommonDefiningOp(values);
+    if (!defOp)
+      return nullptr;
+
+    // Skip cast-like operations: continue the search from their source.
+    // TODO: CastOpInterface could be used if CollapseShapeOp and ExpandShapeOp
+    // implement that interface.
+    Value source =
+        llvm::TypeSwitch<Operation *, Value>(defOp)
+            .Case<CastOp, ReshapeOp>(
+                [](auto castOp) -> Value { return castOp.getSource(); })
+            .Case<CollapseShapeOp, ExpandShapeOp>(
+                [](auto reshapeOp) -> Value { return reshapeOp.getSrc(); })
+            .Default([](Operation *) { return Value(); });
+    if (!source)
+      return nullptr;
+
+    // All non-empty entries are defined by `defOp`, a single-result cast-like
+    // op. Rewrite them in place (rather than rebuilding `values`) so there is
+    // still one entry per result of `op`, as the base class asserts.
+    for (Value &value : values)
+      if (value)
+        value = source;
+  }
+}
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -132,6 +132,90 @@
 }
 } // namespace
 
+//===----------------------------------------------------------------------===//
+// TrackingListener
+//===----------------------------------------------------------------------===//
+
+Operation *transform::TrackingListener::getCommonDefiningOp(ValueRange values) {
+  Operation *defOp = nullptr;
+  for (Value v : values) {
+    // Skip empty values.
+    if (!v)
+      continue;
+    // A block argument has no defining op, so there is no common defining op.
+    // (Treating it like an empty value would wrongly report the op of a later
+    // value as the common one.)
+    Operation *valueDefOp = v.getDefiningOp();
+    if (!valueDefOp || (defOp && defOp != valueDefOp))
+      return nullptr;
+    defOp = valueDefOp;
+  }
+  return defOp;
+}
+
+Operation *
+transform::TrackingListener::findReplacementOp(Operation *op,
+                                               ValueRange newValues) const {
+  assert(op->getNumResults() == newValues.size() &&
+         "invalid number of replacement values");
+
+  // If the values are not all defined by one common op, drop the mapping.
+  Operation *defOp = getCommonDefiningOp(newValues);
+  if (!defOp)
+    return nullptr;
+
+  // If the replacement op has a different type, drop the mapping.
+  if (op->getName() != defOp->getName())
+    return nullptr;
+
+  // If the replacement op is not a new op, drop the mapping.
+  if (!isNewOp(defOp))
+    return nullptr;
+
+  return defOp;
+}
+
+bool transform::TrackingListener::isNewOp(Operation *op) const {
+  auto it = newOps.find(op->getName());
+  if (it == newOps.end())
+    return false;
+  return it->second.contains(op);
+}
+
+void transform::TrackingListener::notifyOperationInserted(Operation *op) {
+  newOps[op->getName()].insert(op);
+}
+
+void transform::TrackingListener::notifyOperationRemoved(Operation *op) {
+  // TODO: Walk can be removed when D144193 has landed.
+  op->walk([&](Operation *nestedOp) {
+    // Keep set of new ops up-to-date.
+    auto it = newOps.find(nestedOp->getName());
+    if (it != newOps.end())
+      it->second.erase(nestedOp);
+    // Remove mappings for result values.
+    for (OpResult value : nestedOp->getResults())
+      (void)replacePayloadValue(value, nullptr);
+    // Remove mapping for op.
+    (void)replacePayloadOp(nestedOp, nullptr);
+  });
+}
+
+void transform::TrackingListener::notifyOperationReplaced(
+    Operation *op, ValueRange newValues) {
+  assert(op->getNumResults() == newValues.size() &&
+         "invalid number of replacement values");
+  if (op->getNumResults() == 0)
+    return;
+
+  // Replace value handles.
+  for (auto [oldValue, newValue] : llvm::zip(op->getResults(), newValues))
+    (void)replacePayloadValue(oldValue, newValue);
+
+  // Replace op handle.
+  (void)replacePayloadOp(op, findReplacementOp(op, newValues));
+}
+
 //===----------------------------------------------------------------------===//
 // AlternativesOp
 //===----------------------------------------------------------------------===//
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -5667,6 +5667,19 @@
     ],
 )
 
+cc_library(
+    name = "TensorTransformOps",
+    srcs = glob(["lib/Dialect/Tensor/TransformOps/*.cpp"]),
+    hdrs = glob(["include/mlir/Dialect/Tensor/TransformOps/*.h"]),
+    includes = ["include"],
+    deps = [
+        ":IR",
+        ":TensorDialect",
+        ":TransformDialect",
+        "//llvm:Support",
+    ],
+)
+
 cc_library(
     name = "Rewrite",
     srcs = glob([
@@ -7142,6 +7155,7 @@
         ":TensorDialect",
         ":TensorInferTypeOpInterfaceImpl",
         ":TensorTilingInterfaceImpl",
+        ":TensorTransformOps",
         ":TensorTransforms",
         ":TosaDialect",
         ":TosaToLinalg",