diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -33,7 +33,7 @@ let extraClassDeclaration = [{ ::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne( - ::mlir::linalg::LinalgOp target); + ::mlir::linalg::LinalgOp target, TransformState &state); }]; } @@ -74,7 +74,7 @@ let extraClassDeclaration = [{ ::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne( - ::mlir::linalg::LinalgOp target); + ::mlir::linalg::LinalgOp target, TransformState &state); }]; } @@ -96,7 +96,7 @@ let extraClassDeclaration = [{ ::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne( - ::mlir::linalg::LinalgOp target); + ::mlir::linalg::LinalgOp target, TransformState &state); }]; } @@ -124,7 +124,7 @@ let extraClassDeclaration = [{ ::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne( - ::mlir::linalg::LinalgOp target); + ::mlir::linalg::LinalgOp target, TransformState &state); }]; } @@ -149,7 +149,7 @@ let extraClassDeclaration = [{ ::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne( - ::mlir::linalg::LinalgOp target); + ::mlir::linalg::LinalgOp target, TransformState &state); }]; } @@ -218,7 +218,7 @@ let extraClassDeclaration = [{ ::mlir::FailureOr<::llvm::SmallVector<::mlir::Operation *>> applyToOne( - ::mlir::linalg::LinalgOp target); + ::mlir::linalg::LinalgOp target, TransformState &state); }]; } @@ -275,7 +275,8 @@ let assemblyFormat = "$target attr-dict"; let extraClassDeclaration = [{ - ::mlir::FailureOr applyToOne(::mlir::Operation *target); + ::mlir::FailureOr applyToOne( + ::mlir::Operation *target, TransformState &state); }]; } diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td --- 
a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td +++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td @@ -88,7 +88,8 @@ let assemblyFormat = "$target attr-dict"; let extraClassDeclaration = [{ - ::mlir::FailureOr<::mlir::scf::ForOp> applyToOne(::mlir::scf::ForOp loop); + ::mlir::FailureOr<::mlir::scf::ForOp> applyToOne( + ::mlir::scf::ForOp loop, TransformState &state); }]; } @@ -115,7 +116,8 @@ let assemblyFormat = "$target attr-dict"; let extraClassDeclaration = [{ - ::mlir::FailureOr<::mlir::scf::ForOp> applyToOne(::mlir::scf::ForOp loop); + ::mlir::FailureOr<::mlir::scf::ForOp> applyToOne( + ::mlir::scf::ForOp loop, TransformState &state); }]; } @@ -137,7 +139,8 @@ let assemblyFormat = "$target attr-dict"; let extraClassDeclaration = [{ - ::mlir::LogicalResult applyToOne(::mlir::scf::ForOp loop); + ::mlir::LogicalResult applyToOne( + ::mlir::scf::ForOp loop, TransformState &state); }]; } diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -582,9 +582,9 @@ /// transformation to a single operation handle and producing one or multiple /// operation handles. /// The op must implement a method with one of the following signatures: -/// - FailureOr applyToOne(OpTy) -/// - FailureOr> applyToOne(OpTy) -/// - LogicalResult applyToOne(OpTy) +/// - FailureOr applyToOne(OpTy, state) +/// - FailureOr> applyToOne(OpTy, state) +/// - LogicalResult applyToOne(OpTy, state) /// to perform a transformation that is applied in turn to all payload IR /// operations that correspond to the handle of the transform IR operation. /// In the functions above, OpTy is either Operation * or a concrete payload IR @@ -811,7 +811,7 @@ // produced. 
DiagnosedSilenceableFailure result = detail::applyTransformToEach( targets, results, [&](TransformOpType specificOp) { - return static_cast(this)->applyToOne(specificOp); + return static_cast(this)->applyToOne(specificOp, state); }); if (!result.succeeded()) return result; diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -76,7 +76,8 @@ // DecomposeOp //===----------------------------------------------------------------------===// -FailureOr transform::DecomposeOp::applyToOne(LinalgOp target) { +FailureOr transform::DecomposeOp::applyToOne(LinalgOp target, + TransformState &state) { FailureOr windowed = tryApply(target); if (succeeded(windowed)) @@ -220,7 +221,8 @@ // GeneralizeOp //===----------------------------------------------------------------------===// -FailureOr transform::GeneralizeOp::applyToOne(LinalgOp target) { +FailureOr transform::GeneralizeOp::applyToOne(LinalgOp target, + TransformState &state) { // Exit early if no transformation is needed. if (isa(target)) return target; @@ -236,7 +238,8 @@ // InterchangeOp //===----------------------------------------------------------------------===// -FailureOr transform::InterchangeOp::applyToOne(LinalgOp target) { +FailureOr +transform::InterchangeOp::applyToOne(LinalgOp target, TransformState &state) { SmallVector interchangeVector = extractUIntArray(getIteratorInterchange()); // Exit early if no transformation is needed. @@ -272,7 +275,8 @@ // PadOp //===---------------------------------------------------------------------===// -FailureOr transform::PadOp::applyToOne(LinalgOp target) { +FailureOr transform::PadOp::applyToOne(LinalgOp target, + TransformState &state) { // Convert the integer packing flags to booleans. 
SmallVector packPaddings; for (int64_t packPadding : extractI64Array(getPackPaddings())) @@ -377,7 +381,8 @@ // ScalarizeOp //===----------------------------------------------------------------------===// -FailureOr transform::ScalarizeOp::applyToOne(LinalgOp target) { +FailureOr transform::ScalarizeOp::applyToOne(LinalgOp target, + TransformState &state) { LinalgTilingOptions tilingOptions; tilingOptions.scalarizeDynamicDims(); // Tiling with "scalarize_dyn_dims" actually sets the same lambda as the tile @@ -399,7 +404,8 @@ //===----------------------------------------------------------------------===// FailureOr> -transform::SplitReductionOp::applyToOne(LinalgOp target) { +transform::SplitReductionOp::applyToOne(LinalgOp target, + TransformState &state) { ControlSplitReductionFn splitFn = [&](LinalgOp) { return std::pair(getSplitFactor(), getInsertSplitDimension()); @@ -455,7 +461,8 @@ // VectorizeOp //===----------------------------------------------------------------------===// -FailureOr VectorizeOp::applyToOne(Operation *target) { +FailureOr VectorizeOp::applyToOne(Operation *target, + TransformState &state) { if (!target->hasTrait()) { InFlightDiagnostic diag = emitOpError() << "applies only to isolated-from-above targets"; diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp --- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp +++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp @@ -127,7 +127,8 @@ // LoopPeelOp //===----------------------------------------------------------------------===// -FailureOr transform::LoopPeelOp::applyToOne(scf::ForOp loop) { +FailureOr transform::LoopPeelOp::applyToOne(scf::ForOp loop, + TransformState &state) { scf::ForOp result; IRRewriter rewriter(loop->getContext()); LogicalResult status = @@ -180,7 +181,8 @@ } } -FailureOr transform::LoopPipelineOp::applyToOne(scf::ForOp loop) { +FailureOr +transform::LoopPipelineOp::applyToOne(scf::ForOp 
loop, TransformState &state) { scf::PipeliningOption options; options.getScheduleFn = [this](scf::ForOp forOp, @@ -203,7 +205,8 @@ // LoopUnrollOp //===----------------------------------------------------------------------===// -LogicalResult transform::LoopUnrollOp::applyToOne(scf::ForOp loop) { +LogicalResult transform::LoopUnrollOp::applyToOne(scf::ForOp loop, + TransformState &state) { if (failed(loopUnrollByFactor(loop, getFactor()))) return reportUnknownTransformError(loop); return success(); diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -227,7 +227,8 @@ } FailureOr> -mlir::test::TestWrongNumberOfResultsOp::applyToOne(Operation *) { +mlir::test::TestWrongNumberOfResultsOp::applyToOne( + Operation *, transform::TransformState &state) { return SmallVector{}; } diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td @@ -140,7 +140,7 @@ let cppNamespace = "::mlir::test"; let extraClassDeclaration = [{ ::mlir::FailureOr<::llvm::SmallVector<::mlir::Operation *>> applyToOne( - ::mlir::Operation *target); + ::mlir::Operation *target, transform::TransformState &state); }]; }