diff --git a/mlir/include/mlir/Conversion/VectorToGPU/Utils.h b/mlir/include/mlir/Conversion/VectorToGPU/Utils.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Conversion/VectorToGPU/Utils.h
@@ -0,0 +1,28 @@
+//===- Utils.h - Utilities for the vector to GPU conversion -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_VECTORTOGPU_UTILS_H
+#define MLIR_CONVERSION_VECTORTOGPU_UTILS_H
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+
+namespace mlir {
+/// Return the native size of an operation used in a contraction calculation.
+// TODO: Make this take HW specific sizes.
+std::optional<SmallVector<int64_t>> getWmmaNativeVectorSize(Operation *op);
+
+/// Helper function to return the native size for MMA.SYNC-based operations.
+std::optional<SmallVector<int64_t>> getMmaNativeVectorSize(Operation *op);
+
+/// Pick an unrolling order that will allow tensor core operations to reuse the
+/// LHS register. This is needed to get good performance on sm_80 targets.
+std::optional<SmallVector<int64_t>>
+gpuMmaUnrollOrder(vector::ContractionOp contract);
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_VECTORTOGPU_UTILS_H
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
@@ -11,6 +11,7 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
+#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h"
 #include "mlir/Dialect/Transform/IR/TransformAttrs.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
@@ -55,6 +56,86 @@
     ArrayRef<OpFoldResult> mixedTileSizes, std::optional<ArrayAttr> mapping,
     SmallVector<Operation *> &tileOps, SmallVector<Operation *> &tiledOps);
+/// Selected patterns for ApplyPatternsOp.
+struct ApplyPatternsOpPatterns {
+  bool additionalPatterns = false;
+  bool bubbleCollapse = false;
+  bool bubbleExpand = false;
+  bool bubblePackUnPack = false;
+  bool canonicalization = false;
+  // bool cse = false;
+  bool eraseUnnecessaryTensorOperands = false;
+  bool expandMemrefStridedMetadata = false;
+  bool extractAddressComputations = false;
+  bool foldMemrefAliases = false;
+  bool foldReassociativeReshapes = false;
+  bool foldTensorEmptyExtract = false;
+  bool foldTensorSubsets = false;
+  bool licm = false;
+  bool linalgElementwiseGreedyFusion = false;
+  bool lowerTransferOpPermutations = false;
+  bool lowerVectorMasks = false;
+  bool prepareVectorToMma = false;
+  bool rankReducingLinalg = false;
+  bool rankReducingLinalgViaReshapes = false;
+  bool rankReducingVector = false;
+  bool swapPaddingElideConditional = false;
+  bool swappingPatterns = false;
+  bool tilingCanonicalization = false;
+  bool unrollVectorsGpuMmaSync = false;
+  bool unrollVectorsGpuWmma = false;
+};
+
+/// A tracking listener for tensor IR that checks for payload replacement
+/// errors.
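+/// Replacement failures are latched: callers must consume the error state via
+/// check() or checkErrorState() before the listener is destroyed (this is
+/// asserted when LLVM_ENABLE_ABI_BREAKING_CHECKS is enabled).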
+class ErrorCheckingTrackingListener : public tensor::TrackingListener { +public: + using tensor::TrackingListener::TrackingListener; + + ~ErrorCheckingTrackingListener() override { +#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS + assert((errorStateChecked || !hadErrors) && + "must check listener error state"); +#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS + } + + DiagnosedSilenceableFailure check(Location loc) { + if (failed(checkErrorState())) + return emitDefiniteFailure(loc, "listener failed"); + return DiagnosedSilenceableFailure::success(); + } + DiagnosedSilenceableFailure check(Location loc, + DiagnosedSilenceableFailure &&diag) { + if (failed(checkErrorState())) { + auto definite = emitDefiniteFailure(loc, "listener failed"); + if (diag.isSilenceableFailure()) { + definite.attachNote() + << "was propagating silenceable error:" << diag.getMessage(); + (void)diag.silence(); + } + return definite; + } + return std::move(diag); + } + + LogicalResult checkErrorState() const { +#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS + errorStateChecked = true; +#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS + return failure(hadErrors); + } + +private: + void notifyPayloadReplacementNotFound(Operation *op, + ValueRange values) override; + + bool hadErrors = false; + +#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS + mutable bool errorStateChecked = false; +#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS +}; + } // namespace transform } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -27,6 +27,151 @@ Transform_ParamType.predicate]>, "transform 'param' type or any handle type">; +def ApplyPatternsOp : Op, + TransformEachOpTrait, + TransformOpInterface]> { + let description = [{ + Greedily applies patterns as specified by its attributes. + + Must be applied to an op with trait IsolatedFromAbove since the + GreedyPatternRewriter asserts those. Internally, uses the tracking rewriter + to preserve handles to payload operations nested within operations + associated with `target`. Fails if tracking cannot find replacement for a + payload operation. This may become controllable with an attribute in the + future. + + Returns the IsolatedFromAbove op whose content it has modified for better + chaining APIs. + + The following additive attributes can be set, they add patterns in an + unspecified order: + - additional_patterns: fancy patterns (generate-to-constant) we shortcut into the system, + will need to be sliced out better in the future. + - bubble_collapse: bubble `collapse_shape` down across Linalg ops. This + must be applied separately from `bubble_expand` patterns because of some + upstream pattern interference issue atm. + - bubble_expand: bubble `expand_shape` down across Linalg ops. This + must be applied separately from `bubble_collapse` patterns because of some + upstream pattern interference issue atm. + - bubble_pack_un_pack: bubble `pack` up and `unpack` down across Linalg + ops. + - canonicalization: adds all the canonicalization patterns of all + registered dialects and ops. + // - cse: additionally apply common subexpression elimination. This must + // apply on a funcOp. This is not a set of patterns per se but is still very + // convenient to apply it close to canonicalization and other greedy pattern + // applications. 
+ - erase_unnecessary_tensor_operands: add patterns that erase unnecessary + tensor operands. + - expand_memref_strided_metadata: adds patterns that expand memref + operations into extract_strided_metadata operations and a materialization + of their effect on the metadata (sizes, offset, strides). + - extract_address_computations: adds patterns for anchoring subview + accessing operations at [0, ... 0]. + - fold_memref_aliases: adds patterns for folding ops such as + memref.subview. + - fold_reassociative_reshapes: adds patterns that fold insert_slice/ + extract_slice ops with reassociative reshape ops. + - fold_tensor_empty_extract: Fold tensor.empty used by extract_slice in + case it is the only use of extract. + - fold_tensor_subsets: adds patterns for folding tensor subset ops into + their producer and consumers. + - licm: additionally apply loop-independent code motion and single + iteration loop promotion. This is not a set of patterns per se but is still + very convenient to apply it close to canonicalization and other greedy + pattern applications. + - linalg_elementwise_greedy_fusion: add linalg elementwise ops fusion + patterns using a naive default heuristic. + - lower_transfer_op_permutations: Lower transfer ops to transfer ops + with minor identity permutations. + - lower_vector_masks: Lower vector.mask ops away. + - prepare_vector_to_mma: pre-process vector.contract op to set it in a form + that can be mapped to nvgpu.mma operations. + - rank_reducing_linalg: adds patterns that results in rank-reducing + behavior on subset-based linalg operations using insert/extract slices. + - rank_reducing_linalg_via_reshapes: adds patterns that results in rank-reducing + behavior on subset-based linalg operations using expand/collapse shape ops. + - rank_reducing_vector: adds patterns that results in rank-reducing + behavior on subset-based vector operations. + adopts the upstream version. + - swapping_patterns: adds patterns that swap operations for a better outcome. + This is a catch all that can be refined further if/when needed. + - swap_padding_elide_conditional: refines the tensor.pad + + tensor.extract_slice swapping pattern. This injects static information + that guarantees padding is smaller than the window size which guarantees + we never see a tile comprised of padding-only. + - tiling_canonicalization: adds specific tiling-related canonicalization + patterns. + - unroll_vectors_gpu_mma_sync: adds patterns that unroll vectors to a native tile + size for GPUs with mma operations. The size is currently hardcoded but + should be refactored upstream and made pluggable. + - unroll_vectors_gpu_wmma: adds patterns that unroll vectors to a native tile + size for GPUs with wmma operations. The size is currently hardcoded but + should be refactored upstream and made pluggable. + + + #### Return modes: + + This operation applies a set of patterns specified by attributes. To apply + these patterns, this operation must target an operation that is isolated + from above, otherwise the transform definitely fails. + + If the pattern application fails, or if the underlying listener fails to + capture op handles, the transformation definitely fails. + + Otherwise the transformation is successful. + + This operation does not consume the target handle and does not produce any + handle. 
+ }]; + + let arguments = (ins TransformHandleTypeInterface:$target, + UnitAttr:$additional_patterns, + UnitAttr:$bubble_collapse, + UnitAttr:$bubble_expand, + UnitAttr:$bubble_pack_un_pack, + UnitAttr:$canonicalization, + // UnitAttr:$cse, + UnitAttr:$erase_unnecessary_tensor_operands, + UnitAttr:$expand_memref_strided_metadata, + UnitAttr:$extract_address_computations, + UnitAttr:$fold_memref_aliases, + UnitAttr:$fold_reassociative_reshapes, + UnitAttr:$fold_tensor_empty_extract, + UnitAttr:$fold_tensor_subsets, + UnitAttr:$licm, + UnitAttr:$linalg_elementwise_greedy_fusion, + UnitAttr:$lower_transfer_op_permutations, + UnitAttr:$lower_vector_masks, + UnitAttr:$prepare_vector_to_mma, + UnitAttr:$rank_reducing_linalg, + UnitAttr:$rank_reducing_linalg_via_reshapes, + UnitAttr:$rank_reducing_vector, + UnitAttr:$swap_padding_elide_conditional, + UnitAttr:$swapping_patterns, + UnitAttr:$tiling_canonicalization, + UnitAttr:$unroll_vectors_gpu_mma_sync, + UnitAttr:$unroll_vectors_gpu_wmma); + let results = (outs); + + let assemblyFormat = "$target attr-dict `:` functional-type($target, results)"; + let cppNamespace = "mlir::transform"; + + let builders = [ + // TODO: Some bitvector to scale better than n-bools. + OpBuilder<(ins "Value":$target, + "const ApplyPatternsOpPatterns &":$patterns)> + ]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + //===----------------------------------------------------------------------===// // BufferizeToAllocationOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/VectorToGPU/CMakeLists.txt b/mlir/lib/Conversion/VectorToGPU/CMakeLists.txt --- a/mlir/lib/Conversion/VectorToGPU/CMakeLists.txt +++ b/mlir/lib/Conversion/VectorToGPU/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_conversion_library(MLIRVectorToGPU VectorToGPU.cpp + Utils.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/VectorToGPU diff --git a/mlir/lib/Conversion/VectorToGPU/Utils.cpp b/mlir/lib/Conversion/VectorToGPU/Utils.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/VectorToGPU/Utils.cpp @@ -0,0 +1,301 @@ +//===-- Utils.cpp - VectorToGPU Infrastructure -------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the common initialization infrastructure for the +// VectorToGPU library. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/VectorToGPU/Utils.h" + +#include "mlir/IR/Matchers.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "vector-to-gpu-utils" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define DBGSNL() (llvm::dbgs() << "\n") + +namespace mlir { +std::optional> getWmmaNativeVectorSize(Operation *op) { + // Currently hardcode the size of wmma operation. When more cases are + // supported this should be picked based on what the backend supports. + int64_t m = 16; + int64_t n = 16; + if (auto contract = dyn_cast(op)) { + int64_t k = contract.getLhsType().getElementType().isF16() ? 
16 : 8;
+    SmallVector<int64_t> nativeSize(contract.getIteratorTypes().size() - 3, 1);
+    nativeSize.append({m, n, k});
+    return nativeSize;
+  }
+  if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
+    SmallVector<int64_t> nativeSize(writeOp.getVectorType().getRank() - 2, 1);
+    nativeSize.append({m, n});
+    return nativeSize;
+  }
+  if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) {
+    // Transfer read ops may need different shapes based on how they are being
+    // used. For simplicity just match the shape used by the extract strided op.
+    VectorType sliceType;
+    for (Operation *users : op->getUsers()) {
+      auto extract = dyn_cast<vector::ExtractStridedSliceOp>(users);
+      if (!extract)
+        return std::nullopt;
+      auto vecType = extract.getResult().getType().cast<VectorType>();
+      if (sliceType && sliceType != vecType)
+        return std::nullopt;
+      sliceType = vecType;
+    }
+    return llvm::to_vector(sliceType.getShape());
+  }
+  if ((OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1)) {
+    if (auto vecType = op->getResultTypes()[0].dyn_cast<VectorType>()) {
+      SmallVector<int64_t> nativeSize(vecType.getRank() - 2, 1);
+      // Map elementwise ops to the output shape.
+      nativeSize.append({m, n});
+      return nativeSize;
+    }
+  }
+  return std::nullopt;
+}
+
+//===----------------------------------------------------------------------===//
+// getMmaNativeVectorSize
+//===----------------------------------------------------------------------===//
+/// Returns the index of the vector::ContractionOp operand at which `result`
+/// is used.
+static std::optional<int64_t>
+getVectorContractOpOperandId(vector::ContractionOp contractOp,
+                             OpResult result) {
+  if (contractOp.getLhs() == result)
+    return 0;
+  if (contractOp.getRhs() == result)
+    return 1;
+  if (contractOp.getAcc() == result)
+    return 2;
+  return std::nullopt;
+}
+
+/// Returns the index of the vector::ContractionOp operand at which the
+/// vector::TransferReadOp is consumed, either directly or via
+/// vector::ExtractStridedSliceOp.
+static std::optional<int64_t>
+getVectorContractOpOperandIdForVectorReadOp(Operation *op) {
+  vector::ContractionOp contractOp;
+
+  // Check if the vector::TransferReadOp is consumed directly by
+  // vector::ContractionOp.
+  if (op->use_empty())
+    return std::nullopt;
+  Operation *firstLevelUser = *((op->getUsers()).begin());
+  if (!firstLevelUser)
+    return std::nullopt;
+  if (auto contractOp = dyn_cast<vector::ContractionOp>(firstLevelUser))
+    return getVectorContractOpOperandId(contractOp, op->getResult(0));
+
+  // Check if the vector::TransferReadOp is consumed indirectly by
+  // vector::ContractionOp. Only check until the second level of the use-def
+  // chain.
+  if (firstLevelUser->use_empty())
+    return std::nullopt;
+  Operation *secondLevelUser = *((firstLevelUser->getUsers()).begin());
+  if (!secondLevelUser)
+    return std::nullopt;
+  if (auto contractOp = dyn_cast<vector::ContractionOp>(secondLevelUser))
+    return getVectorContractOpOperandId(contractOp,
+                                        firstLevelUser->getResult(0));
+  return std::nullopt;
+}
+
+/// Helper function to return native size for MMA.SYNC-based operations.
+std::optional<SmallVector<int64_t>> getMmaNativeVectorSize(Operation *op) {
+  // Shape of native Tensor Core GPU mma.sync operations.
+  int64_t mmaShapeM = 16;
+  int64_t mmaShapeN = 8;
+  int64_t mmaShapeK;
+
+  // Shape of the mma.sync warp-level operation.
+  if (auto contract = dyn_cast<vector::ContractionOp>(op)) {
+    Type sourceType = contract.getLhsType().getElementType();
+
+    // Set mmaShapeK based on sourceType.
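+    // The K extent mirrors the native sm_80 mma.sync tile shapes:
+    //   i4 -> m16n8k64, i8 -> m16n8k32, f16/bf16 -> m16n8k16, f32 (tf32) -> m16n8k8.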
+    if (sourceType.isInteger(4))
+      mmaShapeK = 64;
+    else if (sourceType.isInteger(8))
+      mmaShapeK = 32;
+    else if (sourceType.isF16() || sourceType.isBF16())
+      mmaShapeK = 16;
+    else if (sourceType.isF32())
+      mmaShapeK = 8;
+    else
+      return std::nullopt;
+
+    // Initialize/set the starting dims of the ranked shape, such as batch,
+    // to 1.
+    SmallVector<int64_t> mmaShape(contract.getIteratorTypes().size() - 3, 1);
+    mmaShape.append({mmaShapeM, mmaShapeN, mmaShapeK});
+    return mmaShape;
+  }
+
+  // Shape of warp-level vector write operation.
+  if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
+    SmallVector<int64_t> outputShape(writeOp.getVectorType().getRank() - 2, 1);
+    outputShape.append({mmaShapeM, mmaShapeN});
+    return outputShape;
+  }
+
+  // Shape of warp-level vector read (load) operation.
+  if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) {
+    auto resultVectorType = readOp.getVector().getType().cast<VectorType>();
+    Type resultElementType = resultVectorType.getElementType();
+
+    std::optional<int64_t> operandId =
+        getVectorContractOpOperandIdForVectorReadOp(op);
+    if (!operandId) {
+      LLVM_DEBUG({
+        DBGS() << "Failed to get operandId for vector::TransferReadOp: " << *op
+               << "\n";
+      });
+      return std::nullopt;
+    }
+
+    // Loading F16 values from Shared Memory to Registers.
+    if (resultElementType.isF16() || resultElementType.isBF16()) {
+      // For matrixC.
+      if (*operandId == 2) {
+        SmallVector<int64_t> readShape;
+        readShape.append({mmaShapeM, mmaShapeN});
+        return readShape;
+      }
+
+      // For matrixA and matrixB.
+      if (*operandId == 0 || *operandId == 1) {
+        // MmaSyncOp input operands: matrixA and matrixB.
+        // LDSMx1, x2, x4:
+        // - LDSMx1 loads 1 tile of 8x8.
+        // - LDSMx2 loads 2 tiles of 8x8.
+        // - LDSMx4 loads 4 tiles of 8x8 (in use).
+        // IREE uses the largest tiled load, i.e., LDSMx4.
+
+        // MmaSyncOp source operand: matrixC.
+        // matrixC is also read/written in tiled blocks of 16x16. In the
+        // OptimizeVectorTransfer pass, matrixC reads are moved above the main
+        // loop and writes are moved below it. Thus, mma.sync reads/writes the
+        // accumulator in place.
+
+        SmallVector<int64_t> readShape;
+        readShape.append({16, 16});
+        return readShape;
+      }
+    }
+
+    // Loading F32 values from Shared Memory to Registers.
+    if (resultElementType.isF32()) {
+      // Set mmaShapeK for F32 datatype mma.sync.f32.tf32.m16n8k8.
+      mmaShapeK = 8;
+
+      // For matrixC.
+      if (*operandId == 2) {
+        SmallVector<int64_t> readShape;
+        readShape.append({mmaShapeM, mmaShapeN});
+        return readShape;
+      }
+      // For matrixA.
+      if (*operandId == 0) {
+        SmallVector<int64_t> readShape;
+        readShape.append({mmaShapeM, mmaShapeK});
+        return readShape;
+      }
+      // For matrixB.
+      if (*operandId == 1) {
+        // Do not use ldmatrix for matrixB.
+        // Transfer read ops may need different shapes based on how they are
+        // being used. For simplicity just match the shape used by the extract
+        // strided op.
+        VectorType sliceType;
+        for (Operation *users : op->getUsers()) {
+          auto extract = dyn_cast<vector::ExtractStridedSliceOp>(users);
+          if (!extract)
+            return std::nullopt;
+          auto vecType = extract.getResult().getType().cast<VectorType>();
+          if (sliceType && sliceType != vecType)
+            return std::nullopt;
+          sliceType = vecType;
+        }
+        return llvm::to_vector(sliceType.getShape());
+      }
+    }
+  }
+  return std::nullopt;
+}
+
+//===----------------------------------------------------------------------===//
+// GPU vectorization
+//===----------------------------------------------------------------------===//
+
+bool canPerformVectorAccessUsingAllThreads(ArrayRef<int64_t> shape,
+                                           int64_t threadCount,
+                                           int64_t vectorSize) {
+  // Verify that each dimension of the shape can be distributed on the
+  // threads. A zero-dim tensor is considered too small to access using all
+  // threads.
+  if (shape.size() == 0)
+    return false;
+  int64_t threadsAvailable = threadCount;
+  for (const auto &[index, dim] : llvm::enumerate(llvm::reverse(shape))) {
+    int64_t numElementPerThread = index == 0 ? vectorSize : 1;
+    int64_t numThreads = dim / numElementPerThread;
+    if (numThreads == 0)
+      return false;
+    if (numThreads > threadsAvailable) {
+      // If there are not enough remaining threads to distribute the current
+      // dimension, try to use all remaining threads. But we still need to make
+      // sure all work can be distributed to these threads evenly.
+      if (numThreads % threadsAvailable != 0)
+        return false;
+      numThreads = threadsAvailable;
+    }
+    if (threadsAvailable % numThreads != 0)
+      return false;
+    threadsAvailable = threadsAvailable / numThreads;
+    if (threadsAvailable == 1)
+      break;
+  }
+  return threadsAvailable == 1;
+}
+
+/// Pick an unrolling order that will allow tensor core operations to reuse the
+/// LHS register. This is needed to get good performance on sm_80 targets.
+std::optional<SmallVector<int64_t>>
+gpuMmaUnrollOrder(vector::ContractionOp contract) {
+  SmallVector<int64_t> order;
+  // First make reduction the outer dimensions.
+  for (auto [index, iter] : llvm::enumerate(contract.getIteratorTypes())) {
+    if (vector::isReductionIterator(iter)) {
+      order.push_back(index);
+    }
+  }
+
+  llvm::SmallDenseSet<int64_t> dims;
+  for (AffineExpr expr : contract.getIndexingMapsArray()[0].getResults()) {
+    dims.insert(expr.cast<AffineDimExpr>().getPosition());
+  }
+  // Then parallel dimensions that are part of Lhs as we want to re-use Lhs.
+  for (auto [index, iter] : llvm::enumerate(contract.getIteratorTypes())) {
+    if (vector::isParallelIterator(iter) && dims.count(index)) {
+      order.push_back(index);
+    }
+  }
+  // Then the remaining parallel loops.
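+  // (For example, a canonical matmul contraction with iterator types
+  //  [parallel(m), parallel(n), reduction(k)] and an LHS map of (m, k) ends up
+  //  with the order [k, m, n].)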
+ for (auto [index, iter] : llvm::enumerate(contract.getIteratorTypes())) { + if (vector::isParallelIterator(iter) && !dims.count(index)) { + order.push_back(index); + } + } + return order; +} + +} // namespace mlir diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp --- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp +++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp @@ -12,8 +12,6 @@ #include "mlir/Conversion/VectorToGPU/VectorToGPU.h" -#include - #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -25,9 +23,7 @@ #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" -#include "mlir/Dialect/Vector/Utils/VectorUtils.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Region.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LogicalResult.h" diff --git a/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt @@ -18,11 +18,14 @@ MLIRIR MLIRLinalgDialect MLIRLinalgTransforms - MLIRParser MLIRPDLDialect + MLIRParser MLIRSCFDialect MLIRSideEffectInterfaces + MLIRTensorTransformOps MLIRTransformDialect MLIRTransformDialectUtils + MLIRVectorToGPU MLIRVectorTransforms + MLIRVectorUtils ) diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -10,19 +10,22 @@ #include "mlir/AsmParser/AsmParser.h" +#include "mlir/Conversion/VectorToGPU/Utils.h" +#include "mlir/Conversion/VectorToGPU/VectorToGPU.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Affine/LoopUtils.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/Dialect/MemRef/Transforms/Transforms.h" #include "mlir/Dialect/PDL/IR/PDL.h" #include "mlir/Dialect/PDL/IR/PDLTypes.h" #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" +#include "mlir/Dialect/SCF/Transforms/Transforms.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Tensor/Utils/Utils.h" -#include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/Dialect/Transform/IR/TransformOps.h" #include "mlir/Dialect/Transform/IR/TransformTypes.h" @@ -35,6 +38,7 @@ #include "mlir/Interfaces/TilingInterface.h" #include "mlir/Support/LLVM.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/TypeSwitch.h" @@ -141,6 +145,454 @@ return DiagnosedSilenceableFailure::success(); } +//===----------------------------------------------------------------------===// +// ErrorCheckingTrackingListener +//===----------------------------------------------------------------------===// + +void 
ErrorCheckingTrackingListener::notifyPayloadReplacementNotFound( + Operation *op, ValueRange values) { + // Certain ops can dropped safely. + if (isa(op)) { + LLVM_DEBUG(DBGS() << "Silently dropping scf.for op mapping\n"); + return; + } + + hadErrors = true; +#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS + errorStateChecked = false; +#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS +} + +//===---------------------------------------------------------------------===// +// ApplyPatternsOp +//===---------------------------------------------------------------------===// +void transform::ApplyPatternsOp::build( + OpBuilder &builder, OperationState &result, Value target, + const ApplyPatternsOpPatterns &patterns) { + result.addOperands(target); + + auto unitAttr = builder.getUnitAttr(); + +#define ADD_PATTERN(NAME, ATTR) \ + if (patterns.NAME) \ + result.addAttribute(ApplyPatternsOp::ATTR(result.name), unitAttr); + /// + /// When touching something here, do not forget to update + /// LinalgTransformOps.h. + /// + ADD_PATTERN(additionalPatterns, getAdditionalPatternsAttrName) + ADD_PATTERN(bubbleCollapse, getBubbleCollapseAttrName) + ADD_PATTERN(bubbleExpand, getBubbleExpandAttrName) + ADD_PATTERN(bubblePackUnPack, getBubblePackUnPackAttrName) + ADD_PATTERN(canonicalization, getCanonicalizationAttrName) + // ADD_PATTERN(cse, getCseAttrName) + ADD_PATTERN(eraseUnnecessaryTensorOperands, + getEraseUnnecessaryTensorOperandsAttrName) + ADD_PATTERN(expandMemrefStridedMetadata, + getExpandMemrefStridedMetadataAttrName) + ADD_PATTERN(extractAddressComputations, getExtractAddressComputationsAttrName) + ADD_PATTERN(foldMemrefAliases, getFoldMemrefAliasesAttrName) + ADD_PATTERN(foldReassociativeReshapes, getFoldReassociativeReshapesAttrName) + ADD_PATTERN(foldTensorEmptyExtract, getFoldTensorEmptyExtractAttrName) + ADD_PATTERN(foldTensorSubsets, getFoldTensorSubsetsAttrName) + ADD_PATTERN(licm, getLicmAttrName) + ADD_PATTERN(linalgElementwiseGreedyFusion, + getLinalgElementwiseGreedyFusionAttrName) + ADD_PATTERN(lowerTransferOpPermutations, + getLowerTransferOpPermutationsAttrName) + ADD_PATTERN(lowerVectorMasks, getLowerVectorMasksAttrName) + ADD_PATTERN(prepareVectorToMma, getPrepareVectorToMmaAttrName) + ADD_PATTERN(rankReducingLinalg, getRankReducingLinalgAttrName) + ADD_PATTERN(rankReducingLinalgViaReshapes, + getRankReducingLinalgViaReshapesAttrName) + ADD_PATTERN(rankReducingVector, getRankReducingVectorAttrName) + ADD_PATTERN(swapPaddingElideConditional, + getSwapPaddingElideConditionalAttrName) + ADD_PATTERN(swappingPatterns, getSwappingPatternsAttrName) + ADD_PATTERN(tilingCanonicalization, getTilingCanonicalizationAttrName) + ADD_PATTERN(unrollVectorsGpuMmaSync, getUnrollVectorsGpuMmaSyncAttrName) + ADD_PATTERN(unrollVectorsGpuWmma, getUnrollVectorsGpuWmmaAttrName) +#undef ADD_PATTERN +} + +static void addOperands(Operation *op, SetVector &operandSet) { + if (!op) + return; + TypeSwitch(op) + .Case([&](linalg::LinalgOp linalgOp) { + SmallVector inputOperands{linalgOp.getDpsInputOperands()}; + operandSet.insert(inputOperands.begin(), inputOperands.end()); + }) + .Default([&](Operation *operation) { + operandSet.insert(operation->operand_begin(), operation->operand_end()); + }); +} + +template +static bool setFusedOpOperandLimit(OpOperand *fusedOperand) { + Operation *producer = fusedOperand->get().getDefiningOp(); + if (!producer) + return false; + Operation *consumer = fusedOperand->getOwner(); + SetVector fusedOpOperands; + if (producer->getNumResults() != 1) + return false; + addOperands(consumer, 
fusedOpOperands); + fusedOpOperands.remove(producer->getResult(0)); + addOperands(producer, fusedOpOperands); + return fusedOpOperands.size() <= limit; +} + +namespace { +/// Rewrite a tensor.generate as an arith.constant when possible. +struct GenerateToConstant : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::GenerateOp generateOp, + PatternRewriter &rewriter) const final { + auto tensorType = generateOp.getResult().getType().cast(); + if (!tensorType.hasStaticShape()) + return failure(); + auto terminatorOp = + cast(generateOp.getBody().front().getTerminator()); + if (terminatorOp->getNumOperands() > 1) + return failure(); + auto constantOp = + terminatorOp->getOperand(0).getDefiningOp(); + if (!constantOp) + return failure(); + rewriter.replaceOpWithNewOp( + generateOp, tensorType, + DenseElementsAttr::get(tensorType, constantOp.getValueAttr())); + return success(); + } +}; + +/// Fold tensor.empty used by extract_slice if this the only use of +/// extract_slice and the result is static. +struct FoldTensorEmptyExtract + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractOp, + PatternRewriter &rewriter) const final { + auto tensorEmpty = extractOp.getSource().getDefiningOp(); + if (!tensorEmpty || !extractOp.getType().hasStaticShape() || + !tensorEmpty->hasOneUse()) + return failure(); + rewriter.replaceOpWithNewOp( + extractOp, extractOp.getType().getShape(), + extractOp.getType().getElementType()); + return success(); + } +}; + +/// Fold `tensor.pad(cst, tensor.extract*(linalg.fill(cst)))` into +/// `linalg.fill(cst, empty)` when the padding constant and the fill constant +/// are the same. +/// This seems generally desirable as a folding but may be too intrusive, so we +/// only apply it selectively for now. +// TODO: atm hardcoded on linalg.fill but we could take any result of any +// generic that yields a constant in that result. 
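+//
+// Schematically (see the pad_fill_to_fill tests in
+// transform-op-apply-patterns.mlir added below):
+//   %fill   = linalg.fill ins(%cst) outs(%t)       -> tensor<31x62xf32>
+//   %padded = tensor.pad %fill ... { yield %cst }  -> tensor<32x64xf32>
+// is rewritten to:
+//   %empty  = tensor.empty() : tensor<32x64xf32>
+//   %padded = linalg.fill ins(%cst) outs(%empty)   -> tensor<32x64xf32>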
+struct FoldFillIntoPad : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(tensor::PadOp padOp, + PatternRewriter &rewriter) const final { + Operation *currentOp = padOp.getSource().getDefiningOp(); + auto maybeExtractSlice = + dyn_cast_or_null(currentOp); + while (currentOp && maybeExtractSlice) { + currentOp = maybeExtractSlice.getSource().getDefiningOp(); + maybeExtractSlice = dyn_cast_or_null(currentOp); + } + auto fillOp = dyn_cast_or_null(currentOp); + if (!fillOp) { + return rewriter.notifyMatchFailure( + padOp, "not coming from a linalg.fill op via tensor.extract_slice*"); + } + + Value padValue = padOp.getConstantPaddingValue(); + RankedTensorType resultType = padOp.getResultType(); + if (!padValue || + getAsOpFoldResult(padValue) != + getAsOpFoldResult(fillOp.getDpsInputOperand(0)->get())) { + return rewriter.notifyMatchFailure( + padOp, "not a constant value matching the fill value"); + } + + Location loc = padOp.getLoc(); + auto emptyOp = rewriter.create( + loc, resultType, + linalg::createDynamicDimensions(rewriter, loc, padOp.getResult())); + rewriter.replaceOpWithNewOp(padOp, padValue, + emptyOp.getResult()); + + return success(); + } +}; +} // namespace + +static void +addLowerTransferOpPermutationsPatterns(RewritePatternSet &patterns) { + vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); +} + +static void addLowerVectorMasksPatterns(RewritePatternSet &patterns) { + vector::populateVectorMaskLoweringPatternsForSideEffectingOps(patterns); +} + +static void addExtractAddressComputationsPatterns(RewritePatternSet &patterns) { + memref::populateExtractAddressComputationsPatterns(patterns); +} + +static void addFoldMemrefAliasPatterns(RewritePatternSet &patterns) { + memref::populateFoldMemRefAliasOpPatterns(patterns); +} + +static void addFoldTensorEmptyExtract(RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); +} + +static void addReassociativeReshapePatterns(RewritePatternSet &patterns) { + tensor::populateReassociativeReshapeFoldingPatterns(patterns); +} + +static void addFoldTensorSubsetsPatterns(RewritePatternSet &patterns) { + tensor::populateFoldTensorSubsetOpPatterns(patterns); + // TODO: upstream should move these to populateFoldTensorSubsetOpPatterns. 
+ tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns); +} + +static void +addEraseUnnecessaryTensorOperandsPatterns(RewritePatternSet &patterns) { + linalg::populateEraseUnnecessaryInputsPatterns(patterns); +} + +static void addPrepareVectorToMmaPatterns(RewritePatternSet &patterns) { + populatePrepareVectorToMMAPatterns(patterns, /*useNvGpu=*/true); +} + +static void addRankReducingLinalgPatterns(RewritePatternSet &patterns) { + // populateReshapeToInterfaceTensorPatterns(patterns); + linalg::populateFoldUnitExtentDimsViaSlicesPatterns(patterns); +} + +static void +addRankReducingLinalgViaReshapesPatterns(RewritePatternSet &patterns) { + // populateReshapeToInterfaceTensorPatterns(patterns); + linalg::populateFoldUnitExtentDimsViaReshapesPatterns(patterns); +} + +static void addRankReducingVectorPatterns(RewritePatternSet &patterns) { + vector::populateCastAwayVectorLeadingOneDimPatterns(patterns); +} + +static void addSwappingPatterns(RewritePatternSet &patterns, + bool swapPaddingElideCornerCase) { + patterns.add( + patterns.getContext(), + [&](tensor::ExtractSliceOp) -> std::optional { + return !swapPaddingElideCornerCase; + }); +} + +static void addTilingCanonicalizationPatterns(RewritePatternSet &patterns) { + linalg::populateLinalgTilingCanonicalizationPatterns(patterns); + scf::populateSCFForLoopCanonicalizationPatterns(patterns); + /// This seems generally desirable as a folding but may be too intrusive, so + /// we only apply it selectively for now. + patterns.add(patterns.getContext()); +} + +static std::optional> +getGPUTensorCoreNativeMmaSyncVectorSize(Operation *op) { + return mlir::getMmaNativeVectorSize(op); +} + +static void addUnrollVectorsGpuMmaSyncPatterns(RewritePatternSet &patterns) { + auto unrollOrder = [](Operation *op) -> std::optional> { + auto contract = dyn_cast(op); + if (!contract) + return std::nullopt; + return mlir::gpuMmaUnrollOrder(contract); + }; + vector::populateVectorUnrollPatterns( + patterns, vector::UnrollVectorOptions() + .setNativeShapeFn(getGPUTensorCoreNativeMmaSyncVectorSize) + .setUnrollTraversalOrderFn(unrollOrder)); +} + +static std::optional> +getGPUTensorCoreNativeWmmaVectorSize(Operation *op) { + return getWmmaNativeVectorSize(op); +} + +static void addUnrollVectorsGpuWmmaPatterns(RewritePatternSet &patterns) { + auto unrollOrder = [](Operation *op) -> std::optional> { + auto contract = dyn_cast(op); + if (!contract) + return std::nullopt; + return mlir::gpuMmaUnrollOrder(contract); + }; + vector::populateVectorUnrollPatterns( + patterns, vector::UnrollVectorOptions() + .setNativeShapeFn(getGPUTensorCoreNativeWmmaVectorSize) + .setUnrollTraversalOrderFn(unrollOrder)); +} + +static void addAdditionalPatterns(RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); +} + +static void +addAllRegisteredCanonicalizationPatterns(RewritePatternSet &patterns) { + MLIRContext *ctx = patterns.getContext(); + for (Dialect *dialect : ctx->getLoadedDialects()) + dialect->getCanonicalizationPatterns(patterns); + for (RegisteredOperationName op : ctx->getRegisteredOperations()) + op.getCanonicalizationPatterns(patterns, ctx); +} + +DiagnosedSilenceableFailure transform::ApplyPatternsOp::applyToOne( + Operation *target, transform::ApplyToEachResultList &results, + transform::TransformState &state) { + if (!target->hasTrait()) { + return mlir::emitDefiniteFailure( + target, + "applies only to isolated-from-above targets because it needs to apply " + "patterns greedily"); + } + MLIRContext *ctx = target->getContext(); + 
RewritePatternSet patterns(ctx); + if (getAdditionalPatterns()) + addAdditionalPatterns(patterns); + if (getBubbleCollapse()) { + linalg::populateFoldReshapeOpsByCollapsingPatterns( + patterns, [](OpOperand *) { return true; }); + } + if (getBubbleExpand()) { + linalg::populateFoldReshapeOpsByExpansionPatterns( + patterns, [](OpOperand *) { return true; }); + } + if (getBubblePackUnPack()) + linalg::populateDataLayoutPropagationPatterns( + patterns, [](Operation *op) { return true; }); + if (getCanonicalization()) + addAllRegisteredCanonicalizationPatterns(patterns); + if (getEraseUnnecessaryTensorOperands()) + addEraseUnnecessaryTensorOperandsPatterns(patterns); + if (getExpandMemrefStridedMetadata()) + memref::populateExpandStridedMetadataPatterns(patterns); + if (getExtractAddressComputations()) + addExtractAddressComputationsPatterns(patterns); + if (getFoldMemrefAliases()) + addFoldMemrefAliasPatterns(patterns); + if (getFoldReassociativeReshapes()) + addReassociativeReshapePatterns(patterns); + if (getFoldTensorEmptyExtract()) + addFoldTensorEmptyExtract(patterns); + if (getFoldTensorSubsets()) + addFoldTensorSubsetsPatterns(patterns); + if (getLinalgElementwiseGreedyFusion()) + linalg::populateElementwiseOpsFusionPatterns(patterns, + setFusedOpOperandLimit<3>); + if (getLowerTransferOpPermutations()) + addLowerTransferOpPermutationsPatterns(patterns); + if (getLowerVectorMasks()) + addLowerVectorMasksPatterns(patterns); + if (getPrepareVectorToMma()) + addPrepareVectorToMmaPatterns(patterns); + if (getRankReducingLinalg()) + addRankReducingLinalgPatterns(patterns); + if (getRankReducingLinalgViaReshapes()) + addRankReducingLinalgViaReshapesPatterns(patterns); + if (getRankReducingVector()) + addRankReducingVectorPatterns(patterns); + if (getSwappingPatterns()) + addSwappingPatterns(patterns, getSwapPaddingElideConditional()); + if (getTilingCanonicalization()) + addTilingCanonicalizationPatterns(patterns); + if (getUnrollVectorsGpuMmaSync()) + addUnrollVectorsGpuMmaSyncPatterns(patterns); + if (getUnrollVectorsGpuWmma()) + addUnrollVectorsGpuWmmaPatterns(patterns); + + Location loc = target->getLoc(); + ErrorCheckingTrackingListener listener(state, *this); + GreedyRewriteConfig config; + config.listener = &listener; + // Manually gather list of ops because the other GreedyPatternRewriteDriver + // overloads only accepts ops that are isolated from above. + SmallVector ops; + target->walk([&](Operation *nestedOp) { + if (target != nestedOp) + ops.push_back(nestedOp); + }); + LogicalResult result = + applyOpPatternsAndFold(ops, std::move(patterns), config); + if (failed(result)) { + return listener.check( + loc, mlir::emitDefiniteFailure(target, "greedy patterns failed")); + } + + auto diag = listener.check(loc); + if (!diag.succeeded()) + return diag; + + if (getLicm()) { + target->walk([&](func::FuncOp funcOp) { + // This assumes LICM never removes operations so we don't need tracking. + // TODO: confirm / revisit this assumption and plumb a rewriter through + // upstream moveLoopInvariantCode if necessary. + funcOp->walk([](LoopLikeOpInterface loopLike) { + moveLoopInvariantCode(loopLike); + }); + // For now, put single loop promotion as part of licm. Underlying + // implementations perform splice operations which shouldn't need + // tracking. + // TODO: confirm / revisit this assumption and plumb a rewriter through + // upstream moveLoopInvariantCode if necessary. 
+ funcOp->walk([](Operation *op) { + (void)llvm::TypeSwitch(op) + .Case( + [](auto loop) { return promoteIfSingleIteration(loop); }) + .Default([](Operation *) { return success(); }); + }); + }); + } + + // if (getCse()) { + // func::FuncOp lastFuncVisited; + // auto walkResult = target->walk([&](func::FuncOp funcOp) -> WalkResult { + // lastFuncVisited = funcOp; + // result = + // eliminateCommonSubexpressions(funcOp, /*domInfo=*/nullptr, + // &listener); + // if (failed(result)) + // return WalkResult::interrupt(); + // if (failed(listener.checkErrorState())) + // return WalkResult::interrupt(); + // return WalkResult::advance(); + // }); + // if (walkResult.wasInterrupted()) { + // if (failed(result)) { + // return mlir::emitDefiniteFailure(lastFuncVisited, + // "greedy patterns failed"); + // } + // if (failed(listener.checkErrorState())) + // return mlir::emitDefiniteFailure(lastFuncVisited, + // "pattern listener tracker fail"); + // } + // } + + return listener.check(loc); +} + +void transform::ApplyPatternsOp::getEffects( + SmallVectorImpl &effects) { + transform::onlyReadsHandle(getTarget(), effects); + transform::modifiesPayload(effects); +} + //===----------------------------------------------------------------------===// // BufferizeToAllocationOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp --- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp +++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp @@ -142,6 +142,36 @@ return makePermutationMap(op->getBlock(), indices, loopToVectorDim); } +/// Returns vector::ContractionOp operand's index where the result is used. +static Optional +getVectorContractOpOperandId(vector::ContractionOp contractOp, + OpResult result) { + if (contractOp.getLhs() == result) + return 0; + if (contractOp.getRhs() == result) + return 1; + if (contractOp.getAcc() == result) + return 2; + return std::nullopt; +} + +/// Returns vector::ContractionOp operand's index where the +/// vector::TransferReadOp is consumed either consumed directly or via +/// vector::ExtractStridedSliceOp. 
+static Optional +getVectorContractOpOperandIdForVectorReadOp(Operation *op) { + vector::ContractionOp contractOp; + + Operation *firstLevelUser = *((op->getUsers()).begin()); + if (auto contractOp = dyn_cast(firstLevelUser)) + return getVectorContractOpOperandId(contractOp, op->getResult(0)); + Operation *secondLevelUser = *((firstLevelUser->getUsers()).begin()); + if (auto contractOp = dyn_cast(secondLevelUser)) + return getVectorContractOpOperandId(contractOp, + firstLevelUser->getResult(0)); + return std::nullopt; +} + bool matcher::operatesOnSuperVectorsOf(Operation &op, VectorType subVectorType) { // First, extract the vector type and distinguish between: diff --git a/mlir/test/Dialect/Linalg/transform-op-apply-patterns.mlir b/mlir/test/Dialect/Linalg/transform-op-apply-patterns.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/transform-op-apply-patterns.mlir @@ -0,0 +1,201 @@ +// RUN: mlir-opt -split-input-file \ +// RUN: -test-transform-dialect-interpreter -canonicalize \ +// RUN: -allow-unregistered-dialect -split-input-file %s | FileCheck %s + +// CHECK-LABEL: @select_cmp_eq_select +// CHECK: return %arg1 +func.func @select_cmp_eq_select(%arg0: i64, %arg1: i64) -> i64 { + %0 = arith.cmpi eq, %arg0, %arg1 : i64 + %1 = arith.select %0, %arg0, %arg1 : i64 + return %1 : i64 +} + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 : !pdl.operation failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation + transform.apply_patterns %0 { canonicalization } : (!pdl.operation) -> () + } +} + +// ----- + +#map2 = affine_map<(d0, d1) -> (d0, d1)> + +func.func private @mutate(f32) -> f32 + +// CHECK-LABEL: @bubble_up +func.func @bubble_up(%arg0: tensor<32x64xf32>) -> tensor<32x2x32xf32> { + // Check that shape expansion precedes linalg.generic after the patterns were applied. + // CHECK: tensor.expand_shape + // CHECK: tensor.expand_shape + // CHECK: linalg.generic + %init = tensor.empty() : tensor<32x64xf32> + %result = linalg.generic { + indexing_maps = [#map2, #map2], + iterator_types = ["parallel", "parallel"]} + ins(%arg0: tensor<32x64xf32>) outs(%init: tensor<32x64xf32>) { + ^bb0(%arg1: f32, %arg2: f32): + %0 = func.call @mutate(%arg1) : (f32) -> f32 + linalg.yield %0 : f32 + } -> tensor<32x64xf32> + %out = tensor.expand_shape %result[[0], [1, 2]] : tensor<32x64xf32> into tensor<32x2x32xf32> + return %out : tensor<32x2x32xf32> +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation + transform.apply_patterns %0 { bubble_expand } : (!pdl.operation) -> () +} + +// ----- + +// CHECK-LABEL: @pad_fill_to_fill +func.func @pad_fill_to_fill(%arg0: tensor<31x62xf32>) -> tensor<32x64xf32> { + // Check that a pad of a fill with the same constant is replaced by a + // bigger fill. 
+ // CHECK-DAG: %[[FILL_CST:.*]] = arith.constant 0.0{{0*e\+00}} : f32 + // CHECK-DAG: %[[EMPTY:.*]] = tensor.empty() : tensor<32x64xf32> + // CHECK: %[[PADDED_FILL:.*]] = linalg.fill ins(%[[FILL_CST]] : f32) outs(%[[EMPTY]] : tensor<32x64xf32>) -> tensor<32x64xf32> + // CHECK: return %[[PADDED_FILL]] + %cst = arith.constant 0.0 : f32 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %fill = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<31x62xf32>) -> tensor<31x62xf32> + %padded = tensor.pad %fill low[%c0, %c0] high[%c1, %c2] { + ^bb0(%arg3: index, %arg4: index): + tensor.yield %cst : f32 + } : tensor<31x62xf32> to tensor<32x64xf32> + return %padded : tensor<32x64xf32> +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation + transform.apply_patterns %0 { tiling_canonicalization } : (!pdl.operation) -> () +} + +// ----- + +// CHECK-LABEL: @pad_fill_different_ssa_value_but_same_cst +func.func @pad_fill_different_ssa_value_but_same_cst(%arg0: tensor<31x62xf32>) -> tensor<32x64xf32> { + // Check that a pad of a fill with the same constant is replaced by a + // bigger fill even when the constant comes from different ssa value. + // CHECK-DAG: %[[FILL_CST:.*]] = arith.constant 0.0{{0*e\+00}} : f32 + // CHECK-DAG: %[[EMPTY:.*]] = tensor.empty() : tensor<32x64xf32> + // CHECK: %[[PADDED_FILL:.*]] = linalg.fill ins(%[[FILL_CST]] : f32) outs(%[[EMPTY]] : tensor<32x64xf32>) -> tensor<32x64xf32> + // CHECK: return %[[PADDED_FILL]] + %cst = arith.constant 0.0 : f32 + %cst2 = arith.constant 0.0 : f32 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %fill = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<31x62xf32>) -> tensor<31x62xf32> + %padded = tensor.pad %fill low[%c0, %c0] high[%c1, %c2] { + ^bb0(%arg3: index, %arg4: index): + tensor.yield %cst2 : f32 + } : tensor<31x62xf32> to tensor<32x64xf32> + return %padded : tensor<32x64xf32> +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation + transform.apply_patterns %0 { tiling_canonicalization } : (!pdl.operation) -> () +} + +// ----- + +// CHECK-LABEL: @pad_extract_fill_to_fill +func.func @pad_extract_fill_to_fill(%arg0: tensor<31x62xf32>, + %size0 : index, %size1 : index, + %high0 : index, %high1 : index) -> tensor<32x64xf32> { + // Check that a pad of a fill with the same constant is replaced by a + // bigger fill even when the fill is hidden behind an extract_slice. 
+ // CHECK-DAG: %[[FILL_CST:.*]] = arith.constant 0.0{{0*e\+00}} : f32 + // CHECK-DAG: %[[EMPTY:.*]] = tensor.empty() : tensor<32x64xf32> + // CHECK: %[[PADDED_FILL:.*]] = linalg.fill ins(%[[FILL_CST]] : f32) outs(%[[EMPTY]] : tensor<32x64xf32>) -> tensor<32x64xf32> + // CHECK: return %[[PADDED_FILL]] + %cst = arith.constant 0.0 : f32 + %cst2 = arith.constant 0.0 : f32 + %c0 = arith.constant 0 : index + %fill = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<31x62xf32>) -> tensor<31x62xf32> + %extracted_slice = tensor.extract_slice %fill[0, 0] [%size0, %size1] [1, 1] : tensor<31x62xf32> to tensor + %padded = tensor.pad %extracted_slice low[%c0, %c0] high[%high0, %high1] { + ^bb0(%arg3: index, %arg4: index): + tensor.yield %cst2 : f32 + } : tensor to tensor<32x64xf32> + return %padded : tensor<32x64xf32> +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation + transform.apply_patterns %0 { tiling_canonicalization } : (!pdl.operation) -> () +} + +// ----- + +// CHECK-LABEL: @pad_extract_extract_fill_to_fill +func.func @pad_extract_extract_fill_to_fill(%arg0: tensor<31x62xf32>, + %size0a : index, %size1a : index, + %size0b : index, %size1b : index, + %high0 : index, %high1 : index) -> tensor<32x64xf32> { + // Check that a pad of a fill with the same constant is replaced by a + // bigger fill even when the fill is hidden behind a few `extract_slice`s. + // CHECK-DAG: %[[FILL_CST:.*]] = arith.constant 0.0{{0*e\+00}} : f32 + // CHECK-DAG: %[[EMPTY:.*]] = tensor.empty() : tensor<32x64xf32> + // CHECK: %[[PADDED_FILL:.*]] = linalg.fill ins(%[[FILL_CST]] : f32) outs(%[[EMPTY]] : tensor<32x64xf32>) -> tensor<32x64xf32> + // CHECK: return %[[PADDED_FILL]] + %cst = arith.constant 0.0 : f32 + %cst2 = arith.constant 0.0 : f32 + %c0 = arith.constant 0 : index + %fill = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<31x62xf32>) -> tensor<31x62xf32> + %extracted_sliceA = tensor.extract_slice %fill[0, 0] [%size0a, %size1a] [1, 1] : tensor<31x62xf32> to tensor + %extracted_sliceB = tensor.extract_slice %extracted_sliceA[0, 0] [%size0b, %size1b] [1, 1] : tensor to tensor + %padded = tensor.pad %extracted_sliceB low[%c0, %c0] high[%high0, %high1] { + ^bb0(%arg3: index, %arg4: index): + tensor.yield %cst2 : f32 + } : tensor to tensor<32x64xf32> + return %padded : tensor<32x64xf32> +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation + transform.apply_patterns %0 { tiling_canonicalization } : (!pdl.operation) -> () +} + +// ----- + +// CHECK-LABEL: @pad_extract_bigger_fill_to_fill +func.func @pad_extract_bigger_fill_to_fill(%arg0: tensor<253x123xf32>, + %size0 : index, %size1 : index, + %high0 : index, %high1 : index) -> tensor<32x64xf32> { + // Check that a pad of a bigger fill with the same constant is replaced by a + // fill of the right size. 
+ // CHECK-DAG: %[[FILL_CST:.*]] = arith.constant 0.0{{0*e\+00}} : f32 + // CHECK-DAG: %[[EMPTY:.*]] = tensor.empty() : tensor<32x64xf32> + // CHECK: %[[PADDED_FILL:.*]] = linalg.fill ins(%[[FILL_CST]] : f32) outs(%[[EMPTY]] : tensor<32x64xf32>) -> tensor<32x64xf32> + // CHECK: return %[[PADDED_FILL]] + %cst = arith.constant 0.0 : f32 + %cst2 = arith.constant 0.0 : f32 + %c0 = arith.constant 0 : index + %fill = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<253x123xf32>) -> tensor<253x123xf32> + %extracted_slice = tensor.extract_slice %fill[0, 0] [%size0, %size1] [1, 1] : tensor<253x123xf32> to tensor + %padded = tensor.pad %extracted_slice low[%c0, %c0] high[%high0, %high1] { + ^bb0(%arg3: index, %arg4: index): + tensor.yield %cst2 : f32 + } : tensor to tensor<32x64xf32> + return %padded : tensor<32x64xf32> +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation + transform.apply_patterns %0 { tiling_canonicalization } : (!pdl.operation) -> () +} +
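The unit attributes are additive, so a single apply_patterns call can combine several pattern sets in one greedy application. A minimal sketch of this (not part of the patch), reusing the transform-script shape of the tests above; the chosen attribute combination is illustrative only:

transform.sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
  %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation
  transform.apply_patterns %0 { canonicalization, licm, tiling_canonicalization } : (!pdl.operation) -> ()
}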