diff --git a/mlir/include/mlir/Dialect/Linalg/CMakeLists.txt b/mlir/include/mlir/Dialect/Linalg/CMakeLists.txt --- a/mlir/include/mlir/Dialect/Linalg/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Linalg/CMakeLists.txt @@ -1,5 +1,4 @@ add_subdirectory(IR) -add_subdirectory(Transforms) set(LLVM_TARGET_DEFINITIONS Passes.td) mlir_tablegen(Passes.h.inc -gen-pass-decls) diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -0,0 +1,260 @@ +//===- Transforms.h - Linalg transformations as patterns --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef DIALECT_LINALG_TRANSFORMS_TRANSFORMS_H_ +#define DIALECT_LINALG_TRANSFORMS_TRANSFORMS_H_ + +#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" + +#include "llvm/ADT/STLExtras.h" + +namespace mlir { +namespace linalg { + +//============================================================================// +// Transformations exposed as function calls. +//============================================================================// +using LinalgLoops = SmallVector; + +/// Emits a loop nest of with the proper body for `op`. +template +Optional linalgLowerOpToLoops(OpBuilder &builder, Operation *op); + +/// Emits a loop nest of `loop.for` with the proper body for `op`. +template +LogicalResult linalgOpToLoops(OpBuilder &builder, Operation *op); + +/// Emits a loop nest of `loop.parallel` with the proper body for `op`. +template +LogicalResult linalgOpToParallelLoops(OpBuilder &builder, Operation *op); + +/// Emits a loop nest of `affine.for` with the proper body for `op`. +template +LogicalResult linalgOpToAffineLoops(OpBuilder &builder, Operation *op); + +LinalgOp interchange(LinalgOp op, ArrayRef interchangeVector); + +/// Rewrite a linalg.generic into a suitable vector op. +/// Return the newly created vector op. +Operation *vectorizeLinalgOp(OpBuilder &builder, Operation *op); + +//============================================================================// +// Preconditions that ensure the corresponding transformation suceeds and can be +// applied as a rewrite pattern. +//============================================================================// +/// Rewrite a linalg.generic into a suitable vector.contraction op. +LogicalResult vectorizeLinalgOpPrecondition(Operation *op); + +/// Emits a `generic` or `indexed_generic` operation with the `indexing_maps` +/// and `iterator_types` permutated according to `permutation`. +LogicalResult +interchangeGenericLinalgOpPrecondition(Operation *op, + ArrayRef permutation); + +/// Promote std.subviews feeding linalg operations. +LogicalResult promoteSubviewsLinalgOpPrecondition( + Operation *op, + llvm::Optional> operandIndicesToPromote = llvm::None); + +//============================================================================// +// Transformations exposed as rewrite patterns. +//============================================================================// +// Marker used as attribute name in generated Linalg rewriting transformations. +struct LinalgTransforms { + static const StringLiteral kLinalgTransformMarker; +}; + +/// Helper class to control common attribute matching and setting behavior. +struct LinalgMarker { + LinalgMarker(ArrayRef matchDisjunction = {}, + llvm::Optional replacement = llvm::None); + LinalgMarker(ArrayRef matchDisjunction, StringRef replacement); + LogicalResult checkAndNotify(PatternRewriter &rewriter, Operation *op) const; + void replaceLinalgMarker(PatternRewriter &rewriter, Operation *op) const; + +private: + SmallVector matchDisjunction; + llvm::Optional replacement; +}; + +/// +/// Linalg tiling patterns. +/// +struct LinalgBaseTilingPattern : public RewritePattern { + LinalgBaseTilingPattern(StringRef opName, MLIRContext *context, + ArrayRef tileSizes, + ArrayRef interchangeVector = {}, + LinalgMarker marker = LinalgMarker(), + int benefit = 1); + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override; + +private: + LinalgMarker marker; + SmallVector tileSizes; + SmallVector interchangeVector; +}; + +template +struct LinalgTilingPattern : public LinalgBaseTilingPattern { + LinalgTilingPattern(MLIRContext *context, ArrayRef tileSizes, + ArrayRef interchangeVector = {}, + LinalgMarker marker = LinalgMarker(), int benefit = 1) + : LinalgBaseTilingPattern(OpTy::getOperationName(), context, tileSizes, + interchangeVector, marker, benefit) {} + LinalgTilingPattern(MLIRContext *context, ArrayRef tileSizes, + LinalgMarker marker, int benefit = 1) + : LinalgTilingPattern(context, tileSizes, {}, marker, benefit) {} +}; + +/// +/// Linalg interchange patterns. +/// +struct LinalgBaseInterchangePattern : public RewritePattern { + LinalgBaseInterchangePattern(StringRef opName, MLIRContext *context, + ArrayRef interchangeVector, + LinalgMarker marker = LinalgMarker(), + int benefit = 1); + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override; + +private: + LinalgMarker marker; + SmallVector interchangeVector; +}; + +template +struct LinalgInterchangePattern : public LinalgBaseInterchangePattern { + LinalgInterchangePattern(MLIRContext *context, + ArrayRef interchangeVector, + LinalgMarker marker = LinalgMarker(), + int benefit = 1) + : LinalgBaseInterchangePattern(OpTy::getOperationName(), context, + interchangeVector, marker, benefit) {} +}; + +/// +/// Linalg promotion patterns. +/// +struct LinalgBasePromotionPattern : public RewritePattern { + LinalgBasePromotionPattern(StringRef opName, MLIRContext *context, + ArrayRef operandsToPromote = {}, + unsigned alignment = 0, + LinalgMarker marker = LinalgMarker(), + int benefit = 1); + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override; + +private: + LinalgMarker marker; + SmallVector operandsToPromote; + unsigned alignment; +}; + +template +struct LinalgPromotionPattern : public LinalgBasePromotionPattern { + LinalgPromotionPattern(MLIRContext *context, + ArrayRef operandsToPromote = {}, + unsigned alignment = 0, + LinalgMarker marker = LinalgMarker(), int benefit = 1) + : LinalgBasePromotionPattern(OpTy::getOperationName(), context, + operandsToPromote, alignment, marker, + benefit) {} + LinalgPromotionPattern(MLIRContext *context, + ArrayRef operandsToPromote, + LinalgMarker marker = LinalgMarker(), int benefit = 1) + : LinalgPromotionPattern(context, operandsToPromote, 0, marker, benefit) { + } + LinalgPromotionPattern(MLIRContext *context, unsigned alignment, + LinalgMarker marker = LinalgMarker(), int benefit = 1) + : LinalgPromotionPattern(context, {}, alignment, marker, benefit) {} + LinalgPromotionPattern(MLIRContext *context, LinalgMarker marker, + int benefit = 1) + : LinalgPromotionPattern(context, {}, 0, marker, benefit) {} +}; + +/// +/// Linalg vectorization patterns. +/// +struct LinalgBaseVectorizationPattern : public RewritePattern { + LinalgBaseVectorizationPattern(StringRef opName, MLIRContext *context, + LinalgMarker marker = LinalgMarker(), + int benefit = 1); + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override; + +private: + LinalgMarker marker; +}; + +template +struct LinalgVectorizationPattern : public LinalgBaseVectorizationPattern { + LinalgVectorizationPattern(MLIRContext *context, + LinalgMarker marker = LinalgMarker(), + int benefit = 1) + : LinalgBaseVectorizationPattern(OpTy::getOperationName(), context, + marker, benefit) {} +}; + +/// +/// Linalg lowering patterns. +/// +enum class LinalgLoweringType { + LibraryCall = 0, + Loops = 1, + AffineLoops = 2, + ParallelLoops = 3 +}; +template +struct LinalgLoweringPattern : public RewritePattern { + LinalgLoweringPattern(MLIRContext *context, LinalgLoweringType loweringType, + LinalgMarker marker = LinalgMarker(), int benefit = 1) + : RewritePattern(OpTy::getOperationName(), {}, benefit, context), + marker(marker), loweringType(loweringType) {} + // TODO: Move implementation to .cpp once named ops are auto-generated. + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + LinalgOp linalgOp = dyn_cast(op); + if (!linalgOp) + return failure(); + if (failed(marker.checkAndNotify(rewriter, linalgOp))) + return failure(); + if (failed(promoteSubviewsLinalgOpPrecondition(op))) + return failure(); + + if (loweringType == LinalgLoweringType::LibraryCall) { + // TODO: Move lowering to library calls here. + return failure(); + } else if (loweringType == LinalgLoweringType::Loops) { + if (failed(linalgOpToLoops(rewriter, op))) + return failure(); + } else if (loweringType == LinalgLoweringType::AffineLoops) { + if (failed(linalgOpToAffineLoops(rewriter, op))) + return failure(); + } else { + if (failed(linalgOpToParallelLoops(rewriter, op))) + return failure(); + } + rewriter.eraseOp(op); + return success(); + } + +private: + LinalgMarker marker; + LinalgLoweringType loweringType; +}; + +} // namespace linalg +} // namespace mlir + +#endif // DIALECT_LINALG_TRANSFORMS_TRANSFORMS_H_ diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -1,9 +1,11 @@ add_mlir_dialect_library(MLIRLinalgTransforms Fusion.cpp - LinalgTransforms.cpp - LinalgToLoops.cpp + Interchange.cpp + Loops.cpp Promotion.cpp Tiling.cpp + Transforms.cpp + Vectorization.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg @@ -11,7 +13,6 @@ DEPENDS intrinsics_gen MLIRLinalgPassIncGen - MLIRLinalgTransformPatternsIncGen ) target_link_libraries(MLIRLinalgTransforms PUBLIC diff --git a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp @@ -0,0 +1,82 @@ +//===- Interchange.cpp - Linalg interchange transformation ----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the linalg interchange transformation. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/Dialect/Vector/EDSC/Intrinsics.h" +#include "mlir/Dialect/Vector/VectorOps.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include + +#define DEBUG_TYPE "linalg-interchange" + +using namespace mlir; +using namespace mlir::linalg; + +LogicalResult mlir::linalg::interchangeGenericLinalgOpPrecondition( + Operation *op, ArrayRef permutation) { + if (permutation.empty()) + return failure(); + // Transformation applies to generic ops only. + if (!isa(op) && !isa(op)) + return failure(); + LinalgOp linOp = cast(op); + // Transformation applies to buffers only. + if (!linOp.hasBufferSemantics()) + return failure(); + // Permutation must be applicable. + if (linOp.getIndexingMap(0).getNumInputs() != permutation.size()) + return failure(); + // Permutation map must be invertible. + if (!inversePermutation( + AffineMap::getPermutationMap(permutation, op->getContext()))) + return failure(); + return success(); +} + +LinalgOp mlir::linalg::interchange(LinalgOp op, + ArrayRef interchangeVector) { + MLIRContext *context = op.getContext(); + auto permutationMap = inversePermutation( + AffineMap::getPermutationMap(interchangeVector, context)); + assert(permutationMap && "expected permutation to be invertible"); + SmallVector newIndexingMaps; + auto indexingMaps = op.indexing_maps().getValue(); + for (unsigned i = 0, e = op.getNumInputsAndOutputs(); i != e; ++i) { + AffineMap m = indexingMaps[i].cast().getValue(); + if (!permutationMap.isEmpty()) + m = m.compose(permutationMap); + newIndexingMaps.push_back(AffineMapAttr::get(m)); + } + auto itTypes = op.iterator_types().getValue(); + SmallVector itTypesVector; + for (unsigned i = 0, e = itTypes.size(); i != e; ++i) + itTypesVector.push_back(itTypes[i]); + applyPermutationToVector(itTypesVector, interchangeVector); + + op.setAttr(getIndexingMapsAttrName(), + ArrayAttr::get(newIndexingMaps, context)); + op.setAttr(getIteratorTypesAttrName(), + ArrayAttr::get(itTypesVector, context)); + + return op; +} diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp deleted file mode 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp +++ /dev/null @@ -1,381 +0,0 @@ -//===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements logic for transforming Linalg operations. -// -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/Linalg/Transforms/LinalgTransforms.h" -#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" -#include "mlir/Dialect/Linalg/IR/LinalgOps.h" -#include "mlir/Dialect/Linalg/Utils/Utils.h" -#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" -#include "mlir/Dialect/Utils/StructuredOpsUtils.h" -#include "mlir/Dialect/Vector/EDSC/Intrinsics.h" -#include "mlir/Dialect/Vector/VectorOps.h" -#include "mlir/IR/AffineExpr.h" -#include "mlir/IR/Matchers.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Support/LLVM.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/raw_ostream.h" -#include - -#define DEBUG_TYPE "linalg-transforms" - -using namespace mlir; -using namespace mlir::edsc; -using namespace mlir::edsc::intrinsics; -using namespace mlir::linalg; - -using llvm::dbgs; -using llvm::SetVector; - -// Marker used as attribute name in generated Linalg rewriting transformations. -const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker = - "__internal_linalg_transform__"; - -using TileFn = Optional(OpBuilder &, LinalgOp, ArrayRef, - ArrayRef, OperationFolder *); - -static LogicalResult -tileLinalgOpAndSetMarkerImpl(TileFn tileFn, PatternRewriter &rewriter, - Operation *op, ArrayRef sizes, - StringRef linalgMarker, - ArrayRef permutation) { - assert(permutation.empty() || permutation.size() == sizes.size()); - auto tileRes = tileFn(rewriter, op, sizes, permutation, /*folder=*/nullptr); - if (!tileRes) - return failure(); - tileRes->op.setAttr(LinalgTransforms::kLinalgTransformMarker, - rewriter.getStringAttr(linalgMarker)); - return success(); -} - -LogicalResult mlir::linalg::tileLinalgOpAndSetMarker( - PatternRewriter &rewriter, Operation *op, ArrayRef sizes, - StringRef linalgMarker, ArrayRef permutation) { - return tileLinalgOpAndSetMarkerImpl(tileLinalgOp, rewriter, op, sizes, - linalgMarker, permutation); -} -LogicalResult mlir::linalg::tileLinalgOpToParallelLoopsAndSetMarker( - PatternRewriter &rewriter, Operation *op, ArrayRef sizes, - StringRef linalgMarker, ArrayRef permutation) { - return tileLinalgOpAndSetMarkerImpl(tileLinalgOpToParallelLoops, rewriter, op, - sizes, linalgMarker, permutation); -} - -static LogicalResult -tileAndFuseLinalgOpAndSetMarkerImpl(TileFn tileFn, PatternRewriter &rewriter, - Operation *op, ArrayRef sizes, - ArrayRef operandIndicesToFuse, - StringRef linalgMarker) { - auto tileRes = - tileFn(rewriter, op, sizes, /*permutation=*/{}, /*folder=*/nullptr); - if (!tileRes) - return failure(); - tileRes->op.setAttr(LinalgTransforms::kLinalgTransformMarker, - rewriter.getStringAttr(linalgMarker)); - Aliases aliases; - auto G = LinalgDependenceGraph::buildDependenceGraph( - aliases, op->getParentOfType()); - SmallVector originalProducers; - for (auto operandIdx : operandIndicesToFuse) { - auto fusionRes = fuseProducerOf(rewriter, tileRes->op, operandIdx, G); - if (!fusionRes) { - // Linalg fusion requires tiled loops to even determine whether it is - // possible to fuse. As a consequence, the pattern may fail even though a - // tiled version of op has already been introduced. - // So we need to remove the tiled version ourselves in case of failure. - // Another possibility is to ensure the constraints on the pattern - // guarantee that fusion will occur and just assert here. As we develop - // more complex patterns we can choose what is best. - rewriter.eraseOp(tileRes->loops[0]); - return failure(); - } - fusionRes->fusedProducer.setAttr(LinalgTransforms::kLinalgTransformMarker, - rewriter.getStringAttr(linalgMarker)); - originalProducers.push_back(fusionRes->originalProducer); - } - - // The originalProducers can now be safely erased. This is similar to - // SSA-value use-def but in the world of buffer + structured ops. - for (auto *originalProducer : originalProducers) - rewriter.eraseOp(originalProducer); - return success(); -} - -LogicalResult mlir::linalg::tileAndFuseLinalgOpAndSetMarker( - PatternRewriter &rewriter, Operation *op, ArrayRef sizes, - ArrayRef operandIndicesToFuse, StringRef linalgMarker) { - return tileAndFuseLinalgOpAndSetMarkerImpl( - tileLinalgOp, rewriter, op, sizes, operandIndicesToFuse, linalgMarker); -} -LogicalResult mlir::linalg::tileAndFuseLinalgOpToParallelLoopsAndSetMarker( - PatternRewriter &rewriter, Operation *op, ArrayRef sizes, - ArrayRef operandIndicesToFuse, StringRef linalgMarker) { - return tileAndFuseLinalgOpAndSetMarkerImpl( - tileLinalgOpToParallelLoops, rewriter, op, sizes, operandIndicesToFuse, - linalgMarker); -} - -bool mlir::linalg::detail::isProducedByOpOfTypeImpl( - Operation *consumerOp, Value consumedView, - function_ref isaOpType) { - LinalgOp consumer = dyn_cast(consumerOp); - assert(consumer.hasBufferSemantics() && - "expected linalg op with buffer semantics"); - if (!consumer) - return false; - - auto maybeConsumerIndex = consumer.getIndexOfInput(consumedView); - if (!maybeConsumerIndex) - return false; - - Aliases aliases; - auto G = LinalgDependenceGraph::buildDependenceGraph( - aliases, consumer.getParentOfType()); - for (auto dependence : G.getDependencesInto( - consumer, LinalgDependenceGraph::DependenceType::RAW)) { - auto producer = cast(dependence.dependentOpView.op); - if (!isProducerLastWriteOfView(G, consumer, consumedView, producer)) - continue; - if (isaOpType(dependence.dependentOpView.op)) - return true; - } - return false; -} - -//============================================================================// -// Precondition and transformation for vectorization of Linalg generic ops. -//============================================================================// -static bool hasMultiplyAddBody(linalg::GenericOp op) { - auto &r = op.region(); - if (r.empty()) - return false; - if (r.getBlocks().size() != 1) - return false; - auto &ops = r.front().getOperations(); - if (ops.size() != 3) - return false; - - using mlir::matchers::m_Val; - auto a = m_Val(r.front().getArgument(0)); - auto b = m_Val(r.front().getArgument(1)); - auto c = m_Val(r.front().getArgument(2)); - // TODO(ntv) Update this detection once we have matcher support for - // specifying that any permutation of operands matches. - auto pattern1 = m_Op(m_Op(m_Op(a, b), c)); - auto pattern2 = m_Op(m_Op(c, m_Op(a, b))); - auto pattern3 = m_Op(m_Op(m_Op(b, a), c)); - auto pattern4 = m_Op(m_Op(c, m_Op(b, a))); - return pattern1.match(&ops.back()) || pattern2.match(&ops.back()) || - pattern3.match(&ops.back()) || pattern4.match(&ops.back()); -} - -// TODO(ntv) should be Tablegen'd from a single source that generates the op -// itself. -static bool isRowMajorMatmul(linalg::GenericOp genericOp) { - return genericOp.getNumInputs() == 2 && genericOp.getNumOutputs() == 1 && - isRowMajorMatmul(genericOp.indexing_maps()) && - hasMultiplyAddBody(genericOp); -} - -// TODO(ntv, ataei): This is in fact much more general than just vectorization -// for matmul and fill ops. -LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) { - auto linalgOp = cast(op); - // All types must be static shape to go to vector. - for (Value operand : linalgOp.getInputsAndOutputBuffers()) - if (!operand.getType().cast().hasStaticShape()) - return failure(); - for (Type outputTensorType : linalgOp.getOutputTensorTypes()) - if (!outputTensorType.cast().hasStaticShape()) - return failure(); - if (isa(op) || isa(op)) - return success(); - - auto genericOp = dyn_cast(op); - if (!genericOp || !::isRowMajorMatmul(genericOp)) - return failure(); - - // TODO(ntv): non-identity layout. - auto isStaticMemRefWithIdentityLayout = [](Value v) { - auto m = v.getType().dyn_cast(); - if (!m || !m.hasStaticShape() || !m.getAffineMaps().empty()) - return false; - return true; - }; - if (!llvm::all_of(genericOp.getInputsAndOutputBuffers(), - isStaticMemRefWithIdentityLayout)) - return failure(); - return success(); -} - -SmallVector mlir::linalg::vectorizeLinalgOp(PatternRewriter &rewriter, - Operation *op) { - assert(succeeded(vectorizeLinalgOpPrecondition(op)) && - "DRR failure case must be a precondition"); - auto linalgOp = cast(op); - assert(linalgOp.hasBufferSemantics() && - "expected linalg op with buffer semantics"); - if (auto convOp = dyn_cast(op)) { - // TODO(ntv): add a level of indirection to linalg.generic. - if (convOp.padding()) - llvm_unreachable("Unexpected conv with padding"); - } - - edsc::ScopedContext scope(rewriter, op->getLoc()); - - if (auto fillOp = dyn_cast(op)) { - // Vectorize fill as a vector.broadcast. - LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE - "]: Rewrite linalg.fill as vector.broadcast: " - << *op << ":\n"); - auto dstMemrefVec = vector_type_cast(fillOp.getOutputBuffer(0)); - Value dstVec = std_load(dstMemrefVec); - auto resVec = vector_broadcast(dstVec.getType(), fillOp.value()); - std_store(resVec, dstMemrefVec); - } else { - // Vectorize other ops as vector contraction (currently only matmul). - LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE - "]: Rewrite linalg op as vector.contract: " - << *op << ":\n"); - auto vA = std_load(vector_type_cast(linalgOp.getInput(0))); - auto vB = std_load(vector_type_cast(linalgOp.getInput(1))); - auto vectorMemRefC = vector_type_cast(linalgOp.getOutputBuffer(0)); - auto vC = std_load(vectorMemRefC); - auto vRes = vector_contract(vA, vB, vC, linalgOp.indexing_maps(), - linalgOp.iterator_types()); - std_store(vRes, vectorMemRefC); - } - return {}; -} - -//============================================================================// -// Precondition and transformation for permutation of Linalg generic ops. -//============================================================================// -LogicalResult mlir::linalg::permuteGenericLinalgOpPrecondition( - Operation *op, ArrayRef permutation) { - if (permutation.empty()) - return failure(); - // Transformation applies to generic ops only. - if (!isa(op) && !isa(op)) - return failure(); - LinalgOp linOp = cast(op); - // Transformation applies to buffers only. - if (!linOp.hasBufferSemantics()) - return failure(); - return success(); -} - -SmallVector -mlir::linalg::permuteGenericLinalgOp(PatternRewriter &rewriter, Operation *op, - ArrayRef permutation, - StringRef linalgMarker) { - LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: Permute dims for linalg op: " << *op - << ":\n"); - - assert(succeeded(permuteGenericLinalgOpPrecondition(op, permutation)) && - "DRR failure case must be a precondition"); - - auto linOp = cast(op); - auto permutationMap = inversePermutation( - AffineMap::getPermutationMap(permutation, rewriter.getContext())); - assert(permutationMap && "expected permutation to be invertible"); - SmallVector newIndexingMap; - auto indexingMaps = linOp.indexing_maps().getValue(); - for (unsigned i = 0, e = linOp.getNumInputsAndOutputs(); i != e; ++i) { - AffineMap m = indexingMaps[i].cast().getValue(); - if (!permutationMap.isEmpty()) - m = m.compose(permutationMap); - newIndexingMap.push_back(m); - } - auto itTypes = linOp.iterator_types().getValue(); - SmallVector itTypesVector; - for (unsigned i = 0, e = itTypes.size(); i != e; ++i) - itTypesVector.push_back(itTypes[i]); - applyPermutationToVector(itTypesVector, permutation); - op->setAttr(getIndexingMapsAttrName(), - rewriter.getAffineMapArrayAttr(newIndexingMap)); - op->setAttr(getIteratorTypesAttrName(), rewriter.getArrayAttr(itTypesVector)); - op->setAttr(LinalgTransforms::kLinalgTransformMarker, - rewriter.getStringAttr(linalgMarker)); - linOp.clone(rewriter, linOp.getLoc(), op->getOperands()); - return {}; -} - -//============================================================================// -// Precondition and transformation for Linalg subview promotion. -//============================================================================// -LogicalResult mlir::linalg::promoteSubviewsLinalgOpPrecondition(Operation *op) { - LinalgOp linOp = dyn_cast(op); - // Transformation applies to buffers only. - if (!linOp || !linOp.hasBufferSemantics()) - return failure(); - if (llvm::none_of(linOp.getInputsAndOutputBuffers(), [](Value v) { - return isa_and_nonnull(v.getDefiningOp()); - })) - return failure(); - return success(); -} - -SmallVector -mlir::linalg::promoteSubviewsLinalgOp(PatternRewriter &rewriter, - Operation *op) { - LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: Promote subviews for linalg op: " - << *op << ":\n"); - - assert(succeeded(promoteSubviewsLinalgOpPrecondition(op)) && - "DRR failure case must be a precondition"); - - LinalgOp linOp = cast(op); - SmallVector toPromote; - int64_t nBuffers = linOp.getNumInputsAndOutputBuffers(); - toPromote.reserve(nBuffers); - for (int64_t i = 0; i < nBuffers; ++i) - toPromote.push_back(i); - return promoteSelectedSubviewsLinalgOpAndSetMarker(rewriter, op, toPromote); -} - -SmallVector mlir::linalg::promoteSelectedSubviewsLinalgOpAndSetMarker( - PatternRewriter &rewriter, Operation *op, - ArrayRef operandIndicesToPromote, StringRef linalgMarker, - int64_t alignment) { - LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: Promote subviews for linalg op: " - << *op << ":\n"); - - assert(succeeded(promoteSubviewsLinalgOpPrecondition(op)) && - "DRR failure case must be a precondition"); - - if (auto convOp = dyn_cast(op)) { - // TODO(ntv): add a level of indirection to linalg.generic. - if (convOp.padding()) - llvm_unreachable("Unexpected conv with padding"); - } - - LinalgOp linOp = cast(op); - assert(linOp.hasBufferSemantics() && - "expected linalg op with buffer semantics"); - SetVector subViews; - for (int64_t index : operandIndicesToPromote) - if (auto sv = - dyn_cast_or_null(linOp.getBuffer(index).getDefiningOp())) - subViews.insert(sv); - - if (!subViews.empty()) { - auto newOp = - promoteSubViewOperands(rewriter, linOp, subViews, false, alignment); - if (!linalgMarker.empty()) - newOp.setAttr(LinalgTransforms::kLinalgTransformMarker, - rewriter.getStringAttr(linalgMarker)); - return {}; - } - llvm_unreachable("DRR failure case must be a precondition"); -} diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp rename from mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp rename to mlir/lib/Dialect/Linalg/Transforms/Loops.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp @@ -1,4 +1,4 @@ -//===- LinalgToLoops.cpp - conversion from Linalg library ops to loops-----===// +//===- Loops.cpp - conversion from Linalg named and generic ops to loops --===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -12,7 +12,7 @@ #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" #include "mlir/Dialect/Linalg/Passes.h" -#include "mlir/Dialect/Linalg/Transforms/LinalgTransforms.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/LoopOps/EDSC/Builders.h" #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" @@ -489,20 +489,6 @@ } }; -/// This struct is for factoring out the implementation and support template -/// instantiations in the following 2 cases: -/// 1. Appending to a list of patterns via RewritePatternList. -/// 2. Direct invocation via `linalgOpToLoops` and `linalgOpToAffineLoops`. -/// The implementation must work both in DRR and inside a RewritePattern. As a -/// consequence, (1) it is only allowed to emit new ops if the match is -/// guaranteed to be a success, (2) it is not allowed erase/replace, and (3) an -/// encompassing pattern must take care of the erasure logic. -template -class LinalgOpToLoopsImpl { -public: - static Optional doit(Operation *op, PatternRewriter &rewriter); -}; - namespace { /// Helper struct to generate the loop nest for the op. This factored out here /// to be able to partially specialize this for different LoopTy. @@ -573,14 +559,12 @@ } // namespace template -Optional -LinalgOpToLoopsImpl::doit(Operation *op, - PatternRewriter &rewriter) { +Optional linalgOpToLoopsImpl(Operation *op, OpBuilder &builder) { using Impl = GenerateLoopNest; using IndexedValueTy = typename GenerateLoopNest::IndexedValueTy; - ScopedContext scope(rewriter, op->getLoc()); + ScopedContext scope(builder, op->getLoc()); // The flattened loopToOperandRangesMaps is expected to be an invertible // permutation map (which is asserted in the inverse calculation). @@ -607,7 +591,7 @@ SmallVector allIvs(nLoops); auto loopRanges = emitLoopRanges(scope.getBuilderRef(), scope.getLocation(), invertedMap, - getViewSizes(rewriter, linalgOp)); + getViewSizes(builder, linalgOp)); assert(loopRanges.size() == allIvs.size()); Impl::doit(linalgOp, loopRanges, allIvs); // Number of loop ops might be different from the number of ivs since some @@ -635,8 +619,7 @@ LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { - using Impl = LinalgOpToLoopsImpl; - if (!Impl::doit(op, rewriter)) + if (!linalgOpToLoopsImpl(op, rewriter)) return failure(); rewriter.eraseOp(op); return success(); @@ -662,7 +645,7 @@ } }; -/// Populate the given list with patterns that convert from Linalg to LLVM. +/// Populate the given list with patterns that convert from Linalg to loops. template void FillRewritePatterns(OwningRewritePatternList &patterns, MLIRContext *ctx) { RewritePatternList -Optional -mlir::linalg::linalgLowerOpToLoops(PatternRewriter &rewriter, Operation *op) { - return LinalgOpToLoopsImpl::doit(op, rewriter); +Optional mlir::linalg::linalgLowerOpToLoops(OpBuilder &builder, + Operation *op) { + return linalgOpToLoopsImpl(op, builder); } /// Emits a loop nest of `loop.for` with the proper body for `op`. template -LogicalResult mlir::linalg::linalgOpToLoops(PatternRewriter &rewriter, - Operation *op) { +LogicalResult mlir::linalg::linalgOpToLoops(OpBuilder &builder, Operation *op) { Optional loops = - linalgLowerOpToLoops(rewriter, op); + linalgLowerOpToLoops(builder, op); return loops ? success() : failure(); } /// Emits a loop nest of `affine.for` with the proper body for `op`. template -LogicalResult mlir::linalg::linalgOpToAffineLoops(PatternRewriter &rewriter, +LogicalResult mlir::linalg::linalgOpToAffineLoops(OpBuilder &builder, Operation *op) { Optional loops = - linalgLowerOpToLoops(rewriter, op); + linalgLowerOpToLoops(builder, op); return loops ? success() : failure(); } /// Emits a loop nest of `loop.parallel` with the proper body for `op`. template -LogicalResult mlir::linalg::linalgOpToParallelLoops(PatternRewriter &rewriter, +LogicalResult mlir::linalg::linalgOpToParallelLoops(OpBuilder &builder, Operation *op) { Optional loops = - linalgLowerOpToLoops(rewriter, op); + linalgLowerOpToLoops(builder, op); return loops ? success() : failure(); } @@ -795,14 +777,14 @@ // need to update as soon as we add new ops. #define INSTANTIATE_LINALG_OP_TO_LOOPS(OP_TYPE) \ template LogicalResult mlir::linalg::linalgOpToLoops( \ - PatternRewriter & rewriter, Operation * op); \ + OpBuilder & builder, Operation * op); \ template LogicalResult mlir::linalg::linalgOpToAffineLoops( \ - PatternRewriter & rewriter, Operation * op); \ + OpBuilder & builder, Operation * op); \ template LogicalResult mlir::linalg::linalgOpToParallelLoops( \ - PatternRewriter & rewriter, Operation * op); \ + OpBuilder & builder, Operation * op); \ template Optional \ mlir::linalg::linalgLowerOpToLoops( \ - PatternRewriter & rewriter, Operation * op); + OpBuilder & builder, Operation * op); INSTANTIATE_LINALG_OP_TO_LOOPS(CopyOp) INSTANTIATE_LINALG_OP_TO_LOOPS(FillOp) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" #include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/LoopOps/LoopOps.h" #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" @@ -264,6 +265,21 @@ op.erase(); } +LogicalResult mlir::linalg::promoteSubviewsLinalgOpPrecondition( + Operation *op, llvm::Optional> operandIndicesToPromote) { + LinalgOp linOp = dyn_cast(op); + // Transformation applies to buffers only. + if (!linOp || !linOp.hasBufferSemantics()) + return failure(); + for (auto en : llvm::enumerate(linOp.getInputsAndOutputBuffers())) { + auto sv = isa_and_nonnull(en.value().getDefiningOp()); + if (sv && (!operandIndicesToPromote.hasValue() || + operandIndicesToPromote->count(en.index()))) + return success(); + } + return failure(); +} + namespace { struct LinalgPromotionPass : public LinalgPromotionBase { LinalgPromotionPass() = default; diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -0,0 +1,224 @@ +//===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements logic and helpers to expose Linalg transforms as rewrite +// patterns. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/Dialect/Vector/EDSC/Intrinsics.h" +#include "mlir/Dialect/Vector/VectorOps.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include + +#define DEBUG_TYPE "linalg-transforms" + +using namespace mlir; +using namespace mlir::edsc; +using namespace mlir::edsc::intrinsics; +using namespace mlir::linalg; + +using llvm::dbgs; + +//============================================================================// +// Transformations exposed as rewrite patterns. +//============================================================================// +// Marker used as attribute name in generated Linalg rewriting transformations. +const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker = + "__internal_linalg_transform__"; + +mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef matchDisjunction, + llvm::Optional replacement) + : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()), + replacement(replacement) {} + +mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef matchDisjunction, + StringRef replacement) + : LinalgMarker(matchDisjunction, llvm::Optional{replacement}) {} + +LogicalResult +mlir::linalg::LinalgMarker::checkAndNotify(PatternRewriter &rewriter, + Operation *op) const { + auto attr = op->template getAttrOfType( + LinalgTransforms::kLinalgTransformMarker); + + if (!attr) { + // 1. Has no marker case and matchDisjunction is empty. + if (matchDisjunction.empty()) + return success(); + + // 2. Has no marker and matchDisjuntion matches the no-moarker case. + for (auto marker : matchDisjunction) + if (marker.empty()) + return success(); + + // 3. Has no marker but was expecting a marker. + return rewriter.notifyMatchFailure(op, [&](::mlir::Diagnostic &diag) { + diag << " does not have any marker from list: "; + llvm::interleaveComma(matchDisjunction, diag); + }); + } + + // 4. Match explicit marker. + for (auto marker : matchDisjunction) + if (attr.getValue() == marker) + return success(); + + // 5. Fail to match. + return rewriter.notifyMatchFailure(op, [&](::mlir::Diagnostic &diag) { + diag << " does not have any marker from list: "; + llvm::interleaveComma(matchDisjunction, diag); + }); +} + +void mlir::linalg::LinalgMarker::replaceLinalgMarker(PatternRewriter &rewriter, + Operation *op) const { + if (replacement.hasValue()) + op->setAttr(LinalgTransforms::kLinalgTransformMarker, + rewriter.getStringAttr(replacement.getValue())); + else + op->removeAttr(Identifier::get(LinalgTransforms::kLinalgTransformMarker, + rewriter.getContext())); +} + +/// Linalg base tiling pattern. +mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern( + StringRef opName, MLIRContext *context, ArrayRef tileSizes, + ArrayRef interchangeVector, LinalgMarker marker, int benefit) + : RewritePattern(opName, {}, benefit, context), marker(marker), + tileSizes(tileSizes.begin(), tileSizes.end()), + interchangeVector(interchangeVector.begin(), interchangeVector.end()) {} + +LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewrite( + Operation *op, PatternRewriter &rewriter) const { + LinalgOp linalgOp = dyn_cast(op); + if (!linalgOp) + return failure(); + if (failed(marker.checkAndNotify(rewriter, linalgOp))) + return failure(); + auto tileRes = + (true) ? tileLinalgOp(rewriter, linalgOp, tileSizes, interchangeVector) + : tileLinalgOpToParallelLoops(rewriter, linalgOp, tileSizes, + interchangeVector); + if (!tileRes) + return failure(); + + // New marker if specified. + marker.replaceLinalgMarker(rewriter, tileRes->op.getOperation()); + + rewriter.eraseOp(op); + return success(); +} + +/// Linalg base interchange pattern. +mlir::linalg::LinalgBaseInterchangePattern::LinalgBaseInterchangePattern( + StringRef opName, MLIRContext *context, + ArrayRef interchangeVector, LinalgMarker marker, int benefit) + : RewritePattern(opName, {}, benefit, context), marker(marker), + interchangeVector(interchangeVector.begin(), interchangeVector.end()) {} + +LogicalResult mlir::linalg::LinalgBaseInterchangePattern::matchAndRewrite( + Operation *op, PatternRewriter &rewriter) const { + LinalgOp linalgOp = dyn_cast(op); + if (!linalgOp) + return failure(); + if (failed(marker.checkAndNotify(rewriter, linalgOp))) + return failure(); + if (failed(interchangeGenericLinalgOpPrecondition(op, interchangeVector))) + return failure(); + + // TODO(ntv): figure out how this interplays with named ops. In particular + // this should break the named op property. + rewriter.updateRootInPlace(op, [&]() { + interchange(linalgOp, interchangeVector); + // New marker if specified. + marker.replaceLinalgMarker(rewriter, op); + }); + return success(); +} + +mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern( + StringRef opName, MLIRContext *context, + ArrayRef operandsToPromote, unsigned alignment, + LinalgMarker marker, int benefit) + : RewritePattern(opName, {}, benefit, context), marker(marker), + operandsToPromote(operandsToPromote.begin(), operandsToPromote.end()), + alignment(alignment) {} + +LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite( + Operation *op, PatternRewriter &rewriter) const { + LinalgOp linalgOp = dyn_cast(op); + if (!linalgOp) + return failure(); + if (failed(marker.checkAndNotify(rewriter, linalgOp))) + return failure(); + if (operandsToPromote.empty()) { + if (failed(promoteSubviewsLinalgOpPrecondition(op, llvm::None))) + return failure(); + } else { + DenseSet set; + set.insert(operandsToPromote.begin(), operandsToPromote.end()); + if (failed(promoteSubviewsLinalgOpPrecondition(op, set))) + return failure(); + } + + llvm::SetVector subViews; + if (!operandsToPromote.empty()) { + for (unsigned idx : operandsToPromote) { + auto *op = linalgOp.getBuffer(idx).getDefiningOp(); + if (auto sv = dyn_cast_or_null(op)) + subViews.insert(sv); + } + } else { + unsigned nBuffers = linalgOp.getNumInputsAndOutputBuffers(); + for (unsigned idx = 0; idx < nBuffers; ++idx) { + auto *op = linalgOp.getBuffer(idx).getDefiningOp(); + if (auto sv = dyn_cast_or_null(op)) + subViews.insert(sv); + } + } + + auto promotedOp = + promoteSubViewOperands(rewriter, op, subViews, /*dynamicBuffers=*/false, + /*alignment=*/alignment); + marker.replaceLinalgMarker(rewriter, promotedOp.getOperation()); + rewriter.eraseOp(op); + return success(); +} + +mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern( + StringRef opName, MLIRContext *context, LinalgMarker marker, int benefit) + : RewritePattern(opName, {}, benefit, context), marker(marker) {} + +LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite( + Operation *op, PatternRewriter &rewriter) const { + LinalgOp linalgOp = dyn_cast(op); + if (!linalgOp) + return failure(); + if (!linalgOp) + return failure(); + if (failed(marker.checkAndNotify(rewriter, linalgOp))) + return failure(); + if (failed(vectorizeLinalgOpPrecondition(op))) + return failure(); + vectorizeLinalgOp(rewriter, op); + rewriter.eraseOp(op); + return success(); +} diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -0,0 +1,140 @@ +//===- Vectorization.cpp - Implementation of linalg Vectorization ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the linalg dialect Vectorization transformations. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/Dialect/Vector/EDSC/Intrinsics.h" +#include "mlir/Dialect/Vector/VectorOps.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include + +using namespace mlir; +using namespace mlir::edsc; +using namespace mlir::edsc::intrinsics; +using namespace mlir::linalg; + +using llvm::dbgs; + +#define DEBUG_TYPE "linalg-vectorization" + +static bool hasMultiplyAddBody(linalg::GenericOp op) { + auto &r = op.region(); + if (r.empty()) + return false; + if (r.getBlocks().size() != 1) + return false; + auto &ops = r.front().getOperations(); + if (ops.size() != 3) + return false; + + using mlir::matchers::m_Val; + auto a = m_Val(r.front().getArgument(0)); + auto b = m_Val(r.front().getArgument(1)); + auto c = m_Val(r.front().getArgument(2)); + // TODO(ntv) Update this detection once we have matcher support for + // specifying that any permutation of operands matches. + auto pattern1 = m_Op(m_Op(m_Op(a, b), c)); + auto pattern2 = m_Op(m_Op(c, m_Op(a, b))); + auto pattern3 = m_Op(m_Op(m_Op(b, a), c)); + auto pattern4 = m_Op(m_Op(c, m_Op(b, a))); + return pattern1.match(&ops.back()) || pattern2.match(&ops.back()) || + pattern3.match(&ops.back()) || pattern4.match(&ops.back()); +} + +// TODO(ntv) should be Tablegen'd from a single source that generates the op +// itself. +static bool isRowMajorMatmul(linalg::GenericOp genericOp) { + return genericOp.getNumInputs() == 2 && genericOp.getNumOutputs() == 1 && + isRowMajorMatmul(genericOp.indexing_maps()) && + hasMultiplyAddBody(genericOp); +} + +// TODO(ntv, ataei): This is in fact much more general than just vectorization +// for matmul and fill ops. +LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) { + auto linalgOp = cast(op); + // All types must be static shape to go to vector. + for (Value operand : linalgOp.getInputsAndOutputBuffers()) + if (!operand.getType().cast().hasStaticShape()) + return failure(); + for (Type outputTensorType : linalgOp.getOutputTensorTypes()) + if (!outputTensorType.cast().hasStaticShape()) + return failure(); + if (isa(op) || isa(op)) + return success(); + + auto genericOp = dyn_cast(op); + if (!genericOp || !::isRowMajorMatmul(genericOp)) + return failure(); + + // TODO(ntv): non-identity layout. + auto isStaticMemRefWithIdentityLayout = [](Value v) { + auto m = v.getType().dyn_cast(); + if (!m || !m.hasStaticShape() || !m.getAffineMaps().empty()) + return false; + return true; + }; + if (!llvm::all_of(genericOp.getInputsAndOutputBuffers(), + isStaticMemRefWithIdentityLayout)) + return failure(); + return success(); +} + +Operation *mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) { + if (failed(vectorizeLinalgOpPrecondition(op))) + return nullptr; + + if (auto convOp = dyn_cast(op)) { + // TODO(ntv): add a level of indirection to linalg.generic. + if (convOp.padding()) + llvm_unreachable("Unexpected conv with padding"); + } + + edsc::ScopedContext scope(builder, op->getLoc()); + if (auto fillOp = dyn_cast(op)) { + // Vectorize fill as a vector.broadcast. + LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE + "]: Rewrite linalg.fill as vector.broadcast: " + << *op << ":\n"); + auto memref = vector_type_cast(fillOp.getOutputBuffer(0)); + Value dst = std_load(memref); + Value res = vector_broadcast(dst.getType(), fillOp.value()); + std_store(res, memref); + return res.getDefiningOp(); + } + + // Vectorize other ops as vector contraction (currently only matmul). + LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE + "]: Rewrite linalg op as vector.contract: " + << *op << ":\n"); + auto linalgOp = cast(op); + auto a = std_load(vector_type_cast(linalgOp.getInput(0))); + auto b = std_load(vector_type_cast(linalgOp.getInput(1))); + auto memref = vector_type_cast(linalgOp.getOutputBuffer(0)); + auto c = std_load(memref); + Value res = vector_contract(a, b, c, linalgOp.indexing_maps(), + linalgOp.iterator_types()); + std_store(res, memref); + + return res.getDefiningOp(); + ; +} diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir --- a/mlir/test/Dialect/Linalg/transform-patterns.mlir +++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -test-linalg-transform-patterns | FileCheck %s +// RUN: mlir-opt %s -test-linalg-transform-patterns=test-patterns | FileCheck %s // CHECK-DAG: #[[STRIDED_1D:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)> // Map corresponding to a 2D memory access where the stride along the last dim is known to be 1. @@ -25,7 +25,6 @@ // CHECK-DAG: %[[c8:.*]] = constant 8 : index // CHECK-DAG: %[[c8000:.*]] = constant 8000 : index // CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c8000]] { -// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c8]] { // CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c1]] { // CHECK: load // CHECK: load @@ -86,88 +85,6 @@ // CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c4]] { // CHECK: linalg.matmul({{.*}}, {{.*}}, {{.*}}) : memref, memref, memref -#some_generic_trait = { - args_in = 1, - args_out = 1, - indexing_maps = [ - affine_map<(i, j) -> (i, j)>, - affine_map<(i, j) -> (i, j)> - ], - iterator_types = ["parallel", "parallel"] -} -func @fusion_test(%A: memref, - %B: memref, - %C: memref, - %D: memref, - %E: memref) { - // This should not be fused as it would violate dependencies. It will get - // tiled for all levels of the memory hierarchy. - linalg.matmul(%A, %A, %C) : memref, - memref, - memref - - // This should be fused. - linalg.matmul(%A, %B, %C) : memref, - memref, - memref - - // This should not be fused or transformed at all since there are no patterns - // on it. However it will be reordered because there are no dependencies. - linalg.generic #some_generic_trait %A, %D { - ^bb(%a: f32, %b: f32) : - linalg.yield %a : f32 - } : memref, - memref - - linalg.matmul(%C, %D, %E) : memref, - memref, - memref - - return -} -// CHECK-LABEL: func @fusion_test -// CHECK-DAG: %[[c0:.*]] = constant 0 : index -// CHECK-DAG: %[[c2:.*]] = constant 2 : index -// CHECK-DAG: %[[c3:.*]] = constant 3 : index -// CHECK-DAG: %[[c4:.*]] = constant 4 : index -// CHECK-DAG: %[[c20:.*]] = constant 20 : index -// CHECK-DAG: %[[c30:.*]] = constant 30 : index -// CHECK-DAG: %[[c40:.*]] = constant 40 : index -// CHECK-DAG: %[[c100:.*]] = constant 100 : index -// CHECK-DAG: %[[c150:.*]] = constant 150 : index -// CHECK-DAG: %[[c200:.*]] = constant 200 : index -// CHECK-DAG: %[[c300:.*]] = constant 300 : index -// CHECK-DAG: %[[c400:.*]] = constant 400 : index -// CHECK-DAG: %[[c2000:.*]] = constant 2000 : index -// CHECK-DAG: %[[c3000:.*]] = constant 3000 : index -// CHECK-DAG: %[[c4000:.*]] = constant 4000 : index -// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c2000]] { -// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c3000]] { -// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c4000]] { -// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c200]] { -// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c300]] { -// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c400]] { -// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c20]] { -// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c30]] { -// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c40]] { -// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c2]] { -// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c3]] { -// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c4]] { -// CHECK: linalg.matmul({{.*}}, {{.*}}, {{.*}}) : memref, memref, memref -// -// CHECK: linalg.generic -// -// CHECK: loop.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c100]] { -// CHECK: loop.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c150]] { -// CHECK: loop.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c2]] { -// CHECK: loop.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c3]] { -// CHECK: loop.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c4]] { -// CHECK: linalg.matmul(%{{.*}}, %{{.*}}, %{{.*}}) : memref, memref, memref -// CHECK: loop.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c2]] { -// CHECK: loop.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c3]] { -// CHECK: loop.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c4]] { -// CHECK: linalg.matmul(%{{.*}}, %{{.*}}, %{{.*}}) : memref, memref, memref - #matmul_trait = { args_in = 2, args_out = 1, @@ -280,23 +197,6 @@ // CHECK-SAME: memref, // CHECK-SAME: memref -func @dot_perm(%x: memref, - %y: memref, - %v: memref) { - linalg.dot(%x, %y, %v) {__internal_linalg_transform__ = "__with_perm__"} : - memref, - memref, - memref - return -} -// CHECK-LABEL: func @dot_perm -// CHECK-DAG: %[[c0:.*]] = constant 0 : index -// CHECK-DAG: %[[c8:.*]] = constant 8 : index -// CHECK-DAG: %[[c8000:.*]] = constant 8000 : index -// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c8000]] { -// CHECK: loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c8]] { -// CHECK: linalg.dot({{.*}}, {{.*}}, {{.*}}) : memref, memref, memref - func @matvec_perm(%A: memref, %x: memref, %y: memref) { diff --git a/mlir/test/lib/DeclarativeTransforms/CMakeLists.txt b/mlir/test/lib/DeclarativeTransforms/CMakeLists.txt --- a/mlir/test/lib/DeclarativeTransforms/CMakeLists.txt +++ b/mlir/test/lib/DeclarativeTransforms/CMakeLists.txt @@ -1,9 +1,3 @@ -set(LLVM_TARGET_DEFINITIONS TestLinalgTransformPatterns.td) -mlir_tablegen(TestLinalgTransformPatterns.h.inc -gen-rewriters) -add_public_tablegen_target(MLIRTestLinalgTransformPatternsIncGen) -# Including Linalg in TableGen requires to depends on generated files -add_dependencies(MLIRTestLinalgTransformPatternsIncGen LinalgOdsGen) - set(LLVM_TARGET_DEFINITIONS TestVectorTransformPatterns.td) mlir_tablegen(TestVectorTransformPatterns.h.inc -gen-rewriters) add_public_tablegen_target(MLIRTestVectorTransformPatternsIncGen) diff --git a/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td b/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td deleted file mode 100644 --- a/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td +++ /dev/null @@ -1,168 +0,0 @@ -//===- TestLinalgTransformPatterns.td - Test patterns --*- tablegen ----*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This is the pattern definition file for declarative Linalg transformations -// tests. -// -//===----------------------------------------------------------------------===// - -#ifndef TEST_LINALG_TRANSFORMS_PATTERNS -#define TEST_LINALG_TRANSFORMS_PATTERNS - -include "mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td" - -//===----------------------------------------------------------------------===// -// Test Linalg fusion patterns. -//===----------------------------------------------------------------------===// -def : Pat<(MatmulOp:$op $A, $_, $_), - (TileAndFuseLinalgOp<[100, 150], [0], "L1">), - [ - (Constraint), - (Constraint> $A), - ], - // In the buffer world there is no use-def chains or dags so benefits - // cannot be computed automatically from the length of the matched - // pattern. Instead we specify the benefit ourselves for now. - // This is not expected to be a big challenge long-term because - // pattern benefits are akin to feature engineering: features should - // be learned. - (addBenefit 1)>; - -//===----------------------------------------------------------------------===// -// Linalg tiling patterns. -//===----------------------------------------------------------------------===// -def : Pat<(MatmulOp:$op $_, $_, $_), - (TileLinalgOp<[2000, 3000, 4000], "L3">), - [(Constraint]>>)]>; -def : Pat<(MatmulOp:$op $_, $_, $_), - (TileLinalgOp<[200, 300, 400], "L2">), - [(Constraint>)]>; -def : Pat<(MatmulOp:$op $_, $_, $_), - (TileLinalgOp<[20, 30, 40], "L1">), - [(Constraint>)]>; -def : Pat<(MatmulOp:$op $_, $_, $_), - (TileLinalgOp<[2, 3, 4], "REG">), - [(Constraint>)]>; - -def : Pattern<(MatvecOp:$op $_, $_, $_), - [(TileLinalgOp<[5, 6], "L1">)], - [(Constraint)]>; - -def : Pattern<(DotOp:$op $_, $_, $_), - [(TileLinalgOp<[8000], "L1">)], - [(Constraint, - HasLinalgTransformMarker<"L3">, - HasLinalgTransformMarker<"L2">]>>)]>; -def : Pattern<(DotOp:$op $_, $_, $_), - [(TileLinalgOp<[8], "REG">)], - [(Constraint>)]>; - -//===----------------------------------------------------------------------===// -// Linalg tiling and permutation patterns. -//===----------------------------------------------------------------------===// -def : Pat<(MatmulOp:$op $_, $_, $_), - (TileLinalgOp<[2000, 3000, 4000], "L2__with_perm__", [1,2,0]>), - [(Constraint>)]>; -def : Pat<(MatmulOp:$op $_, $_, $_), - (TileLinalgOp<[200, 300, 400], "L1__with_perm__", [1,0,2]>), - [(Constraint>)]>; -def : Pat<(MatmulOp:$op $_, $_, $_), - (TileLinalgOp<[20, 30, 40], "REG__with_perm__">), - [(Constraint>)]>; - - -def : Pattern<(MatvecOp:$op $_, $_, $_), - [(TileLinalgOp<[5, 6], "L1__with_perm__", [1,0]>)], - [(Constraint>)]>; - -def : Pattern<(DotOp:$op $_, $_, $_), - [(TileLinalgOp<[8000], "L1__with_perm__">)], - [(Constraint>)]>; -def : Pattern<(DotOp:$op $_, $_, $_), - [(TileLinalgOp<[8], "REG__with_perm__">)], - [(Constraint>)]>; - -//===----------------------------------------------------------------------===// -// Linalg to loops patterns. -//===----------------------------------------------------------------------===// -def : Pattern<(DotOp:$op $_, $_, $_), - [(LinalgOpToLoops<"DotOp">)], - [(Constraint>)]>; - -//===----------------------------------------------------------------------===// -// Linalg to vector contraction patterns. -//===----------------------------------------------------------------------===// -def : Pattern<(MatmulOp:$op $_, $_, $_), - [(VectorizeLinalgOp)], - [(Constraint, - PreconditionVectorizeLinalgOp - ]>>)]>; -def : Pattern<(FillOp:$op $_, $_), - [(VectorizeLinalgOp)], - [(Constraint, - PreconditionVectorizeLinalgOp - ]>>)]>; -def : Pattern<(GenericOp:$op $_, $_, $_, $_, $_, $_, $_), - [(VectorizeLinalgOp)], - [(Constraint, - PreconditionVectorizeLinalgOp - ]>>)]>; - - -//===----------------------------------------------------------------------===// -// Linalg generic permutation patterns. -//===----------------------------------------------------------------------===// -def : Pat<(GenericOp:$op $_, $_, $_, $_, $_, $_, $_), - (PermuteGenericLinalgOp<[1, 2, 0], "PERMUTE"> $op), - [(Constraint, - PreconditionPermuteGenericLinalgOp<[1, 2, 0]> - ]>>)]>; - -def : Pat<(IndexedGenericOp:$op $_, $_, $_, $_, $_, $_, $_), - (PermuteGenericLinalgOp<[1, 2, 0], "PERMUTE"> $op), - [(Constraint, - PreconditionPermuteGenericLinalgOp<[1, 2, 0]> - ]>>)]>; - -//===----------------------------------------------------------------------===// -// Linalg subview operands promotion. -//===----------------------------------------------------------------------===// -def : Pat<(MatmulOp:$op $_, $_, $_), - (PromoteSubviewsLinalgOp), - [(Constraint, - HasLinalgTransformMarker<"_promote_views_">]>> - )]>; - -def : Pat<(MatmulOp:$op $_, $_, $_), - (PromoteSelectedSubviewsLinalgOp<[0], "first_view_promotion">), - [(Constraint, - HasLinalgTransformMarker<"_promote_first_view_">]>> - )]>; - -def : Pat<(FillOp:$op $_, $_), - (PromoteSelectedSubviewsLinalgOp<[0], "aligned_promotion", 32>), - [(Constraint, - HasLinalgTransformMarker<"_promote_views_aligned_">]>> - )]>; - -#endif // TEST_LINALG_TRANSFORMS_PATTERNS diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt --- a/mlir/test/lib/Transforms/CMakeLists.txt +++ b/mlir/test/lib/Transforms/CMakeLists.txt @@ -25,7 +25,6 @@ DEPENDS MLIRStandardOpsIncGen - MLIRTestLinalgTransformPatternsIncGen MLIRTestVectorTransformPatternsIncGen ) diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp --- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp @@ -10,36 +10,132 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" -#include "mlir/Dialect/Linalg/Transforms/LinalgTransforms.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" +#include "llvm/ADT/SetVector.h" + using namespace mlir; using namespace mlir::linalg; -namespace mlir { -namespace linalg { -namespace { -#include "TestLinalgTransformPatterns.h.inc" -} // end namespace -} // end namespace linalg -} // end namespace mlir - namespace { struct TestLinalgTransforms : public PassWrapper { + TestLinalgTransforms() = default; + TestLinalgTransforms(const TestLinalgTransforms &pass) {} + void runOnFunction() override; + + Option testPatterns{*this, "test-patterns", + llvm::cl::desc("Test a mixed set of patterns"), + llvm::cl::init(false)}; }; } // end anonymous namespace -/// Apply transformations specified as patterns. -void TestLinalgTransforms::runOnFunction() { +static void applyPatterns(FuncOp funcOp) { + MLIRContext *ctx = funcOp.getContext(); OwningRewritePatternList patterns; - auto funcOp = getFunction(); - // Add the generated patterns to the list. - linalg::populateWithGenerated(&getContext(), &patterns); + //===--------------------------------------------------------------------===// + // Linalg tiling patterns. + //===--------------------------------------------------------------------===// + patterns.insert>( + ctx, + /*tileSizes=*/ArrayRef{2000, 3000, 4000}, + LinalgMarker({"MEM", {}}, "L3")); + patterns.insert>( + ctx, + /*tileSizes=*/ArrayRef{200, 300, 400}, + LinalgMarker({"L3"}, "L2")); + patterns.insert>( + ctx, + /*tileSizes=*/ArrayRef{20, 30, 40}, LinalgMarker({"L2"}, "L1")); + patterns.insert>( + ctx, + /*tileSizes=*/ArrayRef{2, 3, 4}, LinalgMarker({"L1"}, "REG")); + + patterns.insert>( + ctx, + /*tileSizes=*/ArrayRef{5, 6}, LinalgMarker({}, "L1")); + + patterns.insert>( + ctx, + /*tileSizes=*/ArrayRef{8000}, + LinalgMarker({"MEM", "L3", "L2", {}}, "L1")); + patterns.insert>( + ctx, + /*tileSizes=*/ArrayRef{8}, LinalgMarker({"L1"}, "REG")); + + //===--------------------------------------------------------------------===// + // Linalg tiling and permutation patterns. + //===--------------------------------------------------------------------===// + patterns.insert>( + ctx, + /*tileSizes=*/ArrayRef{2000, 3000, 4000}, + /*interchangeVector=*/ArrayRef{1, 2, 0}, + LinalgMarker({"__with_perm__"}, "L2__with_perm__")); + patterns.insert>( + ctx, + /*tileSizes=*/ArrayRef{200, 300, 400}, + /*interchangeVector=*/ArrayRef{1, 0, 2}, + LinalgMarker({"L2__with_perm__"}, "L1__with_perm__")); + patterns.insert>( + ctx, + /*tileSizes=*/ArrayRef{20, 30, 40}, + LinalgMarker({"L1__with_perm__"}, "REG__with_perm__")); + + patterns.insert>( + ctx, + /*tileSizes=*/ArrayRef{5, 6}, + /*interchangeVector=*/ArrayRef{1, 0}, + LinalgMarker({"__with_perm__"}, "L1__with_perm__")); + + //===--------------------------------------------------------------------===// + // Linalg to loops patterns. + //===--------------------------------------------------------------------===// + patterns.insert>( + ctx, + /*loweringType=*/LinalgLoweringType::Loops, LinalgMarker({"REG"})); + + //===--------------------------------------------------------------------===// + // Linalg to vector contraction patterns. + //===--------------------------------------------------------------------===// + patterns.insert, + LinalgVectorizationPattern, + LinalgVectorizationPattern>( + ctx, LinalgMarker({"VECTORIZE"})); + + //===--------------------------------------------------------------------===// + // Linalg generic permutation patterns. + //===--------------------------------------------------------------------===// + patterns.insert>( + ctx, + /*interchangeVector=*/ArrayRef{1, 2, 0}, + LinalgMarker({}, "PERMUTED")); + patterns.insert>( + ctx, + /*interchangeVector=*/ArrayRef{1, 2, 0}, + LinalgMarker({}, "PERMUTED")); + + //===--------------------------------------------------------------------===// + // Linalg subview operands promotion. + //===--------------------------------------------------------------------===// + patterns.insert>( + ctx, LinalgMarker({"_promote_views_"}, "_views_promoted_")); + patterns.insert>( + ctx, + /*operandsToPromote=*/ArrayRef{0}, + LinalgMarker({"_promote_first_view_"}, "_first_view_promoted_")); + patterns.insert>( + ctx, + /*operandsToPromote=*/ArrayRef{0}, + /*alignment=*/32, + LinalgMarker({"_promote_views_aligned_"}, "_views_aligned_promoted_")); + applyPatternsAndFoldGreedily(funcOp, patterns); // Drop the marker. @@ -48,9 +144,15 @@ }); } +/// Apply transformations specified as patterns. +void TestLinalgTransforms::runOnFunction() { + if (testPatterns) + return applyPatterns(getFunction()); +} + namespace mlir { void registerTestLinalgTransforms() { - PassRegistration( + PassRegistration testTransformPatternsPass( "test-linalg-transform-patterns", "Test Linalg transformation patterns by applying them greedily."); }