diff --git a/mlir/include/mlir/Dialect/Linalg/CMakeLists.txt b/mlir/include/mlir/Dialect/Linalg/CMakeLists.txt --- a/mlir/include/mlir/Dialect/Linalg/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Linalg/CMakeLists.txt @@ -1,5 +1,4 @@ add_subdirectory(IR) -add_subdirectory(Transforms) set(LLVM_TARGET_DEFINITIONS Passes.td) mlir_tablegen(Passes.h.inc -gen-pass-decls) diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Linalg/Transforms/CMakeLists.txt deleted file mode 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/CMakeLists.txt +++ /dev/null @@ -1,7 +0,0 @@ -set(LLVM_TARGET_DEFINITIONS LinalgTransformPatterns.td) -mlir_tablegen(LinalgTransformPatterns.h.inc -gen-rewriters) -add_public_tablegen_target(MLIRLinalgTransformPatternsIncGen) - -# Including Linalg in TableGen requires to depends on generated files -add_dependencies(MLIRLinalgTransformPatternsIncGen LinalgOdsGen) - diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td deleted file mode 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td +++ /dev/null @@ -1,123 +0,0 @@ -//===- LinalgPatterns.td - Linalg transformation patterns --*- tablegen -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This is the pattern definition file for declarative Linalg transformation. 
-// -//===----------------------------------------------------------------------===// - -#ifndef LINALG_TRANSFORMS -#define LINALG_TRANSFORMS - -include "mlir/Dialect/Linalg/IR/LinalgOps.td" -include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.td" -include "mlir/Dialect/Affine/IR/AffineOps.td" - -def HasNoLinalgTransformMarker : CPred<[{ - !op.getAttrOfType(LinalgTransforms::kLinalgTransformMarker) -}]>; - -class HasLinalgTransformMarker : CPred<[{ - op.getAttrOfType( - LinalgTransforms::kLinalgTransformMarker) && - op.getAttrOfType( - LinalgTransforms::kLinalgTransformMarker).getValue() == "}] # str # [{"}]>; - -class IsProducedByOpOfType : - CPred<"isProducedByOpOfType<" # str # ">(op, $0)">; - -class AffineMapDomainHasDim : CPred<[{ - op.getAttrOfType(getIndexingMapsAttrName()).getValue()[0]. - cast().getValue().getNumDims() ==}] # n # [{}]>; - -class HasOperandsOfType: CPred<[{ - llvm::any_of(op.getOperands(), - [](Value v) { - return dyn_cast_or_null<}] # type # [{>(v.getDefiningOp()); - }) -}]>; - -//===----------------------------------------------------------------------===// -// Linalg fusion patterns. -//===----------------------------------------------------------------------===// -// -// In the future, tile sizes should be derived from op properties + machine -// description but we do not need to wait on this to start having useful -// patterns. -class TileAndFuseLinalgOp< - list sizes, list operandIndices, string value> : NativeCodeCall< - "if (failed(tileAndFuseLinalgOpAndSetMarker($_builder, op, {" # - StrJoinInt.result # "}, {" # StrJoinInt.result # "}," # - " \"" # value # "\")))" # - " return failure();">; - -//===----------------------------------------------------------------------===// -// Linalg tiling patterns. 
-//===----------------------------------------------------------------------===// -// -// In the future, tile sizes should be derived from op properties + machine -// description but we do not need to wait on this to start having useful -// patterns. -// `permutation` is an optional parameter to specify the ordering of the -// tiled loops. If provided, it must be a list of integers with the same number -// of elements as `sizes`. -class TileLinalgOp sizes, string value, list permutation=[]> : - NativeCodeCall< - "if (failed(tileLinalgOpAndSetMarker($_builder, op, {" # - StrJoinInt.result # "}, \"" # value # "\", {" # - StrJoinInt.result # "})))" # - " return failure();">; - -//===----------------------------------------------------------------------===// -// Linalg to loop patterns. -//===----------------------------------------------------------------------===// -class LinalgOpToLoops : NativeCodeCall< - "if (failed(linalgOpToLoops<" # OpType # ">($_builder, op))) " # - " return failure();">; - -class LinalgOpToParallelLoops : NativeCodeCall< - "if (failed(linalgOpToParallelLoops<" # OpType # ">($_builder, op))) " # - " return failure();">; - -class LinalgOpToAffineLoops : NativeCodeCall< - "if (failed(linalgOpToAffineLoops<" # OpType # ">($_builder, op))) " # - " return failure();">; - -//===----------------------------------------------------------------------===// -// Linalg to vector patterns precondition and DRR. -//===----------------------------------------------------------------------===// -def PreconditionVectorizeLinalgOp : CPred< - "succeeded(vectorizeLinalgOpPrecondition(op))">; -def VectorizeLinalgOp : NativeCodeCall< - "vectorizeLinalgOp($_builder, op)">; - - -//===----------------------------------------------------------------------===// -// Linalg generic permutation patterns precondition and DRR. 
-//===----------------------------------------------------------------------===// -class PreconditionPermuteGenericLinalgOp permutation> : CPred< - "succeeded(permuteGenericLinalgOpPrecondition(op, {" # - StrJoinInt.result # "}))">; -class PermuteGenericLinalgOp permutation, string value> : - NativeCodeCall< - "permuteGenericLinalgOp($_builder, op, {" # StrJoinInt.result # - "}, \"" # value # "\")">; - -//===----------------------------------------------------------------------===// -// Linalg promote subview operands precondition and DRR. -//===----------------------------------------------------------------------===// -def PreconditionPromoteSubviewsLinalgOp : CPred< - "succeeded(promoteSubviewsLinalgOpPrecondition(op))">; -def PromoteSubviewsLinalgOp : NativeCodeCall< - "promoteSubviewsLinalgOp($_builder, op)">; - -class PromoteSelectedSubviewsLinalgOp operands, string marker="", - int alignment=0> : - NativeCodeCall<"promoteSelectedSubviewsLinalgOpAndSetMarker($_builder, op, {" # - StrJoinInt.result # "}, \"" # marker # "\", " # alignment # ")">; - -#endif // LINALG_TRANSFORMS diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h deleted file mode 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h +++ /dev/null @@ -1,137 +0,0 @@ -//===- LinalgTransforms.h - Linalg transformations as patterns --*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef DIALECT_LINALG_TRANSFORMS_LINALGTRANSFORMS_H_ -#define DIALECT_LINALG_TRANSFORMS_LINALGTRANSFORMS_H_ - -#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" -#include "mlir/Dialect/Linalg/IR/LinalgOps.h" -#include "mlir/Dialect/Linalg/Passes.h" -#include "mlir/Dialect/Linalg/Utils/Utils.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Pass/Pass.h" - -#include "llvm/ADT/STLExtras.h" - -namespace mlir { -namespace linalg { - -// Marker used as attribute name in generated Linalg rewriting transformations. -struct LinalgTransforms { - static const StringLiteral kLinalgTransformMarker; -}; - -namespace detail { -// Implementation detail of isProducedByOpOfType avoids the need for explicit -// template instantiations. -bool isProducedByOpOfTypeImpl(Operation *consumerOp, Value consumedView, - function_ref isaOpType); -} // namespace detail - -// Returns true if the `consumedView` value use in `consumerOp` is produced by -// an op of type `OpTy`. This is used to implement use-def type information on -// buffers. -template -bool isProducedByOpOfType(Operation *consumerOp, Value consumedView) { - return detail::isProducedByOpOfTypeImpl( - consumerOp, consumedView, [](Operation *op) { return isa(op); }); -} - -//////////////////////////////////////////////////////////////////////////////// -// The following Declarative Rewrite Rule (DRR) helpers are used in rewrite -// patterns. As such, they must not call into `rewriter.erase/replace` APIs and -// it is the responsibility of the enclosing PatternRewriter to erase on -// success. -//////////////////////////////////////////////////////////////////////////////// - -/// Tiles `op` by `sizes` permuting the loops according to `permutation` and -/// sets the attribute `kLinalgTransformMarker` to `linalgMarker`. 
The -/// permutation is expressed as a list of integers that specify the new ordering -/// of the loop nest (using loop.for operations). The length of `permutation` -/// must be equal to the length of `tileSizes`. -/// E.g. the permutation `(i,j,k) -> (j,k,i)` will be expressed with -/// `permutation = [1,2,0]`. All values in `permutation` must be -/// integers, in the range 0..`tileSizes.size()` without duplications -/// (i.e. `[1,1,2]` is an invalid permutation). An empty list -/// states for the identity permutation. -LogicalResult tileLinalgOpAndSetMarker(PatternRewriter &rewriter, Operation *op, - ArrayRef sizes, - StringRef linalgMarker, - ArrayRef permutation); - -/// Tiles ops similar to `tileLinalgOpAndSetMarker` but generates loop.parallel -/// operations instead. -LogicalResult tileLinalgOpToParallelLoopsAndSetMarker( - PatternRewriter &rewriter, Operation *op, ArrayRef sizes, - StringRef linalgMarker, ArrayRef permutation); - -/// Tiles `op` by `sizes`, fuses the producers of `operandIndicesToFuse` and -/// sets the attribute `kLinalgTransformMarker` to `linalgMarker`. -LogicalResult tileAndFuseLinalgOpAndSetMarker( - PatternRewriter &rewriter, Operation *op, ArrayRef sizes, - ArrayRef operandIndicesToFuse, StringRef linalgMarker); - -/// Tiles ops similar to `tileAndFuseLinalgOpAndSetMarker` but generates -/// loop.parallel operations instead. -LogicalResult tileAndFuseLinalgOpToParallelLoopsAndSetMarker( - PatternRewriter &rewriter, Operation *op, ArrayRef sizes, - ArrayRef operandIndicesToFuse, StringRef linalgMarker); - -using LinalgLoops = SmallVector; - -/// Emits a loop nest of with the proper body for `op`. -template -Optional linalgLowerOpToLoops(PatternRewriter &rewriter, - Operation *op); - -/// Emits a loop nest of `loop.for` with the proper body for `op`. -template -LogicalResult linalgOpToLoops(PatternRewriter &rewriter, Operation *op); - -/// Emits a loop nest of `loop.parallel` with the proper body for `op`. 
-template -LogicalResult linalgOpToParallelLoops(PatternRewriter &rewriter, Operation *op); - -/// Emits a loop nest of `affine.for` with the proper body for `op`. -template -LogicalResult linalgOpToAffineLoops(PatternRewriter &rewriter, Operation *op); - -/// Rewrite a linalg.generic into a suitable vector.contraction op. -LogicalResult vectorizeLinalgOpPrecondition(Operation *op); -SmallVector vectorizeLinalgOp(PatternRewriter &rewriter, - Operation *op); - -/// Emits a `generic` or `indexed_generic` operation with the `indexing_maps` -/// and `iterator_types` permutated according to `permutation`. -LogicalResult -permuteGenericLinalgOpPrecondition(Operation *op, - ArrayRef permutation); -SmallVector permuteGenericLinalgOp(PatternRewriter &rewriter, - Operation *op, - ArrayRef permutation, - StringRef linalgMarker); - -/// Promote std.subviews feeding linalg operations. -LogicalResult promoteSubviewsLinalgOpPrecondition(Operation *op); -SmallVector promoteSubviewsLinalgOp(PatternRewriter &rewriter, - Operation *op); - -/// Similar to `promoteSubviewsLinalgOp` but only tries to promote -/// the views corresponding to the operands specified in -/// `operandIndicesToPromote`. Generated allocations are memory-aligned -/// according to the `alignment` parameter. -/// If linalgMarker is specified and the transformation is successfull -/// sets the attribute `kLinalgTransformMarker` to `linalgMarker`. 
-SmallVector promoteSelectedSubviewsLinalgOpAndSetMarker(
-    PatternRewriter &rewriter, Operation *op,
-    ArrayRef operandIndicesToPromote, StringRef linalgMarker = "",
-    int64_t alignment = 0);
-} // namespace linalg
-} // namespace mlir
-
-#endif // DIALECT_LINALG_TRANSFORMS_LINALGTRANSFORMS_H_
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -0,0 +1,363 @@
+//===- Transforms.h - Linalg transformations as patterns --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef DIALECT_LINALG_TRANSFORMS_TRANSFORMS_H_
+#define DIALECT_LINALG_TRANSFORMS_TRANSFORMS_H_
+
+#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+
+#include "llvm/ADT/STLExtras.h"
+
+namespace mlir {
+namespace linalg {
+
+//----------------------------------------------------------------------------//
+// Transformations exposed as function calls.
+//----------------------------------------------------------------------------//
+using LinalgLoops = SmallVector;
+
+struct TiledLinalgOp {
+  LinalgOp op;
+  SmallVector loops;
+};
+
+/// Performs standalone tiling of a single LinalgOp by `tileSizes`
+/// and permutes the resulting loop nest according to `interchangeVector`.
+/// The permutation is expressed as a list of integers that specify
+/// the new ordering of the loop nest. The length of `interchangeVector`
+/// must be equal to the length of `tileSizes`.
+/// An empty vector is interpreted as the identity permutation (the tiled
+/// loops keep their original order).
+///
+/// When non-null, the optional pointer `folder` is used to call into the
+/// `createAndFold` builder method. If `folder` is null, the regular `create`
+/// method is called.
+///
+/// Returns a struct containing the tiled loops in the specified order
+/// and the cloned op if successful, llvm::None otherwise.
+///
+/// E.g. the permutation `(i,j,k) -> (j,k,i)` is expressed by
+/// `interchangeVector = [1,2,0]`. All values in `interchangeVector` must be
+/// integers in the range [0, `tileSizes.size()`) without duplications
+/// (i.e. `[1,1,2]` is an invalid permutation).
+Optional tileLinalgOp(OpBuilder &b, LinalgOp op,
+                                     ArrayRef tileSizes,
+                                     ArrayRef interchangeVector = {},
+                                     OperationFolder *folder = nullptr);
+Optional
+tileLinalgOpToParallelLoops(OpBuilder &b, LinalgOp op,
+                            ArrayRef tileSizes,
+                            ArrayRef interchangeVector = {},
+                            OperationFolder *folder = nullptr);
+
+/// Performs standalone tiling of a single LinalgOp by constant `tileSizes`.
+/// See the `tileLinalgOp` overload above for more details.
+Optional tileLinalgOp(OpBuilder &b, LinalgOp op,
+                                     ArrayRef tileSizes,
+                                     ArrayRef interchangeVector = {},
+                                     OperationFolder *folder = nullptr);
+Optional
+tileLinalgOpToParallelLoops(OpBuilder &b, LinalgOp op,
+                            ArrayRef tileSizes,
+                            ArrayRef interchangeVector = {},
+                            OperationFolder *folder = nullptr);
+
+/// Interchanges the `iterator_types` and `indexing_maps` dimensions of `op`.
+/// This is an in-place transformation controlled by `interchangeVector`.
+/// An empty vector is interpreted as the identity permutation and the
+/// transformation returns early.
+///
+/// E.g. the permutation `(i,j,k) -> (j,k,i)` is expressed with
+/// `interchangeVector = [1,2,0]`. All values in `interchangeVector` must be
+/// integers in the range [0, number of loops of `op`) without duplications
+/// (i.e. `[1,1,2]` is an invalid permutation).
+LinalgOp interchange(LinalgOp op, ArrayRef interchangeVector);
+
+/// Promotes the `subViews` into a new buffer allocated at the insertion point
+/// `b`. Promotion occurs in 3 steps:
+///   1. Create a new buffer for a full tile (i.e. not clipped at the boundary).
+///   2. Take a full view on the buffer and `linalg.fill` it with zeros (use
+///      float zero for now).
+///   3. Take a partial slice of the full view from step 2 and copy into it.
+/// Infers statically sized buffers from subViews unless `dynamicBuffers` is
+/// true.
+///
+/// Returns the cloned `LinalgOp` whose operands are the promoted views; the
+/// new buffer(s) are deallocated as part of the transformation.
+// TODO: revisit dynamicBuffers option.
+LinalgOp promoteSubViewOperands(OpBuilder &b, LinalgOp op,
+                                llvm::SetVector subViews,
+                                bool dynamicBuffers = false,
+                                int64_t alignment = 0,
+                                OperationFolder *folder = nullptr);
+
+/// Emit a suitable vector form for a Linalg op with fully static shape.
+void vectorizeLinalgOp(OpBuilder &builder, Operation *op);
+
+/// Emits a loop nest of `LoopTy` with the proper body for `op`.
+template
+Optional linalgLowerOpToLoops(OpBuilder &builder, Operation *op);
+
+/// Emits a loop nest of `loop.for` with the proper body for `op`.
+template
+LogicalResult linalgOpToLoops(OpBuilder &builder, Operation *op);
+
+/// Emits a loop nest of `loop.parallel` with the proper body for `op`.
+template
+LogicalResult linalgOpToParallelLoops(OpBuilder &builder, Operation *op);
+
+/// Emits a loop nest of `affine.for` with the proper body for `op`.
+template
+LogicalResult linalgOpToAffineLoops(OpBuilder &builder, Operation *op);
+
+//----------------------------------------------------------------------------//
+// Preconditions that ensure the corresponding transformation succeeds and can
+// be applied as a rewrite pattern.
+//----------------------------------------------------------------------------//
+/// Returns success if `op` is a `generic` or `indexed_generic` op on buffers
+/// and `interchangeVector` is a non-empty, invertible permutation of its loops.
+LogicalResult
+interchangeGenericLinalgOpPrecondition(Operation *op,
+                                       ArrayRef interchangeVector);
+
+/// Precondition for promoting the std.subviews feeding the linalg op `op`.
+LogicalResult promoteSubviewsLinalgOpPrecondition(
+    Operation *op, Optional> operandIndicesToPromote = None);
+
+/// Returns success if `vectorizeLinalgOp` can be applied to `op`.
+LogicalResult vectorizeLinalgOpPrecondition(Operation *op);
+
+//----------------------------------------------------------------------------//
+// Transformations exposed as rewrite patterns.
+//----------------------------------------------------------------------------//
+// Marker used as attribute name in generated Linalg rewriting transformations.
+struct LinalgTransforms {
+  static const StringLiteral kLinalgTransformMarker;
+};
+
+/// Helper class to control common attribute matching and setting behavior.
+struct LinalgMarker {
+  LinalgMarker(ArrayRef matchDisjunction = {},
+               Optional replacement = None);
+  LinalgMarker(ArrayRef matchDisjunction, StringRef replacement);
+  LogicalResult checkAndNotify(PatternRewriter &rewriter, Operation *op) const;
+  void replaceLinalgMarker(PatternRewriter &rewriter, Operation *op) const;
+
+private:
+  SmallVector matchDisjunction;
+  Optional replacement;
+};
+
+///
+/// Linalg tiling patterns.
+///
+/// Apply the `tileLinalgOp` transformation as a pattern.
+/// `marker` controls LinalgTransformMarker matching and update when specified.
+/// See `tileLinalgOp` for more details.
+struct LinalgBaseTilingPattern : public RewritePattern {
+  LinalgBaseTilingPattern(StringRef opName, MLIRContext *context,
+                          ArrayRef<int64_t> tileSizes,
+                          ArrayRef<unsigned> interchangeVector = {},
+                          LinalgMarker marker = LinalgMarker(),
+                          PatternBenefit benefit = 1);
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override;
+
+private:
+  /// LinalgTransformMarker handles special attribute manipulations.
+  LinalgMarker marker;
+  /// The tile sizes by which to tile.
+  SmallVector<int64_t, 4> tileSizes;
+  /// The interchange vector to reorder the tiled loops.
+  SmallVector<unsigned, 4> interchangeVector;
+};
+
+template <typename OpTy>
+struct LinalgTilingPattern : public LinalgBaseTilingPattern {
+  LinalgTilingPattern(MLIRContext *context, ArrayRef<int64_t> tileSizes,
+                      ArrayRef<unsigned> interchangeVector = {},
+                      LinalgMarker marker = LinalgMarker(),
+                      PatternBenefit benefit = 1)
+      : LinalgBaseTilingPattern(OpTy::getOperationName(), context, tileSizes,
+                                interchangeVector, marker, benefit) {}
+  LinalgTilingPattern(MLIRContext *context, ArrayRef<int64_t> tileSizes,
+                      LinalgMarker marker, PatternBenefit benefit = 1)
+      : LinalgTilingPattern(context, tileSizes, {}, marker, benefit) {}
+};
+
+///
+/// Linalg interchange patterns.
+///
+/// Apply the `interchange` transformation as a pattern.
+/// `marker` controls LinalgTransformMarker matching and update when specified.
+/// See `interchange` for more details.
+struct LinalgBaseInterchangePattern : public RewritePattern {
+  LinalgBaseInterchangePattern(StringRef opName, MLIRContext *context,
+                               ArrayRef<unsigned> interchangeVector,
+                               LinalgMarker marker = LinalgMarker(),
+                               PatternBenefit benefit = 1);
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override;
+
+private:
+  /// LinalgTransformMarker handles special attribute manipulations.
+  LinalgMarker marker;
+  /// The interchange vector to reorder the iterators and indexing_maps dims.
+  SmallVector<unsigned, 4> interchangeVector;
+};
+
+template <typename OpTy>
+struct LinalgInterchangePattern : public LinalgBaseInterchangePattern {
+  LinalgInterchangePattern(MLIRContext *context,
+                           ArrayRef<unsigned> interchangeVector,
+                           LinalgMarker marker = LinalgMarker(),
+                           PatternBenefit benefit = 1)
+      : LinalgBaseInterchangePattern(OpTy::getOperationName(), context,
+                                     interchangeVector, marker, benefit) {}
+};
+
+///
+/// Linalg promotion patterns.
+///
+/// Apply the `promoteSubViewOperands` transformation as a pattern.
+/// `marker` controls LinalgTransformMarker matching and update when specified.
+/// See `promoteSubViewOperands` for more details.
+struct LinalgBasePromotionPattern : public RewritePattern {
+  LinalgBasePromotionPattern(StringRef opName, MLIRContext *context,
+                             ArrayRef<int64_t> operandsToPromote = {},
+                             unsigned alignment = 0,
+                             LinalgMarker marker = LinalgMarker(),
+                             PatternBenefit benefit = 1);
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override;
+
+private:
+  /// LinalgTransformMarker handles special attribute manipulations.
+  LinalgMarker marker;
+  /// Indices of subViews to promote.
+  SmallVector<int64_t, 4> operandsToPromote;
+  /// Alignment of promoted buffer.
+  unsigned alignment;
+};
+
+template <typename OpTy>
+struct LinalgPromotionPattern : public LinalgBasePromotionPattern {
+  LinalgPromotionPattern(MLIRContext *context,
+                         ArrayRef<int64_t> operandsToPromote = {},
+                         unsigned alignment = 0,
+                         LinalgMarker marker = LinalgMarker(),
+                         PatternBenefit benefit = 1)
+      : LinalgBasePromotionPattern(OpTy::getOperationName(), context,
+                                   operandsToPromote, alignment, marker,
+                                   benefit) {}
+  // No default for `marker`: `(context, operandsToPromote)` would be ambiguous.
+  LinalgPromotionPattern(MLIRContext *context,
+                         ArrayRef<int64_t> operandsToPromote,
+                         LinalgMarker marker, PatternBenefit benefit = 1)
+      : LinalgPromotionPattern(context, operandsToPromote, 0, marker, benefit) {
+  }
+  LinalgPromotionPattern(MLIRContext *context, unsigned alignment,
+                         LinalgMarker marker = LinalgMarker(),
+                         PatternBenefit benefit = 1)
+      : LinalgPromotionPattern(context, {}, alignment, marker, benefit) {}
+  LinalgPromotionPattern(MLIRContext *context, LinalgMarker marker,
+                         PatternBenefit benefit = 1)
+      : LinalgPromotionPattern(context, {}, 0, marker, benefit) {}
+};
+
+///
+/// Linalg vectorization patterns.
+///
+/// Apply the `vectorizeLinalgOp` transformation as a pattern.
+/// `marker` controls LinalgTransformMarker matching and update when specified.
+/// See `vectorizeLinalgOp` for more details.
+struct LinalgBaseVectorizationPattern : public RewritePattern {
+  LinalgBaseVectorizationPattern(StringRef opName, MLIRContext *context,
+                                 LinalgMarker marker = LinalgMarker(),
+                                 PatternBenefit benefit = 1);
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override;
+
+private:
+  /// LinalgTransformMarker handles special attribute manipulations.
+  LinalgMarker marker;
+};
+
+template <typename OpTy>
+struct LinalgVectorizationPattern : public LinalgBaseVectorizationPattern {
+  LinalgVectorizationPattern(MLIRContext *context,
+                             LinalgMarker marker = LinalgMarker(),
+                             PatternBenefit benefit = 1)
+      : LinalgBaseVectorizationPattern(OpTy::getOperationName(), context,
+                                       marker, benefit) {}
+};
+
+///
+/// Linalg lowering patterns.
+///
+/// Apply the `linalgLowerOpToLoops` transformation as a pattern.
+/// `marker` controls LinalgTransformMarker matching and update when specified.
+/// See `linalgLowerOpToLoops` for more details.
+enum class LinalgLoweringType {
+  LibraryCall = 0,
+  Loops = 1,
+  AffineLoops = 2,
+  ParallelLoops = 3
+};
+template <typename OpTy>
+struct LinalgLoweringPattern : public RewritePattern {
+  LinalgLoweringPattern(MLIRContext *context, LinalgLoweringType loweringType,
+                        LinalgMarker marker = LinalgMarker(),
+                        PatternBenefit benefit = 1)
+      : RewritePattern(OpTy::getOperationName(), {}, benefit, context),
+        marker(marker), loweringType(loweringType) {}
+  // TODO: Move implementation to .cpp once named ops are auto-generated.
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
+    if (!linalgOp)
+      return failure();
+    if (failed(marker.checkAndNotify(rewriter, linalgOp)))
+      return failure();
+    // Only lower ops on buffers; subview operands are not required here.
+    if (!linalgOp.hasBufferSemantics())
+      return failure();
+    if (loweringType == LinalgLoweringType::LibraryCall) {
+      // TODO: Move lowering to library calls here.
+      return failure();
+    } else if (loweringType == LinalgLoweringType::Loops) {
+      if (failed(linalgOpToLoops<OpTy>(rewriter, op)))
+        return failure();
+    } else if (loweringType == LinalgLoweringType::AffineLoops) {
+      if (failed(linalgOpToAffineLoops<OpTy>(rewriter, op)))
+        return failure();
+    } else {
+      if (failed(linalgOpToParallelLoops<OpTy>(rewriter, op)))
+        return failure();
+    }
+    rewriter.eraseOp(op);
+    return success();
+  }
+
+private:
+  /// LinalgTransformMarker handles special attribute manipulations.
+  LinalgMarker marker;
+  /// Controls whether the pattern lowers to library calls, loop.for, affine.for
+  /// or loop.parallel.
+  LinalgLoweringType loweringType;
+};
+
+} // namespace linalg
+} // namespace mlir
+
+#endif // DIALECT_LINALG_TRANSFORMS_TRANSFORMS_H_
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -101,63 +101,6 @@
                               AffineMap map, ArrayRef values,
                               OperationFolder *folder = nullptr);
 
