diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -284,9 +284,10 @@ public: TransformOptions() {} - /// Requests computationally expensive checks of the transform and payload IR - /// well-formedness to be performed before each transformation. In particular, - /// these ensure that the handles still point to valid operations when used. + /// Requests computationally expensive checks of the transform and payload + /// IR well-formedness to be performed before each transformation. In + /// particular, these ensure that the handles still point to valid + /// operations when used. TransformOptions &enableExpensiveChecks(bool enable = true) { expensiveChecksEnabled = enable; return *this; @@ -313,16 +314,16 @@ /// TransformOpInterface. The operations implementing this interface and the /// surrounding structure are referred to as transform IR. The operations to /// which transformations apply are referred to as payload IR. The state thus -/// contains the many-to-many mapping between values defined in the transform IR -/// ops and payload IR ops. The "expensive-checks" option can be passed to +/// contains the many-to-many mapping between values defined in the transform +/// IR ops and payload IR ops. The "expensive-checks" option can be passed to /// the constructor at transformation execution time that transform IR values /// used as operands by a transform IR operation are not associated with /// dangling pointers to payload IR operations that are known to have been -/// erased by previous transformation through the same or a different transform -/// IR value. +/// erased by previous transformation through the same or a different +/// transform IR value. 
/// -/// A reference to this class is passed as an argument to "apply" methods of the -/// transform op interface. Thus the "apply" method can call +/// A reference to this class is passed as an argument to "apply" methods of +/// the transform op interface. Thus the "apply" method can call /// `state.getPayloadOps( getSomeOperand() )` to obtain the list of operations /// associated with its operand and subject to transformation. The method is /// expected to populate the `TransformResults` class instance in order to @@ -331,11 +332,11 @@ /// /// When applying transform IR operations with regions, the client is expected /// to create a RegionScope RAII object to create a new "stack frame" for -/// values defined inside the region. The mappings from and to these values will -/// be automatically dropped when the object goes out of scope, typically at the -/// end of the "apply" function of the parent operation. If a region contains -/// blocks with arguments, the client can map those arguments to payload IR ops -/// using "mapBlockArguments". +/// values defined inside the region. The mappings from and to these values +/// will be automatically dropped when the object goes out of scope, typically +/// at the end of the "apply" function of the parent operation. If a region +/// contains blocks with arguments, the client can map those arguments to +/// payload IR ops using "mapBlockArguments". class TransformState { /// Mapping between a Value in the transform IR and the corresponding set of /// operations in the payload IR. @@ -362,17 +363,18 @@ /// typically helpful for transformations that apply globally. Operation *getTopLevel() const; - /// Returns the list of ops that the given transform IR value corresponds to. - /// This is helpful for transformations that apply to a particular handle. + /// Returns the list of ops that the given transform IR value corresponds + /// to. This is helpful for transformations that apply to a particular + /// handle. 
ArrayRef getPayloadOps(Value value) const; - /// Populates `handles` with all handles pointing to the given Payload IR op. - /// Returns success if such handles exist, failure otherwise. + /// Populates `handles` with all handles pointing to the given Payload IR + /// op. Returns success if such handles exist, failure otherwise. LogicalResult getHandlesForPayloadOp(Operation *op, SmallVectorImpl &handles) const; - /// Applies the transformation specified by the given transform op and updates - /// the state accordingly. + /// Applies the transformation specified by the given transform op and + /// updates the state accordingly. DiagnosedSilenceableFailure applyTransform(TransformOpInterface transform); /// Records the mapping between a block argument in the transform IR and a @@ -380,8 +382,8 @@ /// blocks of the currently processed transform IR region, typically after a /// region scope is defined. /// - /// Returns failure if the payload does not satisfy the conditions associated - /// with the type of the handle value. + /// Returns failure if the payload does not satisfy the conditions + /// associated with the type of the handle value. LogicalResult mapBlockArguments(BlockArgument argument, ArrayRef operations) { #if LLVM_ENABLE_ABI_BREAKING_CHECKS diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformUtils.h b/mlir/include/mlir/Dialect/Transform/IR/TransformUtils.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformUtils.h @@ -0,0 +1,29 @@ +//===- TransformUtils.h - Transform Dialect Utils ---------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMUTILS_H +#define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMUTILS_H + +#include "mlir/IR/PatternMatch.h" + +namespace mlir { +namespace transform { + +/// A PatternRewriter with no custom behavior that can be constructed directly +/// from an MLIRContext. Transform ops use it to apply a single pattern to a +/// specific payload op, or to invoke rewriter-based transformations, locally. +class TrivialPatternRewriter : public PatternRewriter { +public: + explicit TrivialPatternRewriter(MLIRContext *context) + : PatternRewriter(context) {} +}; + +} // namespace transform +} // namespace mlir + +#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMUTILS_H diff --git a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp --- a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp +++ b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp @@ -11,17 +11,10 @@ #include "mlir/Dialect/Affine/LoopUtils.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/Dialect/Transform/IR/TransformUtils.h" using namespace mlir; -namespace { -/// A simple pattern rewriter that implements no special logic.
-class SimpleRewriter : public PatternRewriter { -public: - SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {} -}; -} // namespace - //===----------------------------------------------------------------------===// // Transform op registration //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp --- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp +++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp @@ -16,24 +16,13 @@ #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/Dialect/Transform/IR/TransformUtils.h" #include "mlir/IR/BlockAndValueMapping.h" -#include "mlir/IR/Diagnostics.h" -#include "mlir/IR/Value.h" -#include "llvm/ADT/None.h" -#include "llvm/ADT/Optional.h" using namespace mlir; using namespace mlir::gpu; using namespace mlir::transform; -namespace { -/// A simple pattern rewriter that implements no special logic. -class SimpleRewriter : public PatternRewriter { -public: - SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {} -}; -} // namespace - /// Check if given mapping attributes are one of the desired attributes static DiagnosedSilenceableFailure checkAttributeType(ArrayRef threadMappingAttributes, @@ -135,7 +124,7 @@ /// Alter kernel configuration of the given kernel. 
static DiagnosedSilenceableFailure -alterGpuLaunch(SimpleRewriter &rewriter, LaunchOp gpuLaunch, +alterGpuLaunch(TrivialPatternRewriter &rewriter, LaunchOp gpuLaunch, TransformOpInterface transformOp, Optional gridDimX = llvm::None, Optional gridDimY = llvm::None, @@ -305,7 +294,7 @@ SmallVectorImpl &results, transform::TransformState &state) { LaunchOp gpuLaunch = dyn_cast(target); - SimpleRewriter rewriter(getContext()); + TrivialPatternRewriter rewriter(getContext()); auto transformOp = cast(getOperation()); if (!getGenerateGpuLaunch() && !gpuLaunch) { @@ -555,7 +544,7 @@ } MLIRContext *ctx = getContext(); - SimpleRewriter rewriter(ctx); + TrivialPatternRewriter rewriter(ctx); rewriter.setInsertionPoint(target); SmallVector threadMappingAttributes = { diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -19,6 +19,7 @@ #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/Dialect/Transform/IR/TransformUtils.h" #include "mlir/IR/OpDefinition.h" #include "mlir/Interfaces/TilingInterface.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -41,14 +42,6 @@ return result; } -namespace { -/// A simple pattern rewriter that implements no special logic. -class SimpleRewriter : public PatternRewriter { -public: - SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {} -}; -} // namespace - /// Attempts to apply the pattern specified as template argument to the given /// operation. The pattern is expected to have a `returningMatchAndRewrite` /// function that returns the "main" result or failure. Returns failure if the @@ -65,7 +58,7 @@ // Apply the pattern directly to the op. 
PatternTy pattern(operation->getContext(), std::forward(args)...); - SimpleRewriter rewriter(operation->getContext()); + TrivialPatternRewriter rewriter(operation->getContext()); rewriter.setInsertionPoint(operation); auto result = pattern.returningMatchAndRewrite(op, rewriter); if (failed(result)) @@ -125,7 +118,7 @@ if (!tilingInterfaceOp) return transformOp->emitError("only TilingInterface ops are supported"); - SimpleRewriter rewriter(target->getContext()); + TrivialPatternRewriter rewriter(target->getContext()); rewriter.setInsertionPoint(target); FailureOr tiledResults = applyFn(tilingInterfaceOp); @@ -209,7 +202,7 @@ tileSizes.size() - llvm::count(tileSizes, 0), transformResults, [&](TilingInterface tilingInterfaceOp) -> FailureOr { - SimpleRewriter rewriter(getContext()); + TrivialPatternRewriter rewriter(getContext()); return tileConsumerAndFuseProducerGreedilyUsingSCFForOp( rewriter, tilingInterfaceOp, tileAndFuseOptions); }); @@ -601,7 +594,7 @@ results.push_back(target); return DiagnosedSilenceableFailure(success()); } - SimpleRewriter rewriter(target->getContext()); + TrivialPatternRewriter rewriter(target->getContext()); FailureOr res = interchangeGenericOp(rewriter, target, interchangeVector); if (failed(res)) @@ -875,7 +868,7 @@ if (failed(promoteSubviewsPrecondition(target, promotionOptions))) return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); - SimpleRewriter rewriter(target->getContext()); + TrivialPatternRewriter rewriter(target->getContext()); rewriter.setInsertionPoint(target); FailureOr res = promoteSubViews(rewriter, target, promotionOptions); if (failed(res)) @@ -974,7 +967,7 @@ return tileSizes; }); SmallVector emptyTileSizes; - SimpleRewriter rewriter(getContext()); + TrivialPatternRewriter rewriter(getContext()); rewriter.setInsertionPoint(target); FailureOr maybeTilingResult = tileUsingSCFForOp( rewriter, cast(target.getOperation()), tilingOptions); @@ -993,7 +986,7 @@ TransformState &state) { // Collect the 
dynamic split points if provided. ArrayRef payload = state.getPayloadOps(getTarget()); - SimpleRewriter rewriter(getContext()); + TrivialPatternRewriter rewriter(getContext()); SmallVector splitPoints; splitPoints.reserve(payload.size()); if (getDynamicSplitPoint()) { @@ -1122,8 +1115,7 @@ } LogicalResult SplitOp::verify() { - if ((static_cast(getStaticSplitPoint()) != - ShapedType::kDynamic) ^ + if ((static_cast(getStaticSplitPoint()) != ShapedType::kDynamic) ^ (getDynamicSplitPoint() == nullptr)) { return emitOpError() << "expects either a dynamic or a static split " "point to be provided"; @@ -1172,7 +1164,7 @@ unsigned(getInsertSplitDimension()), bool(getInnerParallel())}; }; - SimpleRewriter rewriter(getContext()); + TrivialPatternRewriter rewriter(getContext()); rewriter.setInsertionPoint(target); FailureOr splitResult = (getUseScalingAlgorithm()) @@ -1195,7 +1187,7 @@ DiagnosedSilenceableFailure transform::TileReductionUsingScfOp::applyToOne( linalg::LinalgOp target, SmallVectorImpl &results, transform::TransformState &state) { - SimpleRewriter rewriter(getContext()); + TrivialPatternRewriter rewriter(getContext()); rewriter.setInsertionPoint(target); SmallVector tileSizes = extractFromI64ArrayAttr(getTileSizes()); SmallVector sizes; @@ -1223,7 +1215,7 @@ transform::TileReductionUsingForeachThreadOp::applyToOne( linalg::LinalgOp target, SmallVectorImpl &results, transform::TransformState &state) { - SimpleRewriter rewriter(getContext()); + TrivialPatternRewriter rewriter(getContext()); rewriter.setInsertionPoint(target); SmallVector numThreads = extractFromI64ArrayAttr(getNumThreads()); SmallVector numThreadResults; @@ -1321,7 +1313,7 @@ } tilingOptions.setInterchange(getInterchange()); - SimpleRewriter rewriter(linalgOp.getContext()); + TrivialPatternRewriter rewriter(linalgOp.getContext()); FailureOr maybeTilingResult = tileUsingSCFForOp( rewriter, cast(linalgOp.getOperation()), tilingOptions); @@ -1714,7 +1706,7 @@ } 
tilingOptions.setInterchange(getInterchange()); - SimpleRewriter rewriter(tilingInterfaceOp.getContext()); + TrivialPatternRewriter rewriter(tilingInterfaceOp.getContext()); FailureOr tilingResult = tileUsingSCFForOp(rewriter, tilingInterfaceOp, tilingOptions); if (failed(tilingResult)) diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp --- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp +++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp @@ -16,18 +16,11 @@ #include "mlir/Dialect/SCF/Utils/Utils.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/Dialect/Transform/IR/TransformUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" using namespace mlir; -namespace { -/// A simple pattern rewriter that implements no special logic. -class SimpleRewriter : public PatternRewriter { -public: - SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {} -}; -} // namespace - //===----------------------------------------------------------------------===// // GetParentForOp //===----------------------------------------------------------------------===// @@ -97,7 +90,7 @@ for (Operation *target : state.getPayloadOps(getTarget())) { Location location = target->getLoc(); Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(target); - SimpleRewriter rewriter(getContext()); + TrivialPatternRewriter rewriter(getContext()); scf::ExecuteRegionOp exec = wrapInExecuteRegion(rewriter, target); if (!exec) { DiagnosedSilenceableFailure diag = emitSilenceableError() @@ -201,7 +194,7 @@ getReadLatency()); }; scf::ForLoopPipeliningPattern pattern(options, target->getContext()); - SimpleRewriter rewriter(getContext()); + TrivialPatternRewriter rewriter(getContext()); rewriter.setInsertionPoint(target); FailureOr patternResult = pattern.returningMatchAndRewrite(target, rewriter);