diff --git a/mlir/include/mlir/Dialect/Transform/ApplyPatternsExtension/ApplyPatternsExtension.h b/mlir/include/mlir/Dialect/Transform/ApplyPatternsExtension/ApplyPatternsExtension.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Transform/ApplyPatternsExtension/ApplyPatternsExtension.h @@ -0,0 +1,122 @@ +//===- ApplyPatternsExtension.h - Transform Dialect Extension ---------*- C++ +//-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef TRANSFORMEXTENSIONS_APPLYPATTERNSEXTENSION_H_ +#define TRANSFORMEXTENSIONS_APPLYPATTERNSEXTENSION_H_ + +#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h" +#include "mlir/Dialect/PDL/IR/PDLTypes.h" +#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/Transform/IR/TransformInterfaces.h" + +namespace mlir { +class DialectRegistry; + +namespace func { +class FuncOp; +} // namespace func + +namespace scf { +class ForallOp; +} // namespace scf + +namespace transform { +// Types needed for builders. +struct TileSizesSpec; +struct NumThreadsSpec; +class TransformTypeInterface; + +/// Selected patterns for ApplyPatternOp. +struct ApplyPatternsOpPatterns { + bool additionalPatterns = false; + bool bubbleCollapse = false; + bool bubbleExpand = false; + bool bubblePackUnPack = false; + bool canonicalization = false; + bool eraseUnnecessaryTensorOperands = false; + bool expandMemrefStridedMetadata = false; + bool extractAddressComputations = false; + bool foldMemrefAliases = false; + bool foldReassociativeReshapes = false; + bool foldTensorEmptyExtract = false; + bool foldTensorSubsets = false; + bool linalgElementwiseGreedyFusion = false; + bool lowerTransferOpPermutations = false; + bool lowerVectorMasks = false; + bool prepareVectorToMma = false; + bool rankReducingLinalg = false; + bool rankReducingLinalgViaReshapes = false; + bool rankReducingVector = false; + bool swapPaddingElideConditional = false; + bool swappingPatterns = false; + bool tilingCanonicalization = false; + bool unrollVectorsGpuMmaSync = false; + bool unrollVectorsGpuWmma = false; +}; + +void registerApplyPatternsExtensionTransformDialectExtension( + DialectRegistry ®istry); + +class ErrorCheckingTrackingListener : public tensor::TrackingListener { +public: + using tensor::TrackingListener::TrackingListener; + + ~ErrorCheckingTrackingListener() override { +#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS + assert((errorStateChecked || !hadErrors) && + "must check listener error state"); +#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS + } + + DiagnosedSilenceableFailure check(Location loc) { + if (failed(checkErrorState())) + return emitDefiniteFailure(loc, "listener failed"); + return DiagnosedSilenceableFailure::success(); + } + DiagnosedSilenceableFailure check(Location loc, + DiagnosedSilenceableFailure &&diag) { + if (failed(checkErrorState())) { + auto definite = emitDefiniteFailure(loc, "listener failed"); + if (diag.isSilenceableFailure()) { + definite.attachNote() + << "was propagating silenceable error:" << diag.getMessage(); + (void)diag.silence(); + } + return definite; + } + return std::move(diag); + } + + LogicalResult checkErrorState() const { +#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS + errorStateChecked = true; 
+#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
+    return failure(hadErrors);
+  }
+
+private:
+  void notifyPayloadReplacementNotFound(Operation *op,
+                                        ValueRange values) override;
+
+  bool hadErrors = false;
+
+#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
+  mutable bool errorStateChecked = false;
+#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
+};
+
+} // namespace transform
+
+} // namespace mlir
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Transform/ApplyPatternsExtension/ApplyPatternsExtensionOps.h.inc"
+
+#endif // TRANSFORMEXTENSIONS_APPLYPATTERNSEXTENSION_H_
\ No newline at end of file
diff --git a/mlir/include/mlir/Dialect/Transform/ApplyPatternsExtension/ApplyPatternsExtensionOps.td b/mlir/include/mlir/Dialect/Transform/ApplyPatternsExtension/ApplyPatternsExtensionOps.td
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/ApplyPatternsExtension/ApplyPatternsExtensionOps.td
@@ -0,0 +1,142 @@
+include "mlir/Bindings/Python/Attributes.td"
+include "mlir/Dialect/PDL/IR/PDLTypes.td"
+include "mlir/Dialect/Transform/IR/TransformAttrs.td"
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
+include "mlir/Dialect/Transform/IR/TransformTypes.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Dialect/SCF/IR/DeviceMappingInterface.td"
+include "mlir/IR/EnumAttr.td"
+include "mlir/IR/OpBase.td"
+
+// TODO: Some bitvector to scale better than n-bools.
+def ApplyPatternsOp : Op<Transform_Dialect, "apply_patterns",
+    [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     TransformEachOpTrait,
+     TransformOpInterface]> {
+  let description = [{
+    Greedily applies patterns as specified by its attributes.
+
+    Must be applied to an op with trait IsolatedFromAbove since the
+    GreedyPatternRewriter asserts this. Internally, uses the tracking rewriter
+    to preserve handles to payload operations nested within operations
+    associated with `target`. Fails if tracking cannot find a replacement for
+    a payload operation. This may become controllable with an attribute in the
+    future.
+
+    The following additive attributes can be set; they add patterns in an
+    unspecified order:
+    - additional_patterns: extra patterns we shortcut into the system;
+      currently `generate-to-constant`.
+    - bubble_collapse: bubble `collapse_shape` down across Linalg ops. This
+      must be applied separately from `bubble_expand` patterns because of some
+      upstream pattern interference issue at the moment.
+    - bubble_expand: bubble `expand_shape` down across Linalg ops. This
+      must be applied separately from `bubble_collapse` patterns because of
+      some upstream pattern interference issue at the moment.
+    - bubble_pack_un_pack: bubble `pack` up and `unpack` down across Linalg
+      ops.
+    - canonicalization: adds all the canonicalization patterns of all
+      registered dialects and ops.
+    - erase_unnecessary_tensor_operands: adds patterns that erase unnecessary
+      tensor operands.
+    - expand_memref_strided_metadata: adds patterns that expand memref
+      operations into extract_strided_metadata operations and a
+      materialization of their effect on the metadata (sizes, offset,
+      strides).
+    - extract_address_computations: adds patterns for anchoring subview
+      accessing operations at [0, ... 0].
+    - fold_memref_aliases: adds patterns for folding ops such as
+      memref.subview.
+    - fold_reassociative_reshapes: adds patterns that fold insert_slice/
+      extract_slice ops with reassociative reshape ops.
+    - fold_tensor_empty_extract: folds `tensor.empty` ops used by
+      `tensor.extract_slice` when the extract_slice is the only use of the
+      tensor.empty and the result shape is static.
+    - fold_tensor_subsets: adds patterns for folding tensor subset ops into
+      their producer and consumers.
+    - linalg_elementwise_greedy_fusion: adds linalg elementwise ops fusion
+      patterns using a naive default heuristic.
+    - lower_transfer_op_permutations: lowers transfer ops to transfer ops
+      with minor identity permutations.
+    - lower_vector_masks: lowers vector.mask ops away.
+    - prepare_vector_to_mma: pre-processes vector.contract ops so that they
+      can be mapped to nvgpu.mma operations.
+    - rank_reducing_linalg: adds patterns that result in rank-reducing
+      behavior on subset-based linalg operations using insert/extract slices.
+    - rank_reducing_linalg_via_reshapes: adds patterns that result in
+      rank-reducing behavior on subset-based linalg operations using
+      expand/collapse shape ops.
+    - rank_reducing_vector: adds patterns that result in rank-reducing
+      behavior on subset-based vector operations. This adopts the upstream
+      version of these patterns.
+    - swapping_patterns: adds patterns that swap operations for a better
+      outcome. This is a catch-all that can be refined further if/when
+      needed.
+    - swap_padding_elide_conditional: refines the tensor.pad +
+      tensor.extract_slice swapping pattern. This injects static information
+      that guarantees the padding is smaller than the window size, so a tile
+      consisting only of padding is never produced.
+    - tiling_canonicalization: adds specific tiling-related canonicalization
+      patterns.
+    - unroll_vectors_gpu_mma_sync: adds patterns that unroll vectors to a
+      native tile size for GPUs with mma operations. The size is currently
+      hardcoded but should be refactored upstream and made pluggable.
+    - unroll_vectors_gpu_wmma: adds patterns that unroll vectors to a
+      native tile size for GPUs with wmma operations. The size is currently
+      hardcoded but should be refactored upstream and made pluggable.
+
+    #### Return modes:
+
+    This operation applies a set of patterns specified by attributes. To apply
+    these patterns, this operation must target an operation that is isolated
+    from above, otherwise the transform definitely fails.
+
+    If the pattern application fails, or if the underlying listener fails to
+    capture op handles, the transformation definitely fails.
+
+    Otherwise the transformation is successful.
+
+    This operation does not consume the target handle and does not produce any
+    handle.
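+
+    #### Example
+
+    A minimal usage sketch, mirroring the tests in
+    transform-op-apply-patterns.mlir; any other subset of the unit attributes
+    listed above can be set instead of `canonicalization`:
+
+    ```mlir
+    transform.sequence failures(propagate) {
+    ^bb1(%arg1: !pdl.operation):
+      %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+      transform.apply_patterns %0 { canonicalization } : (!pdl.operation) -> ()
+    }
+    ```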
+ }]; + + let arguments = (ins TransformHandleTypeInterface:$target, + UnitAttr:$additional_patterns, + UnitAttr:$bubble_collapse, + UnitAttr:$bubble_expand, + UnitAttr:$bubble_pack_un_pack, + UnitAttr:$canonicalization, + UnitAttr:$erase_unnecessary_tensor_operands, + UnitAttr:$expand_memref_strided_metadata, + UnitAttr:$extract_address_computations, + UnitAttr:$fold_memref_aliases, + UnitAttr:$fold_reassociative_reshapes, + UnitAttr:$fold_tensor_empty_extract, + UnitAttr:$fold_tensor_subsets, + UnitAttr:$licm, + UnitAttr:$linalg_elementwise_greedy_fusion, + UnitAttr:$lower_transfer_op_permutations, + UnitAttr:$lower_vector_masks, + UnitAttr:$prepare_vector_to_mma, + UnitAttr:$rank_reducing_linalg, + UnitAttr:$rank_reducing_linalg_via_reshapes, + UnitAttr:$rank_reducing_vector, + UnitAttr:$swap_padding_elide_conditional, + UnitAttr:$swapping_patterns, + UnitAttr:$tiling_canonicalization); + let results = (outs); + + let assemblyFormat = "$target attr-dict `:` functional-type($target, results)"; + + let builders = [ + OpBuilder<(ins "Value":$target, "const ApplyPatternsOpPatterns &":$patterns)> + ]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} \ No newline at end of file diff --git a/mlir/include/mlir/Dialect/Transform/ApplyPatternsExtension/CMakeLists.txt b/mlir/include/mlir/Dialect/Transform/ApplyPatternsExtension/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Transform/ApplyPatternsExtension/CMakeLists.txt @@ -0,0 +1,4 @@ +set(LLVM_TARGET_DEFINITIONS ApplyPatternsExtensionOps.td) +mlir_tablegen(ApplyPatternsExtensionOps.h.inc -gen-op-decls) +mlir_tablegen(ApplyPatternsExtensionOps.cpp.inc -gen-op-defs) +add_public_tablegen_target(MLIRTransformApplyPatternsExtensionOpsIncGen) \ No newline at end of file diff --git a/mlir/include/mlir/Dialect/Transform/CMakeLists.txt b/mlir/include/mlir/Dialect/Transform/CMakeLists.txt --- a/mlir/include/mlir/Dialect/Transform/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Transform/CMakeLists.txt @@ -1,2 +1,3 @@ +add_subdirectory(ApplyPatternsExtension) add_subdirectory(IR) add_subdirectory(Transforms) diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -74,6 +74,7 @@ #include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h" #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/Dialect/Transform/ApplyPatternsExtension/ApplyPatternsExtension.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h" @@ -134,6 +135,7 @@ memref::registerTransformDialectExtension(registry); scf::registerTransformDialectExtension(registry); tensor::registerTransformDialectExtension(registry); + transform::registerApplyPatternsExtensionTransformDialectExtension(registry); vector::registerTransformDialectExtension(registry); // Register all external models. 
diff --git a/mlir/lib/Dialect/Transform/ApplyPatternsExtension/ApplyPatternsExtension.cpp b/mlir/lib/Dialect/Transform/ApplyPatternsExtension/ApplyPatternsExtension.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Transform/ApplyPatternsExtension/ApplyPatternsExtension.cpp
@@ -0,0 +1,467 @@
+//===- ApplyPatternsExtension.cpp - Transform Dialect Extension ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/ApplyPatternsExtension/ApplyPatternsExtension.h"
+
+#include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/LoopUtils.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/Utils/IndexingUtils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/PDL/IR/PDLTypes.h"
+#include "mlir/Dialect/SCF/Transforms/Transforms.h"
+#include "mlir/Dialect/SCF/Utils/Utils.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "apply-patterns-extension"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
+
+using namespace mlir;
+using namespace mlir::transform;
+
+//===----------------------------------------------------------------------===//
+// ErrorCheckingTrackingListener
+//===----------------------------------------------------------------------===//
+
+void ErrorCheckingTrackingListener::notifyPayloadReplacementNotFound(
+    Operation *op, ValueRange values) {
+  // Certain ops can be dropped safely.
+  if (isa<scf::ForOp>(op)) {
+    LLVM_DEBUG(DBGS() << "Silently dropping scf.for op mapping\n");
+    return;
+  }
+
+  hadErrors = true;
+#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
+  errorStateChecked = false;
+#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
+}
+
+// Returns true if all uses of `op` are store-like (memref.store /
+// vector.transfer_write). SubViewOp users are allowed as long as all of their
+// users are also store-like. If this returns true, `uses` is extended with
+// the traversed users; if it returns false, `uses` is left unchanged.
+static bool allUsesAreStores(Operation *op, std::vector<Operation *> &uses) {
+  std::vector<Operation *> opUses;
+  for (OpOperand &use : op->getUses()) {
+    Operation *useOp = use.getOwner();
+    if (isa<memref::StoreOp, vector::TransferWriteOp>(useOp) ||
+        (isa<memref::SubViewOp>(useOp) && allUsesAreStores(useOp, opUses))) {
+      opUses.push_back(useOp);
+      continue;
+    }
+    return false;
+  }
+  uses.insert(uses.end(), opUses.begin(), opUses.end());
+  return true;
+}
+
+// Track temporary allocations that are never read from. If this is the case
+// it means both the allocations and associated stores can be removed.
+static void eraseDeadAllocAndStores(RewriterBase &rewriter,
+                                    Operation *parentOp) {
+  std::vector<Operation *> opToErase;
+  parentOp->walk([&](memref::AllocOp op) {
+    if (allUsesAreStores(op, opToErase)) {
+      opToErase.push_back(op.getOperation());
+    }
+  });
+  for (Operation *op : opToErase)
+    rewriter.eraseOp(op);
+}
+
+//===---------------------------------------------------------------------===//
+// ApplyPatternsOp
+//===---------------------------------------------------------------------===//
+
+void transform::ApplyPatternsOp::build(
+    OpBuilder &builder, OperationState &result, Value target,
+    const ApplyPatternsOpPatterns &patterns) {
+  result.addOperands(target);
+
+  auto unitAttr = builder.getUnitAttr();
+
+#define ADD_PATTERN(NAME, ATTR)                                                \
+  if (patterns.NAME)                                                           \
+    result.addAttribute(ApplyPatternsOp::ATTR(result.name), unitAttr);
+  ///
+  /// When touching something here, do not forget to update
+  /// ApplyPatternsExtension.h.
+  ///
+  ADD_PATTERN(additionalPatterns, getAdditionalPatternsAttrName)
+  ADD_PATTERN(bubbleCollapse, getBubbleCollapseAttrName)
+  ADD_PATTERN(bubbleExpand, getBubbleExpandAttrName)
+  ADD_PATTERN(bubblePackUnPack, getBubblePackUnPackAttrName)
+  ADD_PATTERN(canonicalization, getCanonicalizationAttrName)
+  ADD_PATTERN(eraseUnnecessaryTensorOperands,
+              getEraseUnnecessaryTensorOperandsAttrName)
+  ADD_PATTERN(expandMemrefStridedMetadata,
+              getExpandMemrefStridedMetadataAttrName)
+  ADD_PATTERN(extractAddressComputations, getExtractAddressComputationsAttrName)
+  ADD_PATTERN(foldMemrefAliases, getFoldMemrefAliasesAttrName)
+  ADD_PATTERN(foldReassociativeReshapes, getFoldReassociativeReshapesAttrName)
+  ADD_PATTERN(foldTensorEmptyExtract, getFoldTensorEmptyExtractAttrName)
+  ADD_PATTERN(foldTensorSubsets, getFoldTensorSubsetsAttrName)
+  ADD_PATTERN(linalgElementwiseGreedyFusion,
+              getLinalgElementwiseGreedyFusionAttrName)
+  ADD_PATTERN(lowerTransferOpPermutations,
+              getLowerTransferOpPermutationsAttrName)
+  ADD_PATTERN(lowerVectorMasks, getLowerVectorMasksAttrName)
+  ADD_PATTERN(prepareVectorToMma, getPrepareVectorToMmaAttrName)
+  ADD_PATTERN(rankReducingLinalg, getRankReducingLinalgAttrName)
+  ADD_PATTERN(rankReducingLinalgViaReshapes,
+              getRankReducingLinalgViaReshapesAttrName)
+  ADD_PATTERN(rankReducingVector, getRankReducingVectorAttrName)
+  ADD_PATTERN(swapPaddingElideConditional,
+              getSwapPaddingElideConditionalAttrName)
+  ADD_PATTERN(swappingPatterns, getSwappingPatternsAttrName)
+  ADD_PATTERN(tilingCanonicalization, getTilingCanonicalizationAttrName)
+#undef ADD_PATTERN
+}
+
+static void addOperands(Operation *op, SetVector<Value> &operandSet) {
+  if (!op)
+    return;
+  TypeSwitch<Operation *, void>(op)
+      .Case([&](linalg::LinalgOp linalgOp) {
+        SmallVector<Value> inputOperands{linalgOp.getDpsInputOperands()};
+        operandSet.insert(inputOperands.begin(), inputOperands.end());
+      })
+      .Default([&](Operation *operation) {
+        operandSet.insert(operation->operand_begin(), operation->operand_end());
+      });
+}
+
+template <int limit>
+static bool setFusedOpOperandLimit(OpOperand *fusedOperand) {
+  Operation *producer = fusedOperand->get().getDefiningOp();
+  if (!producer)
+    return false;
+  Operation *consumer = fusedOperand->getOwner();
+  SetVector<Value> fusedOpOperands;
+  if (producer->getNumResults() != 1)
+    return false;
+  addOperands(consumer, fusedOpOperands);
+  fusedOpOperands.remove(producer->getResult(0));
+  addOperands(producer, fusedOpOperands);
+  return fusedOpOperands.size() <= limit;
+}
+
+namespace {
+/// Rewrite a tensor.generate as an arith.constant when possible.
+struct GenerateToConstant : public OpRewritePattern<tensor::GenerateOp> {
+  using OpRewritePattern<tensor::GenerateOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::GenerateOp generateOp,
+                                PatternRewriter &rewriter) const final {
+    auto tensorType =
+        generateOp.getResult().getType().cast<RankedTensorType>();
+    if (!tensorType.hasStaticShape())
+      return failure();
+    auto terminatorOp =
+        cast<tensor::YieldOp>(generateOp.getBody().front().getTerminator());
+    if (terminatorOp->getNumOperands() > 1)
+      return failure();
+    auto constantOp =
+        terminatorOp->getOperand(0).getDefiningOp<arith::ConstantOp>();
+    if (!constantOp)
+      return failure();
+    rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+        generateOp, tensorType,
+        DenseElementsAttr::get(tensorType, constantOp.getValueAttr()));
+    return success();
+  }
+};
+
+/// Fold tensor.empty used by extract_slice if the extract_slice is the only
+/// use of the tensor.empty and the result is static.
+struct FoldTensorEmptyExtract
+    : public OpRewritePattern<tensor::ExtractSliceOp> {
+  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractOp,
+                                PatternRewriter &rewriter) const final {
+    auto tensorEmpty = extractOp.getSource().getDefiningOp<tensor::EmptyOp>();
+    if (!tensorEmpty || !extractOp.getType().hasStaticShape() ||
+        !tensorEmpty->hasOneUse())
+      return failure();
+    rewriter.replaceOpWithNewOp<tensor::EmptyOp>(
+        extractOp, extractOp.getType().getShape(),
+        extractOp.getType().getElementType());
+    return success();
+  }
+};
+
+/// Fold `tensor.pad(cst, tensor.extract*(linalg.fill(cst)))` into
+/// `linalg.fill(cst, empty)` when the padding constant and the fill constant
+/// are the same.
+/// This seems generally desirable as a folding but may be too intrusive, so we
+/// only apply it selectively for now.
+// TODO: at the moment this is hardcoded on linalg.fill but we could take any
+// result of any generic that yields a constant in that result.
+struct FoldFillIntoPad : public OpRewritePattern<tensor::PadOp> {
+  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(tensor::PadOp padOp,
+                                PatternRewriter &rewriter) const final {
+    Operation *currentOp = padOp.getSource().getDefiningOp();
+    auto maybeExtractSlice =
+        dyn_cast_or_null<tensor::ExtractSliceOp>(currentOp);
+    while (currentOp && maybeExtractSlice) {
+      currentOp = maybeExtractSlice.getSource().getDefiningOp();
+      maybeExtractSlice = dyn_cast_or_null<tensor::ExtractSliceOp>(currentOp);
+    }
+    auto fillOp = dyn_cast_or_null<linalg::FillOp>(currentOp);
+    if (!fillOp) {
+      return rewriter.notifyMatchFailure(
+          padOp, "not coming from a linalg.fill op via tensor.extract_slice*");
+    }
+
+    Value padValue = padOp.getConstantPaddingValue();
+    RankedTensorType resultType = padOp.getResultType();
+    if (!padValue ||
+        getAsOpFoldResult(padValue) !=
+            getAsOpFoldResult(fillOp.getDpsInputOperand(0)->get())) {
+      return rewriter.notifyMatchFailure(
+          padOp, "padding value does not match the fill value");
+    }
+
+    Location loc = padOp.getLoc();
+    auto emptyOp = rewriter.create<tensor::EmptyOp>(
+        loc, resultType,
+        linalg::createDynamicDimensions(rewriter, loc, padOp.getResult()));
+    rewriter.replaceOpWithNewOp<linalg::FillOp>(padOp, padValue,
+                                                emptyOp.getResult());
+
+    return success();
+  }
+};
+} // namespace
+
+static void
+addLowerTransferOpPermutationsPatterns(RewritePatternSet &patterns) {
+  vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
+}
+
+static void addLowerVectorMasksPatterns(RewritePatternSet &patterns) {
+  vector::populateVectorMaskLoweringPatternsForSideEffectingOps(patterns);
+}
+
+static void
+addExtractAddressComputationsPatterns(RewritePatternSet &patterns) {
+  memref::populateExtractAddressComputationsPatterns(patterns);
+}
+
+static void addFoldMemrefAliasPatterns(RewritePatternSet &patterns) {
+  memref::populateFoldMemRefAliasOpPatterns(patterns);
+}
+
+static void addFoldTensorEmptyExtract(RewritePatternSet &patterns) {
+  patterns.add<FoldTensorEmptyExtract>(patterns.getContext());
+}
+
+static void addReassociativeReshapePatterns(RewritePatternSet &patterns) {
+  tensor::populateReassociativeReshapeFoldingPatterns(patterns);
+}
+
+static void addFoldTensorSubsetsPatterns(RewritePatternSet &patterns) {
+  tensor::populateFoldTensorSubsetOpPatterns(patterns);
+  // TODO: upstream should move these to populateFoldTensorSubsetOpPatterns.
+  tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns);
+}
+
+static void
+addEraseUnnecessaryTensorOperandsPatterns(RewritePatternSet &patterns) {
+  linalg::populateEraseUnnecessaryInputsPatterns(patterns);
+}
+
+static void addPrepareVectorToMmaPatterns(RewritePatternSet &patterns) {
+  populatePrepareVectorToMMAPatterns(patterns, /*useNvGpu=*/true);
+}
+
+static void addRankReducingLinalgPatterns(RewritePatternSet &patterns) {
+  linalg::populateFoldUnitExtentDimsViaSlicesPatterns(patterns);
+}
+
+static void
+addRankReducingLinalgViaReshapesPatterns(RewritePatternSet &patterns) {
+  linalg::populateFoldUnitExtentDimsViaReshapesPatterns(patterns);
+}
+
+static void addRankReducingVectorPatterns(RewritePatternSet &patterns) {
+  vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
+}
+
+static void addSwappingPatterns(RewritePatternSet &patterns,
+                                bool swapPaddingElideCornerCase) {
+  patterns.add<linalg::ExtractSliceOfPadTensorSwapPattern>(
+      patterns.getContext(),
+      [swapPaddingElideCornerCase](tensor::ExtractSliceOp)
+          -> std::optional<bool> { return !swapPaddingElideCornerCase; });
+}
+
+static void addTilingCanonicalizationPatterns(RewritePatternSet &patterns) {
+  linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
+  scf::populateSCFForLoopCanonicalizationPatterns(patterns);
+  /// This seems generally desirable as a folding but may be too intrusive, so
+  /// we only apply it selectively for now.
+  patterns.add<FoldFillIntoPad>(patterns.getContext());
+}
+
+static void addAdditionalPatterns(RewritePatternSet &patterns) {
+  patterns.add<GenerateToConstant>(patterns.getContext());
+}
+
+static void
+addAllRegisteredCanonicalizationPatterns(RewritePatternSet &patterns) {
+  MLIRContext *ctx = patterns.getContext();
+  for (Dialect *dialect : ctx->getLoadedDialects())
+    dialect->getCanonicalizationPatterns(patterns);
+  for (RegisteredOperationName op : ctx->getRegisteredOperations())
+    op.getCanonicalizationPatterns(patterns, ctx);
+}
+
+DiagnosedSilenceableFailure transform::ApplyPatternsOp::applyToOne(
+    Operation *target, transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  if (!target->hasTrait<OpTrait::IsolatedFromAbove>()) {
+    return mlir::emitDefiniteFailure(
+        target,
+        "applies only to isolated-from-above targets because it needs to "
+        "apply patterns greedily");
+  }
+  MLIRContext *ctx = target->getContext();
+  RewritePatternSet patterns(ctx);
+  if (getAdditionalPatterns())
+    addAdditionalPatterns(patterns);
+  if (getBubbleCollapse()) {
+    linalg::populateFoldReshapeOpsByCollapsingPatterns(
+        patterns, [](OpOperand *) { return true; });
+  }
+  if (getBubbleExpand()) {
+    linalg::populateFoldReshapeOpsByExpansionPatterns(
+        patterns, [](OpOperand *) { return true; });
+  }
+  if (getBubblePackUnPack())
+    linalg::populateDataLayoutPropagationPatterns(
+        patterns, [](Operation *op) { return true; });
+  if (getCanonicalization())
+    addAllRegisteredCanonicalizationPatterns(patterns);
+  if (getEraseUnnecessaryTensorOperands())
+    addEraseUnnecessaryTensorOperandsPatterns(patterns);
+  if (getExpandMemrefStridedMetadata())
+    memref::populateExpandStridedMetadataPatterns(patterns);
+  if (getExtractAddressComputations())
+    addExtractAddressComputationsPatterns(patterns);
+  if (getFoldMemrefAliases())
+    addFoldMemrefAliasPatterns(patterns);
+  if (getFoldReassociativeReshapes())
+    addReassociativeReshapePatterns(patterns);
+  if (getFoldTensorEmptyExtract())
+    addFoldTensorEmptyExtract(patterns);
+  if (getFoldTensorSubsets())
+    addFoldTensorSubsetsPatterns(patterns);
+  if (getLinalgElementwiseGreedyFusion())
+    linalg::populateElementwiseOpsFusionPatterns(patterns,
+                                                 setFusedOpOperandLimit<3>);
+  if (getLowerTransferOpPermutations())
+    addLowerTransferOpPermutationsPatterns(patterns);
+  if (getLowerVectorMasks())
+    addLowerVectorMasksPatterns(patterns);
+  if (getPrepareVectorToMma())
+    addPrepareVectorToMmaPatterns(patterns);
+  if (getRankReducingLinalg())
+    addRankReducingLinalgPatterns(patterns);
+  if (getRankReducingLinalgViaReshapes())
+    addRankReducingLinalgViaReshapesPatterns(patterns);
+  if (getRankReducingVector())
+    addRankReducingVectorPatterns(patterns);
+  if (getSwappingPatterns())
+    addSwappingPatterns(patterns, getSwapPaddingElideConditional());
+  if (getTilingCanonicalization())
+    addTilingCanonicalizationPatterns(patterns);
+
+  Location loc = target->getLoc();
+  ErrorCheckingTrackingListener listener(state, *this);
+  GreedyRewriteConfig config;
+  config.listener = &listener;
+  // Manually gather the list of ops because the other
+  // GreedyPatternRewriteDriver overloads only accept ops that are isolated
+  // from above.
+  SmallVector<Operation *> ops;
+  target->walk([&](Operation *nestedOp) {
+    if (target != nestedOp)
+      ops.push_back(nestedOp);
+  });
+  LogicalResult result =
+      applyOpPatternsAndFold(ops, std::move(patterns), config);
+  if (failed(result)) {
+    return listener.check(
+        loc, mlir::emitDefiniteFailure(target, "greedy patterns failed"));
+  }
+
+  return listener.check(loc);
+}
+
+void transform::ApplyPatternsOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::onlyReadsHandle(getTarget(), effects);
+  transform::modifiesPayload(effects);
+}
+
+//===----------------------------------------------------------------------===//
+// Transform op registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Registers new ops and declares PDL as dependent dialect since the
+/// additional ops are using PDL types for operands and results.
+class ApplyPatternsExtensionTransformDialectExtension
+    : public transform::TransformDialectExtension<
+          ApplyPatternsExtensionTransformDialectExtension> {
+public:
+  ApplyPatternsExtensionTransformDialectExtension() {
+    declareDependentDialect<pdl::PDLDialect>();
+    declareDependentDialect();
+    declareGeneratedDialect();
+    declareGeneratedDialect();
+    declareGeneratedDialect();
+    registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/Transform/ApplyPatternsExtension/ApplyPatternsExtensionOps.cpp.inc"
+        >();
+  }
+};
+} // namespace
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Transform/ApplyPatternsExtension/ApplyPatternsExtensionOps.cpp.inc"
+
+void mlir::transform::registerApplyPatternsExtensionTransformDialectExtension(
+    DialectRegistry &registry) {
+  registry.addExtensions<ApplyPatternsExtensionTransformDialectExtension>();
+}
diff --git a/mlir/lib/Dialect/Transform/ApplyPatternsExtension/CMakeLists.txt b/mlir/lib/Dialect/Transform/ApplyPatternsExtension/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Transform/ApplyPatternsExtension/CMakeLists.txt
@@ -0,0 +1,37 @@
+add_mlir_dialect_library(ApplyPatternsExtension
+  ApplyPatternsExtension.cpp
+
+  DEPENDS
+  MLIRTransformApplyPatternsExtensionOpsIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRAffineDialect
+  MLIRAffineUtils
+  MLIRAnalysis
+  MLIRArithDialect
+  MLIRArithUtils
+  MLIRBufferizationDialect
+  MLIRBufferizationTransforms
+  MLIRGPUOps
+  MLIRIR
+  MLIRLinalgDialect
+  MLIRLinalgTransformOps
+  MLIRLinalgTransforms
+  MLIRLinalgUtils
+  MLIRMemRefDialect
+  MLIRMemRefTransforms
+  MLIRPDLDialect
+  MLIRPass
+  MLIRSCFDialect
+  MLIRSCFTransforms
+  MLIRSCFUtils
+  MLIRTensorDialect
+  MLIRTensorTransformOps
+  MLIRTensorTransforms
+  MLIRTransformDialect
+  MLIRTransforms
+  MLIRVectorDialect
+  MLIRVectorToGPU
+  MLIRVectorTransforms
+  )
\ No newline at end of file
diff --git a/mlir/lib/Dialect/Transform/CMakeLists.txt b/mlir/lib/Dialect/Transform/CMakeLists.txt
--- a/mlir/lib/Dialect/Transform/CMakeLists.txt
+++ b/mlir/lib/Dialect/Transform/CMakeLists.txt
@@ -1,3 +1,4 @@
+add_subdirectory(ApplyPatternsExtension)
 add_subdirectory(IR)
 add_subdirectory(Transforms)
 add_subdirectory(Utils)
diff --git a/mlir/test/Dialect/Transform/transform-op-apply-patterns.mlir b/mlir/test/Dialect/Transform/transform-op-apply-patterns.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Transform/transform-op-apply-patterns.mlir
@@ -0,0 +1,200 @@
+// RUN: mlir-opt -split-input-file \
+// RUN:   -test-transform-dialect-interpreter -canonicalize \
+// RUN:   -allow-unregistered-dialect %s | FileCheck %s
+
+// CHECK-LABEL: @select_cmp_eq_select
+// CHECK: return %arg1
+func.func @select_cmp_eq_select(%arg0: i64, %arg1: i64) -> i64 {
+  %0 = arith.cmpi eq, %arg0, %arg1 : i64
+  %1 = arith.select %0, %arg0, %arg1 : i64
+  return %1 : i64
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 : !pdl.operation failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+    transform.apply_patterns %0 { canonicalization } : (!pdl.operation) -> ()
+  }
+}
+
+// -----
+
+#map2 = affine_map<(d0, d1) -> (d0, d1)>
+
+func.func private @mutate(f32) -> f32
+
+// CHECK-LABEL: @bubble_up
+func.func @bubble_up(%arg0: tensor<32x64xf32>) -> tensor<32x2x32xf32> {
+  // Check that shape expansion precedes linalg.generic after the patterns
+  // were applied.
+ // CHECK: tensor.expand_shape + // CHECK: tensor.expand_shape + // CHECK: linalg.generic + %init = tensor.empty() : tensor<32x64xf32> + %result = linalg.generic { + indexing_maps = [#map2, #map2], + iterator_types = ["parallel", "parallel"]} + ins(%arg0: tensor<32x64xf32>) outs(%init: tensor<32x64xf32>) { + ^bb0(%arg1: f32, %arg2: f32): + %0 = func.call @mutate(%arg1) : (f32) -> f32 + linalg.yield %0 : f32 + } -> tensor<32x64xf32> + %out = tensor.expand_shape %result[[0], [1, 2]] : tensor<32x64xf32> into tensor<32x2x32xf32> + return %out : tensor<32x2x32xf32> +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation + transform.apply_patterns %0 { bubble_expand } : (!pdl.operation) -> () +} + +// ----- + +// CHECK-LABEL: @pad_fill_to_fill +func.func @pad_fill_to_fill(%arg0: tensor<31x62xf32>) -> tensor<32x64xf32> { + // Check that a pad of a fill with the same constant is replaced by a + // bigger fill. + // CHECK-DAG: %[[FILL_CST:.*]] = arith.constant 0.0{{0*e\+00}} : f32 + // CHECK-DAG: %[[EMPTY:.*]] = tensor.empty() : tensor<32x64xf32> + // CHECK: %[[PADDED_FILL:.*]] = linalg.fill ins(%[[FILL_CST]] : f32) outs(%[[EMPTY]] : tensor<32x64xf32>) -> tensor<32x64xf32> + // CHECK: return %[[PADDED_FILL]] + %cst = arith.constant 0.0 : f32 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %fill = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<31x62xf32>) -> tensor<31x62xf32> + %padded = tensor.pad %fill low[%c0, %c0] high[%c1, %c2] { + ^bb0(%arg3: index, %arg4: index): + tensor.yield %cst : f32 + } : tensor<31x62xf32> to tensor<32x64xf32> + return %padded : tensor<32x64xf32> +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation + transform.apply_patterns %0 { tiling_canonicalization } : (!pdl.operation) -> () +} + +// ----- + +// CHECK-LABEL: @pad_fill_different_ssa_value_but_same_cst +func.func @pad_fill_different_ssa_value_but_same_cst(%arg0: tensor<31x62xf32>) -> tensor<32x64xf32> { + // Check that a pad of a fill with the same constant is replaced by a + // bigger fill even when the constant comes from different ssa value. 
+ // CHECK-DAG: %[[FILL_CST:.*]] = arith.constant 0.0{{0*e\+00}} : f32 + // CHECK-DAG: %[[EMPTY:.*]] = tensor.empty() : tensor<32x64xf32> + // CHECK: %[[PADDED_FILL:.*]] = linalg.fill ins(%[[FILL_CST]] : f32) outs(%[[EMPTY]] : tensor<32x64xf32>) -> tensor<32x64xf32> + // CHECK: return %[[PADDED_FILL]] + %cst = arith.constant 0.0 : f32 + %cst2 = arith.constant 0.0 : f32 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %fill = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<31x62xf32>) -> tensor<31x62xf32> + %padded = tensor.pad %fill low[%c0, %c0] high[%c1, %c2] { + ^bb0(%arg3: index, %arg4: index): + tensor.yield %cst2 : f32 + } : tensor<31x62xf32> to tensor<32x64xf32> + return %padded : tensor<32x64xf32> +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation + transform.apply_patterns %0 { tiling_canonicalization } : (!pdl.operation) -> () +} + +// ----- + +// CHECK-LABEL: @pad_extract_fill_to_fill +func.func @pad_extract_fill_to_fill(%arg0: tensor<31x62xf32>, + %size0 : index, %size1 : index, + %high0 : index, %high1 : index) -> tensor<32x64xf32> { + // Check that a pad of a fill with the same constant is replaced by a + // bigger fill even when the fill is hidden behind an extract_slice. + // CHECK-DAG: %[[FILL_CST:.*]] = arith.constant 0.0{{0*e\+00}} : f32 + // CHECK-DAG: %[[EMPTY:.*]] = tensor.empty() : tensor<32x64xf32> + // CHECK: %[[PADDED_FILL:.*]] = linalg.fill ins(%[[FILL_CST]] : f32) outs(%[[EMPTY]] : tensor<32x64xf32>) -> tensor<32x64xf32> + // CHECK: return %[[PADDED_FILL]] + %cst = arith.constant 0.0 : f32 + %cst2 = arith.constant 0.0 : f32 + %c0 = arith.constant 0 : index + %fill = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<31x62xf32>) -> tensor<31x62xf32> + %extracted_slice = tensor.extract_slice %fill[0, 0] [%size0, %size1] [1, 1] : tensor<31x62xf32> to tensor + %padded = tensor.pad %extracted_slice low[%c0, %c0] high[%high0, %high1] { + ^bb0(%arg3: index, %arg4: index): + tensor.yield %cst2 : f32 + } : tensor to tensor<32x64xf32> + return %padded : tensor<32x64xf32> +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation + transform.apply_patterns %0 { tiling_canonicalization } : (!pdl.operation) -> () +} + +// ----- + +// CHECK-LABEL: @pad_extract_extract_fill_to_fill +func.func @pad_extract_extract_fill_to_fill(%arg0: tensor<31x62xf32>, + %size0a : index, %size1a : index, + %size0b : index, %size1b : index, + %high0 : index, %high1 : index) -> tensor<32x64xf32> { + // Check that a pad of a fill with the same constant is replaced by a + // bigger fill even when the fill is hidden behind a few `extract_slice`s. 
+ // CHECK-DAG: %[[FILL_CST:.*]] = arith.constant 0.0{{0*e\+00}} : f32 + // CHECK-DAG: %[[EMPTY:.*]] = tensor.empty() : tensor<32x64xf32> + // CHECK: %[[PADDED_FILL:.*]] = linalg.fill ins(%[[FILL_CST]] : f32) outs(%[[EMPTY]] : tensor<32x64xf32>) -> tensor<32x64xf32> + // CHECK: return %[[PADDED_FILL]] + %cst = arith.constant 0.0 : f32 + %cst2 = arith.constant 0.0 : f32 + %c0 = arith.constant 0 : index + %fill = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<31x62xf32>) -> tensor<31x62xf32> + %extracted_sliceA = tensor.extract_slice %fill[0, 0] [%size0a, %size1a] [1, 1] : tensor<31x62xf32> to tensor + %extracted_sliceB = tensor.extract_slice %extracted_sliceA[0, 0] [%size0b, %size1b] [1, 1] : tensor to tensor + %padded = tensor.pad %extracted_sliceB low[%c0, %c0] high[%high0, %high1] { + ^bb0(%arg3: index, %arg4: index): + tensor.yield %cst2 : f32 + } : tensor to tensor<32x64xf32> + return %padded : tensor<32x64xf32> +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation + transform.apply_patterns %0 { tiling_canonicalization } : (!pdl.operation) -> () +} + +// ----- + +// CHECK-LABEL: @pad_extract_bigger_fill_to_fill +func.func @pad_extract_bigger_fill_to_fill(%arg0: tensor<253x123xf32>, + %size0 : index, %size1 : index, + %high0 : index, %high1 : index) -> tensor<32x64xf32> { + // Check that a pad of a bigger fill with the same constant is replaced by a + // fill of the right size. + // CHECK-DAG: %[[FILL_CST:.*]] = arith.constant 0.0{{0*e\+00}} : f32 + // CHECK-DAG: %[[EMPTY:.*]] = tensor.empty() : tensor<32x64xf32> + // CHECK: %[[PADDED_FILL:.*]] = linalg.fill ins(%[[FILL_CST]] : f32) outs(%[[EMPTY]] : tensor<32x64xf32>) -> tensor<32x64xf32> + // CHECK: return %[[PADDED_FILL]] + %cst = arith.constant 0.0 : f32 + %cst2 = arith.constant 0.0 : f32 + %c0 = arith.constant 0 : index + %fill = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<253x123xf32>) -> tensor<253x123xf32> + %extracted_slice = tensor.extract_slice %fill[0, 0] [%size0, %size1] [1, 1] : tensor<253x123xf32> to tensor + %padded = tensor.pad %extracted_slice low[%c0, %c0] high[%high0, %high1] { + ^bb0(%arg3: index, %arg4: index): + tensor.yield %cst2 : f32 + } : tensor to tensor<32x64xf32> + return %padded : tensor<32x64xf32> +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation + transform.apply_patterns %0 { tiling_canonicalization } : (!pdl.operation) -> () +} \ No newline at end of file
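+
+// -----
+
+// A suggested additional negative test, sketched here for completeness: when
+// the padding constant differs from the fill constant, the fold must not
+// trigger and the tensor.pad is expected to survive. The function name and
+// CHECK lines below are illustrative and not part of the original patch.
+// CHECK-LABEL: @pad_fill_different_cst_not_folded
+func.func @pad_fill_different_cst_not_folded(%arg0: tensor<31x62xf32>) -> tensor<32x64xf32> {
+  // CHECK: linalg.fill
+  // CHECK: tensor.pad
+  %cst = arith.constant 0.0 : f32
+  %cst1 = arith.constant 1.0 : f32
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %fill = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<31x62xf32>) -> tensor<31x62xf32>
+  %padded = tensor.pad %fill low[%c0, %c0] high[%c1, %c2] {
+  ^bb0(%arg3: index, %arg4: index):
+    tensor.yield %cst1 : f32
+  } : tensor<31x62xf32> to tensor<32x64xf32>
+  return %padded : tensor<32x64xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+  transform.apply_patterns %0 { tiling_canonicalization } : (!pdl.operation) -> ()
+}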