diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -26,6 +26,53 @@ Transform_ParamType.predicate]>, "transform 'param' type or any handle type">; +//===----------------------------------------------------------------------===// +// Apply...PatternsOp +//===----------------------------------------------------------------------===// + +def ApplyEraseUnnecessaryInputsPatternsOp : Op]> { + let description = [{ + Collects patterns that promote inputs to outputs and remove unused inputs of + `linalg.generic` ops. + }]; + + let assemblyFormat = "attr-dict"; +} + +def ApplyFoldUnitExtentDimsViaReshapesPatternsOp : Op]> { + let description = [{ + Collects patterns to fold unit-extent dimensions in operands/results of + linalg ops on tensors via reassociative reshape ops. + }]; + + let assemblyFormat = "attr-dict"; +} + +def ApplyFoldUnitExtentDimsViaSlicesPatternsOp : Op]> { + let description = [{ + Collects patterns to fold unit-extent dimensions in operands/results of + linalg ops on tensors via rank-reducing slices. + }]; + + let assemblyFormat = "attr-dict"; +} + +def ApplyTilingCanonicalizationPatternsOp : Op]> { + let description = [{ + Collects canonicalization patterns relevant to apply after tiling patterns. 
+ }]; + + let assemblyFormat = "attr-dict"; +} + //===----------------------------------------------------------------------===// // BufferizeToAllocationOp //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td --- a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td +++ b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td @@ -15,6 +15,81 @@ include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/IR/OpBase.td" +def ApplyExpandOpsPatternsOp : Op]> { + let description = [{ + Collects patterns to rewrite ops within the memref dialect. + + - Converts `atomic_rmw` that cannot be lowered to a simple atomic op with + the AtomicRMWOpLowering pattern, e.g. with "minf" or "maxf" attributes, to + `memref.generic_atomic_rmw` with the expanded code. + - Converts `memref.reshape` that has a target shape of a statically-known + size to `memref.reinterpret_cast`. + }]; + + let assemblyFormat = "attr-dict"; +} + +def ApplyExpandStridedMetadataPatternsOp : Op]> { + let description = [{ + Collects patterns for expanding memref operations that modify the metadata + (sizes, offset, strides) of a memref into easier-to-analyze constructs. + }]; + + let assemblyFormat = "attr-dict"; +} + +def ApplyExtractAddressComputationsPatternsOp : Op]> { + let description = [{ + Collects patterns for extracting address computations from operations + with memory accesses such that these memory accesses use only a base + pointer. + + For instance, + ```mlir + memref.load %base[%off0, ...] + ``` + + is rewritten into: + ```mlir + %new_base = memref.subview %base[%off0,...][1,...][1,...] + memref.load %new_base[%c0,...]
+ ``` + }]; + + let assemblyFormat = "attr-dict"; +} + +def ApplyFoldMemrefAliasOpsPatternsOp : Op]> { + let description = [{ + Collects patterns for folding memref aliasing ops (memref.subview) into + consumer load/store ops (affine.load, memref.load, nvgpu.ldmatrix, + vector.load, vector.transfer_read, affine.store, memref.store, etc.) and + other ops (e.g., memref.subview). + }]; + + let assemblyFormat = "attr-dict"; +} + +def ApplyResolveRankedShapedTypeResultDimsPatternsOp : Op]> { + let description = [{ + Collects patterns that resolve `memref.dim` operations with values that are + defined by operations that implement the `ReifyRankedShapedTypeOpInterface`, + in terms of shapes of its input operands. + }]; + + let assemblyFormat = "attr-dict"; +} + def Transform_MemRefAllocOp : Transform_ConcreteOpType<"memref.alloc">; def MemRefMultiBufferOp : Op { - let summary = "Extract address computations from memory accesses"; - let description = [{ - Transformation that extracts address computations from instructions - with memory accesses such that these memory accesses use only a base - pointer. - - For instance, - ```mlir - memref.load %base[%off0, ...] - ``` - - Will be rewritten in: - ```mlir - %new_base = memref.subview %base[%off0,...][1,...][1,...] - memref.load %new_base[%c0,...] - ``` - - Note: The current implementation requires that the input operation - is "isolated from above". - - #### Return modes - - This operation produces `definiteFailure` if the extraction fails for any - reason. - The operation always returns the handle to the target op that is expected - to be isolated from above. 
- }]; - - let arguments = (ins TransformHandleTypeInterface:$target); - let results = (outs TransformHandleTypeInterface:$transformed); - - let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)"; - - let extraClassDeclaration = [{ - ::mlir::DiagnosedSilenceableFailure applyToOne( - ::mlir::Operation *target, - ::mlir::transform::ApplyToEachResultList &transformResults, - ::mlir::transform::TransformState &state); - }]; -} - def MemRefMakeLoopIndependentOp : Op]> { + let description = [{ + Collects patterns for canonicalizing operations inside SCF loop bodies. + At the moment, only affine.min/max computations with iteration variables, + loop bounds and loop steps are canonicalized. + }]; + + let assemblyFormat = "attr-dict"; +} + def Transform_ScfForOp : Transform_ConcreteOpType<"scf.for">; def GetParentForOp : Op]> { + let description = [{ + Indicates that redundant tensor.insert_slice rank expansions should be + dropped. E.g., cases where a tensor.extract_slice rank reduction immediately + follows an inverse tensor.insert_slice rank expansion. + }]; + + let assemblyFormat = "attr-dict"; +} + +def ApplyFoldTensorEmptyPatternsOp : Op]> { + let description = [{ + Indicates that tensor.empty should be folded with tensor.extract_slice, + tensor.expand_shape and tensor.collapse_shape. + }]; + + let assemblyFormat = "attr-dict"; +} +def ApplyFoldIntoPackAndUnpackPatternsOp : Op]> { + let description = [{ + Indicates that operations like tensor.pad and tensor.extract_slice should + be folded into tensor.pack and tensor.unpack operations, respectively. + }]; + + let assemblyFormat = "attr-dict"; +} + +def ApplyFoldTensorSubsetOpsPatternsOp : Op]> { + let description = [{ + Indicates that tensor subset ops (tensor.extract_slice / + tensor.insert_slice) should be folded into the load/store ops that consume + or produce them (e.g., vector.transfer_read / vector.transfer_write).
+ }]; + + let assemblyFormat = "attr-dict"; +} + +def ApplyMergeConsecutiveInsertExtractSlicePatternsOp : Op]> { + let description = [{ + Indicates that consecutive tensor.extract_slice/tensor.insert_slice ops + should be merged into a single op. These patterns are not canonicalizations + because the bufferization is sensitive to IR structure. + }]; + + let assemblyFormat = "attr-dict"; +} + +def ApplyReassociativeReshapeFoldingPatternsOp : Op]> { + let description = [{ + Indicates that reassociative reshapes (tensor.collapse_shape / + tensor.expand_shape) should be folded with inverse rank expansions / rank + reductions (via tensor.insert_slice / tensor.extract_slice). + }]; + + let assemblyFormat = "attr-dict"; +} + def Transform_TensorPadOp : Transform_ConcreteOpType<"tensor.pad">; def MakeLoopIndependentOp diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h @@ -39,8 +39,8 @@ void populateFoldTensorSubsetOpPatterns(RewritePatternSet &patterns); /// Collects patterns to merge consecutive tensor.insert_slice/extract_slice -/// into one. These patterns are in in this separate entry point because the -/// bufferization is sensitive over IR structure, particularly those +/// into one. These patterns are in this separate entry point because the +/// bufferization is sensitive to IR structure, particularly those /// tensor.extract_slice and tensor.insert_slice ops for creating the slices. void populateMergeConsecutiveInsertExtractSlicePatterns( RewritePatternSet &patterns); @@ -55,7 +55,7 @@ void populateReassociativeReshapeFoldingPatterns(RewritePatternSet &patterns); /// Populates `patterns` with patterns that fold tensor.empty with -/// tensor.[extract_slice|cast|expand_shape|collapse_shape]. +/// tensor.[extract_slice|expand_shape|collapse_shape]. 
void populateFoldTensorEmptyPatterns(RewritePatternSet &patterns); /// Populates `patterns` with patterns that fold operations like `tensor.pad` diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h @@ -26,7 +26,6 @@ namespace mlir { namespace transform { -class ApplyPatternsOp; enum class FailurePropagationMode : uint32_t; class FailurePropagationModeAttr; @@ -152,51 +151,9 @@ int64_t errorCounter = 0; }; -/// The PatternRegistry stores callbacks to functions that populate a -/// `RewritePatternSet`. Registered patterns can be applied with the -/// "transform.apply_patterns" op. -class PatternRegistry : public TransformDialectData { -public: - PatternRegistry(MLIRContext *ctx) : TransformDialectData(ctx), builder(ctx) {} - - /// A function that populates a `RewritePatternSet`. - using PopulatePatternsFn = std::function; - /// A function that populates a `RewritePatternSet` with a specified benefit. - using PopulatePatternsWithBenefitFn = - std::function; - - /// Registers patterns with the specified identifier. The identifier should - /// be prefixed with the dialect to which the patterns belong. - void registerPatterns(StringRef identifier, PopulatePatternsFn &&fn); - - /// Registers patterns with the specified identifier. The identifier should - /// be prefixed with the dialect to which the patterns belong. The pattern - /// benefit is currently ignored. - void registerPatterns(StringRef identifier, - PopulatePatternsWithBenefitFn &&fn); - -protected: - friend class ApplyPatternsOp; - - /// Returns "true" if patterns are registered with the specified identifier. - bool hasPatterns(StringAttr identifier) const; - - /// Populates the given pattern set with the specified patterns. 
- void populatePatterns(StringAttr identifier, - RewritePatternSet &patternSet) const; - -private: - /// A builder for creating StringAttrs. - Builder builder; - - DenseMap patterns; -}; - } // namespace transform } // namespace mlir -MLIR_DECLARE_EXPLICIT_TYPE_ID(mlir::transform::PatternRegistry) - #define GET_OP_CLASSES #include "mlir/Dialect/Transform/IR/TransformOps.h.inc" diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -138,12 +138,9 @@ targeted op itself. The patterns that should be applied are specified in the graph region of - this op. They must implement the `PatternDescriptorOpInterface`. - - (Deprecated) In addition, patterns that were registered in the transform - dialect's `PatternRegistry` are available. "canonicalization" is a special - set of patterns that refers to all canonicalization patterns of all loaded - dialects. + this op. They must implement the `PatternDescriptorOpInterface`. The order + in which patterns are applied is unspecified; i.e., the ordering of ops in + the region of this op is irrelevant. This transform only reads the target handle and modifies the payload. If a pattern erases or replaces a tracked op, the mapping is updated accordingly. 
@@ -161,12 +158,12 @@ }]; let arguments = (ins - TransformHandleTypeInterface:$target, ArrayAttr:$patterns, + TransformHandleTypeInterface:$target, DefaultValuedAttr:$fail_on_payload_replacement_not_found); let results = (outs); let regions = (region MaxSizedRegion<1>:$region); - let assemblyFormat = "$patterns `to` $target $region attr-dict `:` type($target)"; + let assemblyFormat = "`to` $target $region attr-dict `:` type($target)"; let hasVerifier = 1; let extraClassDeclaration = [{ diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td --- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td +++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td @@ -52,7 +52,7 @@ let assemblyFormat = "attr-dict"; } -def LowerBroadcastOp : Op]> { let description = [{ @@ -67,7 +67,7 @@ } // TODO: evolve lowering_strategy to proper enums. -def LowerContractionOp : Op]> { let description = [{ @@ -86,7 +86,7 @@ }]; } -def LowerMasksOp : Op]> { let description = [{ @@ -100,7 +100,7 @@ let assemblyFormat = "attr-dict"; } -def LowerMaskedTransfersOp : Op]> { let description = [{ @@ -114,7 +114,7 @@ let assemblyFormat = "attr-dict"; } -def MaterializeMasksOp : Op]> { let description = [{ @@ -129,7 +129,7 @@ } // TODO: evolve lowering_strategy to proper enums. -def LowerMultiReductionOp : Op]> { let description = [{ @@ -149,7 +149,7 @@ }]; } -def LowerOuterProductOp : Op]> { let description = [{ @@ -163,7 +163,29 @@ let assemblyFormat = "attr-dict"; } -def LowerShapeCastOp : Op]> { + let description = [{ + Indicates that vector.gather operations should be lowered to + finer-grained vector primitives. + }]; + + let assemblyFormat = "attr-dict"; +} + +def ApplyLowerScanPatternsOp : Op]> { + let description = [{ + Indicates that vector.scan operations should be lowered to + finer-grained vector primitives. 
+ }]; + + let assemblyFormat = "attr-dict"; +} + +def ApplyLowerShapeCastPatternsOp : Op]> { let description = [{ @@ -177,7 +199,7 @@ let assemblyFormat = "attr-dict"; } -def LowerTransferOp : Op]> { let description = [{ @@ -196,7 +218,7 @@ } // TODO: evolve lowering_strategy to proper enums. -def LowerTransposeOp : Op]> { let description = [{ @@ -223,7 +245,7 @@ } // TODO: evolve split_transfer_strategy to proper enums. -def SplitTransferFullPartialOp : Op]> { let description = [{ @@ -244,7 +266,7 @@ }]; } -def TransferToScfOp : Op]> { let description = [{ diff --git a/mlir/lib/Dialect/Linalg/TransformOps/DialectExtension.cpp b/mlir/lib/Dialect/Linalg/TransformOps/DialectExtension.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/DialectExtension.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/DialectExtension.cpp @@ -49,22 +49,6 @@ #define GET_OP_LIST #include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp.inc" >(); - - addDialectDataInitializer( - [&](transform::PatternRegistry ®istry) { - registry.registerPatterns( - "linalg.erase_unnecessary_inputs", - linalg::populateEraseUnnecessaryInputsPatterns); - registry.registerPatterns( - "linalg.fold_unit_extent_dims_via_slices", - linalg::populateFoldUnitExtentDimsViaSlicesPatterns); - registry.registerPatterns( - "linalg.fold_unit_extent_dims_via_reshapes", - linalg::populateFoldUnitExtentDimsViaReshapesPatterns); - registry.registerPatterns( - "linalg.tiling_canonicalization", - linalg::populateLinalgTilingCanonicalizationPatterns); - }); } }; } // namespace diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -141,9 +141,34 @@ return DiagnosedSilenceableFailure::success(); } +//===----------------------------------------------------------------------===// +// Apply...PatternsOp 
+//===----------------------------------------------------------------------===// + +void transform::ApplyEraseUnnecessaryInputsPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + linalg::populateEraseUnnecessaryInputsPatterns(patterns); +} + +void transform::ApplyFoldUnitExtentDimsViaReshapesPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + linalg::populateFoldUnitExtentDimsViaReshapesPatterns(patterns); +} + +void transform::ApplyFoldUnitExtentDimsViaSlicesPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + linalg::populateFoldUnitExtentDimsViaSlicesPatterns(patterns); +} + +void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + linalg::populateLinalgTilingCanonicalizationPatterns(patterns); +} + //===----------------------------------------------------------------------===// // BufferizeToAllocationOp //===----------------------------------------------------------------------===// + DiagnosedSilenceableFailure transform::BufferizeToAllocationOp::apply(transform::TransformResults &results, transform::TransformState &state) { diff --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp --- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp +++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp @@ -16,7 +16,6 @@ #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" -#include "mlir/Dialect/Transform/IR/TransformOps.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Interfaces/LoopLikeInterface.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -27,6 +26,35 @@ #define DEBUG_TYPE "memref-transforms" #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") +//===----------------------------------------------------------------------===// +// Apply...PatternsOp 
+//===----------------------------------------------------------------------===// + +void transform::ApplyExpandOpsPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + memref::populateExpandOpsPatterns(patterns); +} + +void transform::ApplyExpandStridedMetadataPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + memref::populateExpandStridedMetadataPatterns(patterns); +} + +void transform::ApplyExtractAddressComputationsPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + memref::populateExtractAddressComputationsPatterns(patterns); +} + +void transform::ApplyFoldMemrefAliasOpsPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + memref::populateFoldMemRefAliasOpPatterns(patterns); +} + +void transform::ApplyResolveRankedShapedTypeResultDimsPatternsOp:: + populatePatterns(RewritePatternSet &patterns) { + memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns); +} + //===----------------------------------------------------------------------===// // MemRefMultiBufferOp //===----------------------------------------------------------------------===// @@ -72,31 +100,6 @@ return DiagnosedSilenceableFailure::success(); } -//===----------------------------------------------------------------------===// -// MemRefExtractAddressComputationsOp -//===----------------------------------------------------------------------===// - -DiagnosedSilenceableFailure -transform::MemRefExtractAddressComputationsOp::applyToOne( - Operation *target, transform::ApplyToEachResultList &results, - transform::TransformState &state) { - if (!target->hasTrait()) { - auto diag = this->emitOpError("requires isolated-from-above targets"); - diag.attachNote(target->getLoc()) << "non-isolated target"; - return DiagnosedSilenceableFailure::definiteFailure(); - } - - MLIRContext *ctx = getContext(); - RewritePatternSet patterns(ctx); - memref::populateExtractAddressComputationsPatterns(patterns); - - if (failed(applyPatternsAndFoldGreedily(target, 
std::move(patterns)))) - return emitDefaultDefiniteFailure(target); - - results.push_back(target); - return DiagnosedSilenceableFailure::success(); -} - //===----------------------------------------------------------------------===// // MemRefMakeLoopIndependentOp //===----------------------------------------------------------------------===// @@ -162,23 +165,6 @@ #define GET_OP_LIST #include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp.inc" >(); - - addDialectDataInitializer( - [&](transform::PatternRegistry ®istry) { - registry.registerPatterns("memref.expand_ops", - memref::populateExpandOpsPatterns); - registry.registerPatterns("memref.fold_memref_alias_ops", - memref::populateFoldMemRefAliasOpPatterns); - registry.registerPatterns( - "memref.resolve_ranked_shaped_type_result_dims", - memref::populateResolveRankedShapedTypeResultDimsPatterns); - registry.registerPatterns( - "memref.expand_strided_metadata", - memref::populateExpandStridedMetadataPatterns); - registry.registerPatterns( - "memref.extract_address_computations", - memref::populateExtractAddressComputationsPatterns); - }); } }; } // namespace diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp --- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp +++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp @@ -22,6 +22,15 @@ using namespace mlir; using namespace mlir::affine; +//===----------------------------------------------------------------------===// +// Apply...PatternsOp +//===----------------------------------------------------------------------===// + +void transform::ApplyForLoopCanonicalizationPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + scf::populateSCFForLoopCanonicalizationPatterns(patterns); +} + //===----------------------------------------------------------------------===// // GetParentForOp //===----------------------------------------------------------------------===// @@ 
-309,13 +318,6 @@ #define GET_OP_LIST #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc" >(); - - addDialectDataInitializer( - [&](transform::PatternRegistry ®istry) { - registry.registerPatterns( - "scf.for_loop_canonicalization", - scf::populateSCFForLoopCanonicalizationPatterns); - }); } }; } // namespace diff --git a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp --- a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp +++ b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp @@ -79,6 +79,40 @@ }); } +//===----------------------------------------------------------------------===// +// Apply...PatternsOp +//===----------------------------------------------------------------------===// + +void transform::ApplyDropRedundantInsertSliceRankExpansionPatternsOp:: + populatePatterns(RewritePatternSet &patterns) { + tensor::populateDropRedundantInsertSliceRankExpansionPatterns(patterns); +} + +void transform::ApplyFoldTensorEmptyPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + tensor::populateFoldTensorEmptyPatterns(patterns); +} + +void transform::ApplyFoldIntoPackAndUnpackPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + tensor::populateFoldIntoPackAndUnpackPatterns(patterns); +} + +void transform::ApplyFoldTensorSubsetOpsPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + tensor::populateFoldTensorSubsetOpPatterns(patterns); +} + +void transform::ApplyMergeConsecutiveInsertExtractSlicePatternsOp:: + populatePatterns(RewritePatternSet &patterns) { + tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns); +} + +void transform::ApplyReassociativeReshapeFoldingPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + tensor::populateReassociativeReshapeFoldingPatterns(patterns); +} + //===----------------------------------------------------------------------===// // MakeLoopIndependentOp 
//===----------------------------------------------------------------------===// @@ -144,26 +178,6 @@ #define GET_OP_LIST #include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.cpp.inc" >(); - - addDialectDataInitializer( - [&](transform::PatternRegistry ®istry) { - registry.registerPatterns("tensor.fold_tensor_subset_ops", - tensor::populateFoldTensorSubsetOpPatterns); - registry.registerPatterns( - "tensor.merge_consecutive_insert_extract_slice", - tensor::populateMergeConsecutiveInsertExtractSlicePatterns); - registry.registerPatterns( - "tensor.drop_redundant_insert_slice_rank_expansion", - tensor::populateDropRedundantInsertSliceRankExpansionPatterns); - registry.registerPatterns( - "tensor.reassociative_reshape_folding", - tensor::populateReassociativeReshapeFoldingPatterns); - registry.registerPatterns("tensor.fold_tensor_empty", - tensor::populateFoldTensorEmptyPatterns); - registry.registerPatterns( - "tensor.fold_into_pack_and_unpack", - tensor::populateFoldIntoPackAndUnpackPatterns); - }); } }; } // namespace diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -32,8 +32,6 @@ using namespace mlir; -MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::transform::PatternRegistry) - static ParseResult parseSequenceOpOperands( OpAsmParser &parser, std::optional &root, Type &rootType, @@ -217,37 +215,6 @@ ++errorCounter; } -//===----------------------------------------------------------------------===// -// PatternRegistry -//===----------------------------------------------------------------------===// - -void transform::PatternRegistry::registerPatterns(StringRef identifier, - PopulatePatternsFn &&fn) { - StringAttr attr = builder.getStringAttr(identifier); - assert(!patterns.contains(attr) && "patterns identifier is already in use"); - patterns.try_emplace(attr, std::move(fn)); -} - -void 
transform::PatternRegistry::registerPatterns( - StringRef identifier, PopulatePatternsWithBenefitFn &&fn) { - StringAttr attr = builder.getStringAttr(identifier); - assert(!patterns.contains(attr) && "patterns identifier is already in use"); - patterns.try_emplace(attr, [f = std::move(fn)](RewritePatternSet &patternSet) { - f(patternSet, /*benefit=*/1); - }); -} - -void transform::PatternRegistry::populatePatterns( - StringAttr identifier, RewritePatternSet &patternSet) const { - auto it = patterns.find(identifier); - assert(it != patterns.end() && "patterns not registered in registry"); - it->second(patternSet); -} - -bool transform::PatternRegistry::hasPatterns(StringAttr identifier) const { - return patterns.contains(identifier); -} - //===----------------------------------------------------------------------===// // AlternativesOp //===----------------------------------------------------------------------===// @@ -440,11 +407,6 @@ // Gather all specified patterns. MLIRContext *ctx = target->getContext(); RewritePatternSet patterns(ctx); - const auto ®istry = getContext() - ->getLoadedDialect() - ->getExtraData(); - for (Attribute attr : getPatterns()) - registry.populatePatterns(attr.cast(), patterns); if (!getRegion().empty()) { for (Operation &op : getRegion().front()) { cast(&op).populatePatterns( @@ -495,17 +457,6 @@ } LogicalResult transform::ApplyPatternsOp::verify() { - const auto ®istry = getContext() - ->getLoadedDialect() - ->getExtraData(); - for (Attribute attr : getPatterns()) { - auto strAttr = attr.dyn_cast(); - if (!strAttr) - return emitOpError() << "expected " << getPatternsAttrName() - << " to be an array of strings"; - if (!registry.hasPatterns(strAttr)) - return emitOpError() << "patterns not registered: " << strAttr.strref(); - } if (!getRegion().empty()) { for (Operation &op : getRegion().front()) { if (!isa(&op)) { diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp 
b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp --- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp +++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp @@ -24,7 +24,7 @@ using namespace mlir::transform; //===----------------------------------------------------------------------===// -// ApplyRankReducingSubviewPatternsOp +// Apply...PatternsOp //===----------------------------------------------------------------------===// void transform::ApplyRankReducingSubviewPatternsOp::populatePatterns( @@ -32,29 +32,17 @@ vector::populateVectorTransferDropUnitDimsPatterns(patterns); } -//===----------------------------------------------------------------------===// -// ApplyTransferPermutationPatternsOp -//===----------------------------------------------------------------------===// - void transform::ApplyTransferPermutationPatternsOp::populatePatterns( RewritePatternSet &patterns) { vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); } -//===----------------------------------------------------------------------===// -// LowerBroadcastOp -//===----------------------------------------------------------------------===// - -void transform::LowerBroadcastOp::populatePatterns( +void transform::ApplyLowerBroadcastPatternsOp::populatePatterns( RewritePatternSet &patterns) { populateVectorBroadcastLoweringPatterns(patterns); } -//===----------------------------------------------------------------------===// -// LowerContractionOp -//===----------------------------------------------------------------------===// - -void transform::LowerContractionOp::populatePatterns( +void transform::ApplyLowerContractionPatternsOp::populatePatterns( RewritePatternSet &patterns) { vector::VectorTransformsOptions vectorTransformOptions; vectorTransformOptions.setVectorTransformsOptions(getLoweringStrategy()); @@ -63,37 +51,23 @@ /*disableOuterProductLowering=*/true); } 
-//===----------------------------------------------------------------------===// -// LowerMasksOp -//===----------------------------------------------------------------------===// - -void transform::LowerMasksOp::populatePatterns(RewritePatternSet &patterns) { +void transform::ApplyLowerMasksPatternsOp::populatePatterns( + RewritePatternSet &patterns) { populateVectorMaskOpLoweringPatterns(patterns); } -//===----------------------------------------------------------------------===// -// LowerMaskedTransfersOp -//===----------------------------------------------------------------------===// - -void transform::LowerMaskedTransfersOp::populatePatterns( +void transform::ApplyLowerMaskedTransfersPatternsOp::populatePatterns( RewritePatternSet &patterns) { populateVectorMaskLoweringPatternsForSideEffectingOps(patterns); } -//===----------------------------------------------------------------------===// -// MaterializeMasksOp -//===----------------------------------------------------------------------===// - -void transform::MaterializeMasksOp::populatePatterns(RewritePatternSet &patterns) { +void transform::ApplyMaterializeMasksPatternsOp::populatePatterns( + RewritePatternSet &patterns) { populateVectorMaskMaterializationPatterns(patterns, /*force32BitVectorIndices=*/false); } -//===----------------------------------------------------------------------===// -// LowerMultiReductionOp -//===----------------------------------------------------------------------===// - -void transform::LowerMultiReductionOp::populatePatterns( +void transform::ApplyLowerMultiReductionPatternsOp::populatePatterns( RewritePatternSet &patterns) { vector::VectorTransformsOptions vectorTransformOptions; vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy()); @@ -101,38 +75,33 @@ patterns, vectorTransformOptions.vectorMultiReductionLowering); } -//===----------------------------------------------------------------------===// -// LowerOuterProductOp 
-//===----------------------------------------------------------------------===// - -void transform::LowerOuterProductOp::populatePatterns( +void transform::ApplyLowerOuterProductPatternsOp::populatePatterns( RewritePatternSet &patterns) { populateVectorOuterProductLoweringPatterns(patterns); } -//===----------------------------------------------------------------------===// -// LowerShapeCastOp -//===----------------------------------------------------------------------===// +void transform::ApplyLowerGatherPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + vector::populateVectorGatherLoweringPatterns(patterns); +} -void transform::LowerShapeCastOp::populatePatterns( +void transform::ApplyLowerScanPatternsOp::populatePatterns( RewritePatternSet &patterns) { - vector::populateVectorShapeCastLoweringPatterns(patterns); + vector::populateVectorScanLoweringPatterns(patterns); } -//===----------------------------------------------------------------------===// -// LowerTransferOp -//===----------------------------------------------------------------------===// +void transform::ApplyLowerShapeCastPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + vector::populateVectorShapeCastLoweringPatterns(patterns); +} -void transform::LowerTransferOp::populatePatterns(RewritePatternSet &patterns) { +void transform::ApplyLowerTransferPatternsOp::populatePatterns( + RewritePatternSet &patterns) { vector::populateVectorTransferLoweringPatterns(patterns, getMaxTransferRank()); } -//===----------------------------------------------------------------------===// -// LowerTransposeOp -//===----------------------------------------------------------------------===// - -void transform::LowerTransposeOp::populatePatterns( +void transform::ApplyLowerTransposePatternsOp::populatePatterns( RewritePatternSet &patterns) { vector::populateVectorTransposeLoweringPatterns( patterns, vector::VectorTransformsOptions().setVectorTransposeLowering( @@ -148,22 +117,15 @@ } } 
-//===----------------------------------------------------------------------===// -// SplitTransferFullPartialOp -//===----------------------------------------------------------------------===// - -void transform::SplitTransferFullPartialOp::populatePatterns( +void transform::ApplySplitTransferFullPartialPatternsOp::populatePatterns( RewritePatternSet &patterns) { vector::VectorTransformsOptions vectorTransformOptions; vectorTransformOptions.setVectorTransferSplit(getSplitTransferStrategy()); populateVectorTransferFullPartialPatterns(patterns, vectorTransformOptions); } -//===----------------------------------------------------------------------===// -// TransferToScfOp -//===----------------------------------------------------------------------===// - -void transform::TransferToScfOp::populatePatterns(RewritePatternSet &patterns) { +void transform::ApplyTransferToScfPatternsOp::populatePatterns( + RewritePatternSet &patterns) { VectorTransferToSCFOptions vectorTransferToSCFOptions = VectorTransferToSCFOptions() .enableFullUnroll(getFullUnroll()) @@ -189,34 +151,6 @@ #define GET_OP_LIST #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.cpp.inc" >(); - - addDialectDataInitializer( - [&](transform::PatternRegistry ®istry) { - registry.registerPatterns("vector.outer_product_lowering", - populateVectorOuterProductLoweringPatterns); - registry.registerPatterns("vector.broadcast_lowering", - populateVectorBroadcastLoweringPatterns); - registry.registerPatterns("vector.mask_op_lowering", - populateVectorMaskOpLoweringPatterns); - registry.registerPatterns("vector.shape_cast_lowering", - populateVectorShapeCastLoweringPatterns); - registry.registerPatterns( - "vector.transfer_lowering", - [&](RewritePatternSet &set, PatternBenefit benefit) { - return populateVectorTransferLoweringPatterns( - set, /*maxTransferRank=*/std::nullopt, benefit); - }); - registry.registerPatterns( - "vector.transfer_permutation_map_lowering", - 
populateVectorTransferPermutationMapLoweringPatterns); - registry.registerPatterns("vector.scan_lowering", - populateVectorScanLoweringPatterns); - registry.registerPatterns("vector.vector_gather_lowering", - populateVectorGatherLoweringPatterns); - registry.registerPatterns( - "vector.mask_lowering_for_side_effecting_ops", - populateVectorMaskLoweringPatternsForSideEffectingOps); - }); } }; } // namespace diff --git a/mlir/test/Dialect/LLVM/transform-e2e.mlir b/mlir/test/Dialect/LLVM/transform-e2e.mlir --- a/mlir/test/Dialect/LLVM/transform-e2e.mlir +++ b/mlir/test/Dialect/LLVM/transform-e2e.mlir @@ -27,35 +27,35 @@ // TODO: group these lower-level controls into various properly named vector // lowering TD macros. - transform.apply_patterns [] to %f { + transform.apply_patterns to %f { transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct" } : !transform.any_op - transform.apply_patterns [] to %f { + transform.apply_patterns to %f { transform.apply_patterns.vector.apply_transfer_permutation_patterns } : !transform.any_op - transform.apply_patterns [] to %f { + transform.apply_patterns to %f { transform.apply_patterns.vector.lower_multi_reduction lowering_strategy = "innerparallel" } : !transform.any_op - transform.apply_patterns [] to %f { + transform.apply_patterns to %f { transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "linalg-copy" } : !transform.any_op - transform.apply_patterns [] to %f { + transform.apply_patterns to %f { transform.apply_patterns.vector.transfer_to_scf max_transfer_rank = 1 full_unroll = true } : !transform.any_op - transform.apply_patterns [] to %f { + transform.apply_patterns to %f { transform.apply_patterns.vector.lower_transfer max_transfer_rank = 1 } : !transform.any_op - transform.apply_patterns [] to %f { + transform.apply_patterns to %f { transform.apply_patterns.vector.lower_shape_cast } : !transform.any_op - transform.apply_patterns [] to %f { + 
transform.apply_patterns to %f { transform.apply_patterns.vector.lower_transpose lowering_strategy = "shuffle_1d" } : !transform.any_op } diff --git a/mlir/test/Dialect/MemRef/extract-address-computations.mlir b/mlir/test/Dialect/MemRef/extract-address-computations.mlir --- a/mlir/test/Dialect/MemRef/extract-address-computations.mlir +++ b/mlir/test/Dialect/MemRef/extract-address-computations.mlir @@ -24,9 +24,11 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = transform.memref.extract_address_computations %0 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %0 { + transform.apply_patterns.memref.extract_address_computations + } : !transform.any_op // Verify that the returned handle is usable. - transform.test_print_remark_at_operand %1, "transformed" : !transform.any_op + transform.test_print_remark_at_operand %0, "transformed" : !transform.any_op } // ----- @@ -50,8 +52,9 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = transform.memref.extract_address_computations %0 : (!transform.any_op) -> !transform.any_op -} + transform.apply_patterns to %0 { + transform.apply_patterns.memref.extract_address_computations + } : !transform.any_op} // ----- @@ -79,8 +82,9 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = transform.memref.extract_address_computations %0 : (!transform.any_op) -> !transform.any_op -} + transform.apply_patterns to %0 { + transform.apply_patterns.memref.extract_address_computations + } : !transform.any_op} // ----- @@ -105,8 +109,9 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = 
transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = transform.memref.extract_address_computations %0 : (!transform.any_op) -> !transform.any_op -} + transform.apply_patterns to %0 { + transform.apply_patterns.memref.extract_address_computations + } : !transform.any_op} // ----- // For this test, we made the source memref fully dynamic. @@ -159,8 +164,9 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = transform.memref.extract_address_computations %0 : (!transform.any_op) -> !transform.any_op -} + transform.apply_patterns to %0 { + transform.apply_patterns.memref.extract_address_computations + } : !transform.any_op} // ----- @@ -197,8 +203,9 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = transform.memref.extract_address_computations %0 : (!transform.any_op) -> !transform.any_op -} + transform.apply_patterns to %0 { + transform.apply_patterns.memref.extract_address_computations + } : !transform.any_op} // ----- @@ -231,8 +238,9 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = transform.memref.extract_address_computations %0 : (!transform.any_op) -> !transform.any_op -} + transform.apply_patterns to %0 { + transform.apply_patterns.memref.extract_address_computations + } : !transform.any_op} // ----- @@ -266,8 +274,9 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = transform.memref.extract_address_computations %0 : (!transform.any_op) -> !transform.any_op -} + 
transform.apply_patterns to %0 { + transform.apply_patterns.memref.extract_address_computations + } : !transform.any_op} // ----- @@ -294,8 +303,9 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = transform.memref.extract_address_computations %0 : (!transform.any_op) -> !transform.any_op -} + transform.apply_patterns to %0 { + transform.apply_patterns.memref.extract_address_computations + } : !transform.any_op} // ----- @@ -328,8 +338,9 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = transform.memref.extract_address_computations %0 : (!transform.any_op) -> !transform.any_op -} + transform.apply_patterns to %0 { + transform.apply_patterns.memref.extract_address_computations + } : !transform.any_op} // ----- @@ -363,8 +374,9 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = transform.memref.extract_address_computations %0 : (!transform.any_op) -> !transform.any_op -} + transform.apply_patterns to %0 { + transform.apply_patterns.memref.extract_address_computations + } : !transform.any_op} // ----- // Same as test_transfer_write_op but with tensors. 
@@ -389,5 +401,6 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = transform.memref.extract_address_computations %0 : (!transform.any_op) -> !transform.any_op -} + transform.apply_patterns to %0 { + transform.apply_patterns.memref.extract_address_computations + } : !transform.any_op} diff --git a/mlir/test/Dialect/Transform/ops-invalid.mlir b/mlir/test/Dialect/Transform/ops-invalid.mlir --- a/mlir/test/Dialect/Transform/ops-invalid.mlir +++ b/mlir/test/Dialect/Transform/ops-invalid.mlir @@ -693,7 +693,7 @@ transform.sequence failures(propagate) { ^bb0(%arg0: !transform.any_op): // expected-error @below {{expected children ops to implement PatternDescriptorOpInterface}} - transform.apply_patterns [] to %arg0 { + transform.apply_patterns to %arg0 { // expected-note @below {{op without interface}} transform.named_sequence @foo() } : !transform.any_op diff --git a/mlir/test/Dialect/Transform/test-pattern-application.mlir b/mlir/test/Dialect/Transform/test-pattern-application.mlir --- a/mlir/test/Dialect/Transform/test-pattern-application.mlir +++ b/mlir/test/Dialect/Transform/test-pattern-application.mlir @@ -15,29 +15,7 @@ ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op %1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns ["transform.test"] to %0 {} : !transform.any_op - // Add an attribute to %1, which is now mapped to a new op. 
- transform.annotate %1 "annotated" : !transform.any_op -} - -// ----- - -// CHECK-LABEL: func @update_tracked_op_mapping_region() -// CHECK: "test.container"() ({ -// CHECK: %0 = "test.foo"() {annotated} : () -> i32 -// CHECK: }) : () -> () -func.func @update_tracked_op_mapping_region() { - "test.container"() ({ - %0 = "test.foo"() {replace_with_new_op = "test.foo"} : () -> (i32) - }) : () -> () - return -} - -transform.sequence failures(propagate) { -^bb1(%arg1: !transform.any_op): - %0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns [] to %0 { + transform.apply_patterns to %0 { transform.apply_patterns.transform.test_patterns } : !transform.any_op // Add an attribute to %1, which is now mapped to a new op. @@ -60,7 +38,9 @@ %0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op %1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op // expected-error @below {{tracking listener failed to find replacement op}} - transform.apply_patterns ["transform.test"] to %0 {} : !transform.any_op + transform.apply_patterns to %0 { + transform.apply_patterns.transform.test_patterns + } : !transform.any_op // %1 must be used in some way. If no replacement payload op could be found, // an error is thrown only if the handle is not dead. transform.annotate %1 "annotated" : !transform.any_op @@ -84,7 +64,9 @@ %0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op %1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op // No error because %1 is dead. 
- transform.apply_patterns ["transform.test"] to %0 {} : !transform.any_op + transform.apply_patterns to %0 { + transform.apply_patterns.transform.test_patterns + } : !transform.any_op } // ----- @@ -104,7 +86,9 @@ ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op %1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns ["transform.test"] to %0 {} {fail_on_payload_replacement_not_found = false}: !transform.any_op + transform.apply_patterns to %0 { + transform.apply_patterns.transform.test_patterns + } {fail_on_payload_replacement_not_found = false} : !transform.any_op transform.annotate %1 "annotated" : !transform.any_op } @@ -120,7 +104,9 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns ["transform.test"] to %0 {} : !transform.any_op + transform.apply_patterns to %0 { + transform.apply_patterns.transform.test_patterns + } : !transform.any_op } // ----- @@ -142,7 +128,9 @@ %0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op %1 = transform.structured.match ops{["test.erase_op"]} in %arg1 : (!transform.any_op) -> !transform.any_op transform.test_print_remark_at_operand %1, "matched op" : !transform.any_op - transform.apply_patterns ["transform.test"] to %0 {} : !transform.any_op + transform.apply_patterns to %0 { + transform.apply_patterns.transform.test_patterns + } : !transform.any_op transform.test_print_remark_at_operand %1, "op was deleted" : !transform.any_op } @@ -162,7 +150,7 @@ ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["tensor.dim"]} in %arg1 : (!transform.any_op) -> !transform.any_op %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) 
-> !transform.any_op - transform.apply_patterns [] to %1 { + transform.apply_patterns to %1 { transform.apply_patterns.canonicalization } : !transform.any_op transform.test_print_remark_at_operand %0, "op was replaced" : !transform.any_op diff --git a/mlir/test/Dialect/Vector/transform-vector.mlir b/mlir/test/Dialect/Vector/transform-vector.mlir --- a/mlir/test/Dialect/Vector/transform-vector.mlir +++ b/mlir/test/Dialect/Vector/transform-vector.mlir @@ -30,35 +30,35 @@ // TODO: group these lower-level controls into various properly named vector // lowering TD macros. - transform.apply_patterns [] to %f { + transform.apply_patterns to %f { transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct" } : !transform.any_op - transform.apply_patterns [] to %f { + transform.apply_patterns to %f { transform.apply_patterns.vector.apply_transfer_permutation_patterns } : !transform.any_op - transform.apply_patterns [] to %f { + transform.apply_patterns to %f { transform.apply_patterns.vector.lower_multi_reduction lowering_strategy = "innerparallel" } : !transform.any_op - transform.apply_patterns [] to %f { + transform.apply_patterns to %f { transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "linalg-copy" } : !transform.any_op - transform.apply_patterns [] to %f { + transform.apply_patterns to %f { transform.apply_patterns.vector.transfer_to_scf max_transfer_rank = 1 full_unroll = true } : !transform.any_op - transform.apply_patterns [] to %f { + transform.apply_patterns to %f { transform.apply_patterns.vector.lower_transfer max_transfer_rank = 1 } : !transform.any_op - transform.apply_patterns [] to %f { + transform.apply_patterns to %f { transform.apply_patterns.vector.lower_shape_cast } : !transform.any_op - transform.apply_patterns [] to %f { + transform.apply_patterns to %f { transform.apply_patterns.vector.lower_transpose lowering_strategy = "shuffle_1d" } : !transform.any_op } diff --git 
a/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir @@ -167,7 +167,7 @@ %f = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op - transform.apply_patterns [] to %f { + transform.apply_patterns to %f { transform.apply_patterns.vector.lower_broadcast } : !transform.any_op } diff --git a/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir @@ -210,7 +210,7 @@ transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns [] to %module_op { + transform.apply_patterns to %module_op { transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct" } : !transform.any_op } diff --git a/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir @@ -300,7 +300,7 @@ %f = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op - transform.apply_patterns [] to %f { + transform.apply_patterns to %f { transform.apply_patterns.vector.lower_contraction lowering_strategy = "dot" } : !transform.any_op } diff --git a/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir +++ 
b/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir @@ -48,11 +48,11 @@ %f = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op - transform.apply_patterns [] to %f { + transform.apply_patterns to %f { transform.apply_patterns.vector.lower_contraction lowering_strategy = "matmulintrinsics" } : !transform.any_op - transform.apply_patterns [] to %f { + transform.apply_patterns to %f { transform.apply_patterns.vector.lower_shape_cast } : !transform.any_op } diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir @@ -347,7 +347,7 @@ %f = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op - transform.apply_patterns [] to %f { + transform.apply_patterns to %f { transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct" } : !transform.any_op } diff --git a/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir @@ -56,7 +56,7 @@ %f = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op - transform.apply_patterns [] to %f { + transform.apply_patterns to %f { transform.apply_patterns.vector.lower_contraction lowering_strategy = "parallelarith" } : !transform.any_op } diff --git a/mlir/test/Dialect/Vector/vector-mask-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-mask-lowering-transforms.mlir --- 
a/mlir/test/Dialect/Vector/vector-mask-lowering-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-mask-lowering-transforms.mlir @@ -96,7 +96,7 @@ %f = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op - transform.apply_patterns [] to %f { + transform.apply_patterns to %f { transform.apply_patterns.vector.lower_masks } : !transform.any_op } @@ -127,7 +127,7 @@ %f = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op - transform.apply_patterns [] to %f { + transform.apply_patterns to %f { transform.apply_patterns.vector.lower_masked_transfers } : !transform.any_op } diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir --- a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir @@ -267,7 +267,7 @@ transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns [] to %module_op { + transform.apply_patterns to %module_op { transform.apply_patterns.vector.lower_multi_reduction lowering_strategy = "innerreduction" } : !transform.any_op } diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir --- a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir @@ -190,7 +190,7 @@ transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns [] to %module_op { + transform.apply_patterns to %module_op { transform.apply_patterns.vector.lower_multi_reduction lowering_strategy = "innerparallel" } : !transform.any_op } diff --git a/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir 
b/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir @@ -140,11 +140,11 @@ %f = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op - transform.apply_patterns [] to %f { + transform.apply_patterns to %f { transform.apply_patterns.vector.lower_outerproduct } : !transform.any_op - transform.apply_patterns [] to %f { + transform.apply_patterns to %f { transform.apply_patterns.vector.lower_broadcast } : !transform.any_op } diff --git a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir @@ -154,7 +154,7 @@ %f = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op - transform.apply_patterns [] to %f { + transform.apply_patterns to %f { transform.apply_patterns.vector.lower_shape_cast } : !transform.any_op } diff --git a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir --- a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir @@ -17,7 +17,7 @@ transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns [] to %module_op { + transform.apply_patterns to %module_op { transform.apply_patterns.vector.apply_rank_reducing_subview_patterns } : !transform.any_op } @@ -39,7 +39,7 @@ transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns [] to %module_op { + transform.apply_patterns to %module_op { 
transform.apply_patterns.vector.apply_rank_reducing_subview_patterns } : !transform.any_op } @@ -63,7 +63,7 @@ transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns [] to %module_op { + transform.apply_patterns to %module_op { transform.apply_patterns.vector.apply_rank_reducing_subview_patterns } : !transform.any_op } @@ -87,7 +87,7 @@ transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns [] to %module_op { + transform.apply_patterns to %module_op { transform.apply_patterns.vector.apply_rank_reducing_subview_patterns } : !transform.any_op } @@ -111,7 +111,7 @@ transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns [] to %module_op { + transform.apply_patterns to %module_op { transform.apply_patterns.vector.apply_rank_reducing_subview_patterns } : !transform.any_op } @@ -135,7 +135,7 @@ transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns [] to %module_op { + transform.apply_patterns to %module_op { transform.apply_patterns.vector.apply_rank_reducing_subview_patterns } : !transform.any_op } diff --git a/mlir/test/Dialect/Vector/vector-transfer-full-partial-split-copy-transform.mlir b/mlir/test/Dialect/Vector/vector-transfer-full-partial-split-copy-transform.mlir --- a/mlir/test/Dialect/Vector/vector-transfer-full-partial-split-copy-transform.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-full-partial-split-copy-transform.mlir @@ -108,7 +108,7 @@ transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns [] to %module_op { + transform.apply_patterns to %module_op { transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "linalg-copy" } : !transform.any_op } @@ -169,7 +169,7 @@ transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns [] to 
%module_op { + transform.apply_patterns to %module_op { transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "linalg-copy" } : !transform.any_op } @@ -237,7 +237,7 @@ transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns [] to %module_op { + transform.apply_patterns to %module_op { transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "linalg-copy" } : !transform.any_op } diff --git a/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir b/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir --- a/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir @@ -103,7 +103,7 @@ transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns [] to %module_op { + transform.apply_patterns to %module_op { transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "vector-transfer" } : !transform.any_op } @@ -161,7 +161,7 @@ transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns [] to %module_op { + transform.apply_patterns to %module_op { transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "vector-transfer" } : !transform.any_op } @@ -223,7 +223,7 @@ transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns [] to %module_op { + transform.apply_patterns to %module_op { transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "vector-transfer" } : !transform.any_op } @@ -265,7 +265,7 @@ transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns [] to %module_op { + transform.apply_patterns to %module_op { transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = 
"vector-transfer" } : !transform.any_op } diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir --- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir @@ -240,7 +240,7 @@ transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns [] to %module_op { + transform.apply_patterns to %module_op { transform.apply_patterns.vector.lower_transfer max_transfer_rank = 99 transform.apply_patterns.vector.apply_transfer_permutation_patterns } : !transform.any_op @@ -362,7 +362,7 @@ transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns [] to %module_op { + transform.apply_patterns to %module_op { transform.apply_patterns.vector.lower_transfer max_transfer_rank = 99 transform.apply_patterns.vector.apply_transfer_permutation_patterns } : !transform.any_op diff --git a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir --- a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir @@ -76,7 +76,7 @@ transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns [] to %module_op { + transform.apply_patterns to %module_op { transform.apply_patterns.vector.lower_transpose lowering_strategy = "eltwise" } : !transform.any_op } @@ -99,7 +99,7 @@ transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns [] to %module_op { + transform.apply_patterns to %module_op { transform.apply_patterns.vector.lower_transpose lowering_strategy = "shuffle_1d" } : !transform.any_op } @@ -118,7 +118,7 @@ transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns [] to 
%module_op { + transform.apply_patterns to %module_op { transform.apply_patterns.vector.lower_transpose lowering_strategy = "flat_transpose" } : !transform.any_op } @@ -605,7 +605,7 @@ transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns [] to %module_op { + transform.apply_patterns to %module_op { transform.apply_patterns.vector.lower_transpose avx2_lowering_strategy = true } : !transform.any_op } @@ -683,7 +683,7 @@ transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns [] to %module_op { + transform.apply_patterns to %module_op { transform.apply_patterns.vector.lower_transpose lowering_strategy = "shuffle_16x16" } : !transform.any_op } @@ -762,7 +762,7 @@ transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns [] to %module_op { + transform.apply_patterns to %module_op { transform.apply_patterns.vector.lower_transpose lowering_strategy = "shuffle_16x16" } : !transform.any_op } diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -784,10 +784,6 @@ } namespace { -void populateTestPatterns(RewritePatternSet &patterns) { - patterns.insert(patterns.getContext()); -} - /// Test extension of the Transform dialect. Registers additional ops and /// declares PDL as dependent dialect since the additional ops are using PDL /// types for operands and results. @@ -825,11 +821,6 @@ constraints.try_emplace("verbose_constraint", verboseConstraint); hooks.mergeInPDLMatchHooks(std::move(constraints)); }); - - addDialectDataInitializer( - [&](transform::PatternRegistry ®istry) { - registry.registerPatterns("transform.test", populateTestPatterns); - }); } }; } // namespace