diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -27,167 +27,310 @@ #include "llvm/ADT/SmallSet.h" namespace mlir { -namespace bufferization { -class BufferizeTypeConverter; -} // namespace bufferization - -class FrozenRewritePatternSet; - namespace linalg { -struct LinalgElementwiseFusionOptions; -struct LinalgFusionOptions; -struct LinalgTilingOptions; +class LinalgOp; //===----------------------------------------------------------------------===// -// Transformations exposed as function calls. +// Utils. //===----------------------------------------------------------------------===// -using LinalgLoops = SmallVector; - -/// Materialize a buffer allocation for the given tensor.pad op and lower the -/// op to linalg.fill/linalg.generic + memref.tensor_store. E.g.: -/// -/// %0 = tensor.pad low[%l] high[%h] %t ... -/// -/// is lowered to: -/// -/// %alloc = memref.alloc -/// linalg.fill ... outs(%alloc) -/// %subview = memref.subview %alloc [%l] [...] [1] -/// memref.tensor_store %t, %subview -/// %0 = bufferization.to_tensor %alloc restrict writable -/// -/// In addition to rewriting the IR as shown above, the result of the -/// bufferization.to_tensor op is returned. -Value bufferizeToAllocation(RewriterBase &rewriter, tensor::PadOp padOp, - Attribute memorySpace = {}); - -/// Materialize a buffer allocation for the given tensor value. E.g.: -/// -/// %alloc = memref.alloc -/// memref.tensor_store %value, %alloc -/// %0 = bufferization.to_tensor %alloc restrict writable -/// -/// In case `value` is a tensor.pad result, the corresponding overload is used -/// internally to produce a better bufferization. -Value bufferizeToAllocation(RewriterBase &rewriter, Value value, - Attribute memorySpace = {}); -void populatePadTensorTilingPatterns(RewritePatternSet &patterns, - const LinalgTilingOptions &options); - -/// Populate patterns for splitting a `LinalgOp` with multiple statements within -/// its payload into multiple `GenericOp` that have a single statement. -/// The option `removeDeadArgsAndResults` adds patterns to remove dead arguments -/// and results from the generated decomposed ops. This is default `true` since -/// the core decomposition patterns relies on these clean up patterns. It is set -/// to false only for testing purposes. -void populateDecomposeLinalgOpsPattern(RewritePatternSet &patterns, - bool removeDeadArgsAndResults = true); - -/// Populate patterns that convert non-destination-style ops to destination -/// style ops. -void populateConvertToDestinationStylePatterns(RewritePatternSet &patterns); - -/// Populate patterns for vectorizing low-D convolution ops. This is a step in -/// progressive lowering for convolution ops, it assume high-D convolution ops -/// were decomposed previously. -void populateConvolutionVectorizationPatterns(RewritePatternSet &patterns, - PatternBenefit benefit = 1); - -/// Populate patterns that convert `ElementwiseMappable` ops to linalg -/// parallel loops. -void populateElementwiseToLinalgConversionPatterns(RewritePatternSet &patterns); +/// Return vector::CombiningKind for the given op. +std::optional getCombinerOpKind(Operation *combinerOp); -/// Populate patterns that are only useful in the context of sparse tensors. 
-void populateSparseTensorRewriting(RewritePatternSet &patterns); +//===----------------------------------------------------------------------===// +// Structs that configure the behavior of various transformations. +//===----------------------------------------------------------------------===// -/// Function type which is used to control when to stop fusion. It is expected -/// that OpOperand is not modified in the callback. The OpOperand is not marked -/// as const to allow callers to use non-const methods. -using ControlFusionFn = std::function; +using TileSizeComputationFunction = + std::function(OpBuilder &, Operation *)>; -/// Patterns for fusing linalg operation on tensors. +struct LinalgTilingOptions { + /// Computation function that returns the tile sizes for each operation. + /// Delayed construction of constant tile sizes should occur to interoperate + /// with folding. + TileSizeComputationFunction tileSizeComputationFunction = nullptr; -/// Pattern to fuse `linalg.generic` -> `linalg.generic` operations -/// when both operations are fusable elementwise operations. -void populateElementwiseOpsFusionPatterns( - RewritePatternSet &patterns, - const ControlFusionFn &controlElementwiseOpFusion); + LinalgTilingOptions & + setTileSizeComputationFunction(TileSizeComputationFunction fun) { + tileSizeComputationFunction = std::move(fun); + return *this; + } + /// Set the `tileSizeComputationFunction` to return the values `ts`. The + /// values must not fold away when tiling. Otherwise, use a more robust + /// `tileSizeComputationFunction`. + LinalgTilingOptions &setTileSizes(const SmallVector &ts) { + tileSizeComputationFunction = [=](OpBuilder &, Operation *) { return ts; }; + return *this; + } + /// Convenience function to set the `tileSizeComputationFunction` to a + /// function that computes tile sizes at the point they are needed. Allows + /// proper interaction with folding. + LinalgTilingOptions &setTileSizes(ArrayRef ts); -/// Patterns to bubble up or down data layout ops across other operations. -void populateDataLayoutPropagationPatterns(RewritePatternSet &patterns); + /// Tile all dynamic dimensions by 1. I.e., scalarize those dimensions. + /// Note: `scalarizeDynamicDims` and `setTileSizes` cannot be used together. + LinalgTilingOptions &scalarizeDynamicDims(); -/// Pattern to remove dead operands and results of `linalg.generic` operations. -/// This is effectively DCE for a linalg op. -void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns); + /// The interchange vector to reorder the tiled loops. + SmallVector interchangeVector = {}; -/// Patterns to promote inputs to outputs and remove unused inputs of -/// `linalg.generic` ops. -void populateEraseUnnecessaryInputsPatterns(RewritePatternSet &patterns); + LinalgTilingOptions &setInterchange(ArrayRef interchange) { + interchangeVector.assign(interchange.begin(), interchange.end()); + return *this; + } -/// Function type to control generic op dimension collapsing. It is expected -/// to return an array of `ReassociationIndices` representing dimensions that -/// should be merged. -using GetCollapsableDimensionsFn = - std::function(linalg::GenericOp)>; + /// The type of tile loops to generate. + LinalgTilingLoopType loopType = LinalgTilingLoopType::Loops; -/// Pattern to collapse dimensions in a linalg.generic op. This will collapse -/// tensor operands when needed and expand back the result tensors. 
-void populateCollapseDimensions( - RewritePatternSet &patterns, - const GetCollapsableDimensionsFn &controlCollapseDimensions); + LinalgTilingOptions &setLoopType(LinalgTilingLoopType lt) { + loopType = lt; + return *this; + } -/// Patterns to fold an expanding (collapsing) tensor_reshape operation with its -/// producer (consumer) generic operation by expanding the dimensionality of the -/// loop in the generic op. -void populateFoldReshapeOpsByExpansionPatterns( - RewritePatternSet &patterns, const ControlFusionFn &controlFoldingReshapes); + /// When specified, specifies distribution of generated tile loops to + /// processors. + std::optional distribution; -/// Patterns to fold an expanding tensor.expand_shape operation with its -/// producer generic operation by collapsing the dimensions of the generic op. -void populateFoldReshapeOpsByCollapsingPatterns( - RewritePatternSet &patterns, const ControlFusionFn &controlFoldingReshapes); + LinalgTilingOptions & + setDistributionOptions(LinalgLoopDistributionOptions distributionOptions) { + distribution = std::move(distributionOptions); + return *this; + } -/// Patterns to constant fold Linalg operations. -void populateConstantFoldLinalgOperations(RewritePatternSet &patterns, - const ControlFusionFn &controlFn); + /// Specification markers of how to distribute the `linalg.tiled_loop`. + SmallVector distributionTypes = {}; -/// Pattern to fuse a `tensor.pad` operation with the producer of its source, -/// if the producer is a `linalg` operation with all parallel iterator types. -void populateFuseTensorPadWithProducerLinalgOpPatterns( - RewritePatternSet &patterns); + LinalgTilingOptions &setDistributionTypes(ArrayRef types) { + distributionTypes.assign(types.begin(), types.end()); + return *this; + } -/// Patterns to convert from one named op to another. These can be seen as -/// canonicalizations of named ops into another named op. -void populateLinalgNamedOpConversionPatterns(RewritePatternSet &patterns); + /// Peel the specified loops. + SmallVector peeledLoops; -/// Patterns to fold unit-extent dimensions in operands/results of linalg ops on -/// tensors via reassociative reshape ops. -void populateFoldUnitExtentDimsViaReshapesPatterns(RewritePatternSet &patterns); + LinalgTilingOptions &setPeeledLoops(ArrayRef loops) { + peeledLoops.clear(); + peeledLoops.append(loops.begin(), loops.end()); + return *this; + } +}; -/// Patterns to fold unit-extent dimensions in operands/results of linalg ops on -/// tensors via rank-reducing slices. -void populateFoldUnitExtentDimsViaSlicesPatterns(RewritePatternSet &patterns); +struct LinalgTilingAndFusionOptions { + /// Tile sizes used to tile the root operation. + SmallVector tileSizes; + LinalgTilingAndFusionOptions &setTileSizes(ArrayRef ts) { + tileSizes.assign(ts.begin(), ts.end()); + return *this; + } + /// Tile interchange used to permute the tile loops. + SmallVector tileInterchange; + /// When specified, specifies distribution of generated tile loops to + /// processors. + std::optional tileDistribution; + LinalgTilingAndFusionOptions & + setDistributionOptions(LinalgLoopDistributionOptions distributionOptions) { + tileDistribution = std::move(distributionOptions); + return *this; + } +}; -/// A pattern that converts init operands to input operands. -void populateMoveInitOperandsToInputPattern(RewritePatternSet &patterns); +struct LinalgPaddingOptions { + /// A padding value for every operand. 
+ SmallVector paddingValues; + LinalgPaddingOptions &setPaddingValues(ArrayRef pv) { + paddingValues.assign(pv.begin(), pv.end()); + return *this; + } + /// A list of iterator dimensions to pad. + SmallVector paddingDimensions; + LinalgPaddingOptions &setPaddingDimensions(ArrayRef pd) { + paddingDimensions.assign(pd.begin(), pd.end()); + return *this; + } + /// A flag for every operand to mark the PadOp as nofold which enables + /// packing for statically shaped operands. + SmallVector packPaddings; + LinalgPaddingOptions &setPackPaddings(ArrayRef pp) { + packPaddings.assign(pp.begin(), pp.end()); + return *this; + } + /// A number of loops to hoist the PadOp out for every operand. + SmallVector hoistPaddings; + LinalgPaddingOptions &setHoistPaddings(ArrayRef hp) { + hoistPaddings.assign(hp.begin(), hp.end()); + return *this; + } + /// A permutation vector for every operand used to transpose the packed + /// PadOp results. + SmallVector> transposePaddings; + LinalgPaddingOptions & + setTransposePaddings(ArrayRef> tp) { + transposePaddings.assign(tp.begin(), tp.end()); + return *this; + } +}; -/// Patterns that are used to inline constant operands into linalg generic ops. -void populateInlineConstantOperandsPatterns(RewritePatternSet &patterns); +/// Callback function type used to perform the allocation for the promoted +/// `subView`. In `boundingSubViewsize` a best attempt is made to find the +/// smallest constant value for the size of the buffer needed for each +/// dimension. If that is not possible, contains the dynamic size of the +/// subview. The call back should return the buffer to use. +using AllocBufferCallbackFn = std::function( + OpBuilder &b, memref::SubViewOp subView, + ArrayRef boundingSubViewSize, DataLayout &layout)>; -/// Patterns that are used to bubble up extract slice op above linalg op. -void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns); +/// Callback function type used to deallocate the buffers used to hold the +/// promoted subview. +using DeallocBufferCallbackFn = + std::function; -/// Adds patterns that waps tensor.extract_slice(linalg.fill(%cst, %init)) into -/// linalg.fill(%cst, tensor.extract_slice(%init)). -void populateSwapExtractSliceWithFillPatterns(RewritePatternSet &patterns); +/// Callback function type used to insert copy from original subview to +/// subview of the promoted region for the read operands/subview of promoted +/// region to original subview for the results. The copy has to happen from +/// `src` to `dst`. +using CopyCallbackFn = + std::function; + +struct LinalgPromotionOptions { + /// Indices of subViews to promote. If `std::nullopt`, try to promote all + /// operands. + std::optional> operandsToPromote; + LinalgPromotionOptions &setOperandsToPromote(ArrayRef operands) { + operandsToPromote = DenseSet(); + operandsToPromote->insert(operands.begin(), operands.end()); + return *this; + } + /// If ith element of `useFullTiles` is true the full view should be used + /// for the promoted buffer of the ith operand in `operandsToPromote`. + /// Otherwise the partial view will be used. The decision is defaulted to + /// `useFullTileBuffersDefault` when `useFullTileBuffers` is None and for + /// operands missing from `useFullTileBuffers`. 
+ std::optional useFullTileBuffers; + LinalgPromotionOptions &setUseFullTileBuffers(ArrayRef useFullTiles) { + unsigned size = useFullTiles.size(); + llvm::SmallBitVector tmp(size, false); + for (unsigned i = 0; i < size; ++i) + tmp[i] = useFullTiles[i]; + useFullTileBuffers = tmp; + return *this; + } + /// If true all operands unspecified by `useFullTileBuffers` will use the + /// full view, otherwise the partial view. + bool useFullTileBuffersDefault = false; + LinalgPromotionOptions &setUseFullTileBuffersByDefault(bool use) { + useFullTileBuffersDefault = use; + return *this; + } + /// Alignment of promoted buffer. If `std::nullopt` do not specify alignment. + std::optional alignment; + LinalgPromotionOptions &setAlignment(unsigned align) { + alignment = align; + return *this; + } + /// Use alloca with the default allocation scheme. + bool useAlloca = false; + LinalgPromotionOptions &setUseAlloca(bool use) { + useAlloca = use; + return *this; + } + /// Callback function to do the allocation of the promoted buffer. If + /// std::nullopt, then the default allocation scheme of allocating a + /// memref buffer followed by a view operation is used. + std::optional allocationFn; + std::optional deallocationFn; + LinalgPromotionOptions & + setAllocationDeallocationFns(AllocBufferCallbackFn const &allocFn, + DeallocBufferCallbackFn const &deallocFn) { + allocationFn = allocFn; + deallocationFn = deallocFn; + return *this; + } + /// Callback function to do the copy of data to and from the promoted + /// subview. If std::nullopt then a memref.copy is used. + std::optional copyInFn; + std::optional copyOutFn; + LinalgPromotionOptions &setCopyInOutFns(CopyCallbackFn const ©In, + CopyCallbackFn const ©Out) { + copyInFn = copyIn; + copyOutFn = copyOut; + return *this; + } +}; + +/// Split Reduction options. +struct SplitReductionOptions { + // Ratio used to split the reduction dimension. If the ratio is <= 1, + // nothing will be done. + int64_t ratio = 0; + // Index where the extra dimension is added to the intermediate tensor + // shape. + unsigned index = 0; + // If the inner dimension after splitting is parallel or reduction. + bool innerParallel = false; +}; + +/// Function signature to control reduction splitting. This returns +/// `SplitReductionOptions`. +// TODO: don't use unsigned unless doing bit manipulation. +using ControlSplitReductionFn = + std::function; + +//===----------------------------------------------------------------------===// +// Preconditions that ensure the corresponding transformation succeeds and can +// be applied as a rewrite pattern. +//===----------------------------------------------------------------------===// /// Return true if two `linalg.generic` operations with producer/consumer /// relationship through `fusedOperand` can be fused using elementwise op /// fusion. bool areElementwiseOpsFusable(OpOperand *fusedOperand); +/// Promote memref.subviews feeding linalg-on-buffers operations. +LogicalResult promoteSubviewsPrecondition(Operation *op, + LinalgPromotionOptions options); + +/// Return success if the operation can be vectorized. +LogicalResult +vectorizeLinalgOpPrecondition(LinalgOp linalgOp, + ArrayRef inputVectorSizes = {}, + bool vectorizeNDExtract = false); + +//===----------------------------------------------------------------------===// +// Transformations exposed as functional-style API calls. 
+//===----------------------------------------------------------------------===// + +using LinalgLoops = SmallVector; + +/// Materialize a buffer allocation for the given tensor.pad op and lower the +/// op to linalg.fill/linalg.generic + memref.tensor_store. E.g.: +/// +/// %0 = tensor.pad low[%l] high[%h] %t ... +/// +/// is lowered to: +/// +/// %alloc = memref.alloc +/// linalg.fill ... outs(%alloc) +/// %subview = memref.subview %alloc [%l] [...] [1] +/// memref.tensor_store %t, %subview +/// %0 = bufferization.to_tensor %alloc restrict writable +/// +/// In addition to rewriting the IR as shown above, the result of the +/// bufferization.to_tensor op is returned. +Value bufferizeToAllocation(RewriterBase &rewriter, tensor::PadOp padOp, + Attribute memorySpace = {}); + +/// Materialize a buffer allocation for the given tensor value. E.g.: +/// +/// %alloc = memref.alloc +/// memref.tensor_store %value, %alloc +/// %0 = bufferization.to_tensor %alloc restrict writable +/// +/// In case `value` is a tensor.pad result, the corresponding overload is used +/// internally to produce a better bufferization. +Value bufferizeToAllocation(RewriterBase &rewriter, Value value, + Attribute memorySpace = {}); + /// Fuse two `linalg.generic` operations that have a producer-consumer /// relationship captured through `fusedOperand`. The method expects /// that `areElementwiseOpsFusable` returns true for the given `fusedOperand`. @@ -198,6 +341,31 @@ FailureOr fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand); +/// Try to peel and canonicalize loop `op` and return the new result. +/// Also applies affine_min/max bounds simplification on the fly where relevant. +// TODO: Add support for scf.parallel and affine.for loops. +SmallVector peelLoop(RewriterBase &rewriter, Operation *op); + +/// Peel 'loops' and applies affine_min/max bounds simplification on the fly +/// where relevant. +void peelLoops(RewriterBase &rewriter, ArrayRef loops); + +/// Pad the iterator dimensions `paddingDimensions` of all `opToPad` operands +/// to a static bounding box. Use `paddingValues` and `packPaddings` to set +/// padding value and nofold attribute of the created tensor::PadOps, +/// respectively. Update `paddedOp` to the cloned operation with statically +/// shaped `paddingDimensions` and return the extracted dynamically shaped +/// results. If padding fails, return failure. +FailureOr> +rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad, + ArrayRef paddingDimensions, + ArrayRef paddingValues, + ArrayRef packPaddings, LinalgOp &paddedOp); + +/// Apply padding to `linalgOp` +FailureOr padLinalgOp(RewriterBase &rewriter, LinalgOp linalgOp, + LinalgPaddingOptions options); + /// Split the given `op` into two parts along the given iteration space /// `dimension` at the specified `splitPoint`, and return the two parts. /// If the second part is statically known to be empty, do not create it @@ -253,12 +421,6 @@ FailureOr tileLinalgOp(RewriterBase &b, LinalgOp op, const LinalgTilingOptions &options); -/// Try to peel and canonicalize loop `op` and return the new result. -// TODO: Add support for scf.parallel and affine.for loops. -SmallVector peelLoop(RewriterBase &rewriter, Operation *op); -/// Peel and canonicalize 'loops'. -void peelLoops(RewriterBase &rewriter, ArrayRef loops); - /// Interchange the `iterator_types` and `iterator_maps` dimensions and adapts /// the index accesses of `op`. This is an in-place transformation controlled /// by `interchangeVector`. 
An empty vector is interpreted as the identity @@ -280,93 +442,6 @@ FailureOr generalizeNamedOp(RewriterBase &rewriter, LinalgOp namedOp); -/// Callback function type used to perform the allocation for the promoted -/// `subView`. In `boundingSubViewsize` a best attempt is made to find the -/// smallest constant value for the size of the buffer needed for each -/// dimension. If that is not possible, contains the dynamic size of the -/// subview. The call back should return the buffer to use. -using AllocBufferCallbackFn = std::function( - OpBuilder &b, memref::SubViewOp subView, - ArrayRef boundingSubViewSize, DataLayout &layout)>; - -/// Callback function type used to deallocate the buffers used to hold the -/// promoted subview. -using DeallocBufferCallbackFn = - std::function; - -/// Callback function type used to insert copy from original subview to -/// subview of the promoted region for the read operands/subview of promoted -/// region to original subview for the results. The copy has to happen from -/// `src` to `dst`. -using CopyCallbackFn = - std::function; - -struct LinalgPromotionOptions { - /// Indices of subViews to promote. If `std::nullopt`, try to promote all - /// operands. - std::optional> operandsToPromote; - LinalgPromotionOptions &setOperandsToPromote(ArrayRef operands) { - operandsToPromote = DenseSet(); - operandsToPromote->insert(operands.begin(), operands.end()); - return *this; - } - /// If ith element of `useFullTiles` is true the full view should be used - /// for the promoted buffer of the ith operand in `operandsToPromote`. - /// Otherwise the partial view will be used. The decision is defaulted to - /// `useFullTileBuffersDefault` when `useFullTileBuffers` is None and for - /// operands missing from `useFullTileBuffers`. - std::optional useFullTileBuffers; - LinalgPromotionOptions &setUseFullTileBuffers(ArrayRef useFullTiles) { - unsigned size = useFullTiles.size(); - llvm::SmallBitVector tmp(size, false); - for (unsigned i = 0; i < size; ++i) - tmp[i] = useFullTiles[i]; - useFullTileBuffers = tmp; - return *this; - } - /// If true all operands unspecified by `useFullTileBuffers` will use the - /// full view, otherwise the partial view. - bool useFullTileBuffersDefault = false; - LinalgPromotionOptions &setUseFullTileBuffersByDefault(bool use) { - useFullTileBuffersDefault = use; - return *this; - } - /// Alignment of promoted buffer. If `std::nullopt` do not specify alignment. - std::optional alignment; - LinalgPromotionOptions &setAlignment(unsigned align) { - alignment = align; - return *this; - } - /// Use alloca with the default allocation scheme. - bool useAlloca = false; - LinalgPromotionOptions &setUseAlloca(bool use) { - useAlloca = use; - return *this; - } - /// Callback function to do the allocation of the promoted buffer. If - /// std::nullopt, then the default allocation scheme of allocating a - /// memref buffer followed by a view operation is used. - std::optional allocationFn; - std::optional deallocationFn; - LinalgPromotionOptions & - setAllocationDeallocationFns(AllocBufferCallbackFn const &allocFn, - DeallocBufferCallbackFn const &deallocFn) { - allocationFn = allocFn; - deallocationFn = deallocFn; - return *this; - } - /// Callback function to do the copy of data to and from the promoted - /// subview. If std::nullopt then a memref.copy is used. 
- std::optional copyInFn; - std::optional copyOutFn; - LinalgPromotionOptions &setCopyInOutFns(CopyCallbackFn const ©In, - CopyCallbackFn const ©Out) { - copyInFn = copyIn; - copyOutFn = copyOut; - return *this; - } -}; - /// Create a new buffer using the `allocationFn` provided. The size of this /// buffer is the smallest constant bounding size along each dimension that /// can be computed for the size of the result of `subView`. Returns the @@ -444,27 +519,6 @@ FailureOr linalgOpToAffineLoops(PatternRewriter &rewriter, LinalgOp linalgOp); -//===----------------------------------------------------------------------===// -// Preconditions that ensure the corresponding transformation succeeds and can -// be applied as a rewrite pattern. -//===----------------------------------------------------------------------===// -/// Promote memref.subviews feeding linalg-on-buffers operations. -LogicalResult promoteSubviewsPrecondition(Operation *op, - LinalgPromotionOptions options); - -/// Return success if the operation can be vectorized. -LogicalResult -vectorizeLinalgOpPrecondition(LinalgOp linalgOp, - ArrayRef inputVectorSizes = {}, - bool vectorizeNDExtract = false); - -//===----------------------------------------------------------------------===// -// Transformations exposed as rewrite patterns. -//===----------------------------------------------------------------------===// - -using TileSizeComputationFunction = - std::function(OpBuilder &, Operation *)>; - /// Creates a number of ranges equal to the number of non-zero in `tileSizes`. /// One for each loop of the LinalgOp that is tiled. The `tileSizes` argument /// has one entry per surrounding loop. It uses zero as the convention that a @@ -655,137 +709,218 @@ SmallVectorImpl &ivs, const LoopIndexToRangeIndexMap &loopIndexToRangeIndex); -struct LinalgPaddingOptions { - /// A padding value for every operand. - SmallVector paddingValues; - LinalgPaddingOptions &setPaddingValues(ArrayRef pv) { - paddingValues.assign(pv.begin(), pv.end()); - return *this; - } - /// A list of iterator dimensions to pad. - SmallVector paddingDimensions; - LinalgPaddingOptions &setPaddingDimensions(ArrayRef pd) { - paddingDimensions.assign(pd.begin(), pd.end()); - return *this; - } - /// A flag for every operand to mark the PadOp as nofold which enables - /// packing for statically shaped operands. - SmallVector packPaddings; - LinalgPaddingOptions &setPackPaddings(ArrayRef pp) { - packPaddings.assign(pp.begin(), pp.end()); - return *this; - } - /// A number of loops to hoist the PadOp out for every operand. - SmallVector hoistPaddings; - LinalgPaddingOptions &setHoistPaddings(ArrayRef hp) { - hoistPaddings.assign(hp.begin(), hp.end()); - return *this; - } - /// A permutation vector for every operand used to transpose the packed - /// PadOp results. - SmallVector> transposePaddings; - LinalgPaddingOptions & - setTransposePaddings(ArrayRef> tp) { - transposePaddings.assign(tp.begin(), tp.end()); - return *this; - } -}; - -struct LinalgTilingAndFusionOptions { - /// Tile sizes used to tile the root operation. - SmallVector tileSizes; - LinalgTilingAndFusionOptions &setTileSizes(ArrayRef ts) { - tileSizes.assign(ts.begin(), ts.end()); - return *this; - } - /// Tile interchange used to permute the tile loops. - SmallVector tileInterchange; - /// When specified, specifies distribution of generated tile loops to - /// processors. 
- std::optional tileDistribution; - LinalgTilingAndFusionOptions & - setDistributionOptions(LinalgLoopDistributionOptions distributionOptions) { - tileDistribution = std::move(distributionOptions); - return *this; - } +/// Apply transformation to split the single linalg op reduction into a +/// parallel and reduction dimension. Then create a new linalg.generic op +/// doing the rest of the reduction. Return the new linalg op with an extra +/// parallel dimension or failure if the transformation didn't happen. +/// +/// Example: +/// ``` +/// %r = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, +/// affine_map<(d0) -> ()>], +/// iterator_types = ["reduction"]} +/// ins(%in : tensor<32xf32>) +/// outs(%out : tensor) { +/// ^bb0(%arg1: f32, %arg2: f32): +/// %y = arith.addf %arg1, %arg2 : f32 +/// linalg.yield %y : f32 +/// } -> tensor +/// ``` +/// To: +/// ``` +/// %cst = arith.constant 0.000000e+00 : f32 +/// %0 = tensor.expand_shape %in [[0, 1]] : tensor<32xf32> into +/// tensor<4x8xf32> %1 = tensor.empty [4] : tensor<4xf32> %2 = linalg.fill +/// ins(%cst : f32) outs(%1 : tensor<4xf32>) -> tensor<4xf32> %3 = +/// linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, +/// affine_map<(d0, d1) -> (d0)>], +/// iterator_types = ["parallel", "reduction"]} +/// ins(%0 : tensor<4x8xf32>) outs(%2 : tensor<4xf32>) { +/// ^bb0(%arg3: f32, %arg5: f32): +/// %5 = arith.addf %arg3, %arg4 : f32 +/// linalg.yield %5 : f32 +/// } -> tensor<4xf32> +/// %r = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, +/// affine_map<(d0) -> ()>], +/// iterator_types = ["reduction"]} +/// ins(%3 : tensor<4xf32>) outs(%out : tensor) { +/// ^bb0(%arg3: f32, %arg4: f32): +/// %5 = arith.addf %arg3, %arg4 : f32 +/// linalg.yield %5 : f32 +/// } -> tensor +/// ``` +struct SplitReductionResult { + Operation *initOrAlloc; + FillOp fillOp; + LinalgOp splitLinalgOp; + LinalgOp resultCombiningLinalgOp; }; +FailureOr +splitReduction(PatternRewriter &b, LinalgOp op, + const ControlSplitReductionFn &controlSplitReductionFn, + bool useAlloc = false); -struct LinalgTilingOptions { - /// Computation function that returns the tile sizes for each operation. - /// Delayed construction of constant tile sizes should occur to interoperate - /// with folding. - TileSizeComputationFunction tileSizeComputationFunction = nullptr; - - LinalgTilingOptions & - setTileSizeComputationFunction(TileSizeComputationFunction fun) { - tileSizeComputationFunction = std::move(fun); - return *this; - } - /// Set the `tileSizeComputationFunction` to return the values `ts`. The - /// values must not fold away when tiling. Otherwise, use a more robust - /// `tileSizeComputationFunction`. - LinalgTilingOptions &setTileSizes(const SmallVector &ts) { - tileSizeComputationFunction = [=](OpBuilder &, Operation *) { return ts; }; - return *this; - } - /// Convenience function to set the `tileSizeComputationFunction` to a - /// function that computes tile sizes at the point they are needed. Allows - /// proper interaction with folding. - LinalgTilingOptions &setTileSizes(ArrayRef ts); - - /// Tile all dynamic dimensions by 1. I.e., scalarize those dimensions. - /// Note: `scalarizeDynamicDims` and `setTileSizes` cannot be used together. - LinalgTilingOptions &scalarizeDynamicDims(); - - /// The interchange vector to reorder the tiled loops. - SmallVector interchangeVector = {}; +/// Scaling-based implementation of the split reduction transformation. 
+/// Instead of introducing an ExpandShapeOp, this rewrites a reduction +/// dimension `k` into `k * scale + kk`. +/// +/// Example: +/// ``` +/// %0 = linalg.matmul ins(%A, %B: tensor<16x256xf32>, tensor<256x32xf32>) +/// outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32> +/// ``` +/// +/// Is transformed to: +/// +/// ``` +/// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d2 * 4 + d3)> +/// #map1 = affine_map<(d0, d1, d2, d3) -> (d2 * 4 + d3, d1)> +/// #map2 = affine_map<(d0, d1, d2, d3) -> (d2, d3)> +/// #map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> +/// #map4 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +/// #map5 = affine_map<(d0, d1, d2) -> (d0, d1)> +/// %0 = tensor.empty [16, 32, 64] : tensor<16x32x64xf32> +/// %cst = arith.constant 0.000000e+00 : f32 +/// %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<16x32x64xf32>) -> +/// tensor<16x32x64xf32> +/// %2 = tensor.empty [64, 4] : tensor<64x4xi1> +/// +/// %3 = linalg.generic {indexing_maps = [#map0, #map1, #map2, #map3], +/// iterator_types = ["parallel", "parallel", "parallel", "reduction"]} +/// ins(%A, %B, %2 : tensor<16x256xf32>, tensor<256x32xf32>, +/// tensor<64x4xi1>) +/// outs(%1 : tensor<16x32x64xf32>) { +/// ^bb0(%arg3: f32, %arg4: f32, %arg5: i1, %arg6: f32): +/// %5 = arith.mulf %arg3, %arg4 : f32 +/// %6 = arith.addf %arg6, %5 : f32 +/// linalg.yield %6 : f32 +/// } -> tensor<16x32x64xf32> +/// +/// %4 = linalg.generic {indexing_maps = [#map4, #map5], +/// iterator_types = ["parallel", "parallel", "reduction"]} +// ins(%3 : tensor<16x32x64xf32>) +/// outs(%C : tensor<16x32xf32>) { +/// ^bb0(%arg3: f32, %arg4: f32): +/// %5 = arith.addf %arg3, %arg4 : f32 +/// linalg.yield %5 : f32 +/// } -> tensor<16x32xf32> +/// +/// return %4 : tensor<16x32xf32> +/// ``` +FailureOr +splitReductionByScaling(PatternRewriter &b, LinalgOp op, + const ControlSplitReductionFn &controlSplitReductionFn, + bool useAlloc = false); - LinalgTilingOptions &setInterchange(ArrayRef interchange) { - interchangeVector.assign(interchange.begin(), interchange.end()); - return *this; - } +/// Collapses dimensions of linalg.generic operation. It also collapses inputs +/// before the op and expands outputs after the op. +FailureOr> collapseGenericOpIterationDims( + GenericOp genericOp, ArrayRef foldedIterationDims, + RewriterBase &rewriter); - /// The type of tile loops to generate. - LinalgTilingLoopType loopType = LinalgTilingLoopType::Loops; +/// Struct to hold the result of a `pack` call. +struct PackResult { + SmallVector packOps; + linalg::LinalgOp packedLinalgOp; + SmallVector unPackOps; +}; +/// Implement packing of a single LinalgOp by `packedSizes`. +/// There must be one packedSizes entry per `linalgOp` iterator. +/// Return the packed Linalg op on success, failure otherwise. +FailureOr pack(RewriterBase &rewriter, linalg::LinalgOp linalgOp, + ArrayRef packedSizes); - LinalgTilingOptions &setLoopType(LinalgTilingLoopType lt) { - loopType = lt; - return *this; - } +/// Struct to hold the result of a `packTranspose` call. +struct PackTransposeResult { + tensor::PackOp transposedPackOp; + linalg::LinalgOp transposedLinalgOp; + tensor::UnPackOp transposedUnPackOp; +}; +/// Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the +/// transposed PackOp -> LinalgOp -> UnPackOp chain after replacements. +/// Return failure if either: +/// 1. the `packOp` does not have the `linalgOp` as its unique use. +/// 2. the `maybeUnPackOp`, if specified must be a consumer of the result tied +/// to the unique `packOp` use. +/// 3. 
`outerPerm` (resp. `innerPerm`) must be valid permutations of +/// `packOp.getOuterDimsPerm` (resp. `packOp.getInnerDimsPerm`) or empty. +FailureOr +packTranspose(RewriterBase &rewriter, tensor::PackOp packOp, + linalg::LinalgOp linalgOp, tensor::UnPackOp maybeUnPackOp, + ArrayRef outerPerm, ArrayRef innerPerm); - /// When specified, specifies distribution of generated tile loops to - /// processors. - std::optional distribution; +/// Rewrite tensor.from_elements to linalg.generic. +FailureOr +rewriteInDestinationPassingStyle(RewriterBase &rewriter, + tensor::FromElementsOp fromElementsOp); - LinalgTilingOptions & - setDistributionOptions(LinalgLoopDistributionOptions distributionOptions) { - distribution = std::move(distributionOptions); - return *this; - } +/// Rewrite tensor.generate to linalg.generic. +FailureOr +rewriteInDestinationPassingStyle(RewriterBase &rewriter, + tensor::GenerateOp generateOp); - /// Specification markers of how to distribute the `linalg.tiled_loop`. - SmallVector distributionTypes = {}; +/// Rewrite tensor.pad to linalg.generic + tensor.insert_slice. +FailureOr rewriteInDestinationPassingStyle(RewriterBase &rewriter, + tensor::PadOp padOp); - LinalgTilingOptions &setDistributionTypes(ArrayRef types) { - distributionTypes.assign(types.begin(), types.end()); - return *this; - } +/// Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing) +/// and linalg.matmul. +/// +/// A convolution operation can be written as a matrix-matrix multiplication by +/// unfolding the cross-correlation between input and filter and explicitly copy +/// overlapped sliding window inputs. +/// +/// Consider 2D input X with single channel input and output and 2x2 filter W: +/// [x(0, 0) , x(0, 1) , ..., x(0, n) ] +/// [x(1, 0) , x(1, 1) , ..., x(1, n) ] +/// [. , . ,. , . ] [w(0, 0), w(0, 1)] +/// [. , . , . , . ] (conv) [w(1, 0), w(1, 1)] +/// [. , . , ., . ] +/// [x(n-1, 0), x(n-1, 1), ..., x(n-1, n-1)] +/// +/// The packed input data (img2col) is a matrix with |rows| = output spatial +/// size, |columns| = filter spatial size. To compute the output Y(i, j) we need +/// to calculate the dot product between filter window at input X(x, y)) and the +/// filter which will look like the following where r.h.s is the img2col matrix +/// and l.h.s is the flattened filter: +/// +/// [x(0,0), x(0,1), x(1,0), x(1,1)] +/// [x(0,1), x(1,1), x(0,2), x(1,2)] (matmul) [w(0,0), w(0,1), w(1,0), w(1,1)] +/// [x(0,1), x(1,1), x(0,2), x(1,2)] +/// [ . , . , . , . ] +/// +/// In general for 2D case with (N, H, W, C) input and (Kh, Kw, C, D) filter +/// and output (N, Ho, Wo, D) the convolution is the following matrix-matrix +/// multiplication (Ho x Wo, Kh x Kw x C) * (Kh x Kw x C, D) for each input in +/// the N input. For the case where N > 1 its a batched matrix-matrix +/// multiplication. +/// +/// On success, return both the operation that produces the img2col tensor and +/// the final operation of the sequence that replaces the original convolution. +FailureOr> +rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp); - /// Peel the specified loops. - SmallVector peeledLoops; +/// Similar to rewriteInIm2Col with linalg::Conv2DNhwcHwcfOp except there is no +/// reduction among the input channels so each convolution can be a +/// matrix-vector product and by transposing both input filter so channels are +/// outer most the computation is a batched matrix-vector product. 
+FailureOr> +rewriteInIm2Col(RewriterBase &rewriter, + linalg::DepthwiseConv2DNhwcHwcOp convOp); - LinalgTilingOptions &setPeeledLoops(ArrayRef loops) { - peeledLoops.clear(); - peeledLoops.append(loops.begin(), loops.end()); - return *this; - } -}; +/// Similar to rewriteInIm2Col with linalg::Conv2DNhwcHwcfOp except because the +/// channels are to the left of the image shape dimensions, the position of the +/// contraction dimension in the resulting matmul is reversed. This swaps the +/// LHS and RHS of the matmul when compared with nhwc (i.e. (D, C x Kh x Kw) * +/// (C x Kh x Kw, Ho x Wo)) +FailureOr> +rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp); -/// Canonicalization patterns relevant to apply after tiling patterns. These -/// are applied automatically by the tiling pass but need to be applied -/// manually when tiling is called programmatically. -RewritePatternSet getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx); -void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns); +//===----------------------------------------------------------------------===// +// Rewrite patterns wrapping transformations. +// TODO: every single such pattern should be a close to noop wrapper around a +// functional-stye API call. +//===----------------------------------------------------------------------===// /// /// Linalg padding pattern. @@ -797,15 +932,8 @@ LinalgPaddingOptions options = LinalgPaddingOptions(), PatternBenefit benefit = 1); - /// `matchAndRewrite` implementation that returns the significant - /// transformed pieces of IR. - FailureOr returningMatchAndRewrite(LinalgOp op, - PatternRewriter &rewriter) const; - LogicalResult matchAndRewrite(LinalgOp op, - PatternRewriter &rewriter) const override { - return returningMatchAndRewrite(op, rewriter); - } + PatternRewriter &rewriter) const override; private: /// Options to control padding and hoisting. @@ -884,89 +1012,6 @@ PatternRewriter &rewriter) const override; }; -/// Return vector::CombiningKind for the given op. -std::optional getCombinerOpKind(Operation *combinerOp); - -//===----------------------------------------------------------------------===// -// Transformations exposed as rewrite patterns. -//===----------------------------------------------------------------------===// - -/// Linalg generalization patterns - -/// Populates `patterns` with patterns to convert spec-generated named ops to -/// linalg.generic ops. -void populateLinalgNamedOpsGeneralizationPatterns(RewritePatternSet &patterns); - -/// Linalg decompose convolutions patterns - -/// Populates patterns to decompose high-D convolution ops into low-D ones. -/// This is a step in progressive lowering for convolution ops, afterwards we -/// can vectorize the low-D convolution ops. -void populateDecomposeConvolutionPatterns(RewritePatternSet &patterns, - PatternBenefit benefit = 1); - -/// Populates patterns to transform linalg.conv_2d_xxx operations into -/// linalg.generic (for img2col packing) and linalg.matmul. -/// \see rewriteInIm2Col for more details. -void populateConvertConv2DToImg2ColPatterns(RewritePatternSet &patterns); - -/// Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing) -/// and linalg.matmul. -/// -/// A convolution operation can be written as a matrix-matrix multiplication by -/// unfolding the cross-correlation between input and filter and explicitly copy -/// overlapped sliding window inputs. 
-/// -/// Consider 2D input X with single channel input and output and 2x2 filter W: -/// [x(0, 0) , x(0, 1) , ..., x(0, n) ] -/// [x(1, 0) , x(1, 1) , ..., x(1, n) ] -/// [. , . ,. , . ] [w(0, 0), w(0, 1)] -/// [. , . , . , . ] (conv) [w(1, 0), w(1, 1)] -/// [. , . , ., . ] -/// [x(n-1, 0), x(n-1, 1), ..., x(n-1, n-1)] -/// -/// The packed input data (img2col) is a matrix with |rows| = output spatial -/// size, |columns| = filter spatial size. To compute the output Y(i, j) we need -/// to calculate the dot product between filter window at input X(x, y)) and the -/// filter which will look like the following where r.h.s is the img2col matrix -/// and l.h.s is the flattned filter: -/// -/// [x(0,0), x(0,1), x(1,0), x(1,1)] -/// [x(0,1), x(1,1), x(0,2), x(1,2)] (matmul) [w(0,0), w(0,1), w(1,0), w(1,1)] -/// [x(0,1), x(1,1), x(0,2), x(1,2)] -/// [ . , . , . , . ] -/// -/// In general for 2D case with (N, H, W, C) input and (Kh, Kw, C, D) filter -/// and output (N, Ho, Wo, D) the convolution is the following matrix-matrix -/// multiplication (Ho x Wo, Kh x Kw x C) * (Kh x Kw x C, D) for each input in -/// the N input. For the case where N > 1 its a batched matrxi-matrix -/// multplication. -/// -/// On success, return both the operation that produces the img2col tensor and -/// the final operation of the sequence that replaces the original convolution. -FailureOr> -rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp); - -/// Similar to rewriteInIm2Col with linalg::Conv2DNhwcHwcfOp except there is no -/// reduction among the input channels so each convolution can be a -/// matrix-vector product and by transposing both input filter so channels are -/// outer most the computation is a batched matrix-vector product. -FailureOr> -rewriteInIm2Col(RewriterBase &rewriter, - linalg::DepthwiseConv2DNhwcHwcOp convOp); - -/// Similar to rewriteInIm2Col with linalg::Conv2DNhwcHwcfOp except because the -/// channels are to the left of the image shape dimensions, the position of the -/// contraction dimension in the resulting matmul is reversed. This swaps the -/// LHS and RHS of the matmul when compared with nhwc (i.e. (D, C x Kh x Kw) * -/// (C x Kh x Kw, Ho x Wo)) -FailureOr> -rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp); - -//===----------------------------------------------------------------------===// -// Op-specific patterns. -//===----------------------------------------------------------------------===// - /// tensor::PadOp is not canonicalized away yet, so we provide a /// transformation to `linalg.generic`. struct PadOpTransformationPattern : public OpRewritePattern { @@ -976,18 +1021,6 @@ PatternRewriter &rewriter) const override; }; -/// Pad the iterator dimensions `paddingDimensions` of all `opToPad` operands -/// to a static bounding box. Use `paddingValues` and `packPaddings` to set -/// padding value and nofold attribute of the created tensor::PadOps, -/// respectively. Update `paddedOp` to the cloned operation with statically -/// shaped `paddingDimensions` and return the extracted dynamically shaped -/// results. If padding fails, return failure. -FailureOr> -rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad, - ArrayRef paddingDimensions, - ArrayRef paddingValues, - ArrayRef packPaddings, LinalgOp &paddedOp); - using OptimizeCopyFn = std::function; @@ -1030,18 +1063,6 @@ PatternRewriter &rewriter) const override; }; -/// Populates `patterns` with patterns that vectorize tensor.pad. 
-/// These patterns are meant to apply in a complementary fashion. Benefits -/// are used to encode a certain ordering of pattern application. To avoid -/// scattering magic constants throughout the code base, the patterns must be -/// added with this function. `baseBenefit` can be used to offset the benefit -/// of all tensor::PadOp vectorization patterns by a certain value. -void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, - PatternBenefit baseBenefit = 1); - -void populateExtractOpVectorizationPatterns(RewritePatternSet &patterns, - PatternBenefit baseBenefit = 1); - /// Match and rewrite for the pattern: /// ``` /// %alloc = ... @@ -1127,183 +1148,162 @@ ControlFn controlFn; }; -/// Split Reduction options. -struct SplitReductionOptions { - // Ratio used to split the reduction dimension. If the ratio is <= 1, - // nothing will be done. - int64_t ratio = 0; - // Index where the extra dimension is added to the intermediate tensor - // shape. - unsigned index = 0; - // If the inner dimension after splitting is parallel or reduction. - bool innerParallel = false; -}; +//===----------------------------------------------------------------------===// +// Populate functions. +//===----------------------------------------------------------------------===// -/// Function signature to control reduction splitting. This returns -/// `SplitReductionOptions`. -// TODO: don't use unsigned unless doing bit manipulation. -using ControlSplitReductionFn = - std::function; +/// Canonicalization patterns relevant to apply after tiling patterns. These +/// are applied automatically by the tiling pass but need to be applied +/// manually when tiling is called programmatically. +RewritePatternSet getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx); +void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns); -/// Patterns to apply `splitReduction` below. -void populateSplitReductionPattern( +/// Linalg generalization patterns + +/// Populates `patterns` with patterns to convert spec-generated named ops to +/// linalg.generic ops. +void populateLinalgNamedOpsGeneralizationPatterns(RewritePatternSet &patterns); + +/// Linalg decompose convolutions patterns + +/// Populates patterns to decompose high-D convolution ops into low-D ones. +/// This is a step in progressive lowering for convolution ops, afterwards we +/// can vectorize the low-D convolution ops. +void populateDecomposeConvolutionPatterns(RewritePatternSet &patterns, + PatternBenefit benefit = 1); + +/// Populates patterns to transform linalg.conv_2d_xxx operations into +/// linalg.generic (for img2col packing) and linalg.matmul. +/// \see rewriteInIm2Col for more details. +void populateConvertConv2DToImg2ColPatterns(RewritePatternSet &patterns); + +void populatePadTensorTilingPatterns(RewritePatternSet &patterns, + const LinalgTilingOptions &options); + +/// Populates `patterns` with patterns that vectorize tensor.pad. +/// These patterns are meant to apply in a complementary fashion. Benefits +/// are used to encode a certain ordering of pattern application. To avoid +/// scattering magic constants throughout the code base, the patterns must be +/// added with this function. `baseBenefit` can be used to offset the benefit +/// of all tensor::PadOp vectorization patterns by a certain value. 
+void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, + PatternBenefit baseBenefit = 1); + +void populateExtractOpVectorizationPatterns(RewritePatternSet &patterns, + PatternBenefit baseBenefit = 1); + +/// Populate patterns for splitting a `LinalgOp` with multiple statements within +/// its payload into multiple `GenericOp` that have a single statement. +/// The option `removeDeadArgsAndResults` adds patterns to remove dead arguments +/// and results from the generated decomposed ops. This is default `true` since +/// the core decomposition patterns relies on these clean up patterns. It is set +/// to false only for testing purposes. +void populateDecomposeLinalgOpsPattern(RewritePatternSet &patterns, + bool removeDeadArgsAndResults = true); + +/// Populate patterns that convert non-destination-style ops to destination +/// style ops. +void populateConvertToDestinationStylePatterns(RewritePatternSet &patterns); + +/// Populate patterns for vectorizing low-D convolution ops. This is a step in +/// progressive lowering for convolution ops, it assume high-D convolution ops +/// were decomposed previously. +void populateConvolutionVectorizationPatterns(RewritePatternSet &patterns, + PatternBenefit benefit = 1); + +/// Populate patterns that convert `ElementwiseMappable` ops to linalg +/// parallel loops. +void populateElementwiseToLinalgConversionPatterns(RewritePatternSet &patterns); + +/// Populate patterns that are only useful in the context of sparse tensors. +void populateSparseTensorRewriting(RewritePatternSet &patterns); + +/// Function type which is used to control when to stop fusion. It is expected +/// that OpOperand is not modified in the callback. The OpOperand is not marked +/// as const to allow callers to use non-const methods. +using ControlFusionFn = std::function; + +/// Patterns for fusing linalg operation on tensors. + +/// Pattern to fuse `linalg.generic` -> `linalg.generic` operations +/// when both operations are fusable elementwise operations. +void populateElementwiseOpsFusionPatterns( RewritePatternSet &patterns, - const ControlSplitReductionFn &controlSplitReductionFn, - bool useAlloc = false); + const ControlFusionFn &controlElementwiseOpFusion); -/// Apply transformation to split the single linalg op reduction into a -/// parallel and reduction dimension. Then create a new linalg.generic op -/// doing the rest of the reduction. Return the new linalg op with an extra -/// parallel dimension or failure if the transformation didn't happen. 
-/// -/// Example: -/// ``` -/// %r = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, -/// affine_map<(d0) -> ()>], -/// iterator_types = ["reduction"]} -/// ins(%in : tensor<32xf32>) -/// outs(%out : tensor) { -/// ^bb0(%arg1: f32, %arg2: f32): -/// %y = arith.addf %arg1, %arg2 : f32 -/// linalg.yield %y : f32 -/// } -> tensor -/// ``` -/// To: -/// ``` -/// %cst = arith.constant 0.000000e+00 : f32 -/// %0 = tensor.expand_shape %in [[0, 1]] : tensor<32xf32> into -/// tensor<4x8xf32> %1 = tensor.empty [4] : tensor<4xf32> %2 = linalg.fill -/// ins(%cst : f32) outs(%1 : tensor<4xf32>) -> tensor<4xf32> %3 = -/// linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, -/// affine_map<(d0, d1) -> (d0)>], -/// iterator_types = ["parallel", "reduction"]} -/// ins(%0 : tensor<4x8xf32>) outs(%2 : tensor<4xf32>) { -/// ^bb0(%arg3: f32, %arg5: f32): -/// %5 = arith.addf %arg3, %arg4 : f32 -/// linalg.yield %5 : f32 -/// } -> tensor<4xf32> -/// %r = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, -/// affine_map<(d0) -> ()>], -/// iterator_types = ["reduction"]} -/// ins(%3 : tensor<4xf32>) outs(%out : tensor) { -/// ^bb0(%arg3: f32, %arg4: f32): -/// %5 = arith.addf %arg3, %arg4 : f32 -/// linalg.yield %5 : f32 -/// } -> tensor -/// ``` -struct SplitReductionResult { - Operation *initOrAlloc; - FillOp fillOp; - LinalgOp splitLinalgOp; - LinalgOp resultCombiningLinalgOp; -}; -FailureOr -splitReduction(PatternRewriter &b, LinalgOp op, - const ControlSplitReductionFn &controlSplitReductionFn, - bool useAlloc = false); +/// Patterns to bubble up or down data layout ops across other operations. +void populateDataLayoutPropagationPatterns(RewritePatternSet &patterns); -/// Scaling-based implementation of the split reduction transformation. -/// Instead of introducing an ExpandShapeOp, this rewrites a reduction -/// dimension `k` into `k * scale + kk`. 
-/// -/// Example: -/// ``` -/// %0 = linalg.matmul ins(%A, %B: tensor<16x256xf32>, tensor<256x32xf32>) -/// outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32> -/// ``` -/// -/// Is transformed to: -/// -/// ``` -/// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d2 * 4 + d3)> -/// #map1 = affine_map<(d0, d1, d2, d3) -> (d2 * 4 + d3, d1)> -/// #map2 = affine_map<(d0, d1, d2, d3) -> (d2, d3)> -/// #map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> -/// #map4 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> -/// #map5 = affine_map<(d0, d1, d2) -> (d0, d1)> -/// %0 = tensor.empty [16, 32, 64] : tensor<16x32x64xf32> -/// %cst = arith.constant 0.000000e+00 : f32 -/// %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<16x32x64xf32>) -> -/// tensor<16x32x64xf32> -/// %2 = tensor.empty [64, 4] : tensor<64x4xi1> -/// -/// %3 = linalg.generic {indexing_maps = [#map0, #map1, #map2, #map3], -/// iterator_types = ["parallel", "parallel", "parallel", "reduction"]} -/// ins(%A, %B, %2 : tensor<16x256xf32>, tensor<256x32xf32>, -/// tensor<64x4xi1>) -/// outs(%1 : tensor<16x32x64xf32>) { -/// ^bb0(%arg3: f32, %arg4: f32, %arg5: i1, %arg6: f32): -/// %5 = arith.mulf %arg3, %arg4 : f32 -/// %6 = arith.addf %arg6, %5 : f32 -/// linalg.yield %6 : f32 -/// } -> tensor<16x32x64xf32> -/// -/// %4 = linalg.generic {indexing_maps = [#map4, #map5], -/// iterator_types = ["parallel", "parallel", "reduction"]} -// ins(%3 : tensor<16x32x64xf32>) -/// outs(%C : tensor<16x32xf32>) { -/// ^bb0(%arg3: f32, %arg4: f32): -/// %5 = arith.addf %arg3, %arg4 : f32 -/// linalg.yield %5 : f32 -/// } -> tensor<16x32xf32> -/// -/// return %4 : tensor<16x32xf32> -/// ``` -FailureOr -splitReductionByScaling(PatternRewriter &b, LinalgOp op, - const ControlSplitReductionFn &controlSplitReductionFn, - bool useAlloc = false); +/// Pattern to remove dead operands and results of `linalg.generic` operations. +/// This is effectively DCE for a linalg op. +void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns); -/// Collapses dimensions of linalg.generic operation. It also collapses inputs -/// before the op and expands outputs after the op. -FailureOr> collapseGenericOpIterationDims( - GenericOp genericOp, ArrayRef foldedIterationDims, - RewriterBase &rewriter); +/// Patterns to promote inputs to outputs and remove unused inputs of +/// `linalg.generic` ops. +void populateEraseUnnecessaryInputsPatterns(RewritePatternSet &patterns); -/// Struct to hold the result of a `pack` call. -struct PackResult { - SmallVector packOps; - linalg::LinalgOp packedLinalgOp; - SmallVector unPackOps; -}; -/// Implement packing of a single LinalgOp by `packedSizes`. -/// There must be one packedSizes entry per `linalgOp` iterator. -/// Return the packed Linalg op on success, failure otherwise. -FailureOr pack(RewriterBase &rewriter, linalg::LinalgOp linalgOp, - ArrayRef packedSizes); +/// Function type to control generic op dimension collapsing. It is expected +/// to return an array of `ReassociationIndices` representing dimensions that +/// should be merged. +using GetCollapsableDimensionsFn = + std::function(linalg::GenericOp)>; -/// Struct to hold the result of a `packTranspose` call. -struct PackTransposeResult { - tensor::PackOp transposedPackOp; - linalg::LinalgOp transposedLinalgOp; - tensor::UnPackOp transposedUnPackOp; -}; -/// Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the -/// transposed PackOp -> LinalgOp -> UnPackOp chain after replacements. -/// Return failure if either: -/// 1. 
the `packOp` does not have the `linalgOp` as its unique use. -/// 2. the `maybeUnPackOp`, if specified must be a consumer of the result tied -/// to the unique `packOp` use. -/// 3. `outerPerm` (resp. `innerPerm`) must be valid permutations of -/// `packOp.getOuterDimsPerm` (resp. `packOp.getInnerDimsPerm`) or empty. -FailureOr -packTranspose(RewriterBase &rewriter, tensor::PackOp packOp, - linalg::LinalgOp linalgOp, tensor::UnPackOp maybeUnPackOp, - ArrayRef outerPerm, ArrayRef innerPerm); +/// Pattern to collapse dimensions in a linalg.generic op. This will collapse +/// tensor operands when needed and expand back the result tensors. +void populateCollapseDimensions( + RewritePatternSet &patterns, + const GetCollapsableDimensionsFn &controlCollapseDimensions); -/// Rewrite tensor.from_elements to linalg.generic. -FailureOr -rewriteInDestinationPassingStyle(RewriterBase &rewriter, - tensor::FromElementsOp fromElementsOp); +/// Patterns to fold an expanding (collapsing) tensor_reshape operation with its +/// producer (consumer) generic operation by expanding the dimensionality of the +/// loop in the generic op. +void populateFoldReshapeOpsByExpansionPatterns( + RewritePatternSet &patterns, const ControlFusionFn &controlFoldingReshapes); -/// Rewrite tensor.generate to linalg.generic. -FailureOr -rewriteInDestinationPassingStyle(RewriterBase &rewriter, - tensor::GenerateOp generateOp); +/// Patterns to fold an expanding tensor.expand_shape operation with its +/// producer generic operation by collapsing the dimensions of the generic op. +void populateFoldReshapeOpsByCollapsingPatterns( + RewritePatternSet &patterns, const ControlFusionFn &controlFoldingReshapes); -/// Rewrite tensor.pad to linalg.generic + tensor.insert_slice. -FailureOr rewriteInDestinationPassingStyle(RewriterBase &rewriter, - tensor::PadOp padOp); +/// Patterns to constant fold Linalg operations. +void populateConstantFoldLinalgOperations(RewritePatternSet &patterns, + const ControlFusionFn &controlFn); + +/// Pattern to fuse a `tensor.pad` operation with the producer of its source, +/// if the producer is a `linalg` operation with all parallel iterator types. +void populateFuseTensorPadWithProducerLinalgOpPatterns( + RewritePatternSet &patterns); + +/// Patterns to convert from one named op to another. These can be seen as +/// canonicalizations of named ops into another named op. +void populateLinalgNamedOpConversionPatterns(RewritePatternSet &patterns); + +/// Patterns to fold unit-extent dimensions in operands/results of linalg ops on +/// tensors via reassociative reshape ops. +void populateFoldUnitExtentDimsViaReshapesPatterns(RewritePatternSet &patterns); + +/// Patterns to fold unit-extent dimensions in operands/results of linalg ops on +/// tensors via rank-reducing slices. +void populateFoldUnitExtentDimsViaSlicesPatterns(RewritePatternSet &patterns); + +/// A pattern that converts init operands to input operands. +void populateMoveInitOperandsToInputPattern(RewritePatternSet &patterns); + +/// Patterns that are used to inline constant operands into linalg generic ops. +void populateInlineConstantOperandsPatterns(RewritePatternSet &patterns); + +/// Patterns that are used to bubble up extract slice op above linalg op. +void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns); + +/// Adds patterns that waps tensor.extract_slice(linalg.fill(%cst, %init)) into +/// linalg.fill(%cst, tensor.extract_slice(%init)). 
+void populateSwapExtractSliceWithFillPatterns(RewritePatternSet &patterns); + +/// Patterns to apply `splitReduction` below. +void populateSplitReductionPattern( + RewritePatternSet &patterns, + const ControlSplitReductionFn &controlSplitReductionFn, + bool useAlloc = false); } // namespace linalg } // namespace mlir diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h @@ -70,11 +70,11 @@ /// } /// ``` /// -/// After loop peeling, this function tries to simplify/canonicalize affine.min -/// and affine.max ops in the body of the peeled loop and in the body of the -/// partial iteration loop, taking advantage of the fact that the peeled loop -/// has only "full" iterations. This canonicalization is expected to enable -/// further canonicalization opportunities through other patterns. +/// After loop peeling, this function tries to simplify affine.min and +/// affine.max ops in the body of the peeled loop and in the body of the partial +/// iteration loop, taking advantage of the fact that the peeled loop has only +/// "full" iterations. This simplification is expected to enable further +/// canonicalization opportunities through other patterns. /// /// The return value indicates whether the loop was rewritten or not. Loops are /// not rewritten if: @@ -85,8 +85,8 @@ /// Note: This function rewrites the given scf.for loop in-place and creates a /// new scf.for operation for the last iteration. It replaces all uses of the /// unpeeled loop with the results of the newly generated scf.for. -LogicalResult peelAndCanonicalizeForLoop(RewriterBase &rewriter, ForOp forOp, - scf::ForOp &partialIteration); +LogicalResult peelForLoopAndSimplifyBounds(RewriterBase &rewriter, ForOp forOp, + scf::ForOp &partialIteration); /// Tile a parallel loop of the form /// scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3) diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -1726,8 +1726,8 @@ paddingOptions.setHoistPaddings(extractFromI64ArrayAttr(getHoistPaddings())); paddingOptions.setTransposePaddings(transposePaddings); - FailureOr result = - tryApply(target, paddingOptions); + IRRewriter rewriter(target->getContext()); + FailureOr result = padLinalgOp(rewriter, target, paddingOptions); if (succeeded(result)) { results.push_back(result->getOperation()); return DiagnosedSilenceableFailure::success(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -46,26 +46,6 @@ #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") #define DBGSNL() (llvm::dbgs() << "\n") -//===----------------------------------------------------------------------===// -// Transformations exposed as rewrite patterns. 
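The call-site hunk above switches from the pattern-driven `tryApply` path to constructing an `IRRewriter` and calling the functional-style `padLinalgOp` directly, and the SCF hunk renames `peelAndCanonicalizeForLoop` to `peelForLoopAndSimplifyBounds`. Below is a minimal sketch of driving both entry points from C++; the helper name and the `FailureOr<linalg::LinalgOp>` result spelling are inferred for illustration rather than taken verbatim from the patch (the rendered diff drops text inside angle brackets).

```c++
// Sketch: invoke the functional-style APIs introduced/renamed by this patch.
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Hypothetical helper: pad a linalg op, then peel an scf.for loop.
static LogicalResult padThenPeel(linalg::LinalgOp target,
                                 linalg::LinalgPaddingOptions options,
                                 scf::ForOp loop) {
  IRRewriter rewriter(target->getContext());

  // Functional replacement for the former LinalgPaddingPattern; returns
  // failure if the op could not be rewritten as a padded op.
  FailureOr<linalg::LinalgOp> padded =
      linalg::padLinalgOp(rewriter, target, options);
  if (failed(padded))
    return failure();

  // Renamed from peelAndCanonicalizeForLoop: peels `loop` and only simplifies
  // affine.min/max bounds; general canonicalization is left to other patterns.
  scf::ForOp partialIteration;
  return scf::peelForLoopAndSimplifyBounds(rewriter, loop, partialIteration);
}
```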
-//===----------------------------------------------------------------------===// - -LinalgTilingOptions & -mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef ts) { - assert(!tileSizeComputationFunction && "tile sizes already set"); - SmallVector tileSizes(ts.begin(), ts.end()); - tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) { - OpBuilder::InsertionGuard guard(b); - b.setInsertionPointToStart( - &op->getParentOfType().getBody().front()); - return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) { - Value v = b.create(op->getLoc(), s); - return v; - })); - }; - return *this; -} - /// Pad the `opOperand` in the `paddingDimensions` using the padding value and /// the nofold flag found in `paddingValues` and `packPaddings`, respectively. /// Exit early and return the `opOperand` value if the shape dimensions that @@ -170,6 +150,19 @@ opOperand->get(), paddingValue, nofold); } +static SmallVector +getNParallelLoopsAttrs(unsigned nParallelLoops) { + return SmallVector(nParallelLoops, + utils::IteratorType::parallel); +} + +//===----------------------------------------------------------------------===// +// Transformations exposed as functional-style API calls. +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// rewriteAsPaddedOp transformation. +//===----------------------------------------------------------------------===// FailureOr> linalg::rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad, ArrayRef paddingDimensions, @@ -227,15 +220,20 @@ return paddedSubviewResults; } -/// Try to peel a loop `op` and return the new result. +//===----------------------------------------------------------------------===// +// peelLoop transformation. +//===----------------------------------------------------------------------===// + +/// Try to peel and canonicalize loop `op` and return the new result. +/// Also applies affine_min/max bounds simplification on the fly where relevant. // TODO: Add support for scf.parallel and affine.for loops. SmallVector mlir::linalg::peelLoop(RewriterBase &rewriter, Operation *op) { return llvm::TypeSwitch>(op) .Case([&](scf::ForOp forOp) { scf::ForOp partialIteration; - if (succeeded(scf::peelAndCanonicalizeForLoop(rewriter, forOp, - partialIteration))) + if (succeeded(scf::peelForLoopAndSimplifyBounds(rewriter, forOp, + partialIteration))) return partialIteration->getResults(); assert(!partialIteration && "expected that loop was not peeled"); return forOp->getResults(); @@ -243,24 +241,24 @@ .Default([&](Operation *op) { return op->getResults(); }); } -/// Peel and canonicalize 'loops'. +/// Peel 'loops' and applies affine_min/max bounds simplification on the fly +/// where relevant. void mlir::linalg::peelLoops(RewriterBase &rewriter, ArrayRef loops) { for (auto loopOp : loops) peelLoop(rewriter, loopOp); } -/// Linalg padding pattern. -mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern( - MLIRContext *context, LinalgPaddingOptions options, PatternBenefit benefit) - : OpInterfaceRewritePattern(context, benefit), - options(std::move(options)) {} +//===----------------------------------------------------------------------===// +// pad transformation. 
+//===----------------------------------------------------------------------===// -FailureOr -mlir::linalg::LinalgPaddingPattern::returningMatchAndRewrite( - LinalgOp linalgOp, PatternRewriter &rewriter) const { +FailureOr mlir::linalg::padLinalgOp(RewriterBase &rewriter, + LinalgOp linalgOp, + LinalgPaddingOptions options) { if (!linalgOp.hasTensorSemantics()) - return failure(); + return rewriter.notifyMatchFailure( + linalgOp, "only applies to Linalg ops with tensor semantics"); // Pad the operation. LinalgOp paddedOp; @@ -268,7 +266,8 @@ rewriteAsPaddedOp(rewriter, linalgOp, options.paddingDimensions, options.paddingValues, options.packPaddings, paddedOp); if (failed(newResults)) - return failure(); + return rewriter.notifyMatchFailure(linalgOp, + "failed to rewrite as a padded op"); // Hoist the padding. for (const auto &en : enumerate(options.hoistPaddings)) { @@ -276,12 +275,17 @@ break; OpOperand &opOperand = paddedOp->getOpOperand(en.index()); auto padOp = opOperand.get().getDefiningOp(); - if (!padOp || en.value() == 0) + if (!padOp || en.value() == 0) { + (void)rewriter.notifyMatchFailure(linalgOp, "not a tensor.pad -- skip"); continue; + } // Fail hoisting if the operand shape is not fully static. - if (llvm::any_of(paddedOp.getShape(&opOperand), ShapedType::isDynamic)) - return failure(); + if (llvm::any_of(paddedOp.getShape(&opOperand), ShapedType::isDynamic)) { + (void)rewriter.notifyMatchFailure(linalgOp, + "non static padding shape -- skip"); + continue; + } tensor::PadOp hoistedOp; SmallVector transposeOps; @@ -292,8 +296,11 @@ FailureOr newResult = hoistPaddingOnTensors( padOp, en.value(), transposeVector, hoistedOp, transposeOps); - if (failed(newResult)) + if (failed(newResult)) { + (void)rewriter.notifyMatchFailure(linalgOp, + "failed to apply hoistPadding"); continue; + } rewriter.replaceOp(padOp, *newResult); } @@ -303,1026 +310,1052 @@ return paddedOp; } -LogicalResult mlir::linalg::CopyVectorizationPattern::matchAndRewrite( - memref::CopyOp copyOp, PatternRewriter &rewriter) const { - return vectorizeCopy(rewriter, copyOp); +//===----------------------------------------------------------------------===// +// pack transformation. +//===----------------------------------------------------------------------===// + +#ifndef NDEBUG +/// Return true if `map` has 0 or 1 result function of AffineDimExpr(dim). +static bool hasAtMostOneResultFunctionOfDim(AffineMap map, int64_t dim) { + bool found = false; + for (AffineExpr e : map.getResults()) { + if (!e.isFunctionOfDim(dim)) + continue; + if (found) + return false; + found = true; + } + return true; } +#endif // NDEBUG -static SmallVector -getNParallelLoopsAttrs(unsigned nParallelLoops) { - return SmallVector(nParallelLoops, - utils::IteratorType::parallel); +/// Return the index of the first result of `map` that is a function of +/// AffineDimExpr(dim), std::nullopt otherwise. +static std::optional getFirstResultIndexFunctionOf(AffineMap map, + int64_t dim) { + for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) { + AffineExpr expr = map.getResult(i); + if (!expr.isFunctionOfDim(dim)) + continue; + return i; + } + return std::nullopt; } -/// Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp (to -/// initialize with pad_val) and GenericOp (to copy contents). -LogicalResult -PadOpTransformationPattern::matchAndRewrite(tensor::PadOp padOp, - PatternRewriter &rewriter) const { +/// Perform one step of packing of a LinalgOp's metadata along `dim` into the +/// `newDim` at `iteratorTypes.size()` by: +/// 1. 
Appending `iteratorTypes[newDim]`, equal to `iteratorTypes[dim]`. +/// 2. Appending a `newDim` to the domain of every indexing map. +/// 3. For each operand (i.e. for each map in `indexingMaps`), perform packing +/// by potentially adding a `newDim` result to `map`. +/// The preserved invariant is that `iteratorTypes.size()` is always equal to +/// `map.getNumDims()` for every map in `indexingMaps`. +/// +/// Update `indexingMaps` and `iteratorTypes` inplace as one step of the update. +/// Return a vector that records the optional packing for each operand. +/// Return failure if the packed indexing cannot be represented with a LinalgOp. +/// +/// Further details: +/// ================ +/// The current implementation of packing (i.e. data tiling) consists of +/// rewriting a linearized strip-mined form into a higher-dimensional access. +/// e.g. consider an access `A[I][f(j, k, l)]` and packing by 4; we rewrite +/// `I` into `4 * i + ii`, where `0 <= ii < 4`. +/// The access is further rewritten as `A[i][f(j, k, l)][ii]`. +/// +/// This rewrite into higher dimensional access is not possible for general +/// AffineExpr in Linalg atm, it is restricted to an AffineDimExpr: +/// e.g. consider an access `A[I + J][f(j, k, l)]` and packing by 4; we +/// rewrite `I + J` into `4 * i + ii + J`, where `0 <= ii < 4`. +/// The rewrite of the access would be a form not representable in Linalg: +/// `A[i + (ii + J) / 4][f(j, k, l)][(ii + J) % 4]`. +/// Note however that as `J` and `ii` iterate, the accesses do not have a +/// particular alignment, so packing does not achieve alignment in this case +/// +/// In the future, we may want to consider a mixed-form that allows some +/// alignment in the presence of multiple accesses: +/// `A[I][f(j, k, l)]` and `B[I + J][f(j, k, l)]` +/// And would rewrite accesses as: +/// `A[i][f(j, k, l)][ii]` and `B[4 * i + ii + J][f(j, k, l)]` +static FailureOr>> +packLinalgMetadataOnce(SmallVectorImpl &indexingMaps, + SmallVectorImpl &iteratorTypes, + int64_t dim) { + int64_t newDim = iteratorTypes.size(); + iteratorTypes.push_back(iteratorTypes[dim]); - auto inputShapedType = padOp.getSource().getType().cast(); - auto resultShapedType = padOp.getResult().getType().cast(); + SmallVector> packedDimPerIndexingMap( + indexingMaps.size(), std::nullopt); + SmallVector newMaps; + for (int64_t operandIdx = 0, e = indexingMaps.size(); operandIdx < e; + ++operandIdx) { + AffineMap map = indexingMaps[operandIdx]; - // Bail on non-static shapes. - if (!inputShapedType.hasStaticShape()) - return failure(); - if (!resultShapedType.hasStaticShape()) - return failure(); + // Add the `newDim` to map whatever the case. + assert(map.getNumDims() == newDim && "num dims invariant violation"); + map = map.shiftDims(1, newDim); - // Only support padding with a constant for now, i.e. either: - // 1. A BBarg from a different block. - // 2. A value defined outside of the current block. - Block &block = padOp.getRegion().front(); - auto yieldOp = cast(block.getTerminator()); - Value padValue = yieldOp.getValue(); - Operation *definingOp = padValue.getDefiningOp(); - if (definingOp && definingOp->getBlock() == &block) - return failure(); - if (!definingOp && padValue.cast().getOwner() == &block) - return failure(); + // Get the at-most-1 index of the result that is a function of `dim`. + // If we can find one, we insert `AffineDimExpr(newDim)` to the map, which + // logically chunks dimension `dim` into `K * dim + newDim`, where the + // packing factor `K` is specified separately. 
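    // Illustration (not from the original source): for the canonical matmul
    // maps
    //   A: (d0, d1, d2) -> (d0, d2)
    //   B: (d0, d1, d2) -> (d2, d1)
    //   C: (d0, d1, d2) -> (d0, d1)
    // packing dim = 2 appends a new domain dim d3 (with the same iterator
    // type as d2) and appends a d3 result to every map that already has a
    // result depending on d2:
    //   A: (d0, d1, d2, d3) -> (d0, d2, d3)
    //   B: (d0, d1, d2, d3) -> (d2, d1, d3)
    //   C: (d0, d1, d2, d3) -> (d0, d1)   // unchanged, no use of d2
    // The original index along that dimension thus decomposes as
    // `K * d2 + d3`, with d2 addressing the tile and d3 the element in it.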
+ assert(hasAtMostOneResultFunctionOfDim(map, dim) && + "num results invariant violation"); + auto maybeOperandDimensionToPack = getFirstResultIndexFunctionOf(map, dim); + if (!maybeOperandDimensionToPack.has_value()) { + newMaps.push_back(map); + continue; + } - // Create tensor with the padded shape - Location loc = padOp.getLoc(); - SmallVector indices(resultShapedType.getRank(), - rewriter.create(loc, 0)); - Value emptyTensor = rewriter.create( - loc, resultShapedType.getShape(), resultShapedType.getElementType()); + // We can only pack AffineDimExpr atm. + if (!map.getResult(maybeOperandDimensionToPack.value()) + .isa()) + return failure(); - // Initialize tensor with the pad value - Value tmpTensor = rewriter - .create(loc, ValueRange{padValue}, - ValueRange{emptyTensor}) - .result(); + // Add `newDim` to the results of the map. + map = map.insertResult(Builder(map.getContext()).getAffineDimExpr(newDim), + map.getNumResults()); + newMaps.push_back(map); - // Copy original contents into new tensor - // Uses linalg.generic, but could be done with tensor.insert_slice - SmallVector outputExprs; - for (unsigned i = 0; i < resultShapedType.getRank(); ++i) { - outputExprs.push_back(getAffineDimExpr(i, rewriter.getContext()) + - padOp.getStaticLow()[i]); + // Record the that `operandIdx` is packed. + packedDimPerIndexingMap[operandIdx] = maybeOperandDimensionToPack; } + indexingMaps = newMaps; - SmallVector transferMaps = { - rewriter.getMultiDimIdentityMap(inputShapedType.getRank()), - AffineMap::get(resultShapedType.getRank(), - /*symbolCount=*/0, outputExprs, rewriter.getContext())}; + return packedDimPerIndexingMap; +} - rewriter.replaceOpWithNewOp( - padOp, resultShapedType, padOp.getSource(), tmpTensor, transferMaps, - getNParallelLoopsAttrs(resultShapedType.getRank()), - [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { - nestedBuilder.create(nestedLoc, args[0]); - }); +namespace { - return success(); -} +/// Helper struct to encode packing along one dimension of a LinalgOp. +struct PackedOperandsDim { + OpFoldResult packedSize; + SmallVector> packedDimForEachOperand; +}; -/// Filling `dest` using FillOp constant padding value if possible. -/// Otherwise, generate a tensor::GenerateOp. -Value GeneralizePadOpPattern::createFillOrGenerateOp( - PatternRewriter &rewriter, tensor::PadOp padOp, Value dest, - const SmallVector &dynSizes) const { - auto padValue = padOp.getConstantPaddingValue(); - if (padValue) - return rewriter.create(padOp.getLoc(), padValue, dest).result(); +/// Helper struct to encode packing along all dimensions of a LinalgOp. +struct PackedOperandsDimList { + void push_back(PackedOperandsDim &&packedOperandsDims) { + spec.emplace_back(packedOperandsDims); + } + /// Return all the dims that have been packed for operand @ `operandPos`. + SmallVector extractPackedDimsForOperand(int64_t operandPos); + /// Return all the pack sizes by which an operand @ `operandPos` is packed. + SmallVector extractPackSizesForOperand(int64_t operandPos); - // Fill could not be optimized: Lower to tensor::GenerateOp with region. - auto generateOp = rewriter.create( - padOp.getLoc(), padOp.getResultType(), dynSizes); - // Copy region to new op. 
- IRMapping bvm; - padOp.getRegion().cloneInto(&generateOp.getRegion(), bvm); - return generateOp; +private: + SmallVector spec; +}; + +} // namespace + +SmallVector +PackedOperandsDimList::extractPackedDimsForOperand(int64_t operandPos) { + SmallVector res; + for (int64_t i = 0, e = spec.size(); i < e; ++i) { + if (!spec[i].packedDimForEachOperand[operandPos].has_value()) + continue; + res.push_back(spec[i].packedDimForEachOperand[operandPos].value()); + } + return res; } -LogicalResult -GeneralizePadOpPattern::matchAndRewrite(tensor::PadOp padOp, - PatternRewriter &rewriter) const { - // Given an OpFoldResult, return an index-typed value. - auto getIdxValue = [&](OpFoldResult ofr) { - if (auto val = ofr.dyn_cast()) - return val; - return rewriter - .create( - padOp.getLoc(), ofr.get().cast().getInt()) - .getResult(); - }; - - auto resultType = padOp.getResultType(); - // Compute size of EmptyOp. Any combination of static/dynamic is supported. - SmallVector dynSizes; - SmallVector staticSizes; - for (unsigned dim = 0; dim < resultType.getRank(); ++dim) { - if (resultType.isDynamicDim(dim)) { - auto srcSize = rewriter.createOrFold( - padOp.getLoc(), padOp.getSource(), dim); - // Add low and high padding value. - auto plusLow = rewriter.createOrFold( - padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim])); - auto plusHigh = rewriter.createOrFold( - padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim])); - dynSizes.push_back(plusHigh); - } - staticSizes.push_back(resultType.getDimSize(dim)); +SmallVector +PackedOperandsDimList::extractPackSizesForOperand(int64_t operandPos) { + SmallVector res; + for (int64_t i = 0, e = spec.size(); i < e; ++i) { + if (!spec[i].packedDimForEachOperand[operandPos].has_value()) + continue; + res.push_back(spec[i].packedSize); } + return res; +} - // Init tensor and fill it with padding. - Value emptyTensor = rewriter.create( - padOp.getLoc(), staticSizes, resultType.getElementType(), dynSizes); - Value fill = createFillOrGenerateOp(rewriter, padOp, emptyTensor, dynSizes); - - // Try optimize the copy of source. - if (optimizeCopyFn && optimizeCopyFn(rewriter, padOp, fill).succeeded()) - return success(); - - // tensor::PadOps cannot be optimized. Generate a InsertSliceOp instead - // for copying the PadOp source. - auto sourceType = padOp.getSourceType(); - // Compute size of source of tensor::PadOp. - SmallVector srcSizes; - for (unsigned dim = 0; dim < sourceType.getRank(); ++dim) { - if (sourceType.isDynamicDim(dim)) { - srcSizes.push_back(rewriter.createOrFold( - padOp.getLoc(), padOp.getSource(), dim)); - } else { - srcSizes.push_back(rewriter.getIndexAttr(sourceType.getDimSize(dim))); - } +/// Implement packing of a single LinalgOp by performing packing by +/// `packedSizes`. There must be one packedSizes entry per `linalgOp` iterator. +/// Return the packed Linalg op on success, failure otherwise. +FailureOr linalg::pack(RewriterBase &rewriter, + linalg::LinalgOp linalgOp, + ArrayRef packedSizes) { + if (packedSizes.size() != linalgOp.getNumLoops()) { + return rewriter.notifyMatchFailure(linalgOp, + "incorrect number of pack sizes"); } - // Strides of InsertSliceOp are all 1. 
- SmallVector strides(sourceType.getRank(), - rewriter.getIndexAttr(1)); - rewriter.replaceOpWithNewOp( - padOp, padOp.getSource(), fill, padOp.getMixedLowPad(), srcSizes, - strides); - - return success(); -} -LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite( - tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const { - if (!sliceOp.hasUnitStride()) - return failure(); + Location loc = linalgOp->getLoc(); + SmallVector indexingMaps = linalgOp.getIndexingMapsArray(); + SmallVector iteratorTypes = + linalgOp.getIteratorTypesArray(); + LLVM_DEBUG(DBGS() << "Start packing: " << linalgOp << "\n"; + llvm::interleaveComma(indexingMaps, DBGS() << "maps: "); DBGSNL(); + llvm::interleaveComma(iteratorTypes, DBGS() << "iterators: "); + DBGSNL();); - auto padOp = sliceOp.getSource().getDefiningOp(); - if (!padOp) - return failure(); + SmallVector packOps; + SmallVector unPackOps; + // Step 1. Pack each dim of the LinalgOp metadata by packedSizes[i]. + PackedOperandsDimList listOfPackedOperandsDim; + for (int64_t i = 0, e = packedSizes.size(); i < e; ++i) { + std::optional maybeConstant = getConstantIntValue(packedSizes[i]); + // Skip tile sizes explicitly set to 0. + if (maybeConstant.has_value() && maybeConstant.value() == 0) + continue; - bool zeroSliceGuard = true; - if (controlFn) { - if (std::optional control = controlFn(sliceOp)) - zeroSliceGuard = *control; - else + PackedOperandsDim packedOperandsDims; + packedOperandsDims.packedSize = packedSizes[i]; + FailureOr>> + maybePackedDimForEachOperand = + packLinalgMetadataOnce(indexingMaps, iteratorTypes, i); + if (failed(maybePackedDimForEachOperand)) return failure(); - } + packedOperandsDims.packedDimForEachOperand = *maybePackedDimForEachOperand; + listOfPackedOperandsDim.push_back(std::move(packedOperandsDims)); - Operation *tiledPadOp = - tensor::bubbleUpPadSlice(rewriter, padOp, sliceOp.getMixedOffsets(), - sliceOp.getMixedSizes(), zeroSliceGuard); - // All shapes are static and the data source is actually used. Rewrite into - // pad(extract_slice(x)). - rewriter.replaceOp(sliceOp, tiledPadOp->getResults()); - return success(); -} + LLVM_DEBUG( + DBGS() << "++++ After pack size #" << i << ": " << packedSizes[i] + << "\n"; + llvm::interleaveComma(indexingMaps, DBGS() << "maps: "); DBGSNL(); + llvm::interleaveComma(iteratorTypes, DBGS() << "iterators: "); DBGSNL(); + llvm::interleaveComma(packedOperandsDims.packedDimForEachOperand, + DBGS() << "packedDimForEachOperand: "); + DBGSNL();); + } -/// Returns a tensor.pad op if padding value is set. Otherwise, returns the -/// source directly. The method assumes that the `packOp` has static shapes. -static Value getPackOpSourceOrPaddedSource(OpBuilder &builder, - tensor::PackOp packOp) { - Value input = packOp.getSource(); - if (!packOp.getPaddingValue()) { - return input; + // Step 2. Propagate packing to all LinalgOp operands. 
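  // Illustrative sketch of the IR emitted here for one packed operand
  // (assuming a matmul LHS of type tensor<16x256xf32> whose reduction
  // dimension is packed by 8; names and shapes are hypothetical):
  //   %cst  = arith.constant 0.000000e+00 : f32
  //   %dest = tensor.empty() : tensor<16x32x8xf32>
  //   %pack = tensor.pack %A padding_value(%cst : f32)
  //             inner_dims_pos = [1] inner_tiles = [8] into %dest
  //             : tensor<16x256xf32> -> tensor<16x32x8xf32>
  // The packed value then replaces the original operand of the packed op
  // built in Step 3.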
+ SmallVector inputsAndInits, results; + for (auto operandsList : + {linalgOp.getDpsInputOperands(), linalgOp.getDpsInitOperands()}) { + for (OpOperand *opOperandPtr : operandsList) { + int64_t pos = opOperandPtr->getOperandNumber(); + Value operand = opOperandPtr->get(); + SmallVector innerPos = + listOfPackedOperandsDim.extractPackedDimsForOperand(pos); + SmallVector innerPackSizes = + listOfPackedOperandsDim.extractPackSizesForOperand(pos); + LLVM_DEBUG( + DBGS() << "operand: " << operand << "\n"; + llvm::interleaveComma(innerPos, DBGS() << "innerPos: "); DBGSNL(); + llvm::interleaveComma(innerPackSizes, DBGS() << "innerPackSizes: "); + DBGSNL();); + if (innerPackSizes.empty()) { + inputsAndInits.push_back(operand); + continue; + } + Value dest = tensor::PackOp::createDestinationTensor( + rewriter, loc, operand, innerPackSizes, innerPos, + /*outerDimsPerm=*/{}); + // TODO: value of the padding attribute should be determined by consumers. + Attribute zeroAttr = + rewriter.getZeroAttr(getElementTypeOrSelf(dest.getType())); + Value zero = rewriter.create(loc, zeroAttr); + packOps.push_back(rewriter.create( + loc, operand, dest, innerPos, innerPackSizes, zero)); + inputsAndInits.push_back(packOps.back()); + } } - Location loc = packOp.getLoc(); - ShapedType inputType = packOp.getSourceType(); - int64_t inputRank = inputType.getRank(); - assert(llvm::all_of(packOp.getDestType().getShape().take_front(inputRank), - [](int64_t val) { return val == 1; })); + // Step 3. Build the packed op, use the type of `inits` as result types. + ValueRange inputs = + ValueRange{inputsAndInits}.take_front(linalgOp.getNumDpsInputs()); + ValueRange inits = + ValueRange{inputsAndInits}.take_back(linalgOp.getNumDpsInits()); + auto packedLinalgOp = rewriter.create( + linalgOp.getLoc(), inits.getTypes(), inputs, inits, indexingMaps, + iteratorTypes); + packedLinalgOp.getRegion().takeBody(linalgOp->getRegion(0)); - SmallVector paddedShape; - DenseMap tileAndPosMapping = - packOp.getDimAndTileMapping(); - for (int64_t dim = 0; dim < inputRank; ++dim) { - int64_t size = inputType.getDimSize(dim); - if (!tileAndPosMapping.count(dim)) { - paddedShape.push_back(size); + // Step 4. Propagate packing to all the op results. + for (OpResult result : packedLinalgOp->getResults()) { + int64_t resultNum = result.getResultNumber(); + tensor::PackOp maybePackedInit = + inits[resultNum].getDefiningOp(); + if (!maybePackedInit) { + results.push_back(result); continue; } - - // The size is less than or equal to tileSize because outer dims are all 1s. - std::optional tileSize = - getConstantIntValue(tileAndPosMapping.lookup(dim)); - assert(tileSize.has_value() && "dynamic inner tile size is not supported"); - paddedShape.push_back(tileSize.value()); + // Build the symmetrical UnPackOp to the existing PackOp. 
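      // Schematically (mirroring the tensor.pack sketched in Step 2):
      //   %res = tensor.unpack %packed_result inner_dims_pos = [...]
      //            inner_tiles = [...] into %original_init
      // so each replaced result keeps its original, unpacked type for users
      // outside the packed op.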
+ unPackOps.push_back(rewriter.create( + packedLinalgOp->getLoc(), result, maybePackedInit.getSource(), + maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles())); + results.push_back(unPackOps.back()); } - auto resultType = - RankedTensorType::get(paddedShape, inputType.getElementType()); - return tensor::createPadHighOp(resultType, input, packOp.getPaddingValue(), - /*nofold=*/false, loc, builder); -} -static SmallVector -getPackUnpackNormalizedInnerPerm(int rank, ArrayRef innerDimsPos) { - constexpr int64_t kNonTiledMarker = -1; - SmallVector vec(rank, kNonTiledMarker); - for (auto [index, value] : llvm::enumerate(innerDimsPos)) - vec[value] = index; - SmallVector perm = llvm::to_vector(llvm::make_filter_range( - vec, [&](int64_t v) { return v != kNonTiledMarker; })); - return perm; -} + // Step 5. Replace `linalgOp`. + rewriter.replaceOp(linalgOp, results); -LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite( - tensor::PackOp packOp, PatternRewriter &rewriter) const { - // TODO: support the case that outer dimensions are not all 1s A - // tensor.expand_shape will be generated in this case. - int64_t srcRank = packOp.getSourceRank(); - if (llvm::any_of(packOp.getDestType().getShape().take_front(srcRank), - [](int64_t val) { return val != 1; })) { - return rewriter.notifyMatchFailure( - packOp, "require the outer dimension of the result are all 1s"); - } + // Return packedLinalgOp. + return PackResult{packOps, + cast(packedLinalgOp.getOperation()), + unPackOps}; +} - if (llvm::any_of(packOp.getMixedTiles(), - [](OpFoldResult tile) { return tile.is(); })) { - return rewriter.notifyMatchFailure(packOp, - "require inner tile sizes being static"); - } +//===----------------------------------------------------------------------===// +// packTranspose transformation. +//===----------------------------------------------------------------------===// - // 1. Use rank-reduced tensor.extract_slice op to extract the tile. - Location loc = packOp.getLoc(); - Attribute zeroIdxAttr = rewriter.getIndexAttr(0); - Attribute oneIdxAttr = rewriter.getIndexAttr(1); - SmallVector readOffsets(srcRank, zeroIdxAttr); - SmallVector readStrides(srcRank, oneIdxAttr); - SmallVector readSizes; - SmallVector readShape; - DenseMap dimAndTileMapping = - packOp.getDimAndTileMapping(); - for (auto i : llvm::seq(0, srcRank)) { - if (!dimAndTileMapping.count(i)) { - readSizes.push_back(oneIdxAttr); - continue; - } - readSizes.push_back(dimAndTileMapping[i]); - readShape.push_back(getConstantIntValue(dimAndTileMapping[i]) - .value_or(ShapedType::kDynamic)); - } - Type elemType = packOp.getSourceType().getElementType(); - auto readType = RankedTensorType::get(readShape, elemType); +/// Return a copy of `tensorType` after permutation by `permutationVector`. +// Note: Should be a new method in of MemRef/RankedTensor/VectorType::Builder +// but this would introduce a dependence on Dialect in IR. +// TODO: Restructure. 
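// For example, assuming applyPermutationToVector yields
// result[i] = shape[permutationVector[i]]:
//   permuteShape(tensor<16x32x64xf32>, {2, 0, 1}) -> tensor<64x16x32xf32>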
+static RankedTensorType permuteShape(RankedTensorType tensorType, + ArrayRef permutationVector) { + SmallVector shape(tensorType.getShape()); + applyPermutationToVector(shape, permutationVector); + return RankedTensorType::Builder(tensorType).setShape(shape); +} - Value input = getPackOpSourceOrPaddedSource(rewriter, packOp); - Value tile = rewriter.create( - loc, readType, input, readOffsets, readSizes, readStrides); +/// Return a new GenericOp obtained by transposing opOperand by the permutation +/// vector: +/// - the corresponding indexing map is transposed by `permutation` +/// - the corresponding operand value is replaced by `transposedValue` +/// `linalgOp` is replaced by the return op in the process. +/// Asserts that `transposedValue` is of the proper transposed ShapedType. +static LinalgOp transposeOneLinalgOperandAndReplace( + RewriterBase &rewriter, LinalgOp linalgOp, OpOperand &opOperand, + ArrayRef permutation, Value transposedValue) { + // Sanity check the operand. + assert(linalgOp == opOperand.getOwner() && "linalg op must own the operand"); - // 2. Transpose the tile to match the inner tile order. - SmallVector perm = - getPackUnpackNormalizedInnerPerm(srcRank, packOp.getInnerDimsPos()); - // The permutation is inverted when normalizing so invert back to match the - // ordering in the pack op. - perm = invertPermutationVector(perm); + // Sanity check of the expected transposed tensor type. + auto tensorType = permuteShape( + opOperand.get().getType().cast(), permutation); + (void)tensorType; + assert(tensorType == transposedValue.getType() && + "expected tensor type mismatch"); - LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n"; - llvm::interleaveComma(perm, DBGS() << "perm: "); DBGSNL();); + // Compute the transposed indexing map. + // Sigh unsigned pollution. + SmallVector tmpTransposition = llvm::to_vector( + llvm::map_range(permutation, [](int64_t i) -> unsigned { return i; })); + AffineMap permutationMap = + AffineMap::getPermutationMap(tmpTransposition, rewriter.getContext()); + AffineMap transposedMap = + permutationMap.compose(linalgOp.getMatchingIndexingMap(&opOperand)); - SmallVector transpShape = readShape; - applyPermutationToVector(transpShape, perm); + // Set the transposed indexing map in the proper position. + SmallVector indexingMaps = linalgOp.getIndexingMapsArray(); + indexingMaps[linalgOp.getIndexingMapIndex(&opOperand)] = transposedMap; + // Set the transposedValue in the proper operand position. + SmallVector operands = linalgOp->getOperands(); + operands[opOperand.getOperandNumber()] = transposedValue; - Value empty = rewriter.create(loc, transpShape, elemType); - auto transposedOp = - rewriter.create(loc, tile, empty, perm); + ValueRange operandsRef(operands); + auto transposedGenericOp = rewriter.create( + /*location=*/linalgOp->getLoc(), + /*resultTensorTypes=*/ + operandsRef.drop_front(linalgOp.getNumDpsInputs()).getTypes(), + /*inputs=*/operandsRef.take_front(linalgOp.getNumDpsInputs()), + /*outputs=*/operandsRef.drop_front(linalgOp.getNumDpsInputs()), + /*indexingMaps=*/indexingMaps, + /*iteratorTypes=*/linalgOp.getIteratorTypesArray()); + transposedGenericOp.getRegion().takeBody(linalgOp->getRegion(0)); + rewriter.replaceOp(linalgOp, transposedGenericOp->getResults()); - // 3. Insert the inner tile to the destination. 
- int64_t destRank = packOp.getDestRank(); - SmallVector writeStrides(destRank, oneIdxAttr); - SmallVector writeOffsets(destRank, zeroIdxAttr); - SmallVector writeSizes(srcRank, oneIdxAttr); - for (auto size : transpShape) - writeSizes.push_back(rewriter.getIndexAttr(size)); + return cast(transposedGenericOp.getOperation()); +} - auto insert = rewriter.create( - loc, transposedOp.getResult()[0], packOp.getDest(), writeOffsets, - writeSizes, writeStrides); - rewriter.replaceOp(packOp, insert.getResult()); +FailureOr +linalg::packTranspose(RewriterBase &rewriter, tensor::PackOp packOp, + linalg::LinalgOp linalgOp, tensor::UnPackOp maybeUnPackOp, + ArrayRef outerPerm, + ArrayRef innerPerm) { + Location loc = linalgOp.getLoc(); - return success(); -} + // Step 1. Transpose packOp. + rewriter.setInsertionPoint(packOp); + tensor::PackOp transposedPackOp = + packOp.createTransposedClone(rewriter, loc, innerPerm, outerPerm); -LogicalResult GeneralizeOuterUnitDimsUnPackOpPattern::matchAndRewrite( - tensor::UnPackOp unpackOp, PatternRewriter &rewriter) const { - int64_t srcRank = unpackOp.getSourceRank(); - int64_t destRank = unpackOp.getDestRank(); - ArrayRef srcShape = unpackOp.getSourceType().getShape(); - if (llvm::any_of(srcShape.take_front(destRank), - [](int64_t val) { return val != 1; })) { + if (!packOp.getResult().hasOneUse()) + return rewriter.notifyMatchFailure(linalgOp, "expect single pack use"); + + OpOperand &packUse = *packOp->getUses().begin(); + if (packUse.getOwner() != linalgOp) { return rewriter.notifyMatchFailure( - unpackOp, "require the outer dimension of the result are all 1s"); + linalgOp, "not a single use by the LinalgOp target"); + } + if (maybeUnPackOp && + (!linalgOp.isDpsInit(&packUse) || + maybeUnPackOp.getSource() != linalgOp.getTiedOpResult(&packUse))) { + return rewriter.notifyMatchFailure(linalgOp, + "not produced by the LinalgOp target"); } - // 1. Use rank-reduced tensor.extract_slice op to extract the tile. - Location loc = unpackOp.getLoc(); - Attribute zeroIdxAttr = rewriter.getIndexAttr(0); - Attribute oneIdxAttr = rewriter.getIndexAttr(1); - SmallVector readOffsets(srcRank, zeroIdxAttr); - SmallVector readStrides(srcRank, oneIdxAttr); - - auto mixedTiles = unpackOp.getMixedTiles(); - SmallVector readSizes(destRank, oneIdxAttr); - readSizes.append(mixedTiles.begin(), mixedTiles.end()); + // Step 2. Transpose linalgOp. + // transposedPackOp.getOuterDimsPerm() may be empty, in which case it is the + // identity. Don't rely on it. + int64_t numLeadingDims = packOp.getSourceRank(); + int64_t numTrailingDims = packOp.getInnerDimsPos().size(); + // Step 2.a. Compute the permutation on the whole operand. + // Leading part just reuse the outerPerm. + SmallVector permutation(outerPerm); + if (permutation.empty()) + llvm::append_range(permutation, llvm::seq(0, numLeadingDims)); + // Trailing part needs to reindex positions by `numLeadingDims`. + if (innerPerm.empty()) { + llvm::append_range( + permutation, + llvm::seq(numLeadingDims, numLeadingDims + numTrailingDims)); + } else { + llvm::append_range(permutation, + llvm::map_range(innerPerm, [&](int64_t pos) { + return numLeadingDims + pos; + })); + } + if (!isPermutationVector(permutation)) + return rewriter.notifyMatchFailure(linalgOp, "invalid permutation"); - // Explicitly create the type for extract_slice op because the inner tile - // size could be 1. We want to represent the whole inner tile in this case. 
- ArrayRef readShape = srcShape.drop_front(destRank); - Type elemType = unpackOp.getSourceType().getElementType(); - auto readType = RankedTensorType::get(readShape, elemType); - Value innerTile = rewriter.create( - loc, readType, unpackOp.getSource(), readOffsets, readSizes, readStrides); + // Step 2.b. Save the transposedPackUse operand number in case we need to + // get the tied OpResult after `linalgOp` has been replaced. + int64_t packUseOperandNumber = packUse.getOperandNumber(); + // Step 2.c. Actually perform the transposition. + rewriter.setInsertionPoint(linalgOp); + linalg::LinalgOp transposedLinalgOp = transposeOneLinalgOperandAndReplace( + rewriter, linalgOp, packUse, permutation, transposedPackOp.getResult()); - // 2. Transpose the tile to match the outer corresponding tile order. - ArrayRef innerDimsPos = unpackOp.getInnerDimsPos(); - SmallVector perm = - getPackUnpackNormalizedInnerPerm(srcRank, innerDimsPos); - SmallVector transpShape(readShape); - applyPermutationToVector(transpShape, perm); + // Step 3. Maybe transpose unPackOp. + tensor::UnPackOp transposedUnPackOp; + if (maybeUnPackOp) { + OpOperand &opOperand = + transposedLinalgOp->getOpOperand(packUseOperandNumber); + OpResult transposedResult = transposedLinalgOp.getTiedOpResult(&opOperand); + rewriter.setInsertionPoint(maybeUnPackOp); + transposedUnPackOp = maybeUnPackOp.createTransposedClone( + rewriter, loc, transposedResult, innerPerm, outerPerm); - Value empty = rewriter.create(loc, transpShape, elemType); - auto transposedOp = - rewriter.create(loc, innerTile, empty, perm); + rewriter.replaceOp(maybeUnPackOp, transposedUnPackOp->getResults()); + } - // 3. Handle in-complete tiles if needed. It truncates trailing data from the - // transposed tile. - int numLoops = transpShape.size(); - SmallVector tileStrides(numLoops, oneIdxAttr); - SmallVector tileOffsets(numLoops, zeroIdxAttr); - SmallVector tileSizes; - for (int dim : innerDimsPos) - tileSizes.push_back(getAsOpFoldResult( - rewriter.createOrFold(loc, unpackOp.getDest(), dim))); + // Step 4. Finally, replace packOp now that we don't need it anymore. + rewriter.replaceOp(packOp, transposedPackOp->getResults()); - applyPermutationToVector(tileSizes, perm); - auto partialTile = rewriter.create( - loc, transposedOp.getResult()[0], tileOffsets, tileSizes, tileStrides); + return PackTransposeResult{transposedPackOp, transposedLinalgOp, + transposedUnPackOp}; +} - // 4. Insert the result to the destination tensor. - SmallVector writeSizes; - SmallVector writeStrides(destRank, oneIdxAttr); - SmallVector writeOffsets(destRank, zeroIdxAttr); - DenseMap dimAndTileMapping = - unpackOp.getDimAndTileMapping(); - for (int i = 0, idx = 0; i < destRank; ++i) { - if (dimAndTileMapping.count(i)) - writeSizes.push_back(tileSizes[idx++]); - else - writeSizes.push_back(oneIdxAttr); - } - auto insert = rewriter.create( - loc, partialTile, unpackOp.getDest(), writeOffsets, writeSizes, - writeStrides); - rewriter.replaceOp(unpackOp, insert.getResult()); +//===----------------------------------------------------------------------===// +// Transformations exposed as rewrite patterns. 
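// Usage sketch for the functional packTranspose entry point defined above
// (hypothetical handles and permutations; outerPerm/innerPerm must be valid
// permutations of the outer/inner dims, or empty):
//   FailureOr<PackTransposeResult> res = linalg::packTranspose(
//       rewriter, packOp, packedLinalgOp, unPackOp,
//       /*outerPerm=*/{1, 0}, /*innerPerm=*/{1, 0});
//   if (succeeded(res))
//     packOp = res->transposedPackOp;  // also transposedLinalgOp/UnPackOp.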
+//===----------------------------------------------------------------------===// - return success(); +LinalgTilingOptions & +mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef ts) { + assert(!tileSizeComputationFunction && "tile sizes already set"); + SmallVector tileSizes(ts.begin(), ts.end()); + tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) { + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToStart( + &op->getParentOfType().getBody().front()); + return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) { + Value v = b.create(op->getLoc(), s); + return v; + })); + }; + return *this; } -// The following are patterns for downscaling convolution ops with size-1 -// window dimensions. -// -// Note that we'd eventually want to write such transformations in a generic -// way, e.g., converting to linalg.generic, removing the size-1 dimensions, -// and then turning back to named ops. But for now it's fine to have a few -// patterns matching special ops to get started. +/// Linalg padding pattern. -template -FailureOr DownscaleSizeOneWindowed2DConvolution:: - returningMatchAndRewrite(Conv2DOp convOp, PatternRewriter &rewriter) const { - if (convOp.hasBufferSemantics()) - return failure(); // To be implemented. +mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern( + MLIRContext *context, LinalgPaddingOptions options, PatternBenefit benefit) + : OpInterfaceRewritePattern(context, benefit), + options(std::move(options)) {} - Value input = convOp.getInputs().front(); - Value kernel = convOp.getInputs().back(); - Value output = convOp.getOutputs().front(); +LogicalResult mlir::linalg::LinalgPaddingPattern::matchAndRewrite( + LinalgOp op, PatternRewriter &rewriter) const { + return padLinalgOp(rewriter, op, options); +} - auto inputType = input.getType().dyn_cast(); - auto kernelType = kernel.getType().dyn_cast(); - auto outputType = output.getType().dyn_cast(); +LogicalResult mlir::linalg::CopyVectorizationPattern::matchAndRewrite( + memref::CopyOp copyOp, PatternRewriter &rewriter) const { + return vectorizeCopy(rewriter, copyOp); +} - auto kernelShape = kernelType.getShape(); - auto outputShape = outputType.getShape(); +/// Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp (to +/// initialize with pad_val) and GenericOp (to copy contents). +LogicalResult +PadOpTransformationPattern::matchAndRewrite(tensor::PadOp padOp, + PatternRewriter &rewriter) const { - // Get domain indices based on conv2D layout. 
- auto [khIndex, kwIndex, ohIndex, owIndex] = - TypeSwitch>( - convOp) - .Case([&](linalg::Conv2DNhwcHwcfOp op) { - return std::make_tuple(0, 1, 1, 2); - }) - .Case([&](linalg::Conv2DNchwFchwOp op) { - return std::make_tuple(2, 3, 2, 3); - }) - .Case([&](linalg::PoolingNhwcSumOp op) { - return std::make_tuple(0, 1, 1, 2); - }) - .Case([&](linalg::PoolingNchwSumOp op) { - return std::make_tuple(0, 1, 2, 3); - }) - .Case([&](linalg::PoolingNhwcMaxOp op) { - return std::make_tuple(0, 1, 1, 2); - }) - .Case([&](linalg::PoolingNhwcMaxUnsignedOp op) { - return std::make_tuple(0, 1, 1, 2); - }) - .Case([&](linalg::PoolingNhwcMinOp op) { - return std::make_tuple(0, 1, 1, 2); - }) - .Case([&](linalg::PoolingNhwcMinUnsignedOp op) { - return std::make_tuple(0, 1, 1, 2); - }) - .Case([&](linalg::PoolingNchwMaxOp op) { - return std::make_tuple(0, 1, 2, 3); - }) - .Default([&](Operation *op) { - llvm_unreachable("unexpected conv2d/pool2d operation."); - return std::make_tuple(0, 0, 0, 0); - }); + auto inputShapedType = padOp.getSource().getType().cast(); + auto resultShapedType = padOp.getResult().getType().cast(); - // Only handle the case where at least one of the window dimensions is - // of size 1. Other cases can rely on tiling to reduce to such cases. - int64_t khSize = kernelShape[khIndex], kwSize = kernelShape[kwIndex]; - int64_t ohSize = outputShape[ohIndex], owSize = outputShape[owIndex]; - bool removeH = (khSize == 1 && ohSize == 1); - bool removeW = (kwSize == 1 && owSize == 1); - if (!removeH && !removeW) + // Bail on non-static shapes. + if (!inputShapedType.hasStaticShape()) + return failure(); + if (!resultShapedType.hasStaticShape()) return failure(); - // Get new shapes and types for all operands by removing the size-1 - // dimension. - using RTTBuilder = RankedTensorType::Builder; - RankedTensorType newInputType = - RTTBuilder(inputType).dropDim((removeH ? ohIndex : owIndex)); - RankedTensorType newKernelType = - RTTBuilder(kernelType).dropDim((removeH ? khIndex : kwIndex)); - RankedTensorType newOutputType = - RTTBuilder(outputType).dropDim((removeH ? ohIndex : owIndex)); + // Only support padding with a constant for now, i.e. either: + // 1. A BBarg from a different block. + // 2. A value defined outside of the current block. + Block &block = padOp.getRegion().front(); + auto yieldOp = cast(block.getTerminator()); + Value padValue = yieldOp.getValue(); + Operation *definingOp = padValue.getDefiningOp(); + if (definingOp && definingOp->getBlock() == &block) + return failure(); + if (!definingOp && padValue.cast().getOwner() == &block) + return failure(); - // Rank-reduce operands. - Location loc = convOp.getLoc(); - Value newInput = tensor::createCanonicalRankReducingExtractSliceOp( - rewriter, loc, input, newInputType); - Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp( - rewriter, loc, kernel, newKernelType); - Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp( - rewriter, loc, output, newOutputType); + // Create tensor with the padded shape + Location loc = padOp.getLoc(); + SmallVector indices(resultShapedType.getRank(), + rewriter.create(loc, 0)); + Value emptyTensor = rewriter.create( + loc, resultShapedType.getShape(), resultShapedType.getElementType()); - // Rank-reduce strides and dilations too. - // TODO: dropDim 1-liner helper. - auto strides = - llvm::to_vector<4>(convOp.getStrides().template getValues()); - strides.erase(strides.begin() + (removeH ? 
0 : 1)); - auto stridesAttr = rewriter.getI64VectorAttr(strides); + // Initialize tensor with the pad value + Value tmpTensor = rewriter + .create(loc, ValueRange{padValue}, + ValueRange{emptyTensor}) + .result(); - auto dilations = - llvm::to_vector<4>(convOp.getDilations().template getValues()); - dilations.erase(dilations.begin() + (removeH ? 0 : 1)); - auto dilationsAttr = rewriter.getI64VectorAttr(dilations); + // Copy original contents into new tensor + // Uses linalg.generic, but could be done with tensor.insert_slice + SmallVector outputExprs; + for (unsigned i = 0; i < resultShapedType.getRank(); ++i) { + outputExprs.push_back(getAffineDimExpr(i, rewriter.getContext()) + + padOp.getStaticLow()[i]); + } - auto conv1DOp = rewriter.create( - loc, newOutputType, ValueRange{newInput, newKernel}, - ValueRange{newOutput}, stridesAttr, dilationsAttr); + SmallVector transferMaps = { + rewriter.getMultiDimIdentityMap(inputShapedType.getRank()), + AffineMap::get(resultShapedType.getRank(), + /*symbolCount=*/0, outputExprs, rewriter.getContext())}; - // Insert back. - Value inserted = tensor::createCanonicalRankReducingInsertSliceOp( - rewriter, loc, conv1DOp.getResult(0), output); - rewriter.replaceOp(convOp, inserted); + rewriter.replaceOpWithNewOp( + padOp, resultShapedType, padOp.getSource(), tmpTensor, transferMaps, + getNParallelLoopsAttrs(resultShapedType.getRank()), + [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { + nestedBuilder.create(nestedLoc, args[0]); + }); - return conv1DOp; + return success(); } -template struct linalg::DownscaleSizeOneWindowed2DConvolution; -template struct linalg::DownscaleSizeOneWindowed2DConvolution; -template struct linalg::DownscaleSizeOneWindowed2DConvolution; -template struct linalg::DownscaleSizeOneWindowed2DConvolution; -template struct linalg::DownscaleSizeOneWindowed2DConvolution; -template struct linalg::DownscaleSizeOneWindowed2DConvolution< - PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp>; -template struct linalg::DownscaleSizeOneWindowed2DConvolution; -template struct linalg::DownscaleSizeOneWindowed2DConvolution< - PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp>; -template struct linalg::DownscaleSizeOneWindowed2DConvolution; - -FailureOr -DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite( - DepthwiseConv2DNhwcHwcOp convOp, PatternRewriter &rewriter) const { - if (convOp.hasBufferSemantics()) - return failure(); // To be implemented. +/// Filling `dest` using FillOp constant padding value if possible. +/// Otherwise, generate a tensor::GenerateOp. +Value GeneralizePadOpPattern::createFillOrGenerateOp( + PatternRewriter &rewriter, tensor::PadOp padOp, Value dest, + const SmallVector &dynSizes) const { + auto padValue = padOp.getConstantPaddingValue(); + if (padValue) + return rewriter.create(padOp.getLoc(), padValue, dest).result(); - Value input = convOp.getInputs().front(); - Value kernel = convOp.getInputs().back(); - Value output = convOp.getOutputs().front(); + // Fill could not be optimized: Lower to tensor::GenerateOp with region. + auto generateOp = rewriter.create( + padOp.getLoc(), padOp.getResultType(), dynSizes); + // Copy region to new op. 
+ IRMapping bvm; + padOp.getRegion().cloneInto(&generateOp.getRegion(), bvm); + return generateOp; +} - auto inputType = input.getType().dyn_cast(); - auto kernelType = kernel.getType().dyn_cast(); - auto outputType = output.getType().dyn_cast(); +LogicalResult +GeneralizePadOpPattern::matchAndRewrite(tensor::PadOp padOp, + PatternRewriter &rewriter) const { + // Given an OpFoldResult, return an index-typed value. + auto getIdxValue = [&](OpFoldResult ofr) { + if (auto val = ofr.dyn_cast()) + return val; + return rewriter + .create( + padOp.getLoc(), ofr.get().cast().getInt()) + .getResult(); + }; - auto kernelShape = kernelType.getShape(); - auto outputShape = outputType.getShape(); + auto resultType = padOp.getResultType(); + // Compute size of EmptyOp. Any combination of static/dynamic is supported. + SmallVector dynSizes; + SmallVector staticSizes; + for (unsigned dim = 0; dim < resultType.getRank(); ++dim) { + if (resultType.isDynamicDim(dim)) { + auto srcSize = rewriter.createOrFold( + padOp.getLoc(), padOp.getSource(), dim); + // Add low and high padding value. + auto plusLow = rewriter.createOrFold( + padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim])); + auto plusHigh = rewriter.createOrFold( + padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim])); + dynSizes.push_back(plusHigh); + } + staticSizes.push_back(resultType.getDimSize(dim)); + } - // Only handle the case where at least one of the window dimensions is - // of size 1. Other cases can rely on tiling to reduce to such cases. - int64_t khSize = kernelShape[0], kwSize = kernelShape[1]; - int64_t ohSize = outputShape[1], owSize = outputShape[2]; - bool removeH = (khSize == 1 && ohSize == 1); - bool removeW = (kwSize == 1 && owSize == 1); - if (!removeH && !removeW) - return failure(); + // Init tensor and fill it with padding. + Value emptyTensor = rewriter.create( + padOp.getLoc(), staticSizes, resultType.getElementType(), dynSizes); + Value fill = createFillOrGenerateOp(rewriter, padOp, emptyTensor, dynSizes); - // Get new shapes and types for all operands by removing the size-1 - // dimension. - using RTTBuilder = RankedTensorType::Builder; - RankedTensorType newInputType = - RTTBuilder(inputType).dropDim((removeH ? 1 : 2)); - RankedTensorType newKernelType = - RTTBuilder(kernelType).dropDim((removeH ? 0 : 1)); - RankedTensorType newOutputType = - RTTBuilder(outputType).dropDim(removeH ? 1 : 2); + // Try optimize the copy of source. + if (optimizeCopyFn && optimizeCopyFn(rewriter, padOp, fill).succeeded()) + return success(); - // Rank-reduce operands. - Location loc = convOp.getLoc(); - Value newInput = tensor::createCanonicalRankReducingExtractSliceOp( - rewriter, loc, input, newInputType); - Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp( - rewriter, loc, kernel, newKernelType); - Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp( - rewriter, loc, output, newOutputType); + // tensor::PadOps cannot be optimized. Generate a InsertSliceOp instead + // for copying the PadOp source. + auto sourceType = padOp.getSourceType(); + // Compute size of source of tensor::PadOp. + SmallVector srcSizes; + for (unsigned dim = 0; dim < sourceType.getRank(); ++dim) { + if (sourceType.isDynamicDim(dim)) { + srcSizes.push_back(rewriter.createOrFold( + padOp.getLoc(), padOp.getSource(), dim)); + } else { + srcSizes.push_back(rewriter.getIndexAttr(sourceType.getDimSize(dim))); + } + } + // Strides of InsertSliceOp are all 1. 
+ SmallVector strides(sourceType.getRank(), + rewriter.getIndexAttr(1)); + rewriter.replaceOpWithNewOp( + padOp, padOp.getSource(), fill, padOp.getMixedLowPad(), srcSizes, + strides); - // Rank-reduce strides and dilations too. - // TODO: dropDim 1-liner helper. - auto strides = llvm::to_vector<4>(convOp.getStrides().getValues()); - strides.erase(strides.begin() + (removeH ? 0 : 1)); - auto stridesAttr = rewriter.getI64VectorAttr(strides); + return success(); +} - auto dilations = - llvm::to_vector<4>(convOp.getDilations().getValues()); - dilations.erase(dilations.begin() + (removeH ? 0 : 1)); - auto dilationsAttr = rewriter.getI64VectorAttr(dilations); +LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite( + tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const { + if (!sliceOp.hasUnitStride()) + return failure(); - auto conv1DOp = rewriter.create( - loc, newOutputType, ValueRange{newInput, newKernel}, - ValueRange{newOutput}, stridesAttr, dilationsAttr); + auto padOp = sliceOp.getSource().getDefiningOp(); + if (!padOp) + return failure(); - // Insert back. - Value inserted = tensor::createCanonicalRankReducingInsertSliceOp( - rewriter, loc, conv1DOp.getResult(0), output); - rewriter.replaceOp(convOp, inserted); + bool zeroSliceGuard = true; + if (controlFn) { + if (std::optional control = controlFn(sliceOp)) + zeroSliceGuard = *control; + else + return failure(); + } - return conv1DOp; + Operation *tiledPadOp = + tensor::bubbleUpPadSlice(rewriter, padOp, sliceOp.getMixedOffsets(), + sliceOp.getMixedSizes(), zeroSliceGuard); + // All shapes are static and the data source is actually used. Rewrite into + // pad(extract_slice(x)). + rewriter.replaceOp(sliceOp, tiledPadOp->getResults()); + return success(); } -void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns, - PatternBenefit benefit) { - patterns.add, - DownscaleSizeOneWindowed2DConvolution, - DownscaleDepthwiseConv2DNhwcHwcOp>(patterns.getContext(), - benefit); - patterns.add< - DownscaleSizeOneWindowed2DConvolution, - DownscaleSizeOneWindowed2DConvolution, - DownscaleSizeOneWindowed2DConvolution, - DownscaleSizeOneWindowed2DConvolution, - DownscaleSizeOneWindowed2DConvolution, - DownscaleSizeOneWindowed2DConvolution, - DownscaleSizeOneWindowed2DConvolution>( - patterns.getContext(), benefit); -} +/// Returns a tensor.pad op if padding value is set. Otherwise, returns the +/// source directly. The method assumes that the `packOp` has static shapes. +static Value getPackOpSourceOrPaddedSource(OpBuilder &builder, + tensor::PackOp packOp) { + Value input = packOp.getSource(); + if (!packOp.getPaddingValue()) { + return input; + } -//===----------------------------------------------------------------------===// -// pack transformation. -//===----------------------------------------------------------------------===// + Location loc = packOp.getLoc(); + ShapedType inputType = packOp.getSourceType(); + int64_t inputRank = inputType.getRank(); + assert(llvm::all_of(packOp.getDestType().getShape().take_front(inputRank), + [](int64_t val) { return val == 1; })); -#ifndef NDEBUG -/// Return true if `map` has 0 or 1 result function of AffineDimExpr(dim). 
-static bool hasAtMostOneResultFunctionOfDim(AffineMap map, int64_t dim) { - bool found = false; - for (AffineExpr e : map.getResults()) { - if (!e.isFunctionOfDim(dim)) + SmallVector paddedShape; + DenseMap tileAndPosMapping = + packOp.getDimAndTileMapping(); + for (int64_t dim = 0; dim < inputRank; ++dim) { + int64_t size = inputType.getDimSize(dim); + if (!tileAndPosMapping.count(dim)) { + paddedShape.push_back(size); continue; - if (found) - return false; - found = true; - } - return true; -} -#endif // NDEBUG + } -/// Return the index of the first result of `map` that is a function of -/// AffineDimExpr(dim), std::nullopt otherwise. -static std::optional getFirstResultIndexFunctionOf(AffineMap map, - int64_t dim) { - for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) { - AffineExpr expr = map.getResult(i); - if (!expr.isFunctionOfDim(dim)) - continue; - return i; + // The size is less than or equal to tileSize because outer dims are all 1s. + std::optional tileSize = + getConstantIntValue(tileAndPosMapping.lookup(dim)); + assert(tileSize.has_value() && "dynamic inner tile size is not supported"); + paddedShape.push_back(tileSize.value()); } - return std::nullopt; + auto resultType = + RankedTensorType::get(paddedShape, inputType.getElementType()); + return tensor::createPadHighOp(resultType, input, packOp.getPaddingValue(), + /*nofold=*/false, loc, builder); } -/// Perform one step of packing of a LinalgOp's metadata along `dim` into the -/// `newDim` at `iteratorTypes.size()` by: -/// 1. Appending `iteratorTypes[newDim]`, equal to `iteratorTypes[dim]`. -/// 2. Appending a `newDim` to the domain of every indexing map. -/// 3. For each operand (i.e. for each map in `indexingMaps`), perform packing -/// by potentially adding a `newDim` result to `map`. -/// The preserved invariant is that `iteratorTypes.size()` is always equal to -/// `map.getNumDims()` for every map in `indexingMaps`. -/// -/// Update `indexingMaps` and `iteratorTypes` inplace as one step of the update. -/// Return a vector that records the optional packing for each operand. -/// Return failure if the packed indexing cannot be represented with a LinalgOp. -/// -/// Further details: -/// ================ -/// The current implementation of packing (i.e. data tiling) consists of -/// rewriting a linearized strip-mined form into a higher-dimensional access. -/// e.g. consider an access `A[I][f(j, k, l)]` and packing by 4; we rewrite -/// `I` into `4 * i + ii`, where `0 <= ii < 4`. -/// The access is further rewritten as `A[i][f(j, k, l)][ii]`. -/// -/// This rewrite into higher dimensional access is not possible for general -/// AffineExpr in Linalg atm, it is restricted to an AffineDimExpr: -/// e.g. consider an access `A[I + J][f(j, k, l)]` and packing by 4; we -/// rewrite `I + J` into `4 * i + ii + J`, where `0 <= ii < 4`. -/// The rewrite of the access would be a form not representable in Linalg: -/// `A[i + (ii + J) / 4][f(j, k, l)][(ii + J) % 4]`. 
-/// Note however that as `J` and `ii` iterate, the accesses do not have a -/// particular alignment, so packing does not achieve alignment in this case -/// -/// In the future, we may want to consider a mixed-form that allows some -/// alignment in the presence of multiple accesses: -/// `A[I][f(j, k, l)]` and `B[I + J][f(j, k, l)]` -/// And would rewrite accesses as: -/// `A[i][f(j, k, l)][ii]` and `B[4 * i + ii + J][f(j, k, l)]` -static FailureOr>> -packLinalgMetadataOnce(SmallVectorImpl &indexingMaps, - SmallVectorImpl &iteratorTypes, - int64_t dim) { - int64_t newDim = iteratorTypes.size(); - iteratorTypes.push_back(iteratorTypes[dim]); +static SmallVector +getPackUnpackNormalizedInnerPerm(int rank, ArrayRef innerDimsPos) { + constexpr int64_t kNonTiledMarker = -1; + SmallVector vec(rank, kNonTiledMarker); + for (auto [index, value] : llvm::enumerate(innerDimsPos)) + vec[value] = index; + SmallVector perm = llvm::to_vector(llvm::make_filter_range( + vec, [&](int64_t v) { return v != kNonTiledMarker; })); + return perm; +} - SmallVector> packedDimPerIndexingMap( - indexingMaps.size(), std::nullopt); - SmallVector newMaps; - for (int64_t operandIdx = 0, e = indexingMaps.size(); operandIdx < e; - ++operandIdx) { - AffineMap map = indexingMaps[operandIdx]; +LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite( + tensor::PackOp packOp, PatternRewriter &rewriter) const { + // TODO: support the case that outer dimensions are not all 1s A + // tensor.expand_shape will be generated in this case. + int64_t srcRank = packOp.getSourceRank(); + if (llvm::any_of(packOp.getDestType().getShape().take_front(srcRank), + [](int64_t val) { return val != 1; })) { + return rewriter.notifyMatchFailure( + packOp, "require the outer dimension of the result are all 1s"); + } - // Add the `newDim` to map whatever the case. - assert(map.getNumDims() == newDim && "num dims invariant violation"); - map = map.shiftDims(1, newDim); + if (llvm::any_of(packOp.getMixedTiles(), + [](OpFoldResult tile) { return tile.is(); })) { + return rewriter.notifyMatchFailure(packOp, + "require inner tile sizes being static"); + } - // Get the at-most-1 index of the result that is a function of `dim`. - // If we can find one, we insert `AffineDimExpr(newDim)` to the map, which - // logically chunks dimension `dim` into `K * dim + newDim`, where the - // packing factor `K` is specified separately. - assert(hasAtMostOneResultFunctionOfDim(map, dim) && - "num results invariant violation"); - auto maybeOperandDimensionToPack = getFirstResultIndexFunctionOf(map, dim); - if (!maybeOperandDimensionToPack.has_value()) { - newMaps.push_back(map); + // 1. Use rank-reduced tensor.extract_slice op to extract the tile. + Location loc = packOp.getLoc(); + Attribute zeroIdxAttr = rewriter.getIndexAttr(0); + Attribute oneIdxAttr = rewriter.getIndexAttr(1); + SmallVector readOffsets(srcRank, zeroIdxAttr); + SmallVector readStrides(srcRank, oneIdxAttr); + SmallVector readSizes; + SmallVector readShape; + DenseMap dimAndTileMapping = + packOp.getDimAndTileMapping(); + for (auto i : llvm::seq(0, srcRank)) { + if (!dimAndTileMapping.count(i)) { + readSizes.push_back(oneIdxAttr); continue; } - - // We can only pack AffineDimExpr atm. - if (!map.getResult(maybeOperandDimensionToPack.value()) - .isa()) - return failure(); - - // Add `newDim` to the results of the map. - map = map.insertResult(Builder(map.getContext()).getAffineDimExpr(newDim), - map.getNumResults()); - newMaps.push_back(map); - - // Record the that `operandIdx` is packed. 
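To make the inner-permutation normalization above concrete with assumed values: for rank = 4 and innerDimsPos = [3, 1], the marker vector becomes [-1, 1, -1, 0], and filtering out the -1 markers yields perm = [1, 0]. That is, the i-th entry of perm gives, for the i-th tiled source dimension in source order, its position among the inner tiles; the pack pattern below inverts this permutation before transposing, while the unpack pattern applies it directly.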
- packedDimPerIndexingMap[operandIdx] = maybeOperandDimensionToPack; + readSizes.push_back(dimAndTileMapping[i]); + readShape.push_back(getConstantIntValue(dimAndTileMapping[i]) + .value_or(ShapedType::kDynamic)); } - indexingMaps = newMaps; + Type elemType = packOp.getSourceType().getElementType(); + auto readType = RankedTensorType::get(readShape, elemType); - return packedDimPerIndexingMap; -} + Value input = getPackOpSourceOrPaddedSource(rewriter, packOp); + Value tile = rewriter.create( + loc, readType, input, readOffsets, readSizes, readStrides); -namespace { + // 2. Transpose the tile to match the inner tile order. + SmallVector perm = + getPackUnpackNormalizedInnerPerm(srcRank, packOp.getInnerDimsPos()); + // The permutation is inverted when normalizing so invert back to match the + // ordering in the pack op. + perm = invertPermutationVector(perm); -/// Helper struct to encode packing along one dimension of a LinalgOp. -struct PackedOperandsDim { - OpFoldResult packedSize; - SmallVector> packedDimForEachOperand; -}; + LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n"; + llvm::interleaveComma(perm, DBGS() << "perm: "); DBGSNL();); -/// Helper struct to encode packing along all dimensions of a LinalgOp. -struct PackedOperandsDimList { - void push_back(PackedOperandsDim &&packedOperandsDims) { - spec.emplace_back(packedOperandsDims); - } - /// Return all the dims that have been packed for operand @ `operandPos`. - SmallVector extractPackedDimsForOperand(int64_t operandPos); - /// Return all the pack sizes by which an operand @ `operandPos` is packed. - SmallVector extractPackSizesForOperand(int64_t operandPos); + SmallVector transpShape = readShape; + applyPermutationToVector(transpShape, perm); -private: - SmallVector spec; -}; + Value empty = rewriter.create(loc, transpShape, elemType); + auto transposedOp = + rewriter.create(loc, tile, empty, perm); -} // namespace + // 3. Insert the inner tile to the destination. + int64_t destRank = packOp.getDestRank(); + SmallVector writeStrides(destRank, oneIdxAttr); + SmallVector writeOffsets(destRank, zeroIdxAttr); + SmallVector writeSizes(srcRank, oneIdxAttr); + for (auto size : transpShape) + writeSizes.push_back(rewriter.getIndexAttr(size)); -SmallVector -PackedOperandsDimList::extractPackedDimsForOperand(int64_t operandPos) { - SmallVector res; - for (int64_t i = 0, e = spec.size(); i < e; ++i) { - if (!spec[i].packedDimForEachOperand[operandPos].has_value()) - continue; - res.push_back(spec[i].packedDimForEachOperand[operandPos].value()); - } - return res; -} + auto insert = rewriter.create( + loc, transposedOp.getResult()[0], packOp.getDest(), writeOffsets, + writeSizes, writeStrides); + rewriter.replaceOp(packOp, insert.getResult()); -SmallVector -PackedOperandsDimList::extractPackSizesForOperand(int64_t operandPos) { - SmallVector res; - for (int64_t i = 0, e = spec.size(); i < e; ++i) { - if (!spec[i].packedDimForEachOperand[operandPos].has_value()) - continue; - res.push_back(spec[i].packedSize); - } - return res; + return success(); } -/// Implement packing of a single LinalgOp by performing packing by -/// `packedSizes`. There must be one packedSizes entry per `linalgOp` iterator. -/// Return the packed Linalg op on success, failure otherwise. 
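End to end, the generalization above turns a tensor.pack whose outer dims are all 1s into a rank-reduced extract_slice, a linalg.transpose into inner-tile order, and an insert_slice into the packed destination. A rough before/after, written as a comment because the shapes are assumed for illustration (no padding value):

// Assumed shapes, inner_dims_pos = [1, 0], inner_tiles = [16, 8]:
//   %packed = tensor.pack %src inner_dims_pos = [1, 0] inner_tiles = [16, 8]
//       into %dest : tensor<8x16xf32> -> tensor<1x1x16x8xf32>
// becomes, roughly:
//   %tile = tensor.extract_slice %src[0, 0] [8, 16] [1, 1]
//       : tensor<8x16xf32> to tensor<8x16xf32>
//   %init = tensor.empty() : tensor<16x8xf32>
//   %t = linalg.transpose ins(%tile : tensor<8x16xf32>)
//       outs(%init : tensor<16x8xf32>) permutation = [1, 0]
//   %res = tensor.insert_slice %t into %dest[0, 0, 0, 0] [1, 1, 16, 8]
//       [1, 1, 1, 1] : tensor<16x8xf32> into tensor<1x1x16x8xf32>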
-FailureOr linalg::pack(RewriterBase &rewriter, - linalg::LinalgOp linalgOp, - ArrayRef packedSizes) { - if (packedSizes.size() != linalgOp.getNumLoops()) { - return rewriter.notifyMatchFailure(linalgOp, - "incorrect number of pack sizes"); +LogicalResult GeneralizeOuterUnitDimsUnPackOpPattern::matchAndRewrite( + tensor::UnPackOp unpackOp, PatternRewriter &rewriter) const { + int64_t srcRank = unpackOp.getSourceRank(); + int64_t destRank = unpackOp.getDestRank(); + ArrayRef srcShape = unpackOp.getSourceType().getShape(); + if (llvm::any_of(srcShape.take_front(destRank), + [](int64_t val) { return val != 1; })) { + return rewriter.notifyMatchFailure( + unpackOp, "require the outer dimension of the result are all 1s"); } - Location loc = linalgOp->getLoc(); - SmallVector indexingMaps = linalgOp.getIndexingMapsArray(); - SmallVector iteratorTypes = - linalgOp.getIteratorTypesArray(); - LLVM_DEBUG(DBGS() << "Start packing: " << linalgOp << "\n"; - llvm::interleaveComma(indexingMaps, DBGS() << "maps: "); DBGSNL(); - llvm::interleaveComma(iteratorTypes, DBGS() << "iterators: "); - DBGSNL();); + // 1. Use rank-reduced tensor.extract_slice op to extract the tile. + Location loc = unpackOp.getLoc(); + Attribute zeroIdxAttr = rewriter.getIndexAttr(0); + Attribute oneIdxAttr = rewriter.getIndexAttr(1); + SmallVector readOffsets(srcRank, zeroIdxAttr); + SmallVector readStrides(srcRank, oneIdxAttr); - SmallVector packOps; - SmallVector unPackOps; - // Step 1. Pack each dim of the LinalgOp metadata by packedSizes[i]. - PackedOperandsDimList listOfPackedOperandsDim; - for (int64_t i = 0, e = packedSizes.size(); i < e; ++i) { - std::optional maybeConstant = getConstantIntValue(packedSizes[i]); - // Skip tile sizes explicitly set to 0. - if (maybeConstant.has_value() && maybeConstant.value() == 0) - continue; + auto mixedTiles = unpackOp.getMixedTiles(); + SmallVector readSizes(destRank, oneIdxAttr); + readSizes.append(mixedTiles.begin(), mixedTiles.end()); - PackedOperandsDim packedOperandsDims; - packedOperandsDims.packedSize = packedSizes[i]; - FailureOr>> - maybePackedDimForEachOperand = - packLinalgMetadataOnce(indexingMaps, iteratorTypes, i); - if (failed(maybePackedDimForEachOperand)) - return failure(); - packedOperandsDims.packedDimForEachOperand = *maybePackedDimForEachOperand; - listOfPackedOperandsDim.push_back(std::move(packedOperandsDims)); + // Explicitly create the type for extract_slice op because the inner tile + // size could be 1. We want to represent the whole inner tile in this case. + ArrayRef readShape = srcShape.drop_front(destRank); + Type elemType = unpackOp.getSourceType().getElementType(); + auto readType = RankedTensorType::get(readShape, elemType); + Value innerTile = rewriter.create( + loc, readType, unpackOp.getSource(), readOffsets, readSizes, readStrides); - LLVM_DEBUG( - DBGS() << "++++ After pack size #" << i << ": " << packedSizes[i] - << "\n"; - llvm::interleaveComma(indexingMaps, DBGS() << "maps: "); DBGSNL(); - llvm::interleaveComma(iteratorTypes, DBGS() << "iterators: "); DBGSNL(); - llvm::interleaveComma(packedOperandsDims.packedDimForEachOperand, - DBGS() << "packedDimForEachOperand: "); - DBGSNL();); - } + // 2. Transpose the tile to match the outer corresponding tile order. + ArrayRef innerDimsPos = unpackOp.getInnerDimsPos(); + SmallVector perm = + getPackUnpackNormalizedInnerPerm(srcRank, innerDimsPos); + SmallVector transpShape(readShape); + applyPermutationToVector(transpShape, perm); - // Step 2. Propagate packing to all LinalgOp operands. 
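As a worked illustration with assumed shapes: unpacking a tensor<1x1x16x8xf32> source with inner_dims_pos = [1, 0] into a rank-2 destination reads the whole inner tile, so readSizes = [1, 1, 16, 8] and readShape = [16, 8]; the normalized permutation is [1, 0], so transpShape becomes [8, 16] and the transpose created below restores the tile to destination dimension order before any truncation of partial tiles.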
- SmallVector inputsAndInits, results; - for (auto operandsList : - {linalgOp.getDpsInputOperands(), linalgOp.getDpsInitOperands()}) { - for (OpOperand *opOperandPtr : operandsList) { - int64_t pos = opOperandPtr->getOperandNumber(); - Value operand = opOperandPtr->get(); - SmallVector innerPos = - listOfPackedOperandsDim.extractPackedDimsForOperand(pos); - SmallVector innerPackSizes = - listOfPackedOperandsDim.extractPackSizesForOperand(pos); - LLVM_DEBUG( - DBGS() << "operand: " << operand << "\n"; - llvm::interleaveComma(innerPos, DBGS() << "innerPos: "); DBGSNL(); - llvm::interleaveComma(innerPackSizes, DBGS() << "innerPackSizes: "); - DBGSNL();); - if (innerPackSizes.empty()) { - inputsAndInits.push_back(operand); - continue; - } - Value dest = tensor::PackOp::createDestinationTensor( - rewriter, loc, operand, innerPackSizes, innerPos, - /*outerDimsPerm=*/{}); - // TODO: value of the padding attribute should be determined by consumers. - Attribute zeroAttr = - rewriter.getZeroAttr(getElementTypeOrSelf(dest.getType())); - Value zero = rewriter.create(loc, zeroAttr); - packOps.push_back(rewriter.create( - loc, operand, dest, innerPos, innerPackSizes, zero)); - inputsAndInits.push_back(packOps.back()); - } - } + Value empty = rewriter.create(loc, transpShape, elemType); + auto transposedOp = + rewriter.create(loc, innerTile, empty, perm); - // Step 3. Build the packed op, use the type of `inits` as result types. - ValueRange inputs = - ValueRange{inputsAndInits}.take_front(linalgOp.getNumDpsInputs()); - ValueRange inits = - ValueRange{inputsAndInits}.take_back(linalgOp.getNumDpsInits()); - auto packedLinalgOp = rewriter.create( - linalgOp.getLoc(), inits.getTypes(), inputs, inits, indexingMaps, - iteratorTypes); - packedLinalgOp.getRegion().takeBody(linalgOp->getRegion(0)); + // 3. Handle in-complete tiles if needed. It truncates trailing data from the + // transposed tile. + int numLoops = transpShape.size(); + SmallVector tileStrides(numLoops, oneIdxAttr); + SmallVector tileOffsets(numLoops, zeroIdxAttr); + SmallVector tileSizes; + for (int dim : innerDimsPos) + tileSizes.push_back(getAsOpFoldResult( + rewriter.createOrFold(loc, unpackOp.getDest(), dim))); - // Step 4. Propagate packing to all the op results. - for (OpResult result : packedLinalgOp->getResults()) { - int64_t resultNum = result.getResultNumber(); - tensor::PackOp maybePackedInit = - inits[resultNum].getDefiningOp(); - if (!maybePackedInit) { - results.push_back(result); - continue; - } - // Build the symmetrical UnPackOp to the existing PackOp. - unPackOps.push_back(rewriter.create( - packedLinalgOp->getLoc(), result, maybePackedInit.getSource(), - maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles())); - results.push_back(unPackOps.back()); + applyPermutationToVector(tileSizes, perm); + auto partialTile = rewriter.create( + loc, transposedOp.getResult()[0], tileOffsets, tileSizes, tileStrides); + + // 4. Insert the result to the destination tensor. 
+ SmallVector writeSizes; + SmallVector writeStrides(destRank, oneIdxAttr); + SmallVector writeOffsets(destRank, zeroIdxAttr); + DenseMap dimAndTileMapping = + unpackOp.getDimAndTileMapping(); + for (int i = 0, idx = 0; i < destRank; ++i) { + if (dimAndTileMapping.count(i)) + writeSizes.push_back(tileSizes[idx++]); + else + writeSizes.push_back(oneIdxAttr); } + auto insert = rewriter.create( + loc, partialTile, unpackOp.getDest(), writeOffsets, writeSizes, + writeStrides); + rewriter.replaceOp(unpackOp, insert.getResult()); + + return success(); +} + +// The following are patterns for downscaling convolution ops with size-1 +// window dimensions. +// +// Note that we'd eventually want to write such transformations in a generic +// way, e.g., converting to linalg.generic, removing the size-1 dimensions, +// and then turning back to named ops. But for now it's fine to have a few +// patterns matching special ops to get started. + +template +FailureOr DownscaleSizeOneWindowed2DConvolution:: + returningMatchAndRewrite(Conv2DOp convOp, PatternRewriter &rewriter) const { + if (convOp.hasBufferSemantics()) + return failure(); // To be implemented. + + Value input = convOp.getInputs().front(); + Value kernel = convOp.getInputs().back(); + Value output = convOp.getOutputs().front(); + + auto inputType = input.getType().dyn_cast(); + auto kernelType = kernel.getType().dyn_cast(); + auto outputType = output.getType().dyn_cast(); + + auto kernelShape = kernelType.getShape(); + auto outputShape = outputType.getShape(); + + // Get domain indices based on conv2D layout. + auto [khIndex, kwIndex, ohIndex, owIndex] = + TypeSwitch>( + convOp) + .Case([&](linalg::Conv2DNhwcHwcfOp op) { + return std::make_tuple(0, 1, 1, 2); + }) + .Case([&](linalg::Conv2DNchwFchwOp op) { + return std::make_tuple(2, 3, 2, 3); + }) + .Case([&](linalg::PoolingNhwcSumOp op) { + return std::make_tuple(0, 1, 1, 2); + }) + .Case([&](linalg::PoolingNchwSumOp op) { + return std::make_tuple(0, 1, 2, 3); + }) + .Case([&](linalg::PoolingNhwcMaxOp op) { + return std::make_tuple(0, 1, 1, 2); + }) + .Case([&](linalg::PoolingNhwcMaxUnsignedOp op) { + return std::make_tuple(0, 1, 1, 2); + }) + .Case([&](linalg::PoolingNhwcMinOp op) { + return std::make_tuple(0, 1, 1, 2); + }) + .Case([&](linalg::PoolingNhwcMinUnsignedOp op) { + return std::make_tuple(0, 1, 1, 2); + }) + .Case([&](linalg::PoolingNchwMaxOp op) { + return std::make_tuple(0, 1, 2, 3); + }) + .Default([&](Operation *op) { + llvm_unreachable("unexpected conv2d/pool2d operation."); + return std::make_tuple(0, 0, 0, 0); + }); + + // Only handle the case where at least one of the window dimensions is + // of size 1. Other cases can rely on tiling to reduce to such cases. + int64_t khSize = kernelShape[khIndex], kwSize = kernelShape[kwIndex]; + int64_t ohSize = outputShape[ohIndex], owSize = outputShape[owIndex]; + bool removeH = (khSize == 1 && ohSize == 1); + bool removeW = (kwSize == 1 && owSize == 1); + if (!removeH && !removeW) + return failure(); + + // Get new shapes and types for all operands by removing the size-1 + // dimension. + using RTTBuilder = RankedTensorType::Builder; + RankedTensorType newInputType = + RTTBuilder(inputType).dropDim((removeH ? ohIndex : owIndex)); + RankedTensorType newKernelType = + RTTBuilder(kernelType).dropDim((removeH ? khIndex : kwIndex)); + RankedTensorType newOutputType = + RTTBuilder(outputType).dropDim((removeH ? ohIndex : owIndex)); - // Step 5. Replace `linalgOp`. - rewriter.replaceOp(linalgOp, results); + // Rank-reduce operands. 
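As an example of the index table above: for linalg.conv_2d_nhwc_hwcf the kernel layout is H, W, C, F and the output layout is N, H, W, C, so khIndex = 0, kwIndex = 1, ohIndex = 1 and owIndex = 2. A 1x3xCxF kernel with an output height of 1 therefore takes the removeH path, and the H dimension is dropped from the input, kernel, and output types, which are then materialized via rank-reducing extract_slice ops below.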
+ Location loc = convOp.getLoc(); + Value newInput = tensor::createCanonicalRankReducingExtractSliceOp( + rewriter, loc, input, newInputType); + Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp( + rewriter, loc, kernel, newKernelType); + Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp( + rewriter, loc, output, newOutputType); - // Return packedLinalgOp. - return PackResult{packOps, - cast(packedLinalgOp.getOperation()), - unPackOps}; -} + // Rank-reduce strides and dilations too. + // TODO: dropDim 1-liner helper. + auto strides = + llvm::to_vector<4>(convOp.getStrides().template getValues()); + strides.erase(strides.begin() + (removeH ? 0 : 1)); + auto stridesAttr = rewriter.getI64VectorAttr(strides); -//===----------------------------------------------------------------------===// -// packTranspose transformation. -//===----------------------------------------------------------------------===// + auto dilations = + llvm::to_vector<4>(convOp.getDilations().template getValues()); + dilations.erase(dilations.begin() + (removeH ? 0 : 1)); + auto dilationsAttr = rewriter.getI64VectorAttr(dilations); -/// Return a copy of `tensorType` after permutation by `permutationVector`. -// Note: Should be a new method in of MemRef/RankedTensor/VectorType::Builder -// but this would introduce a dependence on Dialect in IR. -// TODO: Restructure. -static RankedTensorType permuteShape(RankedTensorType tensorType, - ArrayRef permutationVector) { - SmallVector shape(tensorType.getShape()); - applyPermutationToVector(shape, permutationVector); - return RankedTensorType::Builder(tensorType).setShape(shape); -} + auto conv1DOp = rewriter.create( + loc, newOutputType, ValueRange{newInput, newKernel}, + ValueRange{newOutput}, stridesAttr, dilationsAttr); -/// Return a new GenericOp obtained by transposing opOperand by the permutation -/// vector: -/// - the corresponding indexing map is transposed by `permutation` -/// - the corresponding operand value is replaced by `transposedValue` -/// `linalgOp` is replaced by the return op in the process. -/// Asserts that `transposedValue` is of the proper transposed ShapedType. -static LinalgOp transposeOneLinalgOperandAndReplace( - RewriterBase &rewriter, LinalgOp linalgOp, OpOperand &opOperand, - ArrayRef permutation, Value transposedValue) { - // Sanity check the operand. - assert(linalgOp == opOperand.getOwner() && "linalg op must own the operand"); + // Insert back. + Value inserted = tensor::createCanonicalRankReducingInsertSliceOp( + rewriter, loc, conv1DOp.getResult(0), output); + rewriter.replaceOp(convOp, inserted); - // Sanity check of the expected transposed tensor type. - auto tensorType = permuteShape( - opOperand.get().getType().cast(), permutation); - (void)tensorType; - assert(tensorType == transposedValue.getType() && - "expected tensor type mismatch"); + return conv1DOp; +} - // Compute the transposed indexing map. - // Sigh unsigned pollution. 
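Put together, the rewrite above reduces a 2-D convolution with a unit window dimension to its 1-D counterpart operating on rank-reduced slices. A rough before/after with assumed shapes (unit strides and dilations, attributes elided), written as a comment:

// Assumed shapes, kh = 1, OH = 1, so the H dimension is dropped:
//   %0 = linalg.conv_2d_nhwc_hwcf
//       ins(%in, %ker : tensor<1x1x10x4xf32>, tensor<1x3x4x8xf32>)
//       outs(%out : tensor<1x1x8x8xf32>) -> tensor<1x1x8x8xf32>
// becomes, after rank-reducing extract_slices of %in, %ker, and %out:
//   %1 = linalg.conv_1d_nwc_wcf
//       ins(%in1d, %ker1d : tensor<1x10x4xf32>, tensor<3x4x8xf32>)
//       outs(%out1d : tensor<1x8x8xf32>) -> tensor<1x8x8xf32>
// followed by a tensor.insert_slice of %1 back into %out.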
- SmallVector tmpTransposition = llvm::to_vector( - llvm::map_range(permutation, [](int64_t i) -> unsigned { return i; })); - AffineMap permutationMap = - AffineMap::getPermutationMap(tmpTransposition, rewriter.getContext()); - AffineMap transposedMap = - permutationMap.compose(linalgOp.getMatchingIndexingMap(&opOperand)); +template struct linalg::DownscaleSizeOneWindowed2DConvolution; +template struct linalg::DownscaleSizeOneWindowed2DConvolution; +template struct linalg::DownscaleSizeOneWindowed2DConvolution; +template struct linalg::DownscaleSizeOneWindowed2DConvolution; +template struct linalg::DownscaleSizeOneWindowed2DConvolution; +template struct linalg::DownscaleSizeOneWindowed2DConvolution< + PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp>; +template struct linalg::DownscaleSizeOneWindowed2DConvolution; +template struct linalg::DownscaleSizeOneWindowed2DConvolution< + PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp>; +template struct linalg::DownscaleSizeOneWindowed2DConvolution; - // Set the transposed indexing map in the proper position. - SmallVector indexingMaps = linalgOp.getIndexingMapsArray(); - indexingMaps[linalgOp.getIndexingMapIndex(&opOperand)] = transposedMap; - // Set the transposedValue in the proper operand position. - SmallVector operands = linalgOp->getOperands(); - operands[opOperand.getOperandNumber()] = transposedValue; +FailureOr +DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite( + DepthwiseConv2DNhwcHwcOp convOp, PatternRewriter &rewriter) const { + if (convOp.hasBufferSemantics()) + return failure(); // To be implemented. - ValueRange operandsRef(operands); - auto transposedGenericOp = rewriter.create( - /*location=*/linalgOp->getLoc(), - /*resultTensorTypes=*/ - operandsRef.drop_front(linalgOp.getNumDpsInputs()).getTypes(), - /*inputs=*/operandsRef.take_front(linalgOp.getNumDpsInputs()), - /*outputs=*/operandsRef.drop_front(linalgOp.getNumDpsInputs()), - /*indexingMaps=*/indexingMaps, - /*iteratorTypes=*/linalgOp.getIteratorTypesArray()); - transposedGenericOp.getRegion().takeBody(linalgOp->getRegion(0)); - rewriter.replaceOp(linalgOp, transposedGenericOp->getResults()); + Value input = convOp.getInputs().front(); + Value kernel = convOp.getInputs().back(); + Value output = convOp.getOutputs().front(); - return cast(transposedGenericOp.getOperation()); -} + auto inputType = input.getType().dyn_cast(); + auto kernelType = kernel.getType().dyn_cast(); + auto outputType = output.getType().dyn_cast(); -FailureOr -linalg::packTranspose(RewriterBase &rewriter, tensor::PackOp packOp, - linalg::LinalgOp linalgOp, tensor::UnPackOp maybeUnPackOp, - ArrayRef outerPerm, - ArrayRef innerPerm) { - Location loc = linalgOp.getLoc(); + auto kernelShape = kernelType.getShape(); + auto outputShape = outputType.getShape(); - // Step 1. Transpose packOp. - rewriter.setInsertionPoint(packOp); - tensor::PackOp transposedPackOp = - packOp.createTransposedClone(rewriter, loc, innerPerm, outerPerm); + // Only handle the case where at least one of the window dimensions is + // of size 1. Other cases can rely on tiling to reduce to such cases. 
+ int64_t khSize = kernelShape[0], kwSize = kernelShape[1]; + int64_t ohSize = outputShape[1], owSize = outputShape[2]; + bool removeH = (khSize == 1 && ohSize == 1); + bool removeW = (kwSize == 1 && owSize == 1); + if (!removeH && !removeW) + return failure(); - if (!packOp.getResult().hasOneUse()) - return rewriter.notifyMatchFailure(linalgOp, "expect single pack use"); + // Get new shapes and types for all operands by removing the size-1 + // dimension. + using RTTBuilder = RankedTensorType::Builder; + RankedTensorType newInputType = + RTTBuilder(inputType).dropDim((removeH ? 1 : 2)); + RankedTensorType newKernelType = + RTTBuilder(kernelType).dropDim((removeH ? 0 : 1)); + RankedTensorType newOutputType = + RTTBuilder(outputType).dropDim(removeH ? 1 : 2); - OpOperand &packUse = *packOp->getUses().begin(); - if (packUse.getOwner() != linalgOp) { - return rewriter.notifyMatchFailure( - linalgOp, "not a single use by the LinalgOp target"); - } - if (maybeUnPackOp && - (!linalgOp.isDpsInit(&packUse) || - maybeUnPackOp.getSource() != linalgOp.getTiedOpResult(&packUse))) { - return rewriter.notifyMatchFailure(linalgOp, - "not produced by the LinalgOp target"); - } + // Rank-reduce operands. + Location loc = convOp.getLoc(); + Value newInput = tensor::createCanonicalRankReducingExtractSliceOp( + rewriter, loc, input, newInputType); + Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp( + rewriter, loc, kernel, newKernelType); + Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp( + rewriter, loc, output, newOutputType); - // Step 2. Transpose linalgOp. - // transposedPackOp.getOuterDimsPerm() may be empty, in which case it is the - // identity. Don't rely on it. - int64_t numLeadingDims = packOp.getSourceRank(); - int64_t numTrailingDims = packOp.getInnerDimsPos().size(); - // Step 2.a. Compute the permutation on the whole operand. - // Leading part just reuse the outerPerm. - SmallVector permutation(outerPerm); - if (permutation.empty()) - llvm::append_range(permutation, llvm::seq(0, numLeadingDims)); - // Trailing part needs to reindex positions by `numLeadingDims`. - if (innerPerm.empty()) { - llvm::append_range( - permutation, - llvm::seq(numLeadingDims, numLeadingDims + numTrailingDims)); - } else { - llvm::append_range(permutation, - llvm::map_range(innerPerm, [&](int64_t pos) { - return numLeadingDims + pos; - })); - } - if (!isPermutationVector(permutation)) - return rewriter.notifyMatchFailure(linalgOp, "invalid permutation"); + // Rank-reduce strides and dilations too. + // TODO: dropDim 1-liner helper. + auto strides = llvm::to_vector<4>(convOp.getStrides().getValues()); + strides.erase(strides.begin() + (removeH ? 0 : 1)); + auto stridesAttr = rewriter.getI64VectorAttr(strides); - // Step 2.b. Save the transposedPackUse operand number in case we need to - // get the tied OpResult after `linalgOp` has been replaced. - int64_t packUseOperandNumber = packUse.getOperandNumber(); - // Step 2.c. Actually perform the transposition. - rewriter.setInsertionPoint(linalgOp); - linalg::LinalgOp transposedLinalgOp = transposeOneLinalgOperandAndReplace( - rewriter, linalgOp, packUse, permutation, transposedPackOp.getResult()); + auto dilations = + llvm::to_vector<4>(convOp.getDilations().getValues()); + dilations.erase(dilations.begin() + (removeH ? 0 : 1)); + auto dilationsAttr = rewriter.getI64VectorAttr(dilations); - // Step 3. Maybe transpose unPackOp. 
- tensor::UnPackOp transposedUnPackOp; - if (maybeUnPackOp) { - OpOperand &opOperand = - transposedLinalgOp->getOpOperand(packUseOperandNumber); - OpResult transposedResult = transposedLinalgOp.getTiedOpResult(&opOperand); - rewriter.setInsertionPoint(maybeUnPackOp); - transposedUnPackOp = maybeUnPackOp.createTransposedClone( - rewriter, loc, transposedResult, innerPerm, outerPerm); + auto conv1DOp = rewriter.create( + loc, newOutputType, ValueRange{newInput, newKernel}, + ValueRange{newOutput}, stridesAttr, dilationsAttr); - rewriter.replaceOp(maybeUnPackOp, transposedUnPackOp->getResults()); - } + // Insert back. + Value inserted = tensor::createCanonicalRankReducingInsertSliceOp( + rewriter, loc, conv1DOp.getResult(0), output); + rewriter.replaceOp(convOp, inserted); - // Step 4. Finally, replace packOp now that we don't need it anymore. - rewriter.replaceOp(packOp, transposedPackOp->getResults()); + return conv1DOp; +} - return PackTransposeResult{transposedPackOp, transposedLinalgOp, - transposedUnPackOp}; +void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns, + PatternBenefit benefit) { + patterns.add, + DownscaleSizeOneWindowed2DConvolution, + DownscaleDepthwiseConv2DNhwcHwcOp>(patterns.getContext(), + benefit); + patterns.add< + DownscaleSizeOneWindowed2DConvolution, + DownscaleSizeOneWindowed2DConvolution, + DownscaleSizeOneWindowed2DConvolution, + DownscaleSizeOneWindowed2DConvolution, + DownscaleSizeOneWindowed2DConvolution, + DownscaleSizeOneWindowed2DConvolution, + DownscaleSizeOneWindowed2DConvolution>( + patterns.getContext(), benefit); } diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp --- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp +++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp @@ -132,7 +132,7 @@ // "the loop trip count is divisible by the step" // is valid. LogicalResult status = - scf::peelAndCanonicalizeForLoop(rewriter, target, result); + scf::peelForLoopAndSimplifyBounds(rewriter, target, result); // TODO: Return both the peeled loop and the remainder loop. results.push_back(failed(status) ? target : result); return DiagnosedSilenceableFailure::success(); diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp --- a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp @@ -180,9 +180,9 @@ }); } -LogicalResult mlir::scf::peelAndCanonicalizeForLoop(RewriterBase &rewriter, - ForOp forOp, - ForOp &partialIteration) { +LogicalResult mlir::scf::peelForLoopAndSimplifyBounds(RewriterBase &rewriter, + ForOp forOp, + ForOp &partialIteration) { Value previousUb = forOp.getUpperBound(); Value splitBound; if (failed(peelForLoop(rewriter, forOp, partialIteration, splitBound))) @@ -218,7 +218,7 @@ } // Apply loop peeling. scf::ForOp partialIteration; - if (failed(peelAndCanonicalizeForLoop(rewriter, forOp, partialIteration))) + if (failed(peelForLoopAndSimplifyBounds(rewriter, forOp, partialIteration))) return failure(); // Apply label, so that the same loop is not rewritten a second time. rewriter.updateRootInPlace(partialIteration, [&]() {
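The SCF changes are a pure rename of peelAndCanonicalizeForLoop to peelForLoopAndSimplifyBounds, so existing callers only need to update the call site. A minimal caller sketch; the helper name is hypothetical, and `using namespace mlir;` plus the header paths are assumptions:

#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"

// Sketch: peel the partial last iteration off `forOp`. On success, `forOp`
// keeps the full iterations and `partialIteration` holds the remainder, with
// affine.min/max bounds in the main loop simplified by the helper.
static LogicalResult peelOneLoop(RewriterBase &rewriter, scf::ForOp forOp) {
  scf::ForOp partialIteration;
  return scf::peelForLoopAndSimplifyBounds(rewriter, forOp, partialIteration);
}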