-struct TiledLinalgOp {
-  LinalgOp op;
-  SmallVector loops;
-};
-
-/// Performs standalone tiling of a single LinalgOp by `tileSizes`.
-/// and permute the loop nest according to `permutation`
-/// The permutation is expressed as a list of integers that specify
-/// the new ordering of the loop nest.
The length of `permutation` -/// must be equal to the length of `tileSizes`. -/// E.g. the permutation `(i,j,k) -> (j,k,i)` will be expressed with -/// `permutation = [1,2,0]`. All values in `permutation` must be -/// integers, in the range 0..`tileSizes.size()` without duplications -/// (i.e. `[1,1,2]` is an invalid permutation). An empty list -/// states for the identity permutation. -/// Returns a struct containing the tiled loops in the specified order -/// and the cloned op if successful, llvm::None otherwise. -/// When non-null, the optional pointer `folder` is used to call into the -/// `createAndFold` builder method. If `folder` is null, the regular `create` -/// method is called. -Optional tileLinalgOp(OpBuilder &b, LinalgOp op, - ArrayRef tileSizes, - ArrayRef permutation = {}, - OperationFolder *folder = nullptr); -Optional tileLinalgOpToParallelLoops( - OpBuilder &b, LinalgOp op, ArrayRef tileSizes, - ArrayRef permutation = {}, OperationFolder *folder = nullptr); - -template -Optional tileLinalgOperation(OpBuilder &b, Operation *op, - Args... args) { - return tileLinalgOp(b, cast(op), args...); -} - struct PromotionInfo { Value buffer; Value fullLocalView; @@ -198,17 +141,6 @@ inVec = auxVec; } -/// Prepares the SubView promotion later performed by `promoteSubViews` -/// (where most of the transformation happens). 
It arranges the new -/// operands for `LinalgOp op` and deallocates the new buffer(s) -/// It is the entry point for declarative transformation -/// Returns the cloned `LinalgOp` with the new operands -LinalgOp promoteSubViewOperands(OpBuilder &b, LinalgOp op, - llvm::SetVector subViews, - bool dynamicBuffers = false, - int64_t alignment = 0, - OperationFolder *folder = nullptr); - } // namespace linalg } // namespace mlir diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -1,9 +1,11 @@ add_mlir_dialect_library(MLIRLinalgTransforms Fusion.cpp - LinalgTransforms.cpp - LinalgToLoops.cpp + Interchange.cpp + Loops.cpp Promotion.cpp Tiling.cpp + Transforms.cpp + Vectorization.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg @@ -11,7 +13,6 @@ DEPENDS intrinsics_gen MLIRLinalgPassIncGen - MLIRLinalgTransformPatternsIncGen ) target_link_libraries(MLIRLinalgTransforms PUBLIC diff --git a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp @@ -0,0 +1,85 @@ +//===- Interchange.cpp - Linalg interchange transformation ----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the linalg interchange transformation. 
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/Dialect/Vector/EDSC/Intrinsics.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+#include <type_traits>
+
+#define DEBUG_TYPE "linalg-interchange"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+LogicalResult mlir::linalg::interchangeGenericLinalgOpPrecondition(
+    Operation *op, ArrayRef<unsigned> interchangeVector) {
+  if (interchangeVector.empty())
+    return failure();
+  // Transformation applies to generic and indexed_generic ops only.
+  if (!isa<GenericOp>(op) && !isa<IndexedGenericOp>(op))
+    return failure();
+  LinalgOp linOp = cast<LinalgOp>(op);
+  // Transformation applies to buffers only.
+  if (!linOp.hasBufferSemantics())
+    return failure();
+  // Permutation must be applicable.
+  if (linOp.getIndexingMap(0).getNumInputs() != interchangeVector.size())
+    return failure();
+  // Permutation map must be invertible.
+  if (!inversePermutation(
+          AffineMap::getPermutationMap(interchangeVector, op->getContext())))
+    return failure();
+  return success();
+}
+
+LinalgOp mlir::linalg::interchange(LinalgOp op,
+                                   ArrayRef<unsigned> interchangeVector) {
+  if (interchangeVector.empty())
+    return op;
+
+  MLIRContext *context = op.getContext();
+  auto permutationMap = inversePermutation(
+      AffineMap::getPermutationMap(interchangeVector, context));
+  assert(permutationMap && "expected permutation to be invertible");
+  // Compose every indexing map with the inverse permutation of the loops.
+  SmallVector<Attribute, 4> newIndexingMaps;
+  auto indexingMaps = op.indexing_maps().getValue();
+  for (unsigned i = 0, e = op.getNumInputsAndOutputs(); i != e; ++i) {
+    AffineMap m = indexingMaps[i].cast<AffineMapAttr>().getValue();
+    if (!permutationMap.isEmpty())
+      m = m.compose(permutationMap);
+    newIndexingMaps.push_back(AffineMapAttr::get(m));
+  }
+  // Iterator types are reordered directly by `interchangeVector`.
+  auto itTypes = op.iterator_types().getValue();
+  SmallVector<Attribute, 4> itTypesVector(itTypes.begin(), itTypes.end());
+  applyPermutationToVector(itTypesVector, interchangeVector);
+
+  op.setAttr(getIndexingMapsAttrName(),
+             ArrayAttr::get(newIndexingMaps, context));
+  op.setAttr(getIteratorTypesAttrName(),
+             ArrayAttr::get(itTypesVector, context));
+
+  return op;
+}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
deleted file mode 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
+++ /dev/null
@@ -1,381 +0,0 @@
-//===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements logic for transforming Linalg operations.
-// -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/Linalg/Transforms/LinalgTransforms.h" -#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" -#include "mlir/Dialect/Linalg/IR/LinalgOps.h" -#include "mlir/Dialect/Linalg/Utils/Utils.h" -#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" -#include "mlir/Dialect/Utils/StructuredOpsUtils.h" -#include "mlir/Dialect/Vector/EDSC/Intrinsics.h" -#include "mlir/Dialect/Vector/VectorOps.h" -#include "mlir/IR/AffineExpr.h" -#include "mlir/IR/Matchers.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Support/LLVM.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/raw_ostream.h" -#include - -#define DEBUG_TYPE "linalg-transforms" - -using namespace mlir; -using namespace mlir::edsc; -using namespace mlir::edsc::intrinsics; -using namespace mlir::linalg; - -using llvm::dbgs; -using llvm::SetVector; - -// Marker used as attribute name in generated Linalg rewriting transformations. 
-const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker = - "__internal_linalg_transform__"; - -using TileFn = Optional(OpBuilder &, LinalgOp, ArrayRef, - ArrayRef, OperationFolder *); - -static LogicalResult -tileLinalgOpAndSetMarkerImpl(TileFn tileFn, PatternRewriter &rewriter, - Operation *op, ArrayRef sizes, - StringRef linalgMarker, - ArrayRef permutation) { - assert(permutation.empty() || permutation.size() == sizes.size()); - auto tileRes = tileFn(rewriter, op, sizes, permutation, /*folder=*/nullptr); - if (!tileRes) - return failure(); - tileRes->op.setAttr(LinalgTransforms::kLinalgTransformMarker, - rewriter.getStringAttr(linalgMarker)); - return success(); -} - -LogicalResult mlir::linalg::tileLinalgOpAndSetMarker( - PatternRewriter &rewriter, Operation *op, ArrayRef sizes, - StringRef linalgMarker, ArrayRef permutation) { - return tileLinalgOpAndSetMarkerImpl(tileLinalgOp, rewriter, op, sizes, - linalgMarker, permutation); -} -LogicalResult mlir::linalg::tileLinalgOpToParallelLoopsAndSetMarker( - PatternRewriter &rewriter, Operation *op, ArrayRef sizes, - StringRef linalgMarker, ArrayRef permutation) { - return tileLinalgOpAndSetMarkerImpl(tileLinalgOpToParallelLoops, rewriter, op, - sizes, linalgMarker, permutation); -} - -static LogicalResult -tileAndFuseLinalgOpAndSetMarkerImpl(TileFn tileFn, PatternRewriter &rewriter, - Operation *op, ArrayRef sizes, - ArrayRef operandIndicesToFuse, - StringRef linalgMarker) { - auto tileRes = - tileFn(rewriter, op, sizes, /*permutation=*/{}, /*folder=*/nullptr); - if (!tileRes) - return failure(); - tileRes->op.setAttr(LinalgTransforms::kLinalgTransformMarker, - rewriter.getStringAttr(linalgMarker)); - Aliases aliases; - auto G = LinalgDependenceGraph::buildDependenceGraph( - aliases, op->getParentOfType()); - SmallVector originalProducers; - for (auto operandIdx : operandIndicesToFuse) { - auto fusionRes = fuseProducerOf(rewriter, tileRes->op, operandIdx, G); - if (!fusionRes) { - // Linalg 
fusion requires tiled loops to even determine whether it is - // possible to fuse. As a consequence, the pattern may fail even though a - // tiled version of op has already been introduced. - // So we need to remove the tiled version ourselves in case of failure. - // Another possibility is to ensure the constraints on the pattern - // guarantee that fusion will occur and just assert here. As we develop - // more complex patterns we can choose what is best. - rewriter.eraseOp(tileRes->loops[0]); - return failure(); - } - fusionRes->fusedProducer.setAttr(LinalgTransforms::kLinalgTransformMarker, - rewriter.getStringAttr(linalgMarker)); - originalProducers.push_back(fusionRes->originalProducer); - } - - // The originalProducers can now be safely erased. This is similar to - // SSA-value use-def but in the world of buffer + structured ops. - for (auto *originalProducer : originalProducers) - rewriter.eraseOp(originalProducer); - return success(); -} - -LogicalResult mlir::linalg::tileAndFuseLinalgOpAndSetMarker( - PatternRewriter &rewriter, Operation *op, ArrayRef sizes, - ArrayRef operandIndicesToFuse, StringRef linalgMarker) { - return tileAndFuseLinalgOpAndSetMarkerImpl( - tileLinalgOp, rewriter, op, sizes, operandIndicesToFuse, linalgMarker); -} -LogicalResult mlir::linalg::tileAndFuseLinalgOpToParallelLoopsAndSetMarker( - PatternRewriter &rewriter, Operation *op, ArrayRef sizes, - ArrayRef operandIndicesToFuse, StringRef linalgMarker) { - return tileAndFuseLinalgOpAndSetMarkerImpl( - tileLinalgOpToParallelLoops, rewriter, op, sizes, operandIndicesToFuse, - linalgMarker); -} - -bool mlir::linalg::detail::isProducedByOpOfTypeImpl( - Operation *consumerOp, Value consumedView, - function_ref isaOpType) { - LinalgOp consumer = dyn_cast(consumerOp); - assert(consumer.hasBufferSemantics() && - "expected linalg op with buffer semantics"); - if (!consumer) - return false; - - auto maybeConsumerIndex = consumer.getIndexOfInput(consumedView); - if (!maybeConsumerIndex) - 
return false; - - Aliases aliases; - auto G = LinalgDependenceGraph::buildDependenceGraph( - aliases, consumer.getParentOfType()); - for (auto dependence : G.getDependencesInto( - consumer, LinalgDependenceGraph::DependenceType::RAW)) { - auto producer = cast(dependence.dependentOpView.op); - if (!isProducerLastWriteOfView(G, consumer, consumedView, producer)) - continue; - if (isaOpType(dependence.dependentOpView.op)) - return true; - } - return false; -} - -//============================================================================// -// Precondition and transformation for vectorization of Linalg generic ops. -//============================================================================// -static bool hasMultiplyAddBody(linalg::GenericOp op) { - auto &r = op.region(); - if (r.empty()) - return false; - if (r.getBlocks().size() != 1) - return false; - auto &ops = r.front().getOperations(); - if (ops.size() != 3) - return false; - - using mlir::matchers::m_Val; - auto a = m_Val(r.front().getArgument(0)); - auto b = m_Val(r.front().getArgument(1)); - auto c = m_Val(r.front().getArgument(2)); - // TODO(ntv) Update this detection once we have matcher support for - // specifying that any permutation of operands matches. - auto pattern1 = m_Op(m_Op(m_Op(a, b), c)); - auto pattern2 = m_Op(m_Op(c, m_Op(a, b))); - auto pattern3 = m_Op(m_Op(m_Op(b, a), c)); - auto pattern4 = m_Op(m_Op(c, m_Op(b, a))); - return pattern1.match(&ops.back()) || pattern2.match(&ops.back()) || - pattern3.match(&ops.back()) || pattern4.match(&ops.back()); -} - -// TODO(ntv) should be Tablegen'd from a single source that generates the op -// itself. -static bool isRowMajorMatmul(linalg::GenericOp genericOp) { - return genericOp.getNumInputs() == 2 && genericOp.getNumOutputs() == 1 && - isRowMajorMatmul(genericOp.indexing_maps()) && - hasMultiplyAddBody(genericOp); -} - -// TODO(ntv, ataei): This is in fact much more general than just vectorization -// for matmul and fill ops. 
-LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) { - auto linalgOp = cast(op); - // All types must be static shape to go to vector. - for (Value operand : linalgOp.getInputsAndOutputBuffers()) - if (!operand.getType().cast().hasStaticShape()) - return failure(); - for (Type outputTensorType : linalgOp.getOutputTensorTypes()) - if (!outputTensorType.cast().hasStaticShape()) - return failure(); - if (isa(op) || isa(op)) - return success(); - - auto genericOp = dyn_cast(op); - if (!genericOp || !::isRowMajorMatmul(genericOp)) - return failure(); - - // TODO(ntv): non-identity layout. - auto isStaticMemRefWithIdentityLayout = [](Value v) { - auto m = v.getType().dyn_cast(); - if (!m || !m.hasStaticShape() || !m.getAffineMaps().empty()) - return false; - return true; - }; - if (!llvm::all_of(genericOp.getInputsAndOutputBuffers(), - isStaticMemRefWithIdentityLayout)) - return failure(); - return success(); -} - -SmallVector mlir::linalg::vectorizeLinalgOp(PatternRewriter &rewriter, - Operation *op) { - assert(succeeded(vectorizeLinalgOpPrecondition(op)) && - "DRR failure case must be a precondition"); - auto linalgOp = cast(op); - assert(linalgOp.hasBufferSemantics() && - "expected linalg op with buffer semantics"); - if (auto convOp = dyn_cast(op)) { - // TODO(ntv): add a level of indirection to linalg.generic. - if (convOp.padding()) - llvm_unreachable("Unexpected conv with padding"); - } - - edsc::ScopedContext scope(rewriter, op->getLoc()); - - if (auto fillOp = dyn_cast(op)) { - // Vectorize fill as a vector.broadcast. - LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE - "]: Rewrite linalg.fill as vector.broadcast: " - << *op << ":\n"); - auto dstMemrefVec = vector_type_cast(fillOp.getOutputBuffer(0)); - Value dstVec = std_load(dstMemrefVec); - auto resVec = vector_broadcast(dstVec.getType(), fillOp.value()); - std_store(resVec, dstMemrefVec); - } else { - // Vectorize other ops as vector contraction (currently only matmul). 
- LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE - "]: Rewrite linalg op as vector.contract: " - << *op << ":\n"); - auto vA = std_load(vector_type_cast(linalgOp.getInput(0))); - auto vB = std_load(vector_type_cast(linalgOp.getInput(1))); - auto vectorMemRefC = vector_type_cast(linalgOp.getOutputBuffer(0)); - auto vC = std_load(vectorMemRefC); - auto vRes = vector_contract(vA, vB, vC, linalgOp.indexing_maps(), - linalgOp.iterator_types()); - std_store(vRes, vectorMemRefC); - } - return {}; -} - -//============================================================================// -// Precondition and transformation for permutation of Linalg generic ops. -//============================================================================// -LogicalResult mlir::linalg::permuteGenericLinalgOpPrecondition( - Operation *op, ArrayRef permutation) { - if (permutation.empty()) - return failure(); - // Transformation applies to generic ops only. - if (!isa(op) && !isa(op)) - return failure(); - LinalgOp linOp = cast(op); - // Transformation applies to buffers only. 
- if (!linOp.hasBufferSemantics()) - return failure(); - return success(); -} - -SmallVector -mlir::linalg::permuteGenericLinalgOp(PatternRewriter &rewriter, Operation *op, - ArrayRef permutation, - StringRef linalgMarker) { - LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: Permute dims for linalg op: " << *op - << ":\n"); - - assert(succeeded(permuteGenericLinalgOpPrecondition(op, permutation)) && - "DRR failure case must be a precondition"); - - auto linOp = cast(op); - auto permutationMap = inversePermutation( - AffineMap::getPermutationMap(permutation, rewriter.getContext())); - assert(permutationMap && "expected permutation to be invertible"); - SmallVector newIndexingMap; - auto indexingMaps = linOp.indexing_maps().getValue(); - for (unsigned i = 0, e = linOp.getNumInputsAndOutputs(); i != e; ++i) { - AffineMap m = indexingMaps[i].cast().getValue(); - if (!permutationMap.isEmpty()) - m = m.compose(permutationMap); - newIndexingMap.push_back(m); - } - auto itTypes = linOp.iterator_types().getValue(); - SmallVector itTypesVector; - for (unsigned i = 0, e = itTypes.size(); i != e; ++i) - itTypesVector.push_back(itTypes[i]); - applyPermutationToVector(itTypesVector, permutation); - op->setAttr(getIndexingMapsAttrName(), - rewriter.getAffineMapArrayAttr(newIndexingMap)); - op->setAttr(getIteratorTypesAttrName(), rewriter.getArrayAttr(itTypesVector)); - op->setAttr(LinalgTransforms::kLinalgTransformMarker, - rewriter.getStringAttr(linalgMarker)); - linOp.clone(rewriter, linOp.getLoc(), op->getOperands()); - return {}; -} - -//============================================================================// -// Precondition and transformation for Linalg subview promotion. -//============================================================================// -LogicalResult mlir::linalg::promoteSubviewsLinalgOpPrecondition(Operation *op) { - LinalgOp linOp = dyn_cast(op); - // Transformation applies to buffers only. 
- if (!linOp || !linOp.hasBufferSemantics()) - return failure(); - if (llvm::none_of(linOp.getInputsAndOutputBuffers(), [](Value v) { - return isa_and_nonnull(v.getDefiningOp()); - })) - return failure(); - return success(); -} - -SmallVector -mlir::linalg::promoteSubviewsLinalgOp(PatternRewriter &rewriter, - Operation *op) { - LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: Promote subviews for linalg op: " - << *op << ":\n"); - - assert(succeeded(promoteSubviewsLinalgOpPrecondition(op)) && - "DRR failure case must be a precondition"); - - LinalgOp linOp = cast(op); - SmallVector toPromote; - int64_t nBuffers = linOp.getNumInputsAndOutputBuffers(); - toPromote.reserve(nBuffers); - for (int64_t i = 0; i < nBuffers; ++i) - toPromote.push_back(i); - return promoteSelectedSubviewsLinalgOpAndSetMarker(rewriter, op, toPromote); -} - -SmallVector mlir::linalg::promoteSelectedSubviewsLinalgOpAndSetMarker( - PatternRewriter &rewriter, Operation *op, - ArrayRef operandIndicesToPromote, StringRef linalgMarker, - int64_t alignment) { - LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: Promote subviews for linalg op: " - << *op << ":\n"); - - assert(succeeded(promoteSubviewsLinalgOpPrecondition(op)) && - "DRR failure case must be a precondition"); - - if (auto convOp = dyn_cast(op)) { - // TODO(ntv): add a level of indirection to linalg.generic. 
- if (convOp.padding()) - llvm_unreachable("Unexpected conv with padding"); - } - - LinalgOp linOp = cast(op); - assert(linOp.hasBufferSemantics() && - "expected linalg op with buffer semantics"); - SetVector subViews; - for (int64_t index : operandIndicesToPromote) - if (auto sv = - dyn_cast_or_null(linOp.getBuffer(index).getDefiningOp())) - subViews.insert(sv); - - if (!subViews.empty()) { - auto newOp = - promoteSubViewOperands(rewriter, linOp, subViews, false, alignment); - if (!linalgMarker.empty()) - newOp.setAttr(LinalgTransforms::kLinalgTransformMarker, - rewriter.getStringAttr(linalgMarker)); - return {}; - } - llvm_unreachable("DRR failure case must be a precondition"); -} diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp rename from mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp rename to mlir/lib/Dialect/Linalg/Transforms/Loops.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp @@ -1,4 +1,4 @@ -//===- LinalgToLoops.cpp - conversion from Linalg library ops to loops-----===// +//===- Loops.cpp - conversion from Linalg named and generic ops to loops --===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -12,7 +12,7 @@ #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" #include "mlir/Dialect/Linalg/Passes.h" -#include "mlir/Dialect/Linalg/Transforms/LinalgTransforms.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/LoopOps/EDSC/Builders.h" #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" @@ -489,20 +489,6 @@ } }; -/// This struct is for factoring out the implementation and support template -/// instantiations in the following 2 cases: -/// 1. Appending to a list of patterns via RewritePatternList. -/// 2. 
Direct invocation via `linalgOpToLoops` and `linalgOpToAffineLoops`. -/// The implementation must work both in DRR and inside a RewritePattern. As a -/// consequence, (1) it is only allowed to emit new ops if the match is -/// guaranteed to be a success, (2) it is not allowed erase/replace, and (3) an -/// encompassing pattern must take care of the erasure logic. -template -class LinalgOpToLoopsImpl { -public: - static Optional doit(Operation *op, PatternRewriter &rewriter); -}; - namespace { /// Helper struct to generate the loop nest for the op. This factored out here /// to be able to partially specialize this for different LoopTy. @@ -573,14 +559,12 @@ } // namespace template -Optional -LinalgOpToLoopsImpl::doit(Operation *op, - PatternRewriter &rewriter) { +Optional linalgOpToLoopsImpl(Operation *op, OpBuilder &builder) { using Impl = GenerateLoopNest; using IndexedValueTy = typename GenerateLoopNest::IndexedValueTy; - ScopedContext scope(rewriter, op->getLoc()); + ScopedContext scope(builder, op->getLoc()); // The flattened loopToOperandRangesMaps is expected to be an invertible // permutation map (which is asserted in the inverse calculation). @@ -607,7 +591,7 @@ SmallVector allIvs(nLoops); auto loopRanges = emitLoopRanges(scope.getBuilderRef(), scope.getLocation(), invertedMap, - getViewSizes(rewriter, linalgOp)); + getViewSizes(builder, linalgOp)); assert(loopRanges.size() == allIvs.size()); Impl::doit(linalgOp, loopRanges, allIvs); // Number of loop ops might be different from the number of ivs since some @@ -635,8 +619,7 @@ LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { - using Impl = LinalgOpToLoopsImpl; - if (!Impl::doit(op, rewriter)) + if (!linalgOpToLoopsImpl(op, rewriter)) return failure(); rewriter.eraseOp(op); return success(); @@ -662,7 +645,7 @@ } }; -/// Populate the given list with patterns that convert from Linalg to LLVM. +/// Populate the given list with patterns that convert from Linalg to loops. 
template void FillRewritePatterns(OwningRewritePatternList &patterns, MLIRContext *ctx) { RewritePatternList -Optional -mlir::linalg::linalgLowerOpToLoops(PatternRewriter &rewriter, Operation *op) { - return LinalgOpToLoopsImpl::doit(op, rewriter); +Optional mlir::linalg::linalgLowerOpToLoops(OpBuilder &builder, + Operation *op) { + return linalgOpToLoopsImpl(op, builder); } /// Emits a loop nest of `loop.for` with the proper body for `op`. template -LogicalResult mlir::linalg::linalgOpToLoops(PatternRewriter &rewriter, - Operation *op) { +LogicalResult mlir::linalg::linalgOpToLoops(OpBuilder &builder, Operation *op) { Optional loops = - linalgLowerOpToLoops(rewriter, op); + linalgLowerOpToLoops(builder, op); return loops ? success() : failure(); } /// Emits a loop nest of `affine.for` with the proper body for `op`. template -LogicalResult mlir::linalg::linalgOpToAffineLoops(PatternRewriter &rewriter, +LogicalResult mlir::linalg::linalgOpToAffineLoops(OpBuilder &builder, Operation *op) { Optional loops = - linalgLowerOpToLoops(rewriter, op); + linalgLowerOpToLoops(builder, op); return loops ? success() : failure(); } /// Emits a loop nest of `loop.parallel` with the proper body for `op`. template -LogicalResult mlir::linalg::linalgOpToParallelLoops(PatternRewriter &rewriter, +LogicalResult mlir::linalg::linalgOpToParallelLoops(OpBuilder &builder, Operation *op) { Optional loops = - linalgLowerOpToLoops(rewriter, op); + linalgLowerOpToLoops(builder, op); return loops ? success() : failure(); } @@ -795,14 +777,14 @@ // need to update as soon as we add new ops. 
#define INSTANTIATE_LINALG_OP_TO_LOOPS(OP_TYPE) \ template LogicalResult mlir::linalg::linalgOpToLoops( \ - PatternRewriter & rewriter, Operation * op); \ + OpBuilder & builder, Operation * op); \ template LogicalResult mlir::linalg::linalgOpToAffineLoops( \ - PatternRewriter & rewriter, Operation * op); \ + OpBuilder & builder, Operation * op); \ template LogicalResult mlir::linalg::linalgOpToParallelLoops( \ - PatternRewriter & rewriter, Operation * op); \ + OpBuilder & builder, Operation * op); \ template Optional \ mlir::linalg::linalgLowerOpToLoops( \ - PatternRewriter & rewriter, Operation * op); + OpBuilder & builder, Operation * op); INSTANTIATE_LINALG_OP_TO_LOOPS(CopyOp) INSTANTIATE_LINALG_OP_TO_LOOPS(FillOp) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" #include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/LoopOps/LoopOps.h" #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" @@ -264,6 +265,21 @@ op.erase(); } +LogicalResult mlir::linalg::promoteSubviewsLinalgOpPrecondition( + Operation *op, llvm::Optional> operandIndicesToPromote) { + LinalgOp linOp = dyn_cast(op); + // Transformation applies to buffers only. 
+ if (!linOp || !linOp.hasBufferSemantics()) + return failure(); + for (auto en : llvm::enumerate(linOp.getInputsAndOutputBuffers())) { + auto sv = isa_and_nonnull(en.value().getDefiningOp()); + if (sv && (!operandIndicesToPromote.hasValue() || + operandIndicesToPromote->count(en.index()))) + return success(); + } + return failure(); +} + namespace { struct LinalgPromotionPass : public LinalgPromotionBase { LinalgPromotionPass() = default; diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h" #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" #include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/LoopOps/EDSC/Builders.h" #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" @@ -320,10 +321,9 @@ } template -Optional static tileLinalgOpImpl(OpBuilder &b, LinalgOp op, - ArrayRef tileSizes, - ArrayRef permutation, - OperationFolder *folder) { +Optional static tileLinalgOpImpl( + OpBuilder &b, LinalgOp op, ArrayRef tileSizes, + ArrayRef interchangeVector, OperationFolder *folder) { assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics"); // 1. Enforce the convention that "tiling by zero" skips tiling a particular // dimension. This convention is significantly simpler to handle instead of @@ -342,13 +342,13 @@ return llvm::None; } - // If permutation is empty, use the identity. Build the permutation map + // If interchangeVector is empty, use the identity. Build the permutation map // otherwise. 
auto invPermutationMap = AffineMap::getMultiDimIdentityMap( tileSizes.size(), ScopedContext::getContext()); - if (!permutation.empty()) - invPermutationMap = inversePermutation( - AffineMap::getPermutationMap(permutation, ScopedContext::getContext())); + if (!interchangeVector.empty()) + invPermutationMap = inversePermutation(AffineMap::getPermutationMap( + interchangeVector, ScopedContext::getContext())); if (!invPermutationMap) return llvm::None; @@ -371,8 +371,8 @@ std::tie(loopRanges, loopIndexToRangeIndex) = makeTiledLoopRanges(b, scope.getLocation(), viewSizesToLoopsMap, viewSizes, tileSizes, folder); - if (!permutation.empty()) - applyPermutationToVector(loopRanges, permutation); + if (!interchangeVector.empty()) + applyPermutationToVector(loopRanges, interchangeVector); // 3. Create the tiled loops. LinalgOp res = op; @@ -393,7 +393,7 @@ // assuming that loopRanges have previously been permuted by // (i,j,k)->(k,i,j) So this permutation should be the inversePermutation of // that one: (d0,d1,d2)->(d2,d0,d1) - if (!permutation.empty()) + if (!interchangeVector.empty()) ivValues = applyMapToValues(b, loc, invPermutationMap, ivValues, folder); auto views = @@ -420,7 +420,8 @@ template static Optional tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ArrayRef tileSizes, - ArrayRef permutation, OperationFolder *folder) { + ArrayRef interchangeVector, + OperationFolder *folder) { assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics"); if (tileSizes.empty()) return llvm::None; @@ -459,33 +460,36 @@ tileSizeValues.push_back(folded_std_constant_index(folder, 0)); } - return tileLinalgOpImpl(b, op, tileSizeValues, permutation, folder); + return tileLinalgOpImpl(b, op, tileSizeValues, interchangeVector, + folder); } Optional mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op, ArrayRef tileSizes, - ArrayRef permutation, + ArrayRef interchangeVector, OperationFolder *folder) { - return tileLinalgOpImpl(b, op, tileSizes, permutation, folder); + return 
tileLinalgOpImpl(b, op, tileSizes, interchangeVector, + folder); } Optional mlir::linalg::tileLinalgOpToParallelLoops( OpBuilder &b, LinalgOp op, ArrayRef tileSizes, - ArrayRef permutation, OperationFolder *folder) { - return tileLinalgOpImpl(b, op, tileSizes, permutation, + ArrayRef interchangeVector, OperationFolder *folder) { + return tileLinalgOpImpl(b, op, tileSizes, interchangeVector, folder); } Optional mlir::linalg::tileLinalgOp( OpBuilder &b, LinalgOp op, ArrayRef tileSizes, - ArrayRef permutation, OperationFolder *folder) { - return tileLinalgOpImpl(b, op, tileSizes, permutation, folder); + ArrayRef interchangeVector, OperationFolder *folder) { + return tileLinalgOpImpl(b, op, tileSizes, interchangeVector, + folder); } Optional mlir::linalg::tileLinalgOpToParallelLoops( OpBuilder &b, LinalgOp op, ArrayRef tileSizes, - ArrayRef permutation, OperationFolder *folder) { - return tileLinalgOpImpl(b, op, tileSizes, permutation, + ArrayRef interchangeVector, OperationFolder *folder) { + return tileLinalgOpImpl(b, op, tileSizes, interchangeVector, folder); } @@ -496,8 +500,8 @@ f.walk([tileSizes, &b, &folder](LinalgOp op) { if (!op.hasBufferSemantics()) return; - auto opLoopsPair = - tileLinalgOpImpl(b, op, tileSizes, /*permutation=*/{}, &folder); + auto opLoopsPair = tileLinalgOpImpl( + b, op, tileSizes, /*interchangeVector=*/{}, &folder); // If tiling occurred successfully, erase old op. if (opLoopsPair) op.erase(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -0,0 +1,225 @@ +//===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements logic and helpers to expose Linalg transforms as rewrite +// patterns. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/Dialect/Vector/EDSC/Intrinsics.h" +#include "mlir/Dialect/Vector/VectorOps.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include + +#define DEBUG_TYPE "linalg-transforms" + +using namespace mlir; +using namespace mlir::edsc; +using namespace mlir::edsc::intrinsics; +using namespace mlir::linalg; + +using llvm::dbgs; + +//----------------------------------------------------------------------------// +// Transformations exposed as rewrite patterns. +//----------------------------------------------------------------------------// +// Marker used as attribute name in generated Linalg rewriting transformations. 
+const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker = + "__internal_linalg_transform__"; + +mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef matchDisjunction, + llvm::Optional replacement) + : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()), + replacement(replacement) {} + +mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef matchDisjunction, + StringRef replacement) + : LinalgMarker(matchDisjunction, llvm::Optional{replacement}) {} + +LogicalResult +mlir::linalg::LinalgMarker::checkAndNotify(PatternRewriter &rewriter, + Operation *op) const { + auto attr = op->template getAttrOfType( + LinalgTransforms::kLinalgTransformMarker); + + if (!attr) { + // 1. Has no marker case and matchDisjunction is empty. + if (matchDisjunction.empty()) + return success(); + + // 2. Has no marker and matchDisjuntion matches the no-moarker case. + for (auto marker : matchDisjunction) + if (marker.empty()) + return success(); + + // 3. Has no marker but was expecting a marker. + return rewriter.notifyMatchFailure(op, [&](::mlir::Diagnostic &diag) { + diag << " does not have any marker from list: "; + llvm::interleaveComma(matchDisjunction, diag); + }); + } + + // 4. Match explicit marker. + for (auto marker : matchDisjunction) + if (attr.getValue() == marker) + return success(); + + // 5. Fail to match. + return rewriter.notifyMatchFailure(op, [&](::mlir::Diagnostic &diag) { + diag << " does not have any marker from list: "; + llvm::interleaveComma(matchDisjunction, diag); + }); +} + +void mlir::linalg::LinalgMarker::replaceLinalgMarker(PatternRewriter &rewriter, + Operation *op) const { + if (replacement.hasValue()) + op->setAttr(LinalgTransforms::kLinalgTransformMarker, + rewriter.getStringAttr(replacement.getValue())); + else + op->removeAttr(Identifier::get(LinalgTransforms::kLinalgTransformMarker, + rewriter.getContext())); +} + +/// Linalg base tiling pattern. 
+mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
+    StringRef opName, MLIRContext *context, ArrayRef<int64_t> tileSizes,
+    ArrayRef<unsigned> interchangeVector, LinalgMarker marker,
+    PatternBenefit benefit)
+    : RewritePattern(opName, {}, benefit, context), marker(marker),
+      tileSizes(tileSizes.begin(), tileSizes.end()),
+      interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}
+
+/// Tile a marker-accepted LinalgOp by `tileSizes` (loops permuted by
+/// `interchangeVector`), apply the replacement marker to the tiled op and
+/// erase the original op.
+LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewrite(
+    Operation *op, PatternRewriter &rewriter) const {
+  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
+  if (!linalgOp)
+    return failure();
+  if (failed(marker.checkAndNotify(rewriter, linalgOp)))
+    return failure();
+  // tileLinalgOpToParallelLoops is not selectable through this pattern yet;
+  // always tile with tileLinalgOp.
+  auto tileRes =
+      tileLinalgOp(rewriter, linalgOp, tileSizes, interchangeVector);
+  if (!tileRes)
+    return failure();
+
+  // New marker if specified.
+  marker.replaceLinalgMarker(rewriter, tileRes->op.getOperation());
+
+  rewriter.eraseOp(op);
+  return success();
+}
+
+/// Linalg base interchange pattern.
+mlir::linalg::LinalgBaseInterchangePattern::LinalgBaseInterchangePattern(
+    StringRef opName, MLIRContext *context,
+    ArrayRef<unsigned> interchangeVector, LinalgMarker marker,
+    PatternBenefit benefit)
+    : RewritePattern(opName, {}, benefit, context), marker(marker),
+      interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}
+
+/// Permute the loops of a marker-accepted LinalgOp in place according to
+/// `interchangeVector`, provided interchangeGenericLinalgOpPrecondition holds.
+LogicalResult mlir::linalg::LinalgBaseInterchangePattern::matchAndRewrite(
+    Operation *op, PatternRewriter &rewriter) const {
+  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
+  if (!linalgOp)
+    return failure();
+  if (failed(marker.checkAndNotify(rewriter, linalgOp)))
+    return failure();
+  if (failed(interchangeGenericLinalgOpPrecondition(op, interchangeVector)))
+    return failure();
+
+  // TODO(ntv): figure out how this interplays with named ops. In particular
+  // this should break the named op property.
+  rewriter.updateRootInPlace(op, [&]() {
+    interchange(linalgOp, interchangeVector);
+    // New marker if specified.
+    marker.replaceLinalgMarker(rewriter, op);
+  });
+  return success();
+}
+
+/// Linalg base promotion pattern.
+mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
+    StringRef opName, MLIRContext *context,
+    ArrayRef<int64_t> operandsToPromote, unsigned alignment,
+    LinalgMarker marker, PatternBenefit benefit)
+    : RewritePattern(opName, {}, benefit, context), marker(marker),
+      operandsToPromote(operandsToPromote.begin(), operandsToPromote.end()),
+      alignment(alignment) {}
+
+/// Promote the subview-defined buffers of a marker-accepted LinalgOp. Only
+/// the operands listed in `operandsToPromote` are considered, or every
+/// input/output buffer when that list is empty. Then apply the replacement
+/// marker to the promoted op and erase the original op.
+LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
+    Operation *op, PatternRewriter &rewriter) const {
+  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
+  if (!linalgOp)
+    return failure();
+  if (failed(marker.checkAndNotify(rewriter, linalgOp)))
+    return failure();
+  if (operandsToPromote.empty()) {
+    if (failed(promoteSubviewsLinalgOpPrecondition(op, llvm::None)))
+      return failure();
+  } else {
+    DenseSet<unsigned> set;
+    set.insert(operandsToPromote.begin(), operandsToPromote.end());
+    if (failed(promoteSubviewsLinalgOpPrecondition(op, set)))
+      return failure();
+  }
+
+  llvm::SetVector<Value> subViews;
+  // Record the buffer at `idx` when it is produced by a SubViewOp. The
+  // defining op gets its own name so it does not shadow the `op` being
+  // rewritten.
+  auto addIfSubView = [&](unsigned idx) {
+    Operation *defOp = linalgOp.getBuffer(idx).getDefiningOp();
+    if (auto sv = dyn_cast_or_null<SubViewOp>(defOp))
+      subViews.insert(sv);
+  };
+  if (!operandsToPromote.empty()) {
+    for (unsigned idx : operandsToPromote)
+      addIfSubView(idx);
+  } else {
+    for (unsigned idx = 0, e = linalgOp.getNumInputsAndOutputBuffers();
+         idx < e; ++idx)
+      addIfSubView(idx);
+  }
+
+  auto promotedOp =
+      promoteSubViewOperands(rewriter, op, subViews, /*dynamicBuffers=*/false,
+                             /*alignment=*/alignment);
+  marker.replaceLinalgMarker(rewriter, promotedOp.getOperation());
+  rewriter.eraseOp(op);
+  return success();
+}
+
+/// Linalg base vectorization pattern.
+mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
+    StringRef opName, MLIRContext *context, LinalgMarker marker,
+    PatternBenefit benefit)
+    : RewritePattern(opName, {}, benefit, context), marker(marker) {}
+
+/// Rewrite a marker-accepted LinalgOp that satisfies
+/// vectorizeLinalgOpPrecondition into vector ops and erase the original op.
+/// No replacement marker is applied by this pattern.
+LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
+    Operation *op, PatternRewriter &rewriter) const {
+  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
+  if (!linalgOp)
+    return failure();
+  if (failed(marker.checkAndNotify(rewriter, linalgOp)))
+    return failure();
+  if (failed(vectorizeLinalgOpPrecondition(op)))
+    return failure();
+  vectorizeLinalgOp(rewriter, op);
+  rewriter.eraseOp(op);
+  return success();
+}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -0,0 +1,132 @@
+//===- Vectorization.cpp - Implementation of linalg Vectorization ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the linalg dialect Vectorization transformations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/Dialect/Vector/EDSC/Intrinsics.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+#include <type_traits>
+
+using namespace mlir;
+using namespace mlir::edsc;
+using namespace mlir::edsc::intrinsics;
+using namespace mlir::linalg;
+
+using llvm::dbgs;
+
+#define DEBUG_TYPE "linalg-vectorization"
+
+/// Return true if the region of `op` is a single block with exactly three
+/// arguments (a, b, c) and three operations, whose terminator yields
+/// addf(mulf(a, b), c). Both mulf and addf operands may appear in either
+/// order.
+static bool hasMultiplyAddBody(linalg::GenericOp op) {
+  auto &r = op.region();
+  if (!llvm::hasSingleElement(r))
+    return false;
+  Block &body = r.front();
+  // The argument count guards the getArgument(0..2) calls below; e.g. a
+  // generic op with a tensor result may have fewer block arguments.
+  if (body.getNumArguments() != 3 ||
+      !llvm::hasNItems(body.begin(), body.end(), 3))
+    return false;
+
+  using mlir::matchers::m_Val;
+  auto a = m_Val(body.getArgument(0));
+  auto b = m_Val(body.getArgument(1));
+  auto c = m_Val(body.getArgument(2));
+  // TODO(ntv) Update this detection once we have matcher support for
+  // specifying that any permutation of operands matches.
+  auto pattern1 = m_Op<YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(a, b), c));
+  auto pattern2 = m_Op<YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(a, b)));
+  auto pattern3 = m_Op<YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(b, a), c));
+  auto pattern4 = m_Op<YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(b, a)));
+  Operation *terminator = &body.back();
+  return pattern1.match(terminator) || pattern2.match(terminator) ||
+         pattern3.match(terminator) || pattern4.match(terminator);
+}
+
+// TODO(ntv) should be Tablegen'd from a single source that generates the op
+// itself.
+/// Return true if `genericOp` has two inputs and one output, row-major matmul
+/// indexing maps, and a multiply-add body.
+static bool isRowMajorMatmul(linalg::GenericOp genericOp) {
+  return genericOp.getNumInputs() == 2 && genericOp.getNumOutputs() == 1 &&
+         isRowMajorMatmul(genericOp.indexing_maps()) &&
+         hasMultiplyAddBody(genericOp);
+}
+
+// TODO(ntv, ataei): This is in fact much more general than just vectorization
+// for matmul and fill ops.
+/// Succeed if `op` can be handled by vectorizeLinalgOp. All buffers and
+/// output tensors must have static shapes, and `op` must be a MatmulOp, a
+/// FillOp, or a row-major-matmul GenericOp whose buffers are all
+/// static-shaped memrefs with an identity layout.
+LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
+  auto linalgOp = cast<linalg::LinalgOp>(op);
+  // All types must be static shape to go to vector.
+  for (Value operand : linalgOp.getInputsAndOutputBuffers())
+    if (!operand.getType().cast<ShapedType>().hasStaticShape())
+      return failure();
+  for (Type outputTensorType : linalgOp.getOutputTensorTypes())
+    if (!outputTensorType.cast<ShapedType>().hasStaticShape())
+      return failure();
+  if (isa<linalg::MatmulOp>(op) || isa<linalg::FillOp>(op))
+    return success();
+
+  auto genericOp = dyn_cast<linalg::GenericOp>(op);
+  if (!genericOp || !::isRowMajorMatmul(genericOp))
+    return failure();
+
+  // TODO(ntv): non-identity layout.
+  auto isStaticMemRefWithIdentityLayout = [](Value v) {
+    auto m = v.getType().dyn_cast<MemRefType>();
+    return m && m.hasStaticShape() && m.getAffineMaps().empty();
+  };
+  return success(llvm::all_of(genericOp.getInputsAndOutputBuffers(),
+                              isStaticMemRefWithIdentityLayout));
+}
+
+/// Emit vector ops for `op`, which must satisfy
+/// vectorizeLinalgOpPrecondition. A FillOp becomes a vector.broadcast stored
+/// into its output buffer. Other ops become a vector.contract of their loaded
+/// inputs and output, stored back to the output buffer. `op` itself is not
+/// erased here.
+void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
+  assert(succeeded(vectorizeLinalgOpPrecondition(op)));
+
+  if (auto convOp = dyn_cast<linalg::ConvOp>(op)) {
+    // TODO(ntv): add a level of indirection to linalg.generic.
+    if (convOp.padding())
+      llvm_unreachable("Unexpected conv with padding");
+  }
+
+  edsc::ScopedContext scope(builder, op->getLoc());
+  if (auto fillOp = dyn_cast<linalg::FillOp>(op)) {
+    // Vectorize fill as a vector.broadcast.
+    LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
+                         "]: Rewrite linalg.fill as vector.broadcast: "
+                      << *op << ":\n");
+    Value memref = vector_type_cast(fillOp.getOutputBuffer(0));
+    Value dst = std_load(memref);
+    Value res = vector_broadcast(dst.getType(), fillOp.value());
+    std_store(res, memref);
+    return;
+  }
+
+  // Vectorize other ops as vector contraction (currently only matmul).
+  LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
+                       "]: Rewrite linalg op as vector.contract: "
+                    << *op << ":\n");
+  auto linalgOp = cast<linalg::LinalgOp>(op);
+  Value a = std_load(vector_type_cast(linalgOp.getInput(0)));
+  Value b = std_load(vector_type_cast(linalgOp.getInput(1)));
+  Value memref = vector_type_cast(linalgOp.getOutputBuffer(0));
+  Value c = std_load(memref);
+  Value res = vector_contract(a, b, c, linalgOp.indexing_maps(),
+                              linalgOp.iterator_types());
+  std_store(res, memref);
+}
diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir
--- a/mlir/test/Dialect/Linalg/transform-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-linalg-transform-patterns | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-transform-patterns=test-patterns | FileCheck %s
 
 // CHECK-DAG: #[[STRIDED_1D:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
 // Map corresponding to a 2D memory access where the stride along the last dim is known to be 1.
@@ -25,7 +25,6 @@ // CHECK-DAG: %[[c8:.*]] = constant 8 : index // CHECK-DAG: %[[c8000:.*]] = constant 8000 : index // CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c8000]] { -// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c8]] { // CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c1]] { // CHECK: load // CHECK: load @@ -86,88 +85,6 @@ // CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c4]] { // CHECK: linalg.matmul({{.*}}, {{.*}}, {{.*}}) : memref, memref, memref -#some_generic_trait = { - args_in = 1, - args_out = 1, - indexing_maps = [ - affine_map<(i, j) -> (i, j)>, - affine_map<(i, j) -> (i, j)> - ], - iterator_types = ["parallel", "parallel"] -} -func @fusion_test(%A: memref, - %B: memref, - %C: memref, - %D: memref, - %E: memref) { - // This should not be fused as it would violate dependencies. It will get - // tiled for all levels of the memory hierarchy. - linalg.matmul(%A, %A, %C) : memref, - memref, - memref - - // This should be fused. - linalg.matmul(%A, %B, %C) : memref, - memref, - memref - - // This should not be fused or transformed at all since there are no patterns - // on it. However it will be reordered because there are no dependencies.
- linalg.generic #some_generic_trait %A, %D { - ^bb(%a: f32, %b: f32) : - linalg.yield %a : f32 - } : memref, - memref - - linalg.matmul(%C, %D, %E) : memref, - memref, - memref - - return -} -// CHECK-LABEL: func @fusion_test -// CHECK-DAG: %[[c0:.*]] = constant 0 : index -// CHECK-DAG: %[[c2:.*]] = constant 2 : index -// CHECK-DAG: %[[c3:.*]] = constant 3 : index -// CHECK-DAG: %[[c4:.*]] = constant 4 : index -// CHECK-DAG: %[[c20:.*]] = constant 20 : index -// CHECK-DAG: %[[c30:.*]] = constant 30 : index -// CHECK-DAG: %[[c40:.*]] = constant 40 : index -// CHECK-DAG: %[[c100:.*]] = constant 100 : index -// CHECK-DAG: %[[c150:.*]] = constant 150 : index -// CHECK-DAG: %[[c200:.*]] = constant 200 : index -// CHECK-DAG: %[[c300:.*]] = constant 300 : index -// CHECK-DAG: %[[c400:.*]] = constant 400 : index -// CHECK-DAG: %[[c2000:.*]] = constant 2000 : index -// CHECK-DAG: %[[c3000:.*]] = constant 3000 : index -// CHECK-DAG: %[[c4000:.*]] = constant 4000 : index -// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c2000]] { -// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c3000]] { -// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c4000]] { -// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c200]] { -// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c300]] { -// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c400]] { -// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c20]] { -// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c30]] { -// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c40]] { -// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c2]] { -// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c3]] { -// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c4]] { -// CHECK: linalg.matmul({{.*}}, {{.*}}, {{.*}}) : memref, memref, memref -// -// CHECK: linalg.generic -// -// CHECK: loop.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c100]] { -// CHECK: loop.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c150]] { -// CHECK: loop.for
%{{.*}} = %[[c0]] to %{{.*}} step %[[c2]] { -// CHECK: loop.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c3]] { -// CHECK: loop.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c4]] { -// CHECK: linalg.matmul(%{{.*}}, %{{.*}}, %{{.*}}) : memref, memref, memref -// CHECK: loop.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c2]] { -// CHECK: loop.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c3]] { -// CHECK: loop.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c4]] { -// CHECK: linalg.matmul(%{{.*}}, %{{.*}}, %{{.*}}) : memref, memref, memref - #matmul_trait = { args_in = 2, args_out = 1, @@ -280,23 +197,6 @@ // CHECK-SAME: memref, // CHECK-SAME: memref -func @dot_perm(%x: memref, - %y: memref, - %v: memref) { - linalg.dot(%x, %y, %v) {__internal_linalg_transform__ = "__with_perm__"} : - memref, - memref, - memref - return -} -// CHECK-LABEL: func @dot_perm -// CHECK-DAG: %[[c0:.*]] = constant 0 : index -// CHECK-DAG: %[[c8:.*]] = constant 8 : index -// CHECK-DAG: %[[c8000:.*]] = constant 8000 : index -// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c8000]] { -// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c8]] { -// CHECK: linalg.dot({{.*}}, {{.*}}, {{.*}}) : memref, memref, memref - func @matvec_perm(%A: memref, %x: memref, %y: memref) { diff --git a/mlir/test/lib/DeclarativeTransforms/CMakeLists.txt b/mlir/test/lib/DeclarativeTransforms/CMakeLists.txt --- a/mlir/test/lib/DeclarativeTransforms/CMakeLists.txt +++ b/mlir/test/lib/DeclarativeTransforms/CMakeLists.txt @@ -1,9 +1,3 @@ -set(LLVM_TARGET_DEFINITIONS TestLinalgTransformPatterns.td) -mlir_tablegen(TestLinalgTransformPatterns.h.inc -gen-rewriters) -add_public_tablegen_target(MLIRTestLinalgTransformPatternsIncGen) -# Including Linalg in TableGen requires to depends on generated files -add_dependencies(MLIRTestLinalgTransformPatternsIncGen LinalgOdsGen) - set(LLVM_TARGET_DEFINITIONS TestVectorTransformPatterns.td) mlir_tablegen(TestVectorTransformPatterns.h.inc -gen-rewriters)
add_public_tablegen_target(MLIRTestVectorTransformPatternsIncGen) diff --git a/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td b/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td deleted file mode 100644 --- a/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td +++ /dev/null @@ -1,168 +0,0 @@ -//===- TestLinalgTransformPatterns.td - Test patterns --*- tablegen ----*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This is the pattern definition file for declarative Linalg transformations -// tests. -// -//===----------------------------------------------------------------------===// - -#ifndef TEST_LINALG_TRANSFORMS_PATTERNS -#define TEST_LINALG_TRANSFORMS_PATTERNS - -include "mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td" - -//===----------------------------------------------------------------------===// -// Test Linalg fusion patterns. -//===----------------------------------------------------------------------===// -def : Pat<(MatmulOp:$op $A, $_, $_), - (TileAndFuseLinalgOp<[100, 150], [0], "L1">), - [ - (Constraint), - (Constraint> $A), - ], - // In the buffer world there is no use-def chains or dags so benefits - // cannot be computed automatically from the length of the matched - // pattern. Instead we specify the benefit ourselves for now. - // This is not expected to be a big challenge long-term because - // pattern benefits are akin to feature engineering: features should - // be learned. - (addBenefit 1)>; - -//===----------------------------------------------------------------------===// -// Linalg tiling patterns.
-//===----------------------------------------------------------------------===// -def : Pat<(MatmulOp:$op $_, $_, $_), - (TileLinalgOp<[2000, 3000, 4000], "L3">), - [(Constraint]>>)]>; -def : Pat<(MatmulOp:$op $_, $_, $_), - (TileLinalgOp<[200, 300, 400], "L2">), - [(Constraint>)]>; -def : Pat<(MatmulOp:$op $_, $_, $_), - (TileLinalgOp<[20, 30, 40], "L1">), - [(Constraint>)]>; -def : Pat<(MatmulOp:$op $_, $_, $_), - (TileLinalgOp<[2, 3, 4], "REG">), - [(Constraint>)]>; - -def : Pattern<(MatvecOp:$op $_, $_, $_), - [(TileLinalgOp<[5, 6], "L1">)], - [(Constraint)]>; - -def : Pattern<(DotOp:$op $_, $_, $_), - [(TileLinalgOp<[8000], "L1">)], - [(Constraint, - HasLinalgTransformMarker<"L3">, - HasLinalgTransformMarker<"L2">]>>)]>; -def : Pattern<(DotOp:$op $_, $_, $_), - [(TileLinalgOp<[8], "REG">)], - [(Constraint>)]>; - -//===----------------------------------------------------------------------===// -// Linalg tiling and permutation patterns. -//===----------------------------------------------------------------------===// -def : Pat<(MatmulOp:$op $_, $_, $_), - (TileLinalgOp<[2000, 3000, 4000], "L2__with_perm__", [1,2,0]>), - [(Constraint>)]>; -def : Pat<(MatmulOp:$op $_, $_, $_), - (TileLinalgOp<[200, 300, 400], "L1__with_perm__", [1,0,2]>), - [(Constraint>)]>; -def : Pat<(MatmulOp:$op $_, $_, $_), - (TileLinalgOp<[20, 30, 40], "REG__with_perm__">), - [(Constraint>)]>; - - -def : Pattern<(MatvecOp:$op $_, $_, $_), - [(TileLinalgOp<[5, 6], "L1__with_perm__", [1,0]>)], - [(Constraint>)]>; - -def : Pattern<(DotOp:$op $_, $_, $_), - [(TileLinalgOp<[8000], "L1__with_perm__">)], - [(Constraint>)]>; -def : Pattern<(DotOp:$op $_, $_, $_), - [(TileLinalgOp<[8], "REG__with_perm__">)], - [(Constraint>)]>; - -//===----------------------------------------------------------------------===// -// Linalg to loops patterns.
-//===----------------------------------------------------------------------===// -def : Pattern<(DotOp:$op $_, $_, $_), - [(LinalgOpToLoops<"DotOp">)], - [(Constraint>)]>; - -//===----------------------------------------------------------------------===// -// Linalg to vector contraction patterns. -//===----------------------------------------------------------------------===// -def : Pattern<(MatmulOp:$op $_, $_, $_), - [(VectorizeLinalgOp)], - [(Constraint, - PreconditionVectorizeLinalgOp - ]>>)]>; -def : Pattern<(FillOp:$op $_, $_), - [(VectorizeLinalgOp)], - [(Constraint, - PreconditionVectorizeLinalgOp - ]>>)]>; -def : Pattern<(GenericOp:$op $_, $_, $_, $_, $_, $_, $_), - [(VectorizeLinalgOp)], - [(Constraint, - PreconditionVectorizeLinalgOp - ]>>)]>; - - -//===----------------------------------------------------------------------===// -// Linalg generic permutation patterns. -//===----------------------------------------------------------------------===// -def : Pat<(GenericOp:$op $_, $_, $_, $_, $_, $_, $_), - (PermuteGenericLinalgOp<[1, 2, 0], "PERMUTE"> $op), - [(Constraint, - PreconditionPermuteGenericLinalgOp<[1, 2, 0]> - ]>>)]>; - -def : Pat<(IndexedGenericOp:$op $_, $_, $_, $_, $_, $_, $_), - (PermuteGenericLinalgOp<[1, 2, 0], "PERMUTE"> $op), - [(Constraint, - PreconditionPermuteGenericLinalgOp<[1, 2, 0]> - ]>>)]>; - -//===----------------------------------------------------------------------===// -// Linalg subview operands promotion.
-//===----------------------------------------------------------------------===//
-def : Pat<(MatmulOp:$op $_, $_, $_),
-          (PromoteSubviewsLinalgOp),
-          [(Constraint,
-              HasLinalgTransformMarker<"_promote_views_">]>>
-          )]>;
-
-def : Pat<(MatmulOp:$op $_, $_, $_),
-          (PromoteSelectedSubviewsLinalgOp<[0], "first_view_promotion">),
-          [(Constraint,
-              HasLinalgTransformMarker<"_promote_first_view_">]>>
-          )]>;
-
-def : Pat<(FillOp:$op $_, $_),
-          (PromoteSelectedSubviewsLinalgOp<[0], "aligned_promotion", 32>),
-          [(Constraint,
-              HasLinalgTransformMarker<"_promote_views_aligned_">]>>
-          )]>;
-
-#endif // TEST_LINALG_TRANSFORMS_PATTERNS
diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -25,7 +25,6 @@
 
   DEPENDS
   MLIRStandardOpsIncGen
-  MLIRTestLinalgTransformPatternsIncGen
   MLIRTestVectorTransformPatternsIncGen
   )
 
diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
--- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
@@ -10,36 +10,134 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
-#include "mlir/Dialect/Linalg/Transforms/LinalgTransforms.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 
+#include "llvm/ADT/SetVector.h"
+
 using namespace mlir;
 using namespace mlir::linalg;
 
-namespace mlir {
-namespace linalg {
-namespace {
-#include "TestLinalgTransformPatterns.h.inc"
-} // end namespace
-} // end namespace linalg
-} // end namespace mlir
-
 namespace {
 struct TestLinalgTransforms
     : public PassWrapper<TestLinalgTransforms, FunctionPass> {
+  TestLinalgTransforms() = default;
+  TestLinalgTransforms(const TestLinalgTransforms &pass) {}
+
   void runOnFunction() override;
+
+  Option<bool> testPatterns{*this, "test-patterns",
+                            llvm::cl::desc("Test a mixed set of patterns"),
+                            llvm::cl::init(false)};
 };
 } // end anonymous namespace
 
-/// Apply transformations specified as patterns.
-void TestLinalgTransforms::runOnFunction() {
+/// Populate the Linalg test patterns (tiling, interchange, lowering to loops,
+/// vectorization and subview promotion) and apply them greedily to `funcOp`.
+static void applyPatterns(FuncOp funcOp) {
+  MLIRContext *ctx = funcOp.getContext();
   OwningRewritePatternList patterns;
-  auto funcOp = getFunction();
 
-  // Add the generated patterns to the list.
-  linalg::populateWithGenerated(&getContext(), &patterns);
+  //===--------------------------------------------------------------------===//
+  // Linalg tiling patterns.
+  //===--------------------------------------------------------------------===//
+  patterns.insert<LinalgTilingPattern<MatmulOp>>(
+      ctx,
+      /*tileSizes=*/ArrayRef<int64_t>{2000, 3000, 4000},
+      LinalgMarker({"MEM", {}}, "L3"));
+  patterns.insert<LinalgTilingPattern<MatmulOp>>(
+      ctx,
+      /*tileSizes=*/ArrayRef<int64_t>{200, 300, 400},
+      LinalgMarker({"L3"}, "L2"));
+  patterns.insert<LinalgTilingPattern<MatmulOp>>(
+      ctx,
+      /*tileSizes=*/ArrayRef<int64_t>{20, 30, 40}, LinalgMarker({"L2"}, "L1"));
+  patterns.insert<LinalgTilingPattern<MatmulOp>>(
+      ctx,
+      /*tileSizes=*/ArrayRef<int64_t>{2, 3, 4}, LinalgMarker({"L1"}, "REG"));
+
+  patterns.insert<LinalgTilingPattern<MatvecOp>>(
+      ctx,
+      /*tileSizes=*/ArrayRef<int64_t>{5, 6}, LinalgMarker({}, "L1"));
+
+  patterns.insert<LinalgTilingPattern<DotOp>>(
+      ctx,
+      /*tileSizes=*/ArrayRef<int64_t>{8000},
+      LinalgMarker({"MEM", "L3", "L2", {}}, "L1"));
+  patterns.insert<LinalgTilingPattern<DotOp>>(
+      ctx,
+      /*tileSizes=*/ArrayRef<int64_t>{8}, LinalgMarker({"L1"}, "REG"));
+
+  //===--------------------------------------------------------------------===//
+  // Linalg tiling and permutation patterns.
+  //===--------------------------------------------------------------------===//
+  patterns.insert<LinalgTilingPattern<MatmulOp>>(
+      ctx,
+      /*tileSizes=*/ArrayRef<int64_t>{2000, 3000, 4000},
+      /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
+      LinalgMarker({"__with_perm__"}, "L2__with_perm__"));
+  patterns.insert<LinalgTilingPattern<MatmulOp>>(
+      ctx,
+      /*tileSizes=*/ArrayRef<int64_t>{200, 300, 400},
+      /*interchangeVector=*/ArrayRef<unsigned>{1, 0, 2},
+      LinalgMarker({"L2__with_perm__"}, "L1__with_perm__"));
+  patterns.insert<LinalgTilingPattern<MatmulOp>>(
+      ctx,
+      /*tileSizes=*/ArrayRef<int64_t>{20, 30, 40},
+      LinalgMarker({"L1__with_perm__"}, "REG__with_perm__"));
+
+  patterns.insert<LinalgTilingPattern<MatvecOp>>(
+      ctx,
+      /*tileSizes=*/ArrayRef<int64_t>{5, 6},
+      /*interchangeVector=*/ArrayRef<unsigned>{1, 0},
+      LinalgMarker({"__with_perm__"}, "L1__with_perm__"));
+
+  //===--------------------------------------------------------------------===//
+  // Linalg to loops patterns.
+  //===--------------------------------------------------------------------===//
+  patterns.insert<LinalgLoweringPattern<DotOp>>(
+      ctx,
+      /*loweringType=*/LinalgLoweringType::Loops, LinalgMarker({"REG"}));
+
+  //===--------------------------------------------------------------------===//
+  // Linalg to vector contraction patterns.
+  //===--------------------------------------------------------------------===//
+  patterns.insert<LinalgVectorizationPattern<MatmulOp>,
+                  LinalgVectorizationPattern<FillOp>,
+                  LinalgVectorizationPattern<GenericOp>>(
+      ctx, LinalgMarker({"VECTORIZE"}));
+
+  //===--------------------------------------------------------------------===//
+  // Linalg generic permutation patterns.
+  //===--------------------------------------------------------------------===//
+  patterns.insert<LinalgInterchangePattern<GenericOp>>(
+      ctx,
+      /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
+      LinalgMarker({}, "PERMUTED"));
+  patterns.insert<LinalgInterchangePattern<IndexedGenericOp>>(
+      ctx,
+      /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
+      LinalgMarker({}, "PERMUTED"));
+
+  //===--------------------------------------------------------------------===//
+  // Linalg subview operands promotion.
+  //===--------------------------------------------------------------------===//
+  patterns.insert<LinalgPromotionPattern<MatmulOp>>(
+      ctx, LinalgMarker({"_promote_views_"}, "_views_promoted_"));
+  patterns.insert<LinalgPromotionPattern<MatmulOp>>(
+      ctx,
+      /*operandsToPromote=*/ArrayRef<int64_t>{0},
+      LinalgMarker({"_promote_first_view_"}, "_first_view_promoted_"));
+  patterns.insert<LinalgPromotionPattern<FillOp>>(
+      ctx,
+      /*operandsToPromote=*/ArrayRef<int64_t>{0},
+      /*alignment=*/32,
+      LinalgMarker({"_promote_views_aligned_"}, "_views_aligned_promoted_"));
+
   applyPatternsAndFoldGreedily(funcOp, patterns);
 
   // Drop the marker.
@@ -48,9 +146,15 @@
   });
 }
 
+/// Apply the Linalg test patterns when the `test-patterns` option is set.
+void TestLinalgTransforms::runOnFunction() {
+  if (testPatterns)
+    return applyPatterns(getFunction());
+}
+
 namespace mlir {
 void registerTestLinalgTransforms() {
-  PassRegistration<TestLinalgTransforms>(
+  PassRegistration<TestLinalgTransforms> testTransformPatternsPass(
       "test-linalg-transform-patterns",
       "Test Linalg transformation patterns by applying them greedily.");
 }