diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -59,6 +59,11 @@ /// the operands to `block`'s terminator. void forwardTerminatorOperands(Block *block, transform::TransformState &state, transform::TransformResults &results); + +/// Make a dummy transform state for testing purposes. This MUST NOT be used +/// outside of test cases. +TransformState makeTransformStateForTesting(Region *region, + Operation *payloadRoot); } // namespace detail /// Options controlling the application of transform operations by the @@ -162,6 +167,9 @@ const RaggedArray<MappedValue> &, const TransformOptions &); + friend TransformState + detail::makeTransformStateForTesting(Region *region, Operation *payloadRoot); + public: /// Returns the op at which the transformation state is rooted. This is /// typically helpful for transformations that apply globally. diff --git a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp --- a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp +++ b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp @@ -23,6 +23,20 @@ // TrackingListener //===----------------------------------------------------------------------===// +/// A tensor.insert_slice is a cast-like operation if the source tensor and +/// the destination tensor have the same number of elements. I.e., the result +/// tensor data equals the source tensor data, possibly rank-extended to a +/// different shape. +static bool isCastLikeInsertSliceOp(InsertSliceOp op) { + // TODO: Support dynamically shaped tensors. Utilize ValueBoundsOpInterface + // to check if source and destination have the same number of elements.
+ if (!op.getSourceType().hasStaticShape() || + !op.getDestType().hasStaticShape()) + return false; + return op.getSourceType().getNumElements() == + op.getDestType().getNumElements(); +} + Operation * tensor::TrackingListener::findReplacementOp(Operation *op, ValueRange newValues) const { @@ -48,6 +62,10 @@ [&](ExpandShapeOp op) { values.push_back(op.getSrc()); }) .Case<ReshapeOp>( [&](ReshapeOp op) { values.push_back(op.getSource()); }) + .Case<InsertSliceOp>([&](InsertSliceOp op) { + if (isCastLikeInsertSliceOp(op)) + values.push_back(op.getSource()); + }) .Default([](Operation *op) {}); } while (!values.empty()); diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -1238,6 +1238,12 @@ } } +transform::TransformState +transform::detail::makeTransformStateForTesting(Region *region, + Operation *payloadRoot) { + return TransformState(region, payloadRoot); +} + //===----------------------------------------------------------------------===// // Utilities for PossibleTopLevelTransformOpTrait.
//===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Tensor/tracking-listener.mlir b/mlir/test/Dialect/Tensor/tracking-listener.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Tensor/tracking-listener.mlir @@ -0,0 +1,84 @@ +// RUN: mlir-opt -test-tensor-transform-patterns=test-tracking-listener \ +// RUN: -split-input-file -verify-diagnostics %s + +func.func @replace_op_with_op_of_same_type() { + %0 = "test.foo"() {replaced} : () -> (tensor<5xf32>) + // expected-remark @below {{replacement found}} + %1 = "test.foo"() {replacement_0 = 0} : () -> (tensor<5xf32>) + return +} + +// ----- + +func.func @replace_op_with_op_of_different_type() { + // expected-error @below {{listener could not find replacement op}} + %0 = tensor.empty() {replaced} : tensor<5xf32> + %1 = "test.foo"() {replacement_0 = 0} : () -> (tensor<5xf32>) + return +} + +// ----- + +func.func @multi_result_replacement() { + %0:2 = "test.foo"() {replaced} : () -> (tensor<5xf32>, tensor<6xf32>) + // expected-remark @below {{replacement found}} + %1:2 = "test.foo"() {replacement_0 = 0, replacement_1 = 1} + : () -> (tensor<5xf32>, tensor<6xf32>) + return +} + +// ----- + +func.func @multi_result_replacement_with_multiple_ops() { + // expected-error @below {{listener could not find replacement op}} + %0:2 = "test.foo"() {replaced} : () -> (tensor<5xf32>, tensor<6xf32>) + %1:2 = "test.foo"() {replacement_0 = 0} : () -> (tensor<5xf32>, tensor<6xf32>) + %2:2 = "test.foo"() {replacement_1 = 1} : () -> (tensor<5xf32>, tensor<6xf32>) + return +} + +// ----- + +func.func @replacement_wrapped_in_cast() { + %0 = "test.foo"() {replaced} : () -> (tensor<5xf32>) + // expected-remark @below {{replacement found}} + %1 = "test.foo"() : () -> (tensor<?xf32>) + %2 = tensor.cast %1 {replacement_0 = 0} : tensor<?xf32> to tensor<5xf32> + return +} + +// ----- + +func.func @replacement_wrapped_in_chain_of_casts() { + %0 = "test.foo"() {replaced} : () -> (tensor<5xf32>)
+ // expected-remark @below {{replacement found}} + %1 = "test.foo"() : () -> (tensor<?xf32>) + %2 = tensor.cast %1 : tensor<?xf32> to tensor<5xf32> + %3 = tensor.cast %2 : tensor<5xf32> to tensor<?xf32> + %4 = tensor.cast %3 {replacement_0 = 0} : tensor<?xf32> to tensor<5xf32> + return +} + +// ----- + +func.func @cast_like_insert_slice(%t: tensor<1x5xf32>) { + %0 = "test.foo"() {replaced} : () -> (tensor<5xf32>) + // expected-remark @below {{replacement found}} + %1 = "test.foo"() : () -> (tensor<5xf32>) + %2 = tensor.insert_slice %1 into %t[0, 0][1, 5][1, 1] {replacement_0 = 0} + : tensor<5xf32> into tensor<1x5xf32> + return +} + +// ----- + +func.func @non_cast_like_insert_slice(%t: tensor<7xf32>) { + // expected-error @below {{listener could not find replacement op}} + %0 = "test.foo"() {replaced} : () -> (tensor<5xf32>) + %1 = "test.foo"() : () -> (tensor<5xf32>) + // This is not a cast-like insert_slice op because elements from %t are + // contained in %2. + %2 = tensor.insert_slice %1 into %t[0][5][1] {replacement_0 = 0} + : tensor<5xf32> into tensor<7xf32> + return +} diff --git a/mlir/test/lib/Dialect/Tensor/CMakeLists.txt b/mlir/test/lib/Dialect/Tensor/CMakeLists.txt --- a/mlir/test/lib/Dialect/Tensor/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Tensor/CMakeLists.txt @@ -10,6 +10,8 @@ MLIRPass MLIRSCFDialect MLIRTensorDialect + MLIRTensorTransformOps MLIRTensorTransforms + MLIRTransformDialect MLIRTransforms ) diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp --- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp +++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp @@ -14,8 +14,10 @@ #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h" #include "mlir/Dialect/Tensor/Transforms/TransformUtils.h" #include "mlir/Dialect/Tensor/Transforms/Transforms.h" +#include
"mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -85,6 +87,11 @@ *this, "test-simplify-pack-patterns", llvm::cl::desc("Test patterns to simplify tensor.pack"), llvm::cl::init(false)}; + + Option<bool> testTrackingListener{ + *this, "test-tracking-listener", + llvm::cl::desc("Test tensor TrackingListener for the transform dialect"), + llvm::cl::init(false)}; }; } // namespace @@ -276,6 +283,82 @@ return applyPatternsAndFoldGreedily(rootOp, std::move(patterns)); } +namespace { +class DummyTrackingListener : public tensor::TrackingListener { +public: + using tensor::TrackingListener::TrackingListener; + + // Expose `findReplacementOp` as a public function, so that it can be tested. + Operation *getReplacementOp(Operation *op, ValueRange newValues) const { + return findReplacementOp(op, newValues); + } +}; +} // namespace + +static LogicalResult testTrackingListenerReplacements(Operation *rootOp) { + // Find the unique op tagged with the 'replaced' attribute. + Operation *replaced = nullptr; + WalkResult status = rootOp->walk([&](Operation *op) { + if (op->hasAttr("replaced")) { + if (replaced) { + op->emitError("only one 'replaced' op is allowed per test case"); + replaced->emitRemark("other 'replaced' op"); + return WalkResult::interrupt(); + } + replaced = op; + } + return WalkResult::advance(); + }); + if (status.wasInterrupted()) + return failure(); + if (!replaced) { + rootOp->emitError("could not find 'replaced' op"); + return failure(); + } + + // Collect one replacement value per result, tagged 'replacement_N'.
+ SmallVector<Value> replacements(replaced->getNumResults(), Value()); + status = rootOp->walk([&](Operation *op) { + for (int64_t i = 0; i < replaced->getNumResults(); ++i) { + if (auto attr = op->getAttrOfType<IntegerAttr>("replacement_" + + std::to_string(i))) { + if (replacements[i]) { + op->emitError("only one 'replacement_" + std::to_string(i) + + "' is allowed per test case"); + replacements[i].getDefiningOp()->emitRemark("other 'replacement_" + + std::to_string(i) + "'"); + return WalkResult::interrupt(); + } + replacements[i] = op->getResult(attr.getInt()); + } + } + return WalkResult::advance(); + }); + if (status.wasInterrupted()) + return failure(); + + if (!llvm::all_of(replacements, + [](Value v) { return static_cast<bool>(v); })) { + replaced->emitError("insufficient replacement values"); + return failure(); + } + + // Find the replacement op (if any) and emit a remark/error. + transform::TransformState transformState = + transform::detail::makeTransformStateForTesting(/*region=*/nullptr, + /*payloadRoot=*/nullptr); + DummyTrackingListener listener(transformState, + transform::TransformOpInterface()); + Operation *replacement = listener.getReplacementOp(replaced, replacements); + if (!replacement) { + replaced->emitError("listener could not find replacement op"); + return failure(); + } + + replacement->emitRemark("replacement found"); + return success(); +} + void TestTensorTransforms::runOnOperation() { Operation *rootOp = getOperation(); if (testSimplifyPackPatterns) @@ -295,6 +378,9 @@ applyRewriteExtractFromCollapseShapePatterns(rootOp, useForeach))) return signalPassFailure(); } + if (testTrackingListener) + if (failed(testTrackingListenerReplacements(rootOp))) + return signalPassFailure(); } namespace mlir { diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel @@ -855,7 +855,9
@@ "//mlir:Pass", "//mlir:SCFDialect", "//mlir:TensorDialect", + "//mlir:TensorTransformOps", "//mlir:TensorTransforms", + "//mlir:TransformDialect", "//mlir:Transforms", ], )