diff --git a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h --- a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h +++ b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h @@ -33,14 +33,15 @@ namespace transform { namespace gpu { -/// Searches `scf.forall` ops nested under `target` and maps each such -/// op to GPU threads. Mapping is one-to-one and the induction variables of -/// `scf.forall` are rewritten to gpu.thread_id according to the -/// thread_dim_apping attribute. Sibling `scf.forall` are supported in -/// which case, the union of the number of threads is computed and may result in -/// predication. Dynamic, `scf.forall` trip counts are currently not -/// supported. Dynamic block dim sizes are currently not supported. -DiagnosedSilenceableFailure mapNestedForeachToThreadsImpl( +/// Search `scf.forall` ops nested under `target` and map each such op to GPU +/// threads. Mapping is one-to-one and the induction variables of `scf.forall` +/// are rewritten to gpu.thread_id according to the thread_dim_mapping +/// attribute. +/// Sibling `scf.forall` are supported in which case, the union of the number of +/// threads is computed and may result in predication. +/// Dynamic, `scf.forall` trip counts are currently not supported. +/// Dynamic block dim sizes are currently not supported. +DiagnosedSilenceableFailure mapNestedForallToThreadsImpl( RewriterBase &rewriter, Operation *target, const SmallVectorImpl &blockDim, function_ref &)> @@ -48,19 +49,19 @@ bool syncAfterDistribute, std::optional transformOp, const ArrayRef &threadMappingAttributes); -/// Maps the top level `scf.forall` op to GPU Thread Blocks. Mapping is -/// one-to-one and the induction variables of `scf.forall` are rewritten -/// to gpu.block_id according to the thread_dim_apping attribute. Dynamic, -/// `scf.forall` trip counts are currently not supported. 
Dynamic block -/// dim sizes are currently not supported. -DiagnosedSilenceableFailure mapForeachToBlocksImpl( +/// Map the top level `scf.forall` op to GPU Thread Blocks. +/// Mapping is one-to-one and the induction variables of `scf.forall` are +/// rewritten to gpu.block_id according to the thread_dim_mapping attribute. +/// Dynamic, `scf.forall` trip counts are currently not supported. +/// Dynamic block dim sizes are currently not supported. +DiagnosedSilenceableFailure mapForallToBlocksImpl( RewriterBase &rewriter, scf::ForallOp forallOp, function_ref &)> blockIdGenerator, SmallVectorImpl &gridDims, TransformOpInterface transformOp, const ArrayRef &mappingAttributes); -/// Finds the top level scf::ForallOp of given target. +/// Find the unique top level scf::ForallOp within a given target op. DiagnosedSilenceableFailure findTopLevelForallOp(Operation *target, scf::ForallOp &topLevelForallOp, TransformOpInterface transformOp); diff --git a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td --- a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td +++ b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td @@ -15,8 +15,8 @@ include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/IR/OpBase.td" -def MapNestedForeachToThreads : - Op linalgOpToLoops(PatternRewriter &rewriter, +FailureOr linalgOpToLoops(RewriterBase &rewriter, LinalgOp linalgOp); /// Emit a loop nest of `scf.parallel` with the proper body for `linalgOp`. -FailureOr linalgOpToParallelLoops(PatternRewriter &rewriter, +FailureOr linalgOpToParallelLoops(RewriterBase &rewriter, LinalgOp linalgOp); /// Emit a loop nest of `affine.for` with the proper body for `linalgOp`. -FailureOr linalgOpToAffineLoops(PatternRewriter &rewriter, +FailureOr linalgOpToAffineLoops(RewriterBase &rewriter, LinalgOp linalgOp); /// Creates a number of ranges equal to the number of non-zero in `tileSizes`.
@@ -818,7 +818,7 @@ LinalgOp resultCombiningLinalgOp; }; FailureOr -splitReduction(PatternRewriter &b, LinalgOp op, +splitReduction(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc = false); @@ -870,7 +870,7 @@ /// return %4 : tensor<16x32xf32> /// ``` FailureOr -splitReductionByScaling(PatternRewriter &b, LinalgOp op, +splitReductionByScaling(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc = false); @@ -1097,7 +1097,7 @@ }; using OptimizeCopyFn = - std::function; + std::function; /// Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and /// InsertSliceOp. For now, only constant padding values are supported. @@ -1113,7 +1113,7 @@ protected: OptimizeCopyFn optimizeCopyFn; - Value createFillOrGenerateOp(PatternRewriter &rewriter, tensor::PadOp padOp, + Value createFillOrGenerateOp(RewriterBase &rewriter, tensor::PadOp padOp, Value dest, const SmallVector &dynSizes) const; }; diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h --- a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h @@ -31,6 +31,10 @@ /// } /// S1(N) S2(N-1) // Epilogue /// S2(N) // Epilogue +FailureOr pipelineForLoop(RewriterBase &rewriter, ForOp forOp, + const PipeliningOption &options); + +// TODO: such patterns should be auto-generated. 
class ForLoopPipeliningPattern : public OpRewritePattern { public: ForLoopPipeliningPattern(const PipeliningOption &options, @@ -42,7 +46,9 @@ } FailureOr returningMatchAndRewrite(ForOp forOp, - PatternRewriter &rewriter) const; + PatternRewriter &rewriter) const { + return pipelineForLoop(rewriter, forOp, options); + } protected: PipeliningOption options; diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h --- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h @@ -18,7 +18,7 @@ namespace mlir { class Operation; -class PatternRewriter; +class RewriterBase; class TilingInterface; } // namespace mlir @@ -243,7 +243,7 @@ /// : tensor<7x4xf32> -> tensor<7xf32> /// ``` FailureOr -tileReductionUsingScf(PatternRewriter &b, PartialReductionOpInterface op, +tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op, ArrayRef tileSize); } // namespace scf diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h @@ -152,7 +152,7 @@ // peeled. This takes the original operation, an i1 predicate value and the // pattern rewriter. using PredicateOpFn = - std::function; + std::function; PredicateOpFn predicateFn = nullptr; // TODO: add option to decide if the prologue should be peeled. diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformUtils.h b/mlir/include/mlir/Dialect/Transform/IR/TransformUtils.h deleted file mode 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformUtils.h +++ /dev/null @@ -1,28 +0,0 @@ -//===- TransformUtils.h - Transform Dialect Utils ---------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
-// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMUTILS_H -#define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMUTILS_H - -#include "mlir/IR/PatternMatch.h" - -namespace mlir { -namespace transform { - -/// A simple pattern rewriter that can be constructed from a context. This is -/// necessary to apply patterns to a specific op locally. -class TrivialPatternRewriter : public PatternRewriter { -public: - explicit TrivialPatternRewriter(MLIRContext *context) - : PatternRewriter(context) {} -}; - -} // namespace transform -} // namespace mlir - -#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMUTILS_H diff --git a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp --- a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp +++ b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp @@ -16,7 +16,6 @@ #include "mlir/Dialect/PDL/IR/PDLTypes.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" -#include "mlir/Dialect/Transform/IR/TransformUtils.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" using namespace mlir; diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp --- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp +++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp @@ -16,7 +16,6 @@ #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" -#include "mlir/Dialect/Transform/IR/TransformUtils.h" #include "mlir/IR/IRMapping.h" using namespace mlir; @@ -26,13 +25,13 @@ /// Check if given mapping attributes are one of the desired attributes static 
DiagnosedSilenceableFailure checkAttributeType(ArrayRef threadMappingAttributes, - const std::optional &foreachMapping, + const std::optional &forallMapping, std::optional transformOp) { - if (!foreachMapping.has_value()) + if (!forallMapping.has_value()) return transformOp->emitSilenceableError() << "mapping must be present"; DenseSet seen; - for (Attribute map : foreachMapping->getValue()) { + for (Attribute map : forallMapping->getValue()) { if (!llvm::is_contained(threadMappingAttributes, map)) { return transformOp->emitDefiniteFailure() << "mapping must be one of " << threadMappingAttributes; @@ -124,7 +123,7 @@ /// Alter kernel configuration of the given kernel. static DiagnosedSilenceableFailure -alterGpuLaunch(TrivialPatternRewriter &rewriter, LaunchOp gpuLaunch, +alterGpuLaunch(IRRewriter &rewriter, LaunchOp gpuLaunch, TransformOpInterface transformOp, std::optional gridDimX = std::nullopt, std::optional gridDimY = std::nullopt, @@ -165,10 +164,10 @@ } //===----------------------------------------------------------------------===// -// MapForeachToBlocks +// MapForallToBlocks //===----------------------------------------------------------------------===// -DiagnosedSilenceableFailure mlir::transform::gpu::mapForeachToBlocksImpl( +DiagnosedSilenceableFailure mlir::transform::gpu::mapForallToBlocksImpl( RewriterBase &rewriter, scf::ForallOp forallOp, function_ref &)> blockIdGenerator, @@ -262,7 +261,7 @@ if (forallOp->getParentOfType()) return WalkResult::advance(); if (topLevelForallOp) - // TODO: Handle multiple foreach if there is no dependences between them + // TODO: Handle multiple forall if they are independent. return WalkResult::interrupt(); topLevelForallOp = forallOp; return WalkResult::advance(); @@ -274,14 +273,13 @@ return DiagnosedSilenceableFailure::success(); } -/// This is a helper that is only used in -/// rewriteTopLevelForallToGpuBlocks. It generates GPU dialects -/// block_id. 
-static void generateGpuBlockIds(RewriterBase &rewriter, scf::ForallOp foreachOp, - SmallVectorImpl &blockOps) { - Location loc = foreachOp->getLoc(); +/// This is a helper that is only used in rewriteTopLevelForallToGpuBlocks. +/// It generates GPU dialect block_id. +static void createGpuBlockIds(RewriterBase &rewriter, scf::ForallOp forallOp, + SmallVectorImpl &blockOps) { + Location loc = forallOp->getLoc(); OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPoint(foreachOp); + rewriter.setInsertionPoint(forallOp); IndexType indexType = rewriter.getIndexType(); blockOps = SmallVector{ rewriter.create(loc, indexType, Dimension::x), @@ -290,11 +288,11 @@ } DiagnosedSilenceableFailure -transform::MapForeachToBlocks::applyToOne(Operation *target, - ApplyToEachResultList &results, - transform::TransformState &state) { +transform::MapForallToBlocks::applyToOne(Operation *target, + ApplyToEachResultList &results, + transform::TransformState &state) { LaunchOp gpuLaunch = dyn_cast(target); - TrivialPatternRewriter rewriter(getContext()); + IRRewriter rewriter(getContext()); auto transformOp = cast(getOperation()); if (!getGenerateGpuLaunch() && !gpuLaunch) { @@ -339,8 +337,8 @@ diag = checkAttributeType(blockMappingAttributes, topLevelForallOp.getMapping(), transformOp); if (diag.succeeded()) - diag = mlir::transform::gpu::mapForeachToBlocksImpl( - rewriter, topLevelForallOp, generateGpuBlockIds, gridDim, transformOp, + diag = mlir::transform::gpu::mapForallToBlocksImpl( + rewriter, topLevelForallOp, createGpuBlockIds, gridDim, transformOp, blockMappingAttributes); if (diag.succeeded()) { diag = alterGpuLaunch(rewriter, gpuLaunch, @@ -353,7 +351,7 @@ } //===----------------------------------------------------------------------===// -// MapNestedForeachToThreads +// MapNestedForallToThreads //===----------------------------------------------------------------------===// /// Searches `scf.forall` ops nested under `target` and maps each such @@ -497,7 +495,7 
@@ return DiagnosedSilenceableFailure::success(); } -DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForeachToThreadsImpl( +DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForallToThreadsImpl( RewriterBase &rewriter, Operation *target, const SmallVectorImpl &blockDim, function_ref &)> @@ -527,7 +525,7 @@ return diag; } -DiagnosedSilenceableFailure transform::MapNestedForeachToThreads::applyToOne( +DiagnosedSilenceableFailure transform::MapNestedForallToThreads::applyToOne( Operation *target, ApplyToEachResultList &results, TransformState &state) { LaunchOp gpuLaunch = dyn_cast(target); auto transformOp = cast(getOperation()); @@ -548,7 +546,7 @@ } MLIRContext *ctx = getContext(); - TrivialPatternRewriter rewriter(ctx); + IRRewriter rewriter(ctx); rewriter.setInsertionPoint(target); SmallVector threadMappingAttributes = { @@ -565,7 +563,7 @@ rewriter.create(forallOp->getLoc(), indexType, Dimension::z)}); }; - diag = mlir::transform::gpu::mapNestedForeachToThreadsImpl( + diag = mlir::transform::gpu::mapNestedForallToThreadsImpl( rewriter, target, blockDim, threadIdGenerator, getSyncAfterDistribute(), transformOp, threadMappingAttributes); diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -24,7 +24,6 @@ #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/Dialect/Transform/IR/TransformTypes.h" -#include "mlir/Dialect/Transform/IR/TransformUtils.h" #include "mlir/Dialect/Transform/Utils/Utils.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/IR/AffineMap.h" @@ -68,6 +67,13 @@ // Apply the pattern directly to the op. 
PatternTy pattern(operation->getContext(), std::forward(args)...); +  // We want to discourage direct use of PatternRewriter in APIs but in this +  // very specific case, an IRRewriter is not enough. +  struct TrivialPatternRewriter : public PatternRewriter { +  public: +    explicit TrivialPatternRewriter(MLIRContext *context) +        : PatternRewriter(context) {} +  }; TrivialPatternRewriter rewriter(operation->getContext()); rewriter.setInsertionPoint(operation); auto result = pattern.returningMatchAndRewrite(op, rewriter); @@ -293,7 +299,7 @@ if (!tilingInterfaceOp) return transformOp->emitError("only TilingInterface ops are supported"); - TrivialPatternRewriter rewriter(target->getContext()); + IRRewriter rewriter(target->getContext()); rewriter.setInsertionPoint(target); FailureOr tiledResults = applyFn(tilingInterfaceOp); @@ -377,7 +383,7 @@ tileSizes.size() - llvm::count(tileSizes, 0), transformResults, [&](TilingInterface tilingInterfaceOp) -> FailureOr { - TrivialPatternRewriter rewriter(getContext()); + IRRewriter rewriter(getContext()); return tileConsumerAndFuseProducerGreedilyUsingSCFForOp( rewriter, tilingInterfaceOp, tileAndFuseOptions); }); @@ -761,7 +767,9 @@ results.push_back(target); return DiagnosedSilenceableFailure::success(); } - FailureOr generic = tryApply(target); + IRRewriter rewriter(getContext()); + rewriter.setInsertionPoint(target); + FailureOr generic = generalizeNamedOp(rewriter, target); if (succeeded(generic)) { results.push_back(generic->getOperation()); return DiagnosedSilenceableFailure::success(); @@ -783,7 +791,7 @@ results.push_back(target); return DiagnosedSilenceableFailure::success(); } - TrivialPatternRewriter rewriter(target->getContext()); + IRRewriter rewriter(target->getContext()); FailureOr res = interchangeGenericOp(rewriter, target, SmallVector(interchangeVector.begin(), @@ -1867,7 +1875,7 @@ if (failed(promoteSubviewsPrecondition(target, promotionOptions))) return emitDefaultDefiniteFailure(target); - TrivialPatternRewriter
rewriter(target->getContext()); + IRRewriter rewriter(target->getContext()); rewriter.setInsertionPoint(target); FailureOr res = promoteSubViews(rewriter, target, promotionOptions); if (failed(res)) @@ -1966,7 +1974,7 @@ return tileSizes; }); SmallVector emptyTileSizes; - TrivialPatternRewriter rewriter(getContext()); + IRRewriter rewriter(getContext()); rewriter.setInsertionPoint(target); FailureOr maybeTilingResult = tileUsingSCFForOp( rewriter, cast(target.getOperation()), tilingOptions); @@ -2018,7 +2026,7 @@ TransformState &state) { // Collect the dynamic split points if provided. ArrayRef payload = state.getPayloadOps(getTarget()); - TrivialPatternRewriter rewriter(getContext()); + IRRewriter rewriter(getContext()); SmallVector splitPoints; splitPoints.reserve(payload.size()); if (getDynamicSplitPoint()) { @@ -2225,7 +2233,7 @@ unsigned(getInsertSplitDimension()), bool(getInnerParallel())}; }; - TrivialPatternRewriter rewriter(getContext()); + IRRewriter rewriter(getContext()); rewriter.setInsertionPoint(target); FailureOr splitResult = (getUseScalingAlgorithm()) @@ -2265,7 +2273,7 @@ DiagnosedSilenceableFailure transform::TileReductionUsingScfOp::applyToOne( LinalgOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { - TrivialPatternRewriter rewriter(getContext()); + IRRewriter rewriter(getContext()); rewriter.setInsertionPoint(target); FailureOr result = scf::tileReductionUsingScf( rewriter, cast(target.getOperation()), @@ -2308,7 +2316,7 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne( LinalgOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { - TrivialPatternRewriter rewriter(getContext()); + IRRewriter rewriter(getContext()); rewriter.setInsertionPoint(target); SmallVector numThreads = getAsOpFoldResult(rewriter.getI64ArrayAttr(getNumThreads())); @@ -2495,7 +2503,7 @@ } tilingOptions.setInterchange(getInterchange()); - TrivialPatternRewriter 
rewriter(op->getContext()); + IRRewriter rewriter(op->getContext()); FailureOr maybeTilingResult = tileUsingSCFForOp(rewriter, tilingInterface, tilingOptions); if (failed(maybeTilingResult)) @@ -2888,7 +2896,7 @@ } tilingOptions.setInterchange(getInterchange()); - TrivialPatternRewriter rewriter(tilingInterfaceOp.getContext()); + IRRewriter rewriter(tilingInterfaceOp.getContext()); FailureOr tilingResult = tileUsingSCFForOp(rewriter, tilingInterfaceOp, tilingOptions); if (failed(tilingResult)) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp @@ -175,8 +175,8 @@ /// Replace the index operations in the body of the loop nest by the matching /// induction variables. -static void replaceIndexOpsByInductionVariables(LinalgOp linalgOp, - PatternRewriter &rewriter, +static void replaceIndexOpsByInductionVariables(RewriterBase &rewriter, + LinalgOp linalgOp, ArrayRef loopOps) { // Extract the induction variables of the loop nest from outer to inner. SmallVector allIvs; @@ -206,7 +206,7 @@ } template -static FailureOr linalgOpToLoopsImpl(PatternRewriter &rewriter, +static FailureOr linalgOpToLoopsImpl(RewriterBase &rewriter, LinalgOp linalgOp) { using LoadOpTy = std::conditional_t::value, AffineLoadOp, memref::LoadOp>; @@ -247,7 +247,7 @@ } LinalgLoops loops(loopSet.begin(), loopSet.end()); // Replace all index operations in the loop body. - replaceIndexOpsByInductionVariables(linalgOp, rewriter, loops); + replaceIndexOpsByInductionVariables(rewriter, linalgOp, loops); return loops; } @@ -367,20 +367,19 @@ /// Emits a loop nest of `affine.for` with the proper body for `linalgOp`. 
FailureOr -mlir::linalg::linalgOpToAffineLoops(PatternRewriter &rewriter, - LinalgOp linalgOp) { +mlir::linalg::linalgOpToAffineLoops(RewriterBase &rewriter, LinalgOp linalgOp) { return linalgOpToLoopsImpl(rewriter, linalgOp); } /// Emits a loop nest of `scf.for` with the proper body for `linalgOp`. -FailureOr mlir::linalg::linalgOpToLoops(PatternRewriter &rewriter, +FailureOr mlir::linalg::linalgOpToLoops(RewriterBase &rewriter, LinalgOp linalgOp) { return linalgOpToLoopsImpl(rewriter, linalgOp); } /// Emits a loop nest of `scf.parallel` with the proper body for `linalgOp`. FailureOr -mlir::linalg::linalgOpToParallelLoops(PatternRewriter &rewriter, +mlir::linalg::linalgOpToParallelLoops(RewriterBase &rewriter, LinalgOp linalgOp) { return linalgOpToLoopsImpl(rewriter, linalgOp); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp @@ -28,7 +28,7 @@ using namespace mlir::linalg; FailureOr mlir::linalg::splitReduction( - PatternRewriter &b, LinalgOp op, + RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) { OpBuilder::InsertionGuard guard(b); b.setInsertionPoint(op); @@ -238,7 +238,7 @@ /// Core rewrite implementation. FailureOr mlir::linalg::splitReductionByScaling( - PatternRewriter &b, LinalgOp op, + RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) { OpBuilder::InsertionGuard guard(b); b.setInsertionPoint(op); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -857,7 +857,7 @@ /// Filling `dest` using FillOp constant padding value if possible. /// Otherwise, generate a tensor::GenerateOp. 
Value GeneralizePadOpPattern::createFillOrGenerateOp( - PatternRewriter &rewriter, tensor::PadOp padOp, Value dest, + RewriterBase &rewriter, tensor::PadOp padOp, Value dest, const SmallVector &dynSizes) const { auto padValue = padOp.getConstantPaddingValue(); if (padValue) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -31,8 +31,8 @@ #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" -#include #include +#include using namespace mlir; using namespace mlir::linalg; @@ -1535,7 +1535,7 @@ /// Vectorize the copying of a tensor::PadOp's source. This is possible if /// each dimension size is statically know in the source type or the result /// type (or both). - static LogicalResult tryVectorizeCopy(PatternRewriter &rewriter, + static LogicalResult tryVectorizeCopy(RewriterBase &rewriter, tensor::PadOp padOp, Value dest) { auto sourceType = padOp.getSourceType(); auto resultType = padOp.getResultType(); @@ -2592,15 +2592,17 @@ // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c} for (int64_t kw = 0; kw < kwSize; ++kw) { for (int64_t w = 0; w < wSize; w += wSizeStep) { - resVals[w] = depthwiseConv1dSliceAsMulAcc( - rewriter, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw], resVals[w]); + resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, + lhsVals[linearIndex(kw, w)], + rhsVals[kw], resVals[w]); } } // Its possible we failed to create the Fma. if (!llvm::all_of(resVals, [](Value v) { return v; })) { // Manually revert (in reverse order) to avoid leaving a bad IR state. 
- for (auto &collection : {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}}) + for (auto &collection : + {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}}) for (Value v : collection) rewriter.eraseOp(v.getDefiningOp()); return rewriter.notifyMatchFailure(op, "failed to create FMA"); diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp --- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp +++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp @@ -16,7 +16,6 @@ #include "mlir/Dialect/SCF/Utils/Utils.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" -#include "mlir/Dialect/Transform/IR/TransformUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" using namespace mlir; @@ -89,7 +88,7 @@ for (Operation *target : state.getPayloadOps(getTarget())) { Location location = target->getLoc(); Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(target); - TrivialPatternRewriter rewriter(getContext()); + IRRewriter rewriter(getContext()); scf::ExecuteRegionOp exec = wrapInExecuteRegion(rewriter, target); if (!exec) { DiagnosedSilenceableFailure diag = emitSilenceableError() @@ -190,10 +189,10 @@ getReadLatency()); }; scf::ForLoopPipeliningPattern pattern(options, target->getContext()); - TrivialPatternRewriter rewriter(getContext()); + IRRewriter rewriter(getContext()); rewriter.setInsertionPoint(target); FailureOr patternResult = - pattern.returningMatchAndRewrite(target, rewriter); + scf::pipelineForLoop(rewriter, target, options); if (succeeded(patternResult)) { results.push_back(*patternResult); return DiagnosedSilenceableFailure::success(); diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp --- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp @@ -62,13 +62,13 @@ bool 
initializeLoopInfo(ForOp op, const PipeliningOption &options); /// Emits the prologue, this creates `maxStage - 1` part which will contain /// operations from stages [0; i], where i is the part index. - void emitPrologue(PatternRewriter &rewriter); + void emitPrologue(RewriterBase &rewriter); /// Gather liverange information for Values that are used in a different stage /// than its definition. llvm::MapVector analyzeCrossStageValues(); scf::ForOp createKernelLoop( const llvm::MapVector &crossStageValues, - PatternRewriter &rewriter, + RewriterBase &rewriter, llvm::DenseMap, unsigned> &loopArgMap); /// Emits the pipelined kernel. This clones loop operations following user /// order and remaps operands defined in a different stage as their use. @@ -76,10 +76,10 @@ scf::ForOp newForOp, const llvm::MapVector &crossStageValues, const llvm::DenseMap, unsigned> &loopArgMap, - PatternRewriter &rewriter); + RewriterBase &rewriter); /// Emits the epilogue, this creates `maxStage - 1` part which will contain /// operations from stages [i; maxStage], where i is the part index. - llvm::SmallVector emitEpilogue(PatternRewriter &rewriter); + llvm::SmallVector emitEpilogue(RewriterBase &rewriter); }; bool LoopPipelinerInternal::initializeLoopInfo( @@ -172,7 +172,7 @@ return clone; } -void LoopPipelinerInternal::emitPrologue(PatternRewriter &rewriter) { +void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) { // Initialize the iteration argument to the loop initiale values. for (BlockArgument &arg : forOp.getRegionIterArgs()) { OpOperand &operand = forOp.getOpOperandForRegionIterArg(arg); @@ -244,7 +244,7 @@ scf::ForOp LoopPipelinerInternal::createKernelLoop( const llvm::MapVector &crossStageValues, - PatternRewriter &rewriter, + RewriterBase &rewriter, llvm::DenseMap, unsigned> &loopArgMap) { // Creates the list of initial values associated to values used across // stages. The initial values come from the prologue created above. 
@@ -299,7 +299,7 @@ const llvm::MapVector &crossStageValues, const llvm::DenseMap, unsigned> &loopArgMap, - PatternRewriter &rewriter) { + RewriterBase &rewriter) { valueMapping.clear(); // Create the kernel, we clone instruction based on the order given by @@ -380,7 +380,7 @@ } if (predicates[useStage]) { - newOp = predicateFn(newOp, predicates[useStage], rewriter); + newOp = predicateFn(rewriter, newOp, predicates[useStage]); // Remap the results to the new predicated one. for (auto values : llvm::zip(op->getResults(), newOp->getResults())) mapping.map(std::get<0>(values), std::get<1>(values)); @@ -430,7 +430,7 @@ } llvm::SmallVector -LoopPipelinerInternal::emitEpilogue(PatternRewriter &rewriter) { +LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter) { llvm::SmallVector returnValues(forOp->getNumResults()); // Emit different versions of the induction variable. They will be // removed by dead code if not used. @@ -495,9 +495,8 @@ } // namespace -FailureOr ForLoopPipeliningPattern::returningMatchAndRewrite( - ForOp forOp, PatternRewriter &rewriter) const { - +FailureOr mlir::scf::pipelineForLoop(RewriterBase &rewriter, ForOp forOp, + const PipeliningOption &options) { LoopPipelinerInternal pipeliner; if (!pipeliner.initializeLoopInfo(forOp, options)) return failure(); diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -398,7 +398,7 @@ } FailureOr -mlir::scf::tileReductionUsingScf(PatternRewriter &b, +mlir::scf::tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op, ArrayRef tileSize) { Location loc = op.getLoc(); diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -11,7 +11,6 @@ 
#include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/Dialect/Transform/IR/TransformTypes.h" -#include "mlir/Dialect/Transform/IR/TransformUtils.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Rewrite/FrozenRewritePatternSet.h" @@ -114,7 +113,14 @@ } PatternApplicator applicator(it->second); - transform::TrivialPatternRewriter rewriter(root->getContext()); + // We want to discourage direct use of PatternRewriter in APIs but In this + // very specific case, an IRRewriter is not enough. + struct TrivialPatternRewriter : public PatternRewriter { + public: + explicit TrivialPatternRewriter(MLIRContext *context) + : PatternRewriter(context) {} + }; + TrivialPatternRewriter rewriter(root->getContext()); applicator.applyDefaultCostModel(); root->walk([&](Operation *op) { if (succeeded(applicator.matchAndRewrite(op, rewriter))) @@ -453,7 +459,7 @@ DiagnosedSilenceableFailure transform::GetDefiningOp::apply(transform::TransformResults &results, - transform::TransformState &state) { + transform::TransformState &state) { SmallVector definingOps; for (Value v : state.getPayloadValues(getTarget())) { if (v.isa()) { diff --git a/mlir/test/Dialect/GPU/transform-gpu-failing.mlir b/mlir/test/Dialect/GPU/transform-gpu-failing.mlir --- a/mlir/test/Dialect/GPU/transform-gpu-failing.mlir +++ b/mlir/test/Dialect/GPU/transform-gpu-failing.mlir @@ -1,6 +1,6 @@ // RUN: mlir-opt --test-transform-dialect-interpreter --split-input-file -canonicalize -cse --verify-diagnostics %s -func.func @map_nested_foreach_to_threads_not_gpu_launch() -> () { +func.func @map_nested_forall_to_threads_not_gpu_launch() -> () { %1 = tensor.empty() : tensor<4xf32> return } @@ -8,12 +8,12 @@ ^bb0(%arg0: !pdl.operation): %funcop = transform.structured.match ops{["tensor.empty"]} in %arg0 : (!pdl.operation) -> !pdl.operation // expected-error @below {{Given target is not gpu.launch}} - 
%1 = transform.gpu.map_nested_foreach_to_threads %funcop + %1 = transform.gpu.map_nested_forall_to_threads %funcop } // ----- -func.func @map_nested_foreach_to_threads_excessive_threads(%x: memref<2 x 32 x f32>, %y: memref<2 x 32 x f32>, %t: memref<32 x f32>, %alpha : f32, %stream : !gpu.async.token) -> memref<2 x 32 x f32> { +func.func @map_nested_forall_to_threads_excessive_threads(%x: memref<2 x 32 x f32>, %y: memref<2 x 32 x f32>, %t: memref<32 x f32>, %alpha : f32, %stream : !gpu.async.token) -> memref<2 x 32 x f32> { %one = arith.constant 1 : index %c900 = arith.constant 900 : index %c9 = arith.constant 9 : index @@ -49,12 +49,12 @@ %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!pdl.operation) -> !pdl.operation // expected-error @below {{Trying to launch a GPU kernel with gridDim = (1, 1, 1) blockDim = (1200, 9, 1). It is larger than the limits.}} // expected-note @below {{"blockDim" is very large}} - transform.gpu.map_nested_foreach_to_threads %funcop { blockDim = [1200, 9, 1] } + transform.gpu.map_nested_forall_to_threads %funcop { blockDim = [1200, 9, 1] } } // ----- -func.func @map_nested_foreach_to_threads_fewer_threads(%x: memref<2 x 32 x f32>, %y: memref<2 x 32 x f32>, %t: memref<32 x f32>, %alpha : f32, %stream : !gpu.async.token) -> memref<2 x 32 x f32> { +func.func @map_nested_forall_to_threads_fewer_threads(%x: memref<2 x 32 x f32>, %y: memref<2 x 32 x f32>, %t: memref<32 x f32>, %alpha : f32, %stream : !gpu.async.token) -> memref<2 x 32 x f32> { %one = arith.constant 1 : index %c900 = arith.constant 900 : index %c9 = arith.constant 9 : index @@ -90,12 +90,12 @@ ^bb1(%arg0: !pdl.operation): %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!pdl.operation) -> !pdl.operation // expected-error @below {{The requested GPU threads are fewer than the number of loop trip counts. 
Try to tile scf.forall before mapping or set small blockDim.}} - transform.gpu.map_nested_foreach_to_threads %funcop { blockDim = [128, 4, 1] } + transform.gpu.map_nested_forall_to_threads %funcop { blockDim = [128, 4, 1] } } // ----- -func.func @map_nested_foreach_to_threads_dynamic_trip_count(%x: memref<2 x 32 x f32>, %y: memref<2 x 32 x f32>, %t: memref<32 x f32>, %alpha : f32, %stream : !gpu.async.token, %c9 : index, %c7 : index) -> memref<2 x 32 x f32> { +func.func @map_nested_forall_to_threads_dynamic_trip_count(%x: memref<2 x 32 x f32>, %y: memref<2 x 32 x f32>, %t: memref<32 x f32>, %alpha : f32, %stream : !gpu.async.token, %c9 : index, %c7 : index) -> memref<2 x 32 x f32> { %one = arith.constant 1 : index %c900 = arith.constant 900 : index %name = gpu.launch async[%stream] blocks(%arg3, %arg4, %arg5) in (%arg9 = %one, %arg10 = %one, %arg11 = %one) @@ -116,12 +116,12 @@ ^bb1(%arg0: !pdl.operation): %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!pdl.operation) -> !pdl.operation // expected-error @below {{unsupported dynamic blockdim size}} - transform.gpu.map_nested_foreach_to_threads %funcop { blockDim = [128, 4, 1] } + transform.gpu.map_nested_forall_to_threads %funcop { blockDim = [128, 4, 1] } } // ----- -func.func @map_nested_foreach_to_threads_not_buffer(%x: tensor<32x32xf32>, %y: tensor<32x32xf32>, %z: tensor<32x32xf32>, %stream : !gpu.async.token) { +func.func @map_nested_forall_to_threads_not_buffer(%x: tensor<32x32xf32>, %y: tensor<32x32xf32>, %z: tensor<32x32xf32>, %stream : !gpu.async.token) { %one = arith.constant 1 : index %name = gpu.launch async[%stream] blocks(%arg3, %arg4, %arg5) in (%arg9 = %one, %arg10 = %one, %arg11 = %one) threads(%arg6, %arg7, %arg8) in (%arg12 = %one, %arg13 = %one, %arg14 = %one) @@ -135,16 +135,16 @@ transform.sequence failures(propagate) { ^bb1(%arg0: !pdl.operation): %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg0 : (!pdl.operation) -> !pdl.operation - %foreach, %tiled = 
transform.structured.tile_to_forall_op %matmul num_threads [10, 20, 30] (mapping = [ #gpu.thread, #gpu.thread, #gpu.thread ] ) + %forall, %tiled = transform.structured.tile_to_forall_op %matmul num_threads [10, 20, 30] (mapping = [ #gpu.thread, #gpu.thread, #gpu.thread ] ) %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!pdl.operation) -> !pdl.operation // expected-error @below {{only bufferized scf.forall lowers to gpu.thread_id}} - transform.gpu.map_nested_foreach_to_threads %funcop { blockDim = [128, 4, 1] } + transform.gpu.map_nested_forall_to_threads %funcop { blockDim = [128, 4, 1] } } // ----- -func.func @map_foreach_to_blocks_not_gpu_launch() -> () { +func.func @map_forall_to_blocks_not_gpu_launch() -> () { // expected-note @below {{when applied to this payload op}} %1 = tensor.empty() : tensor<4xf32> return @@ -153,12 +153,12 @@ ^bb0(%arg0: !pdl.operation): %funcop = transform.structured.match ops{["tensor.empty"]} in %arg0 : (!pdl.operation) -> !pdl.operation // expected-error @below {{Given target is not gpu.launch}} - %1 = transform.gpu.map_foreach_to_blocks %funcop + %1 = transform.gpu.map_forall_to_blocks %funcop } // ----- -func.func @map_foreach_to_blocks_not_unique(%x: memref<2 x 32 x f32>, %y: memref<2 x 32 x f32>, %t: memref<32 x f32>, %alpha : f32, %stream : !gpu.async.token) -> memref<2 x 32 x f32> { +func.func @map_forall_to_blocks_not_unique(%x: memref<2 x 32 x f32>, %y: memref<2 x 32 x f32>, %t: memref<32 x f32>, %alpha : f32, %stream : !gpu.async.token) -> memref<2 x 32 x f32> { %one = arith.constant 1 : index %c900 = arith.constant 900 : index %c9 = arith.constant 9 : index @@ -190,13 +190,13 @@ ^bb0(%arg0: !pdl.operation): %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!pdl.operation) -> !pdl.operation // expected-error @below {{could not find a unique topLevel scf.forall}} - %1 = transform.gpu.map_foreach_to_blocks %funcop + %1 = transform.gpu.map_forall_to_blocks %funcop } // ----- // 
expected-note @below {{when applied to this payload op}} -func.func @map_foreach_to_blocks_large_loop(%x: memref<2 x 32 x f32>, %y: memref<2 x 32 x f32>, %t: memref<32 x f32>, %alpha : f32, %stream : !gpu.async.token) -> memref<2 x 32 x f32> { +func.func @map_forall_to_blocks_large_loop(%x: memref<2 x 32 x f32>, %y: memref<2 x 32 x f32>, %t: memref<32 x f32>, %alpha : f32, %stream : !gpu.async.token) -> memref<2 x 32 x f32> { %one = arith.constant 1 : index %c65537 = arith.constant 65536 : index %c9 = arith.constant 9 : index @@ -223,12 +223,12 @@ ^bb0(%arg0: !pdl.operation): %funcop = transform.structured.match ops{["func.func"]} in %arg0 : (!pdl.operation) -> !pdl.operation // expected-error @below {{could not find a unique topLevel scf.forall}} - %1 = transform.gpu.map_foreach_to_blocks %funcop { generate_gpu_launch } + %1 = transform.gpu.map_forall_to_blocks %funcop { generate_gpu_launch } } // ----- -func.func @map_foreach_to_blocks_large_loop(%x: memref<2 x 32 x f32>, %y: memref<2 x 32 x f32>, %t: memref<32 x f32>, %alpha : f32, %stream : !gpu.async.token) -> memref<2 x 32 x f32> { +func.func @map_forall_to_blocks_large_loop(%x: memref<2 x 32 x f32>, %y: memref<2 x 32 x f32>, %t: memref<32 x f32>, %alpha : f32, %stream : !gpu.async.token) -> memref<2 x 32 x f32> { %one = arith.constant 1 : index %c65535 = arith.constant 65535 : index scf.forall (%i, %j) in (%c65535, %c65535) { @@ -244,7 +244,7 @@ ^bb0(%arg0: !pdl.operation): %funcop = transform.structured.match ops{["func.func"]} in %arg0 : (!pdl.operation) -> !pdl.operation // expected-error @below {{Trying to launch a GPU kernel with gridDim = (65535, 65535, 1) blockDim = (1, 1, 1). 
It is larger than the limits.}} - %1 = transform.gpu.map_foreach_to_blocks %funcop { generate_gpu_launch } + %1 = transform.gpu.map_forall_to_blocks %funcop { generate_gpu_launch } } // ----- @@ -271,7 +271,7 @@ ^bb1(%arg0: !pdl.operation): %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!pdl.operation) -> !pdl.operation // expected-error @below {{#gpu.thread is duplicated, cannot map different loops to the same processor}} - transform.gpu.map_nested_foreach_to_threads %funcop { blockDim = [32, 32]} + transform.gpu.map_nested_forall_to_threads %funcop { blockDim = [32, 32]} } // ----- @@ -301,5 +301,5 @@ ^bb1(%arg0: !pdl.operation): %matmul = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!pdl.operation) -> !pdl.operation // expected-error @below {{transform.structured.tile_to_forall_op failed to apply}} - %foreach, %tiled = transform.structured.tile_to_forall_op %matmul num_threads [10, 20, 30] (mapping = [ #gpu.thread, #gpu.thread, #gpu.thread ] ) + %forall, %tiled = transform.structured.tile_to_forall_op %matmul num_threads [10, 20, 30] (mapping = [ #gpu.thread, #gpu.thread, #gpu.thread ] ) } diff --git a/mlir/test/Dialect/GPU/transform-gpu.mlir b/mlir/test/Dialect/GPU/transform-gpu.mlir --- a/mlir/test/Dialect/GPU/transform-gpu.mlir +++ b/mlir/test/Dialect/GPU/transform-gpu.mlir @@ -33,7 +33,7 @@ transform.sequence failures(propagate) { ^bb1(%arg0: !pdl.operation): %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!pdl.operation) -> !pdl.operation - transform.gpu.map_foreach_to_blocks %funcop { gridDim = [12, 9]} + transform.gpu.map_forall_to_blocks %funcop { gridDim = [12, 9]} } // ----- @@ -87,7 +87,7 @@ transform.sequence failures(propagate) { ^bb1(%arg0: !pdl.operation): %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!pdl.operation) -> !pdl.operation - transform.gpu.map_nested_foreach_to_threads %funcop { blockDim = [12, 9] } + transform.gpu.map_nested_forall_to_threads %funcop { 
blockDim = [12, 9] } } // ----- @@ -126,8 +126,8 @@ transform.sequence failures(propagate) { ^bb1(%arg0: !pdl.operation): %funcop = transform.structured.match ops{["func.func"]} in %arg0 : (!pdl.operation) -> !pdl.operation - %gpuLaunch = transform.gpu.map_foreach_to_blocks %funcop { generate_gpu_launch } - transform.gpu.map_nested_foreach_to_threads %gpuLaunch { blockDim = [32, 4, 1] } + %gpuLaunch = transform.gpu.map_forall_to_blocks %funcop { generate_gpu_launch } + transform.gpu.map_nested_forall_to_threads %gpuLaunch { blockDim = [32, 4, 1] } } // ----- @@ -160,7 +160,7 @@ transform.sequence failures(propagate) { ^bb1(%arg0: !pdl.operation): %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!pdl.operation) -> !pdl.operation - transform.gpu.map_nested_foreach_to_threads %funcop { blockDim = [12, 9, 1], syncAfterDistribute = false } + transform.gpu.map_nested_forall_to_threads %funcop { blockDim = [12, 9, 1], syncAfterDistribute = false } } // ----- @@ -192,7 +192,7 @@ transform.sequence failures(propagate) { ^bb1(%arg0: !pdl.operation): %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!pdl.operation) -> !pdl.operation - transform.gpu.map_nested_foreach_to_threads %funcop { blockDim = [32]} + transform.gpu.map_nested_forall_to_threads %funcop { blockDim = [32]} } // ----- @@ -228,7 +228,7 @@ transform.sequence failures(propagate) { ^bb1(%arg0: !pdl.operation): %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!pdl.operation) -> !pdl.operation - transform.gpu.map_nested_foreach_to_threads %funcop { blockDim = [12, 9, 1], syncAfterDistribute = false } + transform.gpu.map_nested_forall_to_threads %funcop { blockDim = [12, 9, 1], syncAfterDistribute = false } } // ----- @@ -267,5 +267,5 @@ transform.sequence failures(propagate) { ^bb1(%arg0: !pdl.operation): %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!pdl.operation) -> !pdl.operation - transform.gpu.map_nested_foreach_to_threads 
%funcop { blockDim = [12, 9] } + transform.gpu.map_nested_forall_to_threads %funcop { blockDim = [12, 9] } } diff --git a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp --- a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp +++ b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp @@ -150,8 +150,8 @@ /// Helper to generate "predicated" version of `op`. For simplicity we just /// wrap the operation in a scf.ifOp operation. - static Operation *predicateOp(Operation *op, Value pred, - PatternRewriter &rewriter) { + static Operation *predicateOp(RewriterBase &rewriter, Operation *op, + Value pred) { Location loc = op->getLoc(); auto ifOp = rewriter.create(loc, op->getResultTypes(), pred, true);