diff --git a/mlir/include/mlir/Dialect/Transform/CMakeLists.txt b/mlir/include/mlir/Dialect/Transform/CMakeLists.txt
--- a/mlir/include/mlir/Dialect/Transform/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Transform/CMakeLists.txt
@@ -1,2 +1,3 @@
+add_subdirectory(CommonExtensions)
 add_subdirectory(IR)
 add_subdirectory(Transforms)
diff --git a/mlir/include/mlir/Dialect/Transform/CommonExtensions/CMakeLists.txt b/mlir/include/mlir/Dialect/Transform/CommonExtensions/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/CommonExtensions/CMakeLists.txt
@@ -0,0 +1,4 @@
+set(LLVM_TARGET_DEFINITIONS CommonExtensionsOps.td)
+mlir_tablegen(CommonExtensionsOps.h.inc -gen-op-decls)
+mlir_tablegen(CommonExtensionsOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRTransformCommonExtensionsOpsIncGen)
\ No newline at end of file
diff --git a/mlir/include/mlir/Dialect/Transform/CommonExtensions/CommonExtensions.h b/mlir/include/mlir/Dialect/Transform/CommonExtensions/CommonExtensions.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/CommonExtensions/CommonExtensions.h
@@ -0,0 +1,123 @@
+//===- CommonExtensions.h - Transform Dialect Extension --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRANSFORM_COMMONEXTENSIONS_COMMONEXTENSIONS_H
+#define MLIR_DIALECT_TRANSFORM_COMMONEXTENSIONS_COMMONEXTENSIONS_H
+
+#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
+#include "mlir/Dialect/PDL/IR/PDLTypes.h"
+#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+
+namespace mlir {
+class DialectRegistry;
+
+namespace func {
+class FuncOp;
+} // namespace func
+
+namespace scf {
+class ForallOp;
+} // namespace scf
+
+namespace transform {
+// Types needed for builders.
+struct TileSizesSpec;
+struct NumThreadsSpec;
+class TransformTypeInterface;
+
+/// Selected patterns for ApplyPatternsOp.
+struct ApplyPatternsOpPatterns {
+  bool additionalPatterns = false;
+  bool bubbleCollapse = false;
+  bool bubbleExpand = false;
+  bool bubblePackUnPack = false;
+  bool canonicalization = false;
+  bool cse = false;
+  bool eraseUnnecessaryTensorOperands = false;
+  bool expandMemrefStridedMetadata = false;
+  bool extractAddressComputations = false;
+  bool foldMemrefAliases = false;
+  bool foldReassociativeReshapes = false;
+  bool foldTensorEmptyExtract = false;
+  bool foldTensorSubsets = false;
+  bool licm = false;
+  bool linalgElementwiseGreedyFusion = false;
+  bool lowerTransferOpPermutations = false;
+  bool lowerVectorMasks = false;
+  bool prepareVectorToMma = false;
+  bool rankReducingLinalg = false;
+  bool rankReducingLinalgViaReshapes = false;
+  bool rankReducingVector = false;
+  bool swapPaddingElideConditional = false;
+  bool swappingPatterns = false;
+  bool tilingCanonicalization = false;
+  bool unrollVectorsGpuMmaSync = false;
+  bool unrollVectorsGpuWmma = false;
+};
+
+void registerCommonExtensionsTransformDialectExtension(
+    DialectRegistry &registry);
+
+/// Tracking listener that records payload-replacement failures and lets
+/// callers turn them into definite transform failures via `check`.
+class ErrorCheckingTrackingListener : public tensor::TrackingListener {
+public:
+  using tensor::TrackingListener::TrackingListener;
+
+  ~ErrorCheckingTrackingListener() override {
+#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
+    assert((errorStateChecked || !hadErrors) &&
+           "must check listener error state");
+#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
+  }
+
+  DiagnosedSilenceableFailure check(Location loc) {
+    if (failed(checkErrorState()))
+      return emitDefiniteFailure(loc, "listener failed");
+    return DiagnosedSilenceableFailure::success();
+  }
+  DiagnosedSilenceableFailure check(Location loc,
+                                    DiagnosedSilenceableFailure &&diag) {
+    if (failed(checkErrorState())) {
+      auto definite = emitDefiniteFailure(loc, "listener failed");
+      if (diag.isSilenceableFailure()) {
+        definite.attachNote()
+            << "was propagating silenceable error: " << diag.getMessage();
+        (void)diag.silence();
+      }
+      return definite;
+    }
+    return std::move(diag);
+  }
+
+  LogicalResult checkErrorState() const {
+#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
+    errorStateChecked = true;
+#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
+    return failure(hadErrors);
+  }
+
+private:
+  void notifyPayloadReplacementNotFound(Operation *op,
+                                        ValueRange values) override;
+
+  bool hadErrors = false;
+
+#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
+  mutable bool errorStateChecked = false;
+#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
+};
+
+} // namespace transform
+
+} // namespace mlir
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Transform/CommonExtensions/CommonExtensionsOps.h.inc"
+
+#endif // MLIR_DIALECT_TRANSFORM_COMMONEXTENSIONS_COMMONEXTENSIONS_H
\ No newline at end of file
diff --git a/mlir/include/mlir/Dialect/Transform/CommonExtensions/CommonExtensionsOps.td b/mlir/include/mlir/Dialect/Transform/CommonExtensions/CommonExtensionsOps.td
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/CommonExtensions/CommonExtensionsOps.td
@@ -0,0 +1,155 @@
+include "mlir/Bindings/Python/Attributes.td"
+include "mlir/Dialect/PDL/IR/PDLTypes.td"
+include "mlir/Dialect/Transform/IR/TransformAttrs.td"
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
+include "mlir/Dialect/Transform/IR/TransformTypes.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Dialect/SCF/IR/DeviceMappingInterface.td"
+include "mlir/IR/EnumAttr.td"
+include "mlir/IR/OpBase.td"
+
+def ApplyPatternsOp : Op<Transform_Dialect, "apply_patterns",
+    [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     TransformEachOpTrait,
+     TransformOpInterface]> {
+  let description = [{
+    Greedily applies patterns as specified by its attributes.
+
+    Must be applied to an op with the IsolatedFromAbove trait since the greedy
+    pattern rewriter asserts that property. Internally, uses the tracking
+    rewriter to preserve handles to payload operations nested within
+    operations associated with `target`. Fails if tracking cannot find a
+    replacement for a payload operation. This may become controllable with an
+    attribute in the future.
+
+    The following additive attributes can be set; they add patterns in an
+    unspecified order:
+      - additional_patterns: fancy patterns we shortcut into the system; these
+        will need to be sliced out better in the future.
+      - bubble_collapse: bubble `collapse_shape` down across Linalg ops. This
+        must be applied separately from `bubble_expand` patterns because of an
+        upstream pattern interference issue at the moment.
+      - bubble_expand: bubble `expand_shape` down across Linalg ops. This
+        must be applied separately from `bubble_collapse` patterns because of
+        an upstream pattern interference issue at the moment.
+      - bubble_pack_un_pack: bubble `pack` up and `unpack` down across Linalg
+        ops.
+      - canonicalization: adds the canonicalization patterns of all registered
+        dialects and ops.
+      - cse: additionally applies common subexpression elimination. This must
+        be applied to a FuncOp. This is not a set of patterns per se, but it is
+        still very convenient to apply it close to canonicalization and other
+        greedy pattern applications.
+      - erase_unnecessary_tensor_operands: adds patterns that erase unnecessary
+        tensor operands.
+      - expand_memref_strided_metadata: adds patterns that expand memref
+        operations into extract_strided_metadata operations and a
+        materialization of their effect on the metadata (sizes, offset,
+        strides).
+      - extract_address_computations: adds patterns for anchoring subview
+        accessing operations at [0, ... 0].
+      - fold_memref_aliases: adds patterns for folding ops such as
+        memref.subview.
+      - fold_reassociative_reshapes: adds patterns that fold insert_slice/
+        extract_slice ops with reassociative reshape ops.
+      - fold_tensor_empty_extract: folds tensor.empty used by extract_slice
+        when it is the only use of the extract.
+      - fold_tensor_subsets: adds patterns for folding tensor subset ops into
+        their producers and consumers.
+      - licm: additionally applies loop-invariant code motion and single
+        iteration loop promotion. This is not a set of patterns per se, but it
+        is still very convenient to apply it close to canonicalization and
+        other greedy pattern applications.
+      - linalg_elementwise_greedy_fusion: adds Linalg elementwise op fusion
+        patterns using a naive default heuristic.
+      - lower_transfer_op_permutations: lowers transfer ops to transfer ops
+        with minor identity permutations.
+      - lower_vector_masks: lowers vector.mask ops away.
+      - prepare_vector_to_mma: pre-processes vector.contract ops into a form
+        that can be mapped to nvgpu.mma operations.
+      - rank_reducing_linalg: adds patterns that result in rank-reducing
+        behavior on subset-based Linalg operations using insert/extract slices.
+      - rank_reducing_linalg_via_reshapes: adds patterns that result in
+        rank-reducing behavior on subset-based Linalg operations using
+        expand/collapse shape ops.
+      - rank_reducing_vector: adds patterns that result in rank-reducing
+        behavior on subset-based vector operations; adopts the upstream
+        version.
+      - swapping_patterns: adds patterns that swap operations for a better
+        outcome. This is a catch-all that can be refined further if/when
+        needed.
+      - swap_padding_elide_conditional: refines the tensor.pad +
+        tensor.extract_slice swapping pattern. This injects static information
+        that guarantees the padding is smaller than the window size, which
+        guarantees we never see a tile composed only of padding.
+      - tiling_canonicalization: adds specific tiling-related canonicalization
+        patterns.
+      - unroll_vectors_gpu_mma_sync: adds patterns that unroll vectors to a
+        native tile size for GPUs with mma operations. The size is currently
+        hardcoded but should be refactored upstream and made pluggable.
+      - unroll_vectors_gpu_wmma: adds patterns that unroll vectors to a
+        native tile size for GPUs with wmma operations. The size is currently
+        hardcoded but should be refactored upstream and made pluggable.
+
+    #### Return modes:
+
+    This operation applies a set of patterns specified by attributes. To apply
+    these patterns, this operation must target an operation that is isolated
+    from above; otherwise the transform definitely fails.
+
+    If the pattern application fails, or if the underlying listener fails to
+    capture op handles, the transformation definitely fails.
+
+    Otherwise the transformation is successful.
+
+    This operation does not consume the target handle and does not produce any
+    handle.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target,
+                       UnitAttr:$additional_patterns,
+                       UnitAttr:$bubble_collapse,
+                       UnitAttr:$bubble_expand,
+                       UnitAttr:$bubble_pack_un_pack,
+                       UnitAttr:$canonicalization,
+                       UnitAttr:$cse,
+                       UnitAttr:$erase_unnecessary_tensor_operands,
+                       UnitAttr:$expand_memref_strided_metadata,
+                       UnitAttr:$extract_address_computations,
+                       UnitAttr:$fold_memref_aliases,
+                       UnitAttr:$fold_reassociative_reshapes,
+                       UnitAttr:$fold_tensor_empty_extract,
+                       UnitAttr:$fold_tensor_subsets,
+                       UnitAttr:$licm,
+                       UnitAttr:$linalg_elementwise_greedy_fusion,
+                       UnitAttr:$lower_transfer_op_permutations,
+                       UnitAttr:$lower_vector_masks,
+                       UnitAttr:$prepare_vector_to_mma,
+                       UnitAttr:$rank_reducing_linalg,
+                       UnitAttr:$rank_reducing_linalg_via_reshapes,
+                       UnitAttr:$rank_reducing_vector,
+                       UnitAttr:$swap_padding_elide_conditional,
+                       UnitAttr:$swapping_patterns,
+                       UnitAttr:$tiling_canonicalization,
+                       UnitAttr:$unroll_vectors_gpu_mma_sync,
+                       UnitAttr:$unroll_vectors_gpu_wmma);
+  let results = (outs);
+
+  let assemblyFormat = "$target attr-dict `:` functional-type($target, results)";
+  let cppNamespace = "mlir::transform";
+
+  let builders = [
+    // TODO: Some bitvector to scale better than n-bools.
+    OpBuilder<(ins "Value":$target,
+                   "const ApplyPatternsOpPatterns &":$patterns)>
+  ];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::Operation *target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
\ No newline at end of file
diff --git a/mlir/lib/Dialect/Transform/CMakeLists.txt b/mlir/lib/Dialect/Transform/CMakeLists.txt
--- a/mlir/lib/Dialect/Transform/CMakeLists.txt
+++ b/mlir/lib/Dialect/Transform/CMakeLists.txt
@@ -1,3 +1,4 @@
+add_subdirectory(CommonExtensions)
 add_subdirectory(IR)
 add_subdirectory(Transforms)
 add_subdirectory(Utils)
diff --git a/mlir/lib/Dialect/Transform/CommonExtensions/CMakeLists.txt b/mlir/lib/Dialect/Transform/CommonExtensions/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Transform/CommonExtensions/CMakeLists.txt
@@ -0,0 +1,36 @@
+add_mlir_library(CommonExtensions
+  CommonExtensions.cpp
+
+  DEPENDS
+  MLIRTransformCommonExtensionsOpsIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRAffineDialect
+  MLIRAffineUtils
+  MLIRAnalysis
+  MLIRArithDialect
+  MLIRArithUtils
+  MLIRBufferizationDialect
+  MLIRBufferizationTransforms
+  MLIRGPUOps
+  MLIRIR
+  MLIRLinalgDialect
+  MLIRLinalgTransformOps
+  MLIRLinalgTransforms
+  MLIRLinalgUtils
+  MLIRMemRefDialect
+  MLIRMemRefTransforms
+  MLIRPDLDialect
+  MLIRPass
+  MLIRSCFDialect
+  MLIRSCFTransforms
+  MLIRSCFUtils
+  MLIRTensorDialect
+  MLIRTensorTransformOps
+  MLIRTensorTransforms
+  MLIRTransformDialect
+  MLIRTransforms
+  MLIRVectorDialect
+  MLIRVectorToGPU
+  MLIRVectorTransforms
+  )
\ No newline at end of file
diff --git a/mlir/lib/Dialect/Transform/CommonExtensions/CommonExtensions.cpp b/mlir/lib/Dialect/Transform/CommonExtensions/CommonExtensions.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Transform/CommonExtensions/CommonExtensions.cpp
@@ -0,0 +1,559 @@
+//===- CommonExtensions.cpp - Transform Dialect Extension ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/CommonExtensions/CommonExtensions.h"
+
+#include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/LoopUtils.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/Utils/IndexingUtils.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/PDL/IR/PDLTypes.h"
+#include "mlir/Dialect/SCF/Transforms/Transforms.h"
+#include "mlir/Dialect/SCF/Utils/Utils.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "common-ext"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
+
+using namespace mlir;
+using namespace mlir::transform;
+
+//===----------------------------------------------------------------------===//
+// ErrorCheckingTrackingListener
+//===----------------------------------------------------------------------===//
+
+void ErrorCheckingTrackingListener::notifyPayloadReplacementNotFound(
+    Operation *op, ValueRange values) {
+  // Certain ops can be dropped safely.
+  if (isa<scf::ForOp>(op)) {
+    LLVM_DEBUG(DBGS() << "Silently dropping scf.for op mapping\n");
+    return;
+  }
+
+  hadErrors = true;
+#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
+  errorStateChecked = false;
+#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
+}
+
+// Returns true if all the uses of `op` are either stores/transfer_writes.
+// There can be SubViewOp users as long as all their users are also
+// stores/transfer_writes. If it returns true it also fills out `uses`; if it
+// returns false `uses` is unchanged.
+static bool allUsesAreStores(Operation *op, std::vector<Operation *> &uses) {
+  std::vector<Operation *> opUses;
+  for (OpOperand &use : op->getUses()) {
+    Operation *useOp = use.getOwner();
+    if (isa<memref::StoreOp, vector::TransferWriteOp, memref::DeallocOp>(
+            useOp) ||
+        (isa<memref::SubViewOp>(useOp) && allUsesAreStores(useOp, opUses))) {
+      opUses.push_back(useOp);
+      continue;
+    }
+    return false;
+  }
+  uses.insert(uses.end(), opUses.begin(), opUses.end());
+  return true;
+}
+
+// Track temporary allocations that are never read from. If this is the case
+// it means both the allocations and associated stores can be removed.
+static void eraseDeadAllocAndStores(RewriterBase &rewriter,
+                                    Operation *parentOp) {
+  std::vector<Operation *> opToErase;
+  parentOp->walk([&](memref::AllocOp op) {
+    if (allUsesAreStores(op, opToErase)) {
+      opToErase.push_back(op.getOperation());
+    }
+  });
+  for (Operation *op : opToErase)
+    rewriter.eraseOp(op);
+}
+
+//===---------------------------------------------------------------------===//
+// ApplyPatternsOp
+//===---------------------------------------------------------------------===//
+void transform::ApplyPatternsOp::build(
+    OpBuilder &builder, OperationState &result, Value target,
+    const ApplyPatternsOpPatterns &patterns) {
+  result.addOperands(target);
+
+  auto unitAttr = builder.getUnitAttr();
+
+#define ADD_PATTERN(NAME, ATTR) \
+  if (patterns.NAME)            \
+    result.addAttribute(ApplyPatternsOp::ATTR(result.name), unitAttr);
+  ///
+  /// When touching something here, do not forget to update CommonExtensions.h.
+  ///
+  ADD_PATTERN(additionalPatterns, getAdditionalPatternsAttrName)
+  ADD_PATTERN(bubbleCollapse, getBubbleCollapseAttrName)
+  ADD_PATTERN(bubbleExpand, getBubbleExpandAttrName)
+  ADD_PATTERN(bubblePackUnPack, getBubblePackUnPackAttrName)
+  ADD_PATTERN(canonicalization, getCanonicalizationAttrName)
+  ADD_PATTERN(cse, getCseAttrName)
+  ADD_PATTERN(eraseUnnecessaryTensorOperands,
+              getEraseUnnecessaryTensorOperandsAttrName)
+  ADD_PATTERN(expandMemrefStridedMetadata,
+              getExpandMemrefStridedMetadataAttrName)
+  ADD_PATTERN(extractAddressComputations, getExtractAddressComputationsAttrName)
+  ADD_PATTERN(foldMemrefAliases, getFoldMemrefAliasesAttrName)
+  ADD_PATTERN(foldReassociativeReshapes, getFoldReassociativeReshapesAttrName)
+  ADD_PATTERN(foldTensorEmptyExtract, getFoldTensorEmptyExtractAttrName)
+  ADD_PATTERN(foldTensorSubsets, getFoldTensorSubsetsAttrName)
+  ADD_PATTERN(licm, getLicmAttrName)
+  ADD_PATTERN(linalgElementwiseGreedyFusion,
+              getLinalgElementwiseGreedyFusionAttrName)
+  ADD_PATTERN(lowerTransferOpPermutations,
+              getLowerTransferOpPermutationsAttrName)
+  ADD_PATTERN(lowerVectorMasks, getLowerVectorMasksAttrName)
+  ADD_PATTERN(prepareVectorToMma, getPrepareVectorToMmaAttrName)
+  ADD_PATTERN(rankReducingLinalg, getRankReducingLinalgAttrName)
+  ADD_PATTERN(rankReducingLinalgViaReshapes,
+              getRankReducingLinalgViaReshapesAttrName)
+  ADD_PATTERN(rankReducingVector, getRankReducingVectorAttrName)
+  ADD_PATTERN(swapPaddingElideConditional,
+              getSwapPaddingElideConditionalAttrName)
+  ADD_PATTERN(swappingPatterns, getSwappingPatternsAttrName)
+  ADD_PATTERN(tilingCanonicalization, getTilingCanonicalizationAttrName)
+  ADD_PATTERN(unrollVectorsGpuMmaSync, getUnrollVectorsGpuMmaSyncAttrName)
+  ADD_PATTERN(unrollVectorsGpuWmma, getUnrollVectorsGpuWmmaAttrName)
+#undef ADD_PATTERN
+}
+
+static void addOperands(Operation *op, SetVector<Value> &operandSet) {
+  if (!op)
+    return;
+  TypeSwitch<Operation *, void>(op)
+      .Case([&](linalg::LinalgOp linalgOp) {
+        SmallVector<Value> inputOperands{linalgOp.getDpsInputOperands()};
+        operandSet.insert(inputOperands.begin(), inputOperands.end());
+      })
+      .Default([&](Operation *operation) {
+        operandSet.insert(operation->operand_begin(), operation->operand_end());
+      });
+}
+
+template <int limit>
+static bool setFusedOpOperandLimit(OpOperand *fusedOperand) {
+  Operation *producer = fusedOperand->get().getDefiningOp();
+  if (!producer)
+    return false;
+  Operation *consumer = fusedOperand->getOwner();
+  SetVector<Value> fusedOpOperands;
+  if (producer->getNumResults() != 1)
+    return false;
+  addOperands(consumer, fusedOpOperands);
+  fusedOpOperands.remove(producer->getResult(0));
+  addOperands(producer, fusedOpOperands);
+  return fusedOpOperands.size() <= limit;
+}
+
+namespace {
+/// Rewrite a tensor.generate as an arith.constant when possible.
+struct GenerateToConstant : public OpRewritePattern<tensor::GenerateOp> {
+  using OpRewritePattern<tensor::GenerateOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::GenerateOp generateOp,
+                                PatternRewriter &rewriter) const final {
+    auto tensorType =
+        generateOp.getResult().getType().cast<RankedTensorType>();
+    if (!tensorType.hasStaticShape())
+      return failure();
+    auto terminatorOp =
+        cast<tensor::YieldOp>(generateOp.getBody().front().getTerminator());
+    if (terminatorOp->getNumOperands() > 1)
+      return failure();
+    auto constantOp =
+        terminatorOp->getOperand(0).getDefiningOp<arith::ConstantOp>();
+    if (!constantOp)
+      return failure();
+    rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+        generateOp, tensorType,
+        DenseElementsAttr::get(tensorType, constantOp.getValueAttr()));
+    return success();
+  }
+};
+
+/// Fold tensor.empty used by extract_slice if this is the only use of
+/// extract_slice and the result is static.
+struct FoldTensorEmptyExtract
+    : public OpRewritePattern<tensor::ExtractSliceOp> {
+  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractOp,
+                                PatternRewriter &rewriter) const final {
+    auto tensorEmpty = extractOp.getSource().getDefiningOp<tensor::EmptyOp>();
+    if (!tensorEmpty || !extractOp.getType().hasStaticShape() ||
+        !tensorEmpty->hasOneUse())
+      return failure();
+    rewriter.replaceOpWithNewOp<tensor::EmptyOp>(
+        extractOp, extractOp.getType().getShape(),
+        extractOp.getType().getElementType());
+    return success();
+  }
+};
+
+/// Fold `tensor.pad(cst, tensor.extract*(linalg.fill(cst)))` into
+/// `linalg.fill(cst, empty)` when the padding constant and the fill constant
+/// are the same.
+/// This seems generally desirable as a folding but may be too intrusive, so we
+/// only apply it selectively for now.
+// TODO: atm hardcoded on linalg.fill but we could take any result of any
+// generic that yields a constant in that result.
+struct FoldFillIntoPad : public OpRewritePattern<tensor::PadOp> {
+  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(tensor::PadOp padOp,
+                                PatternRewriter &rewriter) const final {
+    Operation *currentOp = padOp.getSource().getDefiningOp();
+    auto maybeExtractSlice =
+        dyn_cast_or_null<tensor::ExtractSliceOp>(currentOp);
+    while (currentOp && maybeExtractSlice) {
+      currentOp = maybeExtractSlice.getSource().getDefiningOp();
+      maybeExtractSlice = dyn_cast_or_null<tensor::ExtractSliceOp>(currentOp);
+    }
+    auto fillOp = dyn_cast_or_null<linalg::FillOp>(currentOp);
+    if (!fillOp) {
+      return rewriter.notifyMatchFailure(
+          padOp, "not coming from a linalg.fill op via tensor.extract_slice*");
+    }
+
+    Value padValue = padOp.getConstantPaddingValue();
+    RankedTensorType resultType = padOp.getResultType();
+    if (!padValue ||
+        getAsOpFoldResult(padValue) !=
+            getAsOpFoldResult(fillOp.getDpsInputOperand(0)->get())) {
+      return rewriter.notifyMatchFailure(
+          padOp, "not a constant value matching the fill value");
+    }
+
+    Location loc = padOp.getLoc();
+    auto emptyOp = rewriter.create<tensor::EmptyOp>(
+        loc, resultType,
+        linalg::createDynamicDimensions(rewriter, loc, padOp.getResult()));
+    rewriter.replaceOpWithNewOp<linalg::FillOp>(padOp, padValue,
+                                                emptyOp.getResult());
+
+    return success();
+  }
+};
+} // namespace
+
+static void
+addLowerTransferOpPermutationsPatterns(RewritePatternSet &patterns) {
+  vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
+}
+
+static void addLowerVectorMasksPatterns(RewritePatternSet &patterns) {
+  vector::populateVectorMaskLoweringPatternsForSideEffectingOps(patterns);
+}
+
+static void
+addExtractAddressComputationsPatterns(RewritePatternSet &patterns) {
+  memref::populateExtractAddressComputationsPatterns(patterns);
+}
+
+static void addFoldMemrefAliasPatterns(RewritePatternSet &patterns) {
+  memref::populateFoldMemRefAliasOpPatterns(patterns);
+}
+
+static void addFoldTensorEmptyExtract(RewritePatternSet &patterns) {
+  patterns.add<FoldTensorEmptyExtract>(patterns.getContext());
+}
+
+static void addReassociativeReshapePatterns(RewritePatternSet &patterns) {
+  tensor::populateReassociativeReshapeFoldingPatterns(patterns);
+}
+
+static void addFoldTensorSubsetsPatterns(RewritePatternSet &patterns) {
+  tensor::populateFoldTensorSubsetOpPatterns(patterns);
+  // TODO: upstream should move these to populateFoldTensorSubsetOpPatterns.
+  tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns);
+}
+
+static void
+addEraseUnnecessaryTensorOperandsPatterns(RewritePatternSet &patterns) {
+  linalg::populateEraseUnnecessaryInputsPatterns(patterns);
+}
+
+static void addPrepareVectorToMmaPatterns(RewritePatternSet &patterns) {
+  populatePrepareVectorToMMAPatterns(patterns, /*useNvGpu=*/true);
+}
+
+static void addRankReducingLinalgPatterns(RewritePatternSet &patterns) {
+  // populateReshapeToInterfaceTensorPatterns(patterns);
+  linalg::populateFoldUnitExtentDimsViaSlicesPatterns(patterns);
+}
+
+static void
+addRankReducingLinalgViaReshapesPatterns(RewritePatternSet &patterns) {
+  // populateReshapeToInterfaceTensorPatterns(patterns);
+  linalg::populateFoldUnitExtentDimsViaReshapesPatterns(patterns);
+}
+
+static void addRankReducingVectorPatterns(RewritePatternSet &patterns) {
+  vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
+}
+
+static void addSwappingPatterns(RewritePatternSet &patterns,
+                                bool swapPaddingElideCornerCase) {
+  patterns.add<linalg::ExtractSliceOfPadTensorSwapPattern>(
+      patterns.getContext(),
+      [&](tensor::ExtractSliceOp) -> std::optional<bool> {
+        return !swapPaddingElideCornerCase;
+      });
+}
+
+static void addTilingCanonicalizationPatterns(RewritePatternSet &patterns) {
+  linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
+  scf::populateSCFForLoopCanonicalizationPatterns(patterns);
+  /// This seems generally desirable as a folding but may be too intrusive, so
+  /// we only apply it selectively for now.
+  patterns.add<FoldFillIntoPad>(patterns.getContext());
+}
+
+// static std::optional<SmallVector<int64_t>>
+// getGPUTensorCoreNativeMmaSyncVectorSize(Operation *op) {
+//   return getMmaNativeVectorSize(op);
+// }
+//
+// static void addUnrollVectorsGpuMmaSyncPatterns(RewritePatternSet &patterns) {
+//   auto unrollOrder =
+//       [](Operation *op) -> std::optional<SmallVector<int64_t>> {
+//     auto contract = dyn_cast<vector::ContractionOp>(op);
+//     if (!contract)
+//       return std::nullopt;
+//     return gpuMmaUnrollOrder(contract);
+//   };
+//   vector::populateVectorUnrollPatterns(
+//       patterns, vector::UnrollVectorOptions()
+//                     .setNativeShapeFn(getGPUTensorCoreNativeMmaSyncVectorSize)
+//                     .setUnrollTraversalOrderFn(unrollOrder));
+// }
+//
+// static std::optional<SmallVector<int64_t>>
+// getGPUTensorCoreNativeWmmaVectorSize(Operation *op) {
+//   return getWmmaNativeVectorSize(op);
+// }
+//
+// static void addUnrollVectorsGpuWmmaPatterns(RewritePatternSet &patterns) {
+//   auto unrollOrder =
+//       [](Operation *op) -> std::optional<SmallVector<int64_t>> {
+//     auto contract = dyn_cast<vector::ContractionOp>(op);
+//     if (!contract)
+//       return std::nullopt;
+//     return gpuMmaUnrollOrder(contract);
+//   };
+//   vector::populateVectorUnrollPatterns(
+//       patterns, vector::UnrollVectorOptions()
+//                     .setNativeShapeFn(getGPUTensorCoreNativeWmmaVectorSize)
+//                     .setUnrollTraversalOrderFn(unrollOrder));
+// }
+
+static void addAdditionalPatterns(RewritePatternSet &patterns) {
+  patterns.add<GenerateToConstant>(patterns.getContext());
+}
+
+static void
+addAllRegisteredCanonicalizationPatterns(RewritePatternSet &patterns) {
+  MLIRContext *ctx = patterns.getContext();
+  for (Dialect *dialect : ctx->getLoadedDialects())
+    dialect->getCanonicalizationPatterns(patterns);
+  for (RegisteredOperationName op : ctx->getRegisteredOperations())
+    op.getCanonicalizationPatterns(patterns, ctx);
+}
+
+DiagnosedSilenceableFailure transform::ApplyPatternsOp::applyToOne(
+    Operation *target, transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
+    return mlir::emitDefiniteFailure(
+        target,
+        "applies only to isolated-from-above targets because it needs to apply "
"patterns greedily"); + } + MLIRContext *ctx = target->getContext(); + RewritePatternSet patterns(ctx); + if (getAdditionalPatterns()) + addAdditionalPatterns(patterns); + if (getBubbleCollapse()) { + linalg::populateFoldReshapeOpsByCollapsingPatterns( + patterns, [](OpOperand *) { return true; }); + } + if (getBubbleExpand()) { + linalg::populateFoldReshapeOpsByExpansionPatterns( + patterns, [](OpOperand *) { return true; }); + } + if (getBubblePackUnPack()) + linalg::populateDataLayoutPropagationPatterns( + patterns, [](Operation *op) { return true; }); + if (getCanonicalization()) + addAllRegisteredCanonicalizationPatterns(patterns); + if (getEraseUnnecessaryTensorOperands()) + addEraseUnnecessaryTensorOperandsPatterns(patterns); + if (getExpandMemrefStridedMetadata()) + memref::populateExpandStridedMetadataPatterns(patterns); + if (getExtractAddressComputations()) + addExtractAddressComputationsPatterns(patterns); + if (getFoldMemrefAliases()) + addFoldMemrefAliasPatterns(patterns); + if (getFoldReassociativeReshapes()) + addReassociativeReshapePatterns(patterns); + if (getFoldTensorEmptyExtract()) + addFoldTensorEmptyExtract(patterns); + if (getFoldTensorSubsets()) + addFoldTensorSubsetsPatterns(patterns); + if (getLinalgElementwiseGreedyFusion()) + linalg::populateElementwiseOpsFusionPatterns(patterns, + setFusedOpOperandLimit<3>); + if (getLowerTransferOpPermutations()) + addLowerTransferOpPermutationsPatterns(patterns); + if (getLowerVectorMasks()) + addLowerVectorMasksPatterns(patterns); + if (getPrepareVectorToMma()) + addPrepareVectorToMmaPatterns(patterns); + if (getRankReducingLinalg()) + addRankReducingLinalgPatterns(patterns); + if (getRankReducingLinalgViaReshapes()) + addRankReducingLinalgViaReshapesPatterns(patterns); + if (getRankReducingVector()) + addRankReducingVectorPatterns(patterns); + if (getSwappingPatterns()) + addSwappingPatterns(patterns, getSwapPaddingElideConditional()); + if (getTilingCanonicalization()) + addTilingCanonicalizationPatterns(patterns); + // if (getUnrollVectorsGpuMmaSync()) + // addUnrollVectorsGpuMmaSyncPatterns(patterns); + // if (getUnrollVectorsGpuWmma()) + // addUnrollVectorsGpuWmmaPatterns(patterns); + + Location loc = target->getLoc(); + ErrorCheckingTrackingListener listener(state, *this); + GreedyRewriteConfig config; + config.listener = &listener; + // Manually gather list of ops because the other GreedyPatternRewriteDriver + // overloads only accepts ops that are isolated from above. + SmallVector ops; + target->walk([&](Operation *nestedOp) { + if (target != nestedOp) + ops.push_back(nestedOp); + }); + LogicalResult result = + applyOpPatternsAndFold(ops, std::move(patterns), config); + if (failed(result)) { + return listener.check( + loc, mlir::emitDefiniteFailure(target, "greedy patterns failed")); + } + + auto diag = listener.check(loc); + if (!diag.succeeded()) + return diag; + + if (getLicm()) { + target->walk([&](func::FuncOp funcOp) { + // This assumes LICM never removes operations so we don't need tracking. + // TODO: confirm / revisit this assumption and plumb a rewriter through + // upstream moveLoopInvariantCode if necessary. + funcOp->walk([](LoopLikeOpInterface loopLike) { + moveLoopInvariantCode(loopLike); + }); + // For now, put single loop promotion as part of licm. Underlying + // implementations perform splice operations which shouldn't need + // tracking. + // TODO: confirm / revisit this assumption and plumb a rewriter through + // upstream moveLoopInvariantCode if necessary. 
+      funcOp->walk([](Operation *op) {
+        (void)llvm::TypeSwitch<Operation *, LogicalResult>(op)
+            .Case<AffineForOp>(
+                [](auto loop) { return promoteIfSingleIteration(loop); })
+            .Case<scf::ForOp>(
+                [](auto loop) { return promoteIfSingleIteration(loop); })
+            .Default([](Operation *) { return success(); });
+      });
+    });
+  }
+
+  // if (getCse()) {
+  //   func::FuncOp lastFuncVisited;
+  //   auto walkResult = target->walk([&](func::FuncOp funcOp) -> WalkResult {
+  //     lastFuncVisited = funcOp;
+  //     result =
+  //         eliminateCommonSubexpressions(funcOp, /*domInfo=*/nullptr,
+  //                                       &listener);
+  //     if (failed(result))
+  //       return WalkResult::interrupt();
+  //     if (failed(listener.checkErrorState()))
+  //       return WalkResult::interrupt();
+  //     return WalkResult::advance();
+  //   });
+  //   if (walkResult.wasInterrupted()) {
+  //     if (failed(result)) {
+  //       return mlir::emitDefiniteFailure(lastFuncVisited,
+  //                                        "greedy patterns failed");
+  //     }
+  //     if (failed(listener.checkErrorState()))
+  //       return mlir::emitDefiniteFailure(lastFuncVisited,
+  //                                        "pattern listener tracker fail");
+  //   }
+  // }
+
+  return listener.check(loc);
+}
+
+void transform::ApplyPatternsOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::onlyReadsHandle(getTarget(), effects);
+  transform::modifiesPayload(effects);
+}
+
+//===----------------------------------------------------------------------===//
+// Transform op registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Registers new ops and declares PDL as dependent dialect since the
+/// additional ops are using PDL types for operands and results.
+class CommonExtensionsTransformDialectExtension
+    : public transform::TransformDialectExtension<
+          CommonExtensionsTransformDialectExtension> {
+public:
+  CommonExtensionsTransformDialectExtension() {
+    declareDependentDialect<pdl::PDLDialect>();
+    declareDependentDialect<linalg::LinalgDialect>();
+    declareGeneratedDialect<arith::ArithDialect>();
+    declareGeneratedDialect<scf::SCFDialect>();
+    declareGeneratedDialect<tensor::TensorDialect>();
+    registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/Transform/CommonExtensions/CommonExtensionsOps.cpp.inc"
+        >();
+  }
+};
+} // namespace
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Transform/CommonExtensions/CommonExtensionsOps.cpp.inc"
+
+void mlir::transform::registerCommonExtensionsTransformDialectExtension(
+    DialectRegistry &registry) {
+  registry.addExtensions<CommonExtensionsTransformDialectExtension>();
+}
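Usage sketch (not part of the patch above): the snippet below illustrates how client code might register the extension and build a `transform.apply_patterns` op through the `ApplyPatternsOpPatterns` struct and the custom builder added here. Only the registration function, the struct, and the op builder come from this patch; the surrounding function names and the handle value are illustrative assumptions.

// Hypothetical usage sketch; `registerExtensions`, `buildCleanupStep` and
// `funcHandle` are illustrative names, not APIs introduced by the patch.
#include "mlir/Dialect/Transform/CommonExtensions/CommonExtensions.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectRegistry.h"

// Make transform.apply_patterns available: call this on the registry before
// parsing or building transform IR that uses the op.
static void registerExtensions(mlir::DialectRegistry &registry) {
  mlir::transform::registerCommonExtensionsTransformDialectExtension(registry);
}

// Emit a single apply_patterns op that runs canonicalization, CSE and LICM on
// the payload ops mapped to `funcHandle`.
static void buildCleanupStep(mlir::OpBuilder &b, mlir::Location loc,
                             mlir::Value funcHandle) {
  mlir::transform::ApplyPatternsOpPatterns config;
  config.canonicalization = true;
  config.cse = true;
  config.licm = true;
  b.create<mlir::transform::ApplyPatternsOp>(loc, funcHandle, config);
}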