diff --git a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h deleted file mode 100644 --- a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h +++ /dev/null @@ -1,275 +0,0 @@ -//===- DependenceAnalysis.h - Dependence analysis on SSA views --*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_DIALECT_LINALG_ANALYSIS_DEPENDENCEANALYSIS_H_ -#define MLIR_DIALECT_LINALG_ANALYSIS_DEPENDENCEANALYSIS_H_ - -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/OpDefinition.h" -#include - -namespace mlir { -namespace func { -class FuncOp; -} // namespace func - -namespace linalg { - -class LinalgOp; - -/// A very primitive alias analysis which just records for each view, either: -/// 1. The base buffer, or -/// 2. The block argument view -/// that it indexes into. -/// This does not perform inter-block or inter-procedural analysis and assumes -/// that different block argument views do not alias. -class Aliases { -public: - /// Returns true if v1 and v2 alias. - bool alias(Value v1, Value v2) { return find(v1) == find(v2); } - -private: - /// Returns the base buffer or block argument into which the view `v` aliases. - /// This lazily records the new aliases discovered while walking back the - /// use-def chain. - Value find(Value v); - - DenseMap aliases; -}; - -/// Data structure for holding a dependence graph that operates on LinalgOp and -/// views as SSA values. -class LinalgDependenceGraph { -public: - enum DependenceType { RAR = 0, RAW, WAR, WAW, NumTypes }; - // TODO: OpOperand tracks dependencies on buffer operands. Tensor result will - // need an extension to use OpResult. - struct LinalgDependenceGraphElem { - using OpView = PointerUnion; - // dependentOpView may be either: - // 1. src in the case of dependencesIntoGraphs. - // 2. dst in the case of dependencesFromDstGraphs. - OpView dependentOpView; - // View in the op that is used to index in the graph: - // 1. src in the case of dependencesFromDstGraphs. - // 2. dst in the case of dependencesIntoGraphs. - OpView indexingOpView; - // Type of the dependence. - DependenceType dependenceType; - - // Return the Operation that owns the operand or result represented in - // `opView`. - static Operation *getOwner(OpView opView) { - if (OpOperand *operand = opView.dyn_cast()) - return operand->getOwner(); - return opView.get().cast().getOwner(); - } - // Return the operand or the result Value represented by the `opView`. - static Value getValue(OpView opView) { - if (OpOperand *operand = opView.dyn_cast()) - return operand->get(); - return opView.get(); - } - // Return the indexing map of the operand/result in `opView` specified in - // the owning LinalgOp. If the owner is not a LinalgOp returns std::nullopt. - static std::optional getIndexingMap(OpView opView) { - auto owner = dyn_cast(getOwner(opView)); - if (!owner) - return std::nullopt; - if (OpOperand *operand = opView.dyn_cast()) - return owner.getMatchingIndexingMap(operand); - return owner.getMatchingIndexingMap(owner.getDpsInitOperand( - opView.get().cast().getResultNumber())); - } - // Return the operand number if the `opView` is an OpOperand *. Otherwise - // return std::nullopt. - static std::optional getOperandNumber(OpView opView) { - if (OpOperand *operand = opView.dyn_cast()) - return operand->getOperandNumber(); - return std::nullopt; - } - // Return the result number if the `opView` is an OpResult. Otherwise return - // std::nullopt. - static std::optional getResultNumber(OpView opView) { - if (OpResult result = opView.dyn_cast().cast()) - return result.getResultNumber(); - return std::nullopt; - } - - // Return the owner of the dependent OpView. - Operation *getDependentOp() const { return getOwner(dependentOpView); } - - // Return the owner of the indexing OpView. - Operation *getIndexingOp() const { return getOwner(indexingOpView); } - - // Return the operand or result stored in the dependentOpView. - Value getDependentValue() const { return getValue(dependentOpView); } - - // Return the operand or result stored in the indexingOpView. - Value getIndexingValue() const { return getValue(indexingOpView); } - - // If the dependent OpView is an operand, return operand number. Return - // std::nullopt otherwise. - std::optional getDependentOpViewOperandNum() const { - return getOperandNumber(dependentOpView); - } - - // If the indexing OpView is an operand, return operand number. Return - // std::nullopt otherwise. - std::optional getIndexingOpViewOperandNum() const { - return getOperandNumber(indexingOpView); - } - - // If the dependent OpView is a result value, return the result - // number. Return std::nullopt otherwise. - std::optional getDependentOpViewResultNum() const { - return getResultNumber(dependentOpView); - } - - // If the dependent OpView is a result value, return the result - // number. Return std::nullopt otherwise. - std::optional getIndexingOpViewResultNum() const { - return getResultNumber(indexingOpView); - } - - // Return the indexing map of the operand/result in the dependent OpView as - // specified in the owner of the OpView. - std::optional getDependentOpViewIndexingMap() const { - return getIndexingMap(dependentOpView); - } - - // Return the indexing map of the operand/result in the indexing OpView as - // specified in the owner of the OpView. - std::optional getIndexingOpViewIndexingMap() const { - return getIndexingMap(indexingOpView); - } - }; - using LinalgDependences = SmallVector; - using DependenceGraph = DenseMap; - using dependence_iterator = LinalgDependences::const_iterator; - using dependence_range = iterator_range; - - static StringRef getDependenceTypeStr(DependenceType depType); - - // Builds a linalg dependence graph for the ops of type LinalgOp under `f`. - static LinalgDependenceGraph buildDependenceGraph(Aliases &aliases, - func::FuncOp f); - LinalgDependenceGraph(Aliases &aliases, ArrayRef ops); - - /// Returns the X such that op -> X is a dependence of type dt. - dependence_range getDependencesFrom(Operation *src, DependenceType dt) const; - dependence_range getDependencesFrom(LinalgOp src, DependenceType dt) const; - - /// Returns the X such that X -> op is a dependence of type dt. - dependence_range getDependencesInto(Operation *dst, DependenceType dt) const; - dependence_range getDependencesInto(LinalgOp dst, DependenceType dt) const; - - /// Returns the operations that are interleaved between `srcLinalgOp` and - /// `dstLinalgOp` and that are involved in any RAW, WAR or WAW dependence - /// relation with `srcLinalgOp`, on any view. - /// Any such operation prevents reordering. - SmallVector - findCoveringDependences(LinalgOp srcLinalgOp, LinalgOp dstLinalgOp) const; - - /// Returns the operations that are interleaved between `srcLinalgOp` and - /// `dstLinalgOp` and that are involved in a RAR or RAW with `srcLinalgOp`. - /// Dependences are restricted to views aliasing `view`. - SmallVector findCoveringReads(LinalgOp srcLinalgOp, - LinalgOp dstLinalgOp, - Value view) const; - - /// Returns the operations that are interleaved between `srcLinalgOp` and - /// `dstLinalgOp` and that are involved in a WAR or WAW with `srcLinalgOp`. - /// Dependences are restricted to views aliasing `view`. - SmallVector findCoveringWrites(LinalgOp srcLinalgOp, - LinalgOp dstLinalgOp, - Value view) const; - - /// Returns true if the two operations have the specified dependence from - /// `srcLinalgOp` to `dstLinalgOp`. - bool hasDependenceFrom(LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, - ArrayRef depTypes = { - DependenceType::RAW, DependenceType::WAW}) const; - - /// Returns true if the `linalgOp` has dependences into it. - bool hasDependentOperationsInto(LinalgOp linalgOp, - ArrayRef depTypes = { - DependenceType::RAW, - DependenceType::WAW}) const; - - /// Returns true if the `linalgOp` has dependences from it. - bool hasDependentOperationsFrom(LinalgOp linalgOp, - ArrayRef depTypes = { - DependenceType::RAW, - DependenceType::WAW}) const; - - /// Returns true if the `linalgOp` has dependences into or from it. - bool hasDependentOperations(LinalgOp linalgOp, - ArrayRef depTypes = { - DependenceType::RAW, - DependenceType::WAW}) const; - - /// Returns all operations that have a dependence into `linalgOp` of types - /// listed in `depTypes`. - SmallVector getDependentOperationsInto( - LinalgOp linalgOp, ArrayRef depTypes = { - DependenceType::RAW, DependenceType::WAW}) const; - - /// Returns all operations that have a dependence from `linalgOp` of types - /// listed in `depTypes`. - SmallVector getDependentOperationsFrom( - LinalgOp linalgOp, ArrayRef depTypes = { - DependenceType::RAW, DependenceType::WAW}) const; - - /// Returns all dependent operations (into and from) given `operation`. - SmallVector - getDependentOperations(LinalgOp linalgOp, - ArrayRef depTypes = { - DependenceType::RAW, DependenceType::WAW}) const; - - void print(raw_ostream &os) const; - - void dump() const; - -private: - // Keep dependences in both directions, this is not just a performance gain - // but it also reduces usage errors. - // Dependence information is stored as a map of: - // (source operation -> LinalgDependenceGraphElem) - DependenceGraph dependencesFromGraphs[DependenceType::NumTypes]; - // Reverse dependence information is stored as a map of: - // (destination operation -> LinalgDependenceGraphElem) - DependenceGraph dependencesIntoGraphs[DependenceType::NumTypes]; - - /// Analyses the aliasing views between `src` and `dst` and inserts the proper - /// dependences in the graph. - void addDependencesBetween(LinalgOp src, LinalgOp dst); - - // Adds an new dependence unit in the proper graph. - // Uses std::pair to keep operations and view together and avoid usage errors - // related to src/dst and producer/consumer terminology in the context of - // dependences. - void addDependenceElem(DependenceType dt, - LinalgDependenceGraphElem::OpView indexingOpView, - LinalgDependenceGraphElem::OpView dependentOpView); - - /// Implementation detail for findCoveringxxx. - SmallVector - findOperationsWithCoveringDependences(LinalgOp srcLinalgOp, - LinalgOp dstLinalgOp, Value view, - ArrayRef types) const; - - Aliases &aliases; - SmallVector linalgOps; - DenseMap linalgOpPositions; -}; -} // namespace linalg -} // namespace mlir - -#endif // MLIR_DIALECT_LINALG_ANALYSIS_DEPENDENCEANALYSIS_H_ diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -9,10 +9,8 @@ #ifndef MLIR_DIALECT_LINALG_UTILS_UTILS_H #define MLIR_DIALECT_LINALG_UTILS_UTILS_H -#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/SCF/IR/SCF.h" -#include "llvm/ADT/MapVector.h" #include "llvm/ADT/StringSet.h" #include @@ -27,7 +25,6 @@ } // namespace tensor namespace linalg { -class LinalgDependenceGraph; //===----------------------------------------------------------------------===// // General utilities @@ -153,19 +150,6 @@ ParallelLoops = 2 }; -/// Checks whether the specific `producer` is the last write to exactly the -/// whole `consumedView`. This checks structural dominance, that the dependence -/// is a RAW without any interleaved write to any piece of `consumedView`. -bool isProducerLastWriteOfView(const LinalgDependenceGraph &graph, - LinalgOp consumer, Value consumedView, - LinalgOp producer); - -/// Checks whether fusing the specific `producer` of the `consumedView` is -/// feasible. This checks `producer` is the last write of `consumedView` and -/// that no interleaved dependence would be violated (RAW, WAR or WAW). -bool isFusableInto(const LinalgDependenceGraph &graph, LinalgOp consumer, - Value consumedView, LinalgOp producer); - /// Computes tile offsets, given a list of loop `ivs` and `tileSizes`. In case a /// tile size is zero (i.e., no tiling), the corresponding offset is also zero. SmallVector computeTileOffsets(OpBuilder &b, Location loc, @@ -268,13 +252,6 @@ void offsetIndices(RewriterBase &b, LinalgOp linalgOp, ArrayRef offests); -using FusableOpDependencesTy = llvm::MapVector< - Operation *, - SmallVector>; -FusableOpDependencesTy -findAllFusableDependences(ArrayRef ops, - const LinalgDependenceGraph &dependenceGraph); - /// A struct containing the Linalg producer before and after fusion. /// When operating on tensors, `fusedProducer` may feed into a `tensor.cast` op /// before the consumer Linalg op, until enough canonicalizations have applied. @@ -283,14 +260,6 @@ LinalgOp fusedProducer; }; -/// Fuses producer into consumer if the producer is structurally feasible and -/// the fusion would not violate dependencies. -/// Implements the fusion part of the "tileAndFuse on buffers" transformation -/// and thus requires the `consumerOpOperand` to be a `subview` op (generally -/// obtained by applying the tiling transformation). -FailureOr fuseProducerOfBuffer(OpBuilder &b, - OpOperand &consumerOpOperand, - const LinalgDependenceGraph &graph); /// Tensor counterpart of `fuseProducerOfBuffer`. /// This implements the fusion part of the "tileAndFuse on tensors" /// transformation and thus requires the `consumerOpOperand` to be a diff --git a/mlir/lib/Dialect/Linalg/Analysis/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Analysis/CMakeLists.txt deleted file mode 100644 --- a/mlir/lib/Dialect/Linalg/Analysis/CMakeLists.txt +++ /dev/null @@ -1,13 +0,0 @@ -add_mlir_dialect_library(MLIRLinalgAnalysis - DependenceAnalysis.cpp - - ADDITIONAL_HEADER_DIRS - ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg - - LINK_LIBS PUBLIC - MLIRAffineAnalysis - MLIRAnalysis - MLIRIR - MLIRLinalgDialect - MLIRMemRefDialect - ) diff --git a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp deleted file mode 100644 --- a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp +++ /dev/null @@ -1,366 +0,0 @@ -//===- DependenceAnalysis.cpp - Dependence analysis on SSA views ----------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements view-based alias and dependence analyses. -// -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" -#include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/IR/BuiltinOps.h" - -#include "llvm/Support/CommandLine.h" -#include "llvm/Support/Debug.h" - -#define DEBUG_TYPE "linalg-dependence-analysis" - -using namespace mlir; -using namespace mlir::linalg; - -using llvm::dbgs; - -Value Aliases::find(Value v) { - if (v.isa()) - return v; - - auto it = aliases.find(v); - if (it != aliases.end()) { - assert(it->getSecond().getType().isa() && - "Memref expected"); - return it->getSecond(); - } - - while (true) { - if (v.isa()) - return v; - - Operation *defOp = v.getDefiningOp(); - if (!defOp) - return v; - - // Treat RegionBranchOpInterfaces like an allocate and don't try to follow - // the aliasing further. - if (isa(defOp)) - return v; - if (isa(defOp)) - return v; - - if (auto memEffect = dyn_cast(defOp)) { - // Collect all memory effects on `v`. - SmallVector effects; - memEffect.getEffectsOnValue(v, effects); - - // If we have the 'Allocate' memory effect on `v`, then `v` should be the - // original buffer. - if (llvm::any_of( - effects, [](const MemoryEffects::EffectInstance &instance) { - return isa(instance.getEffect()); - })) - return v; - } - - if (auto viewLikeOp = dyn_cast(defOp)) { - auto it = - aliases.insert(std::make_pair(v, find(viewLikeOp.getViewSource()))); - return it.first->second; - } - - llvm::errs() << "View alias analysis reduces to: " << v << "\n"; - llvm_unreachable("unsupported view alias case"); - } -} - -StringRef LinalgDependenceGraph::getDependenceTypeStr(DependenceType depType) { - switch (depType) { - case LinalgDependenceGraph::DependenceType::RAW: - return "RAW"; - case LinalgDependenceGraph::DependenceType::RAR: - return "RAR"; - case LinalgDependenceGraph::DependenceType::WAR: - return "WAR"; - case LinalgDependenceGraph::DependenceType::WAW: - return "WAW"; - default: - break; - } - llvm_unreachable("Unexpected DependenceType"); -} - -LinalgDependenceGraph -LinalgDependenceGraph::buildDependenceGraph(Aliases &aliases, func::FuncOp f) { - SmallVector linalgOps; - f.walk([&](LinalgOp op) { linalgOps.push_back(op); }); - return LinalgDependenceGraph(aliases, linalgOps); -} - -LinalgDependenceGraph::LinalgDependenceGraph(Aliases &aliases, - ArrayRef ops) - : aliases(aliases), linalgOps(ops.begin(), ops.end()) { - for (const auto &en : llvm::enumerate(linalgOps)) { - linalgOpPositions.insert( - std::make_pair(en.value().getOperation(), en.index())); - } - for (unsigned i = 0, e = ops.size(); i < e; ++i) { - for (unsigned j = i + 1; j < e; ++j) { - addDependencesBetween(ops[i], ops[j]); - } - } -} - -void LinalgDependenceGraph::addDependenceElem( - DependenceType dt, LinalgDependenceGraphElem::OpView indexingOpView, - LinalgDependenceGraphElem::OpView dependentOpView) { - LLVM_DEBUG(dbgs() << "\nAdd dep type " << getDependenceTypeStr(dt) << ":\t (" - << LinalgDependenceGraphElem::getValue(indexingOpView) - << " @) -> \n\t\t(" - << LinalgDependenceGraphElem::getValue(dependentOpView) - << " @)"); - dependencesFromGraphs[dt][LinalgDependenceGraphElem::getOwner(indexingOpView)] - .push_back( - LinalgDependenceGraphElem{dependentOpView, indexingOpView, dt}); - dependencesIntoGraphs[dt] - [LinalgDependenceGraphElem::getOwner(dependentOpView)] - .push_back(LinalgDependenceGraphElem{ - indexingOpView, dependentOpView, dt}); -} - -LinalgDependenceGraph::dependence_range -LinalgDependenceGraph::getDependencesFrom( - LinalgOp src, LinalgDependenceGraph::DependenceType dt) const { - return getDependencesFrom(src.getOperation(), dt); -} - -LinalgDependenceGraph::dependence_range -LinalgDependenceGraph::getDependencesFrom( - Operation *src, LinalgDependenceGraph::DependenceType dt) const { - auto iter = dependencesFromGraphs[dt].find(src); - if (iter == dependencesFromGraphs[dt].end()) - return llvm::make_range(nullptr, nullptr); - return llvm::make_range(iter->second.begin(), iter->second.end()); -} - -LinalgDependenceGraph::dependence_range -LinalgDependenceGraph::getDependencesInto( - LinalgOp dst, LinalgDependenceGraph::DependenceType dt) const { - return getDependencesInto(dst.getOperation(), dt); -} - -LinalgDependenceGraph::dependence_range -LinalgDependenceGraph::getDependencesInto( - Operation *dst, LinalgDependenceGraph::DependenceType dt) const { - auto iter = dependencesIntoGraphs[dt].find(dst); - if (iter == dependencesIntoGraphs[dt].end()) - return llvm::make_range(nullptr, nullptr); - return llvm::make_range(iter->second.begin(), iter->second.end()); -} - -void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) { - LLVM_DEBUG(dbgs() << "addDependencesBetween " << *src.getOperation() - << " and " << *dst.getOperation() << "\n"); - if (src.hasTensorSemantics() && dst.hasTensorSemantics()) { - for (OpOperand *dstOpOperand : dst.getDpsInputOperands()) { - if (!dstOpOperand->get().getType().isa()) - continue; - // Check if the operand is defined by the src. - auto definingOp = dstOpOperand->get().getDefiningOp(); - if (definingOp && definingOp == src) - addDependenceElem(DependenceType::RAW, dstOpOperand->get(), - dstOpOperand); - } - for (OpOperand *dstOpOperand : dst.getDpsInitOperands()) { - // Check if the operand is defined by the src. - auto definingOp = dstOpOperand->get().getDefiningOp(); - if (definingOp && definingOp == src) { - if (dst.isInitTensor(dstOpOperand)) { - addDependenceElem(DependenceType::RAW, dstOpOperand->get(), - dstOpOperand); - } - addDependenceElem(DependenceType::WAW, dstOpOperand->get(), - dstOpOperand); - } - } - return; - } - assert(src.hasBufferSemantics() && dst.hasBufferSemantics() && - "unhandled dependence tracking for mixed buffer/tensor operations"); - for (OpOperand *srcOpOperand : src.getDpsInitOperands()) { // W - // RAW graph - for (OpOperand *dstOpOperand : dst.getDpsInputOperands()) { // R - if (!dstOpOperand->get().getType().isa()) - continue; - if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // RAW alias - addDependenceElem(DependenceType::RAW, srcOpOperand, dstOpOperand); - } - // WAW graph - for (OpOperand *dstOpOperand : dst.getDpsInitOperands()) // W - if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // WAW alias - addDependenceElem(DependenceType::WAW, srcOpOperand, dstOpOperand); - } - for (OpOperand *srcOpOperand : src.getDpsInputOperands()) { // R - if (!srcOpOperand->get().getType().isa()) - continue; - // RAR graph - for (OpOperand *dstOpOperand : dst.getDpsInputOperands()) { // R - if (!dstOpOperand->get().getType().isa()) - continue; - if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // RAR alias - addDependenceElem(DependenceType::RAR, srcOpOperand, dstOpOperand); - } - // WAR graph - for (OpOperand *dstOpOperand : dst.getDpsInitOperands()) // W - if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // WAR alias - addDependenceElem(DependenceType::WAR, srcOpOperand, dstOpOperand); - } -} - -SmallVector -LinalgDependenceGraph::findCoveringDependences(LinalgOp srcLinalgOp, - LinalgOp dstLinalgOp) const { - return findOperationsWithCoveringDependences( - srcLinalgOp, dstLinalgOp, nullptr, - {DependenceType::WAW, DependenceType::WAR, DependenceType::RAW}); -} - -SmallVector LinalgDependenceGraph::findCoveringWrites( - LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value view) const { - return findOperationsWithCoveringDependences( - srcLinalgOp, dstLinalgOp, view, - {DependenceType::WAW, DependenceType::WAR}); -} - -SmallVector LinalgDependenceGraph::findCoveringReads( - LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value view) const { - return findOperationsWithCoveringDependences( - srcLinalgOp, dstLinalgOp, view, - {DependenceType::RAR, DependenceType::RAW}); -} - -SmallVector -LinalgDependenceGraph::findOperationsWithCoveringDependences( - LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value view, - ArrayRef types) const { - auto *src = srcLinalgOp.getOperation(); - auto *dst = dstLinalgOp.getOperation(); - auto srcPos = linalgOpPositions.lookup(src); - auto dstPos = linalgOpPositions.lookup(dst); - assert(srcPos < dstPos && "expected dst after src in IR traversal order"); - - SmallVector res; - // Consider an intermediate interleaved `interim` op, look for any dependence - // to an aliasing view on a src -> op -> dst path. - // TODO: we are not considering paths yet, just interleaved positions. - for (auto dt : types) { - for (auto dependence : getDependencesFrom(src, dt)) { - auto interimPos = linalgOpPositions.lookup(dependence.getDependentOp()); - // Skip if not interleaved. - if (interimPos >= dstPos || interimPos <= srcPos) - continue; - Value consumerView = dependence.getIndexingValue(); - if (view && !aliases.alias(view, consumerView)) - continue; - auto *op = dependence.getDependentOp(); - LLVM_DEBUG(dbgs() << "\n***Found covering dependence of type " - << getDependenceTypeStr(dt) << ": " << *src << " -> " - << *op << " on " << consumerView); - res.push_back(op); - } - } - return res; -} - -bool LinalgDependenceGraph::hasDependenceFrom( - LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, - ArrayRef depTypes) const { - for (auto dep : depTypes) - for (auto dependence : getDependencesInto(dstLinalgOp, dep)) - if (dependence.getDependentOp() == srcLinalgOp) - return true; - return false; -} - -bool LinalgDependenceGraph::hasDependentOperationsFrom( - LinalgOp linalgOp, - ArrayRef depTypes) const { - for (auto dep : depTypes) { - if (!getDependencesFrom(linalgOp, dep).empty()) - return true; - } - return false; -} - -bool LinalgDependenceGraph::hasDependentOperationsInto( - LinalgOp linalgOp, - ArrayRef depTypes) const { - for (auto dep : depTypes) { - if (!getDependencesInto(linalgOp, dep).empty()) - return true; - } - return false; -} - -bool LinalgDependenceGraph::hasDependentOperations( - LinalgOp linalgOp, ArrayRef depTypes) const { - return hasDependentOperationsInto(linalgOp, depTypes) || - hasDependentOperationsFrom(linalgOp, depTypes); -} - -SmallVector -LinalgDependenceGraph::getDependentOperationsInto( - LinalgOp linalgOp, ArrayRef depTypes) const { - SmallVector - dependentOperations; - for (auto dependenceType : depTypes) { - auto dependencies = getDependencesInto(linalgOp, dependenceType); - dependentOperations.append(dependencies.begin(), dependencies.end()); - } - return dependentOperations; -} - -SmallVector -LinalgDependenceGraph::getDependentOperationsFrom( - LinalgOp linalgOp, ArrayRef depTypes) const { - SmallVector - dependentOperations; - for (auto dependenceType : depTypes) { - auto dependencies = getDependencesFrom(linalgOp, dependenceType); - dependentOperations.append(dependencies.begin(), dependencies.end()); - } - return dependentOperations; -} - -/// Returns all dependent operations (into and from) given `operation`. -SmallVector -LinalgDependenceGraph::getDependentOperations( - LinalgOp linalgOp, ArrayRef depTypes) const { - SmallVector dependentOperations = - getDependentOperationsInto(linalgOp, depTypes); - SmallVector t = - getDependentOperationsFrom(linalgOp, depTypes); - dependentOperations.append(t.begin(), t.end()); - return dependentOperations; -} - -void LinalgDependenceGraph::print(raw_ostream &os) const { - for (auto dt : { - LinalgDependenceGraph::DependenceType::RAW, - LinalgDependenceGraph::DependenceType::WAW, - }) { - const auto &fromGraph = dependencesFromGraphs[dt]; - for (const auto &it : fromGraph) { - os << "[LinalgDependenceGraph] DT " << dt << " from: " << *it.first - << ":\n"; - for (const auto &dep : it.second) { - os << "\tDT " << dt << " " << *dep.getDependentOp() << ":\n"; - } - } - } -} - -void LinalgDependenceGraph::dump() const { print(llvm::errs()); } diff --git a/mlir/lib/Dialect/Linalg/CMakeLists.txt b/mlir/lib/Dialect/Linalg/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/CMakeLists.txt @@ -1,4 +1,3 @@ -add_subdirectory(Analysis) add_subdirectory(IR) add_subdirectory(TransformOps) add_subdirectory(Transforms) diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -56,7 +56,6 @@ MLIRMemRefDialect MLIRMemRefTransforms MLIRLinalgDialect - MLIRLinalgAnalysis MLIRLinalgUtils MLIRSCFDialect MLIRSCFTransforms diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -12,7 +12,6 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" @@ -204,173 +203,6 @@ return fuse(b, producerOp, fusedLoopsAndRanges); } -// Encode structural fusion safety preconditions. -// Some of these will be lifted in the future with better analysis. -static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView, - LinalgOp consumer) { - assert(producer.hasBufferSemantics() && - "expected linalg op with buffer semantics"); - assert(consumer.hasBufferSemantics() && - "expected linalg op with buffer semantics"); - if (producer.getNumDpsInits() != 1) { - LLVM_DEBUG(llvm::dbgs() << "\nNot structurally fusable (multi-output)"); - return false; - } - // Only fuse when the producer block dominates. - DominanceInfo dom(producer.getOperation()); - if (!dom.dominates(producer->getBlock(), consumer->getBlock())) { - LLVM_DEBUG( - llvm::dbgs() - << "\nNot structurally fusable (producer block does not dominate)"); - return false; - } - return true; -} - -bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph, - LinalgOp consumer, - Value consumedView, - LinalgOp producer) { - assert(producer.hasBufferSemantics() && - "expected linalg op with buffer semantics"); - assert(consumer.hasBufferSemantics() && - "expected linalg op with buffer semantics"); - // Make some simple structural checks that alleviate the need for more - // complex analyses. - if (!isStructurallyFusableProducer(producer, consumedView, consumer)) { - LLVM_DEBUG(llvm::dbgs() << "\n***Not static last write due to structure:\t" - << *producer.getOperation()); - return false; - } - // Check for any interleaved write to consumedView. - if (!graph.findCoveringWrites(producer, consumer, consumedView).empty()) { - LLVM_DEBUG(llvm::dbgs() << "\n***Not fusable due to interleaved write:\t" - << *producer.getOperation()); - return false; - } - return true; -} - -bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph, - LinalgOp consumer, Value consumedView, - LinalgOp producer) { - assert(producer.hasBufferSemantics() && - "expected linalg op with buffer semantics"); - assert(consumer.hasBufferSemantics() && - "expected linalg op with buffer semantics"); - if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer)) - return false; - // Check for any fusion-preventing dependence to any shape read/written that - // would violate dependences. - if (!graph.findCoveringDependences(producer, consumer).empty()) { - LLVM_DEBUG(llvm::dbgs() - << "\n***Not fusable due to an interleaved dependence:\t" - << *producer.getOperation()); - return false; - } - return true; -} - -/// For `consumer` with buffer semantics, find the Linalg operation on buffers -/// that is the last writer of `consumerOpOperand`. For now the fusable -/// dependence is returned as an instance of the `dependenceGraph`. -static FailureOr -findFusableProducer(OpOperand &consumerOpOperand, - const LinalgDependenceGraph &dependenceGraph) { - LLVM_DEBUG(llvm::dbgs() << "findFusableProducer for: " - << consumerOpOperand.get() << " @" - << consumerOpOperand.getOperandNumber() << " in " - << *consumerOpOperand.getOwner() << "\n"); - LinalgOp consumerOp = dyn_cast(consumerOpOperand.getOwner()); - if (!consumerOp) - return failure(); - - // Only consider RAW and WAW atm. - for (auto depType : { - LinalgDependenceGraph::DependenceType::RAW, - LinalgDependenceGraph::DependenceType::WAW, - }) { - LLVM_DEBUG(llvm::dbgs() - << "Dependencies into: " << *consumerOp.getOperation() << "\n"); - for (auto dependence : llvm::make_filter_range( - dependenceGraph.getDependencesInto(consumerOp, depType), - [&](LinalgDependenceGraph::LinalgDependenceGraphElem elem) { - LLVM_DEBUG(llvm::dbgs() << "Inspect dependence btw: " - << elem.getIndexingValue() << " and " - << elem.getDependentValue() << "\n"); - Value v = elem.getIndexingValue(); - std::optional operandNum = - elem.getIndexingOpViewOperandNum(); - return isa(elem.getDependentOp()) && - v == consumerOpOperand.get() && operandNum && - *operandNum == consumerOpOperand.getOperandNumber(); - })) { - // Consumer consumes this view, `isStructurallyFusableProducer` also - // checks whether it is a strict subview of the producer view. - auto producer = cast(dependence.getDependentOp()); - LLVM_DEBUG(llvm::dbgs() - << "\n" - << LinalgDependenceGraph::getDependenceTypeStr(depType) - << "producer: " << *dependence.getDependentOp() - << " view: " << dependence.getDependentValue() << "\n"); - - // If the producer and consumer have tensor semantics, the only dependence - // between them is through a RAW dependence and they are fusable by - // construction. For buffer semantics need additional checks. - if (producer.hasBufferSemantics() && consumerOp.hasBufferSemantics() && - isFusableInto(dependenceGraph, consumerOp, consumerOpOperand.get(), - producer)) - return dependence; - if (producer.hasTensorSemantics() && consumerOp.hasTensorSemantics()) { - assert(dependence.dependenceType == - LinalgDependenceGraph::DependenceType::RAW); - return dependence; - } - } - } - return failure(); -} - -FailureOr -mlir::linalg::fuseProducerOfBuffer(OpBuilder &b, OpOperand &consumerOpOperand, - const LinalgDependenceGraph &graph) { - std::optional - fusableDependence = findFusableProducer(consumerOpOperand, graph); - if (!fusableDependence) - return failure(); - - LinalgOp producerOp = dyn_cast(fusableDependence->getDependentOp()); - if (!producerOp) - return failure(); - - // If producer is already in the same block as consumer, we are done. - if (consumerOpOperand.get().getParentBlock() == - fusableDependence->getDependentValue().getParentBlock()) - return failure(); - - std::optional producerMap = - fusableDependence->getDependentOpViewIndexingMap(); - if (!producerMap) - return failure(); - - // Must be a subview or an extract_slice to guarantee there are loops we can - // fuse into. - auto subView = consumerOpOperand.get().getDefiningOp(); - if (!subView) { - LLVM_DEBUG(llvm::dbgs() << "\nNot fusable (not a subview)"); - return failure(); - } - - // Fuse `producer` just before `consumer`. - OpBuilder::InsertionGuard g(b); - b.setInsertionPoint(consumerOpOperand.getOwner()); - LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " - << *consumerOpOperand.getOwner() << "\n"); - - auto fusedProducer = fuse(b, producerOp, *producerMap, consumerOpOperand); - return FusionInfo{producerOp, fusedProducer}; -} - /// Walk back use-def chain through scf::For yields. /// Sets `producer` and `outputIndex` if it finds a producer LinalgOp diff --git a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp @@ -11,7 +11,6 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -15,7 +15,6 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/HoistPadding.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" diff --git a/mlir/test/Dialect/Linalg/fusion-2-level.mlir b/mlir/test/Dialect/Linalg/fusion-2-level.mlir deleted file mode 100644 --- a/mlir/test/Dialect/Linalg/fusion-2-level.mlir +++ /dev/null @@ -1,51 +0,0 @@ -// RUN: mlir-opt %s -test-linalg-greedy-fusion | FileCheck %s - -func.func @f1(%A: memref>, %B: memref>, %C: memref>, %D: memref>, %E: memref>) -> memref> { - %c1 = arith.constant 1 : index - %c0 = arith.constant 0 : index - %c4 = arith.constant 4 : index - %c3 = arith.constant 3 : index - %c2 = arith.constant 2 : index - %c40 = arith.constant 40 : index - %c30 = arith.constant 30 : index - %c20 = arith.constant 20 : index - %0 = memref.dim %C, %c0 : memref> - %1 = memref.dim %C, %c1 : memref> - %2 = memref.dim %D, %c1 : memref> - linalg.matmul ins(%A, %B: memref>, memref>) - outs(%C: memref>) - scf.for %arg5 = %c0 to %0 step %c20 { - scf.for %arg6 = %c0 to %2 step %c30 { - scf.for %arg7 = %c0 to %1 step %c40 { - %5 = memref.subview %C[%arg5, %arg7][%c20, %c40][%c1, %c1] : memref> to memref> - %7 = memref.subview %D[%arg7, %arg6][%c40, %c30][%c1, %c1]: memref> to memref> - %8 = memref.subview %E[%arg5, %arg6][%c20, %c40][%c1, %c1] : memref> to memref> - %9 = memref.dim %5, %c0 : memref> - %10 = memref.dim %5, %c1 : memref> - %11 = memref.dim %7, %c1 : memref> - scf.for %arg8 = %c0 to %9 step %c2 { - scf.for %arg9 = %c0 to %11 step %c3 { - scf.for %arg10 = %c0 to %10 step %c4 { - %14 = memref.subview %5[%arg8, %arg10][%c2, %c4][%c1, %c1] : memref> to memref> - %16 = memref.subview %7[%arg10, %arg9][%c4, %c3][%c1, %c1]: memref> to memref> - %17 = memref.subview %8[%arg8, %arg9][%c2, %c3][%c1, %c1] : memref> to memref> - linalg.matmul ins(%14, %16: memref>, memref>) - outs(%17: memref>) - } - } - } - } - } - } - return %E : memref> -} -// CHECK-LABEL: func @f1 -// CHECK: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}}) -// CHECK: scf.for -// CHECK: scf.for -// CHECK: scf.for -// CHECK: scf.for -// CHECK: scf.for -// CHECK: scf.for -// CHECK: linalg.matmul -// CHECK: linalg.matmul diff --git a/mlir/test/Dialect/Linalg/fusion-indexed.mlir b/mlir/test/Dialect/Linalg/fusion-indexed.mlir deleted file mode 100644 --- a/mlir/test/Dialect/Linalg/fusion-indexed.mlir +++ /dev/null @@ -1,160 +0,0 @@ -// RUN: mlir-opt %s -test-linalg-greedy-fusion -split-input-file | FileCheck %s - -#id_2d = affine_map<(d0, d1) -> (d0, d1)> -#pointwise_2d_trait = { - indexing_maps = [#id_2d, #id_2d, #id_2d], - iterator_types = ["parallel", "parallel"] -} -func.func @fuse_indexed_consumer(%A: memref, - %B: memref, - %C: memref, - %D: memref) { - linalg.generic #pointwise_2d_trait - ins(%A, %B: memref, memref) - outs(%C : memref) { - ^bb0(%e: f32, %arg5: f32, %arg6: f32): - %2 = arith.addf %e, %arg5 : f32 - linalg.yield %2 : f32 - } - %c1 = arith.constant 1 : index - %c0 = arith.constant 0 : index - %c25 = arith.constant 25 : index - %c10 = arith.constant 10 : index - %0 = memref.dim %C, %c0 : memref - %1 = memref.dim %C, %c1 : memref - %2 = memref.dim %D, %c0 : memref - %3 = memref.dim %D, %c1 : memref - scf.for %arg2 = %c0 to %0 step %c10 { - scf.for %arg3 = %c0 to %1 step %c25 { - %4 = memref.subview %C[%arg2, %arg3][%c10, %c25][%c1, %c1] : - memref to memref> - %5 = memref.subview %D[%arg2, %arg3][%c10, %c25][%c1, %c1] : - memref to memref> - linalg.generic { - indexing_maps = [#id_2d, #id_2d], - iterator_types = ["parallel", "parallel"]} - ins(%4 : memref>) - outs(%5 : memref>) { - ^bb0(%arg4: f32, %arg5: f32): - %idx0 = linalg.index 0 : index - %idx1 = linalg.index 1 : index - %6 = arith.addi %idx0, %arg2 : index - %7 = arith.addi %idx1, %arg3 : index - %8 = arith.index_cast %6 : index to i32 - %9 = arith.sitofp %8 : i32 to f32 - %10 = arith.index_cast %7 : index to i32 - %11 = arith.sitofp %10 : i32 to f32 - %12 = arith.addf %9, %11 : f32 - %13 = arith.addf %12, %arg4 : f32 - linalg.yield %13 : f32 - } - } - } - return -} -// CHECK-LABEL: func @fuse_indexed_consumer -// CHECK: scf.for -// CHECK: scf.for -// CHECK-NOT: scf.for -// CHECK: linalg.generic -// CHECK-NOT: affine.apply -// CHECK: arith.addf -// CHECK: linalg.generic -// CHECK: arith.index_cast - -// ----- - -func.func @fuse_indexed_producer(%A: memref, - %B: memref) { - %c1 = arith.constant 1 : index - %c0 = arith.constant 0 : index - %c25 = arith.constant 25 : index - %c10 = arith.constant 10 : index - linalg.generic { - indexing_maps = [affine_map<(i, j) -> (j, i)>], - iterator_types = ["parallel", "parallel"]} - outs(%A : memref) { - ^bb0(%a: index): - %idx0 = linalg.index 0 : index - %idx1 = linalg.index 1 : index - %0 = arith.addi %idx0, %idx1 : index - linalg.yield %0 : index - } - %A_X = memref.dim %A, %c0 : memref - %A_Y = memref.dim %A, %c1 : memref - scf.parallel (%arg2, %arg3) = (%c0, %c0) to (%A_X, %A_Y) step (%c10, %c25) { - %A_view = memref.subview %A[%arg2, %arg3][%c10, %c25][%c1, %c1] : - memref to memref> - %B_view = memref.subview %B[%arg2, %arg3][%c10, %c25][%c1, %c1] : - memref to memref> - linalg.generic { - indexing_maps = [affine_map<(i, j) -> (i, j)>, - affine_map<(i, j) -> (i, j)>], - iterator_types = ["parallel", "parallel"]} - ins(%A_view : memref>) - outs(%B_view : memref>) { - ^bb0(%a: index, %b: index): - linalg.yield %a : index - } - } - return -} -// CHECK: [[$MAP:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0 + d1)> -// CHECK-LABEL: func @fuse_indexed_producer -// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = -// CHECK: linalg.generic -// CHECK: [[idx0:%.*]] = linalg.index 0 : index -// CHECK: [[i_new:%.*]] = affine.apply [[$MAP]]([[idx0]], [[J]]) -// CHECK: [[idx1:%.*]] = linalg.index 1 : index -// CHECK: [[j_new:%.*]] = affine.apply [[$MAP]]([[idx1]], [[I]]) -// CHECK: [[sum:%.*]] = arith.addi [[i_new]], [[j_new]] : index -// CHECK: linalg.yield [[sum]] : index -// CHECK: linalg.generic - -// ----- - -func.func @fuse_indexed_producer_tiled_second_dim_only(%A: memref, - %B: memref) { - %c1 = arith.constant 1 : index - %c0 = arith.constant 0 : index - %c25 = arith.constant 25 : index - linalg.generic { - indexing_maps = [affine_map<(i, j) -> (i, j)>], - iterator_types = ["parallel", "parallel"]} - outs(%A : memref) { - ^bb0(%a: index): - %idx0 = linalg.index 0 : index - %idx1 = linalg.index 1 : index - %0 = arith.addi %idx0, %idx1 : index - linalg.yield %0 : index - } - %A_X = memref.dim %A, %c0 : memref - %A_Y = memref.dim %A, %c1 : memref - scf.parallel (%arg3) = (%c0) to (%A_Y) step (%c25) { - %A_view = memref.subview %A[%c0, %arg3][%A_X, %c25][%c1, %c1] : - memref to memref> - %B_view = memref.subview %B[%c0, %arg3][%A_X, %c25][%c1, %c1] : - memref to memref> - linalg.generic { - indexing_maps = [affine_map<(i, j) -> (i, j)>, - affine_map<(i, j) -> (i, j)>], - iterator_types = ["parallel", "parallel"]} - ins(%A_view : memref>) - outs(%B_view : memref>) { - ^bb0(%a: index, %b: index): - linalg.yield %a : index - } - } - return -} -// CHECK: [[$MAP:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0 + d1)> -// CHECK-LABEL: func @fuse_indexed_producer_tiled_second_dim_only -// CHECK: scf.parallel ([[J:%.*]]) = -// CHECK: linalg.generic -// CHECK: [[idx0:%.*]] = linalg.index 0 : index -// CHECK: [[idx1:%.*]] = linalg.index 1 : index -// CHECK: [[j_new:%.*]] = affine.apply [[$MAP]]([[idx1]], [[J]]) -// CHECK: [[sum:%.*]] = arith.addi [[idx0]], [[j_new]] : index -// CHECK: linalg.yield [[sum]] : index -// CHECK: linalg.generic - diff --git a/mlir/test/Dialect/Linalg/fusion.mlir b/mlir/test/Dialect/Linalg/fusion.mlir deleted file mode 100644 --- a/mlir/test/Dialect/Linalg/fusion.mlir +++ /dev/null @@ -1,745 +0,0 @@ -// RUN: mlir-opt %s -test-linalg-greedy-fusion -split-input-file | FileCheck %s - -func.func @f1(%A: memref>, - %B: memref>, - %C: memref>, - %D: memref>, - %E: memref> - ) -> memref> { - %c0 = arith.constant 0 : index - %c4 = arith.constant 4 : index - %c3 = arith.constant 3 : index - %c2 = arith.constant 2 : index - %c1 = arith.constant 1 : index - %0 = memref.dim %A, %c0 : memref> - %1 = memref.dim %A, %c1 : memref> - %2 = memref.dim %B, %c1 : memref> - linalg.matmul ins(%A, %B : memref>, - memref>) - outs(%C : memref>) - scf.for %arg5 = %c0 to %0 step %c2 { - scf.for %arg6 = %c0 to %2 step %c3 { - scf.for %arg7 = %c0 to %1 step %c4 { - %5 = memref.subview %A[%arg5, %arg7][%c2, %c4][%c1, %c1] : - memref> to - memref> - %7 = memref.subview %B[%arg7, %arg6][%c4, %c3][%c1, %c1] : - memref> to - memref> - %8 = memref.subview %C[%arg5, %arg6][%c2, %c3][%c1, %c1] : - memref> to - memref> - linalg.matmul ins(%5, %7 : memref>, - memref>) - outs(%8: memref>) - } - } - } - return %E : memref> -} -// CHECK-LABEL: func @f1 -// CHECK: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}}) -// CHECK: scf.for -// CHECK: scf.for -// CHECK: scf.for -// CHECK: linalg.matmul -// CHECK: linalg.matmul - -// ----- - -func.func @f2(%A: memref>, - %B: memref>, - %C: memref>, - %D: memref>, - %E: memref> - ) -> memref> { - %c1 = arith.constant 1 : index - %c0 = arith.constant 0 : index - %c4 = arith.constant 4 : index - %c3 = arith.constant 3 : index - %c2 = arith.constant 2 : index - linalg.matmul ins(%A, %B : memref>, - memref>) - outs(%C: memref>) - %0 = memref.dim %C, %c0 : memref> - %1 = memref.dim %C, %c1 : memref> - %2 = memref.dim %D, %c1 : memref> - scf.for %arg5 = %c0 to %0 step %c2 { - scf.for %arg6 = %c0 to %2 step %c3 { - scf.for %arg7 = %c0 to %1 step %c4 { - %5 = memref.subview %C[%arg5, %arg7][%c2, %c4][%c1, %c1] : - memref> to - memref> - %7 = memref.subview %D[%arg7, %arg6][%c4, %c3][%c1, %c1] : - memref> to - memref> - %8 = memref.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] : - memref> to - memref> - linalg.matmul ins(%5, %7 : memref>, - memref>) - outs(%8 : memref>) - } - } - } - return %E : memref> -} -// CHECK-LABEL: func @f2 -// CHECK: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}}) -// CHECK-DAG: %[[C_0:.*]] = memref.dim %[[C]], %c0{{[_0-9]*}} : memref> -// CHECK-DAG: %[[C_1:.*]] = memref.dim %[[C]], %c1{{[_0-9]*}} : memref> -// CHECK-DAG: %[[D_1:.*]] = memref.dim %[[D]], %c1{{[_0-9]*}} : memref> -// CHECK: scf.for %{{.*}} = %{{.*}} to %[[C_0]] step %{{.*}} { -// CHECK: scf.for %{{.*}} = %{{.*}} to %[[D_1]] step %{{.*}} { -// CHECK: scf.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} { -// CHECK: linalg.matmul -// CHECK: linalg.matmul - -// ----- - -func.func @f3(%A: memref>, - %B: memref>, - %C: memref>, - %D: memref>, - %E: memref> - ) -> memref> { - %c1 = arith.constant 1 : index - %c0 = arith.constant 0 : index - %c4 = arith.constant 4 : index - %c3 = arith.constant 3 : index - %c2 = arith.constant 2 : index - linalg.matmul ins(%A, %B : memref>, - memref>) - outs(%C : memref>) - %0 = memref.dim %D, %c0 : memref> - %1 = memref.dim %D, %c1 : memref> - %2 = memref.dim %C, %c1 : memref> - scf.for %arg5 = %c0 to %0 step %c2 { - scf.for %arg6 = %c0 to %2 step %c3 { - scf.for %arg7 = %c0 to %1 step %c4 { - %5 = memref.subview %D[%arg5, %arg7][%c2, %c4][%c1, %c1] : - memref> to - memref> - %7 = memref.subview %C[%arg7, %arg6][%c4, %c3][%c1, %c1] : - memref> to - memref> - %8 = memref.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] : - memref> to - memref> - linalg.matmul ins(%5, %7 : memref>, - memref>) - outs(%8 : memref>) - } - } - } - return %E : memref> -} -// CHECK-LABEL: func @f3 -// CHECK: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}}) -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[D_0:.*]] = memref.dim %[[D]], %[[C0]] : memref> -// CHECK: %[[D_1:.*]] = memref.dim %[[D]], %[[C1]] : memref> -// CHECK: %[[C_1:.*]] = memref.dim %[[C]], %[[C1]] : memref> -// CHECK: scf.for %{{.*}} = %{{.*}} to %[[D_0]] step %{{.*}} { -// CHECK: scf.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} { -// CHECK: scf.for %{{.*}} = %{{.*}} to %[[D_1]] step %{{.*}} { -// CHECK: linalg.matmul -// CHECK: linalg.matmul - -// ----- - -func.func @f4(%A: memref>, - %B: memref>, - %C: memref>, - %D: memref>, - %E: memref> - ) -> memref> { - %c1 = arith.constant 1 : index - %c0 = arith.constant 0 : index - %c4 = arith.constant 4 : index - %c3 = arith.constant 3 : index - %c2 = arith.constant 2 : index - linalg.matmul ins(%A, %B : memref>, - memref>) - outs(%C : memref>) - linalg.matmul ins(%A, %B : memref>, - memref>) - outs(%D : memref>) - %0 = memref.dim %C, %c0 : memref> - %1 = memref.dim %C, %c1 : memref> - %2 = memref.dim %D, %c1 : memref> - scf.for %arg5 = %c0 to %0 step %c2 { - scf.for %arg6 = %c0 to %2 step %c3 { - scf.for %arg7 = %c0 to %1 step %c4 { - %5 = memref.subview %C[%arg5, %arg7][%c2, %c4][%c1, %c1] : - memref> to - memref> - %7 = memref.subview %D[%arg7, %arg6][%c4, %c3][%c1, %c1] : - memref> to - memref> - %8 = memref.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] : - memref> to - memref> - linalg.matmul ins(%5, %7 : memref>, - memref>) - outs(%8 : memref>) - } - } - } - return %E : memref> -} -// CHECK-LABEL: func @f4 -// CHECK: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}}) -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[C_0:.*]] = memref.dim %[[C]], %[[C0:.*]] : memref> -// CHECK: %[[C_1:.*]] = memref.dim %[[C]], %[[C1:.*]] : memref> -// CHECK: %[[D_1:.*]] = memref.dim %[[D]], %[[C1:.*]] : memref> -// CHECK: scf.for %{{.*}} = %{{.*}} to %[[C_0]] step %{{.*}} { -// CHECK: scf.for %{{.*}} = %{{.*}} to %[[D_1]] step %{{.*}} { -// CHECK: scf.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} { -// Fuse D then fuse C, no false dependence prevent it. -// CHECK: linalg.matmul -// CHECK: linalg.matmul -// CHECK: linalg.matmul - -// ----- - -func.func @f5(%A: memref>, - %B: memref>, - %C: memref>, - %D: memref>, - %E: memref> - ) -> memref> { - %c1 = arith.constant 1 : index - %c0 = arith.constant 0 : index - %c4 = arith.constant 4 : index - %c3 = arith.constant 3 : index - %c2 = arith.constant 2 : index - %0 = memref.dim %B, %c1 : memref> - %1 = memref.dim %D, %c0 : memref> - %2 = memref.dim %D, %c1 : memref> - linalg.matmul ins(%A, %B : memref>, - memref>) - outs(%C : memref>) - linalg.matmul ins(%C, %B : memref>, - memref>) - outs(%D : memref>) - scf.for %arg5 = %c0 to %1 step %c2 { - scf.for %arg6 = %c0 to %0 step %c3 { - scf.for %arg7 = %c0 to %2 step %c4 { - %5 = memref.subview %D[%arg5, %arg7][%c2, %c4][%c1, %c1] : - memref> to - memref> - %7 = memref.subview %B[%arg7, %arg6][%c4, %c3][%c1, %c1] : - memref> to - memref> - %8 = memref.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] : - memref> to - memref> - linalg.matmul ins(%5, %7 : memref>, - memref>) - outs(%8 : memref>) - } - } - } - return %E : memref> -} - -// CHECK-DAG: #[[BOUND_2_MAP:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 2)> -// CHECK-DAG: #[[BOUND_2_MAP_2:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s1, -d0 + s0, 2)> -// CHECK-DAG: #[[BOUND_4_MAP:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 4)> -// CHECK: func @f5 -// CHECK-SAME: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}}) -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[A_0:.*]] = memref.dim %[[A]], %[[C0]] : memref> -// CHECK-DAG: %[[B_1:.*]] = memref.dim %[[B]], %[[C1]] : memref> -// CHECK-DAG: %[[C_0:.*]] = memref.dim %[[C]], %[[C0]] : memref> -// CHECK-DAG: %[[D_0:.*]] = memref.dim %[[D]], %[[C0]] : memref> -// CHECK-DAG: %[[D_1:.*]] = memref.dim %[[D]], %[[C1]] : memref> -// CHECK-DAG: %[[B_00:.*]] = memref.subview %[[B]][0, 0]{{.*}} -// CHECK: scf.for %[[I:.*]] = %{{.*}} to %[[D_0]] step %{{.*}} { -// CHECK: %[[BOUND_2_C0:.+]] = affine.min #[[BOUND_2_MAP]](%[[I]])[%[[C_0]]] -// CHECK: %[[C_I0:.*]] = memref.subview %[[C]][%[[I]], 0] [%[[BOUND_2_C0]] -// CHECK: %[[BOUND_ID_C0:.+]] = affine.min #[[BOUND_2_MAP_2]](%[[I]])[%[[A_0]], %[[C_0]]] -// CHECK: %[[A_I0:.*]] = memref.subview %[[A]][%[[I]], 0] -// CHECK: %[[C_I0_OUT:.*]] = memref.subview %[[C]][%[[I]], 0] [%[[BOUND_ID_C0]] -// CHECK: scf.for %[[J:.*]] = %{{.*}} to %[[B_1]] step %{{.*}} { -// CHECK: %[[E_IJ:.*]] = memref.subview %[[E]][%[[I]], %[[J]]] -// CHECK: scf.for %[[K:.*]] = %{{.*}} to %[[D_1]] step %{{.*}} { -// CHECK: %[[D_IK:.*]] = memref.subview %[[D]][%[[I]], %[[K]]] [2, 4] -// CHECK: %[[B_KJ:.*]] = memref.subview %[[B]][%[[K]], %[[J]]] -// CHECK: %[[BOUND_4_B1:.*]] = affine.min #[[BOUND_4_MAP]](%[[K]])[%[[B_1]]] -// CHECK: %[[B_0K:.*]] = memref.subview %[[B]][0, %[[K]]] -// CHECK: %[[D_IK_OUT:.+]] = memref.subview %[[D]][%[[I]], %[[K]]] [%[[BOUND_2_C0]], %[[BOUND_4_B1]]] -// CHECK: linalg.matmul ins(%[[A_I0]], %[[B_00]]{{.*}} outs(%[[C_I0_OUT]] -// CHECK: linalg.matmul ins(%[[C_I0]], %[[B_0K]]{{.*}} outs(%[[D_IK_OUT]] -// CHECK: linalg.matmul ins(%[[D_IK]], %[[B_KJ]]{{.*}} outs(%[[E_IJ]] - -// ----- - -#map0 = affine_map<(d0) -> (d0 + 2)> -#map1 = affine_map<(d0) -> (d0 + 4)> -#map2 = affine_map<(d0) -> (d0 + 3)> - -func.func @f6(%A: memref>, - %B: memref>, - %C: memref>, - %D: memref>, - %E: memref> - ) -> memref> { - %c1 = arith.constant 1 : index - %c0 = arith.constant 0 : index - %c4 = arith.constant 4 : index - %c3 = arith.constant 3 : index - %c2 = arith.constant 2 : index - %0 = memref.dim %C, %c1 : memref> - linalg.matmul ins(%A, %B : memref>, - memref>) - outs(%C : memref>) - linalg.matmul ins(%A, %C : memref>, - memref>) - outs(%E : memref>) - %1 = memref.dim %C, %c0 : memref> - %2 = memref.dim %D, %c1 : memref> - scf.for %arg5 = %c0 to %1 step %c2 { - scf.for %arg6 = %c0 to %2 step %c3 { - scf.for %arg7 = %c0 to %0 step %c4 { - %3 = affine.apply #map0(%arg5) - %4 = affine.apply #map1(%arg7) - %5 = memref.subview %C[%arg5, %arg7][%c2, %c4][%c1, %c1] : - memref> to - memref> - %6 = affine.apply #map2(%arg6) - %7 = memref.subview %D[%arg7, %arg6][%c4, %c3][%c1, %c1] : - memref> to - memref> - %8 = memref.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] : - memref> to - memref> - linalg.matmul ins(%5, %7 : memref>, - memref>) - outs(%8 : memref>) - } - } - } - return %E : memref> -} -// CHECK-LABEL: func @f6 -// CHECK: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}}) -// Fuse the producer of E (WAW) then the producer of C (WAR). -// CHECK: scf.for -// CHECK: scf.for -// CHECK: scf.for -// CHECK: linalg.matmul -// CHECK: linalg.matmul -// CHECK: linalg.matmul - -// ----- - -func.func @f7(%A: memref>, - %B: memref>, - %C: memref>, - %D: memref>, - %E: memref> - ) -> memref> { - %c1 = arith.constant 1 : index - %c0 = arith.constant 0 : index - %c4 = arith.constant 4 : index - %c3 = arith.constant 3 : index - %c2 = arith.constant 2 : index - %0 = memref.dim %A, %c0 : memref> - %1 = memref.dim %A, %c1 : memref> - %2 = memref.dim %C, %c1 : memref> - %3 = memref.dim %C, %c0 : memref> - %4 = memref.dim %D, %c1 : memref> - linalg.matmul ins(%A, %C : memref>, - memref>) - outs(%E : memref>) - linalg.matmul ins(%A, %B : memref>, - memref>) - outs(%C : memref>) - scf.for %arg5 = %c0 to %0 step %c2 { - scf.for %arg6 = %c0 to %2 step %c3 { - scf.for %arg7 = %c0 to %1 step %c4 { - %7 = memref.subview %A[%arg5, %arg7][%c2, %c4][%c1, %c1] : - memref> to - memref> - %9 = memref.subview %C[%arg7, %arg6][%c4, %c3][%c1, %c1] : - memref> to - memref> - %10 = memref.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] : - memref> to - memref> - linalg.matmul ins(%7, %9 : memref>, - memref>) - outs(%10 : memref>) - } - } - } - scf.for %arg5 = %c0 to %3 step %c2 { - scf.for %arg6 = %c0 to %4 step %c3 { - scf.for %arg7 = %c0 to %2 step %c4 { - %7 = memref.subview %C[%arg5, %arg7][%c2, %c4][%c1, %c1] : - memref> to - memref> - %9 = memref.subview %D[%arg7, %arg6][%c4, %c3][%c1, %c1] : - memref> to - memref> - %10 = memref.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] : - memref> to - memref> - linalg.matmul ins(%7, %9 : memref>, - memref>) - outs(%10 : memref>) - } - } - } - return %E : memref> -} -// CHECK-LABEL: func @f7 -// CHECK: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}}) -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[A_0:.*]] = memref.dim %[[A]], %[[C0:.*]] : memref> -// CHECK: %[[A_1:.*]] = memref.dim %[[A]], %[[C1:.*]] : memref> -// CHECK: %[[C_1:.*]] = memref.dim %[[C]], %[[C1:.*]] : memref> -// CHECK: %[[C_0:.*]] = memref.dim %[[C]], %[[C0:.*]] : memref> -// CHECK: %[[D_1:.*]] = memref.dim %[[D]], %[[C1:.*]] : memref> -// CHECK: linalg.matmul ins(%[[A]], %[[C]]{{.*}} outs(%[[E]] -// CHECK: scf.for %{{.*}} = %{{.*}} to %[[A_0]] step %{{.*}} { -// CHECK: scf.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} { -// CHECK: scf.for %{{.*}} = %{{.*}} to %[[A_1]] step %{{.*}} { -// CHECK: linalg.matmul -// CHECK: linalg.matmul -// CHECK: scf.for %{{.*}} = %{{.*}} to %[[C_0]] step %{{.*}} { -// CHECK: scf.for %{{.*}} = %{{.*}} to %[[D_1]] step %{{.*}} { -// CHECK: scf.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} { -// CHECK: linalg.matmul -// CHECK-NOT: linalg.matmul - -// ----- - -#map0 = affine_map<(d0) -> (d0 + 2)> -#map1 = affine_map<(d0) -> (d0 + 4)> -#map2 = affine_map<(d0) -> (d0 + 3)> - -func.func @f8(%A: memref>, - %B: memref>, - %C: memref>, - %D: memref>, - %E: memref> - ) -> memref> { - %c1 = arith.constant 1 : index - %c0 = arith.constant 0 : index - %c4 = arith.constant 4 : index - %c3 = arith.constant 3 : index - %c2 = arith.constant 2 : index - %0 = memref.dim %A, %c0 : memref> - %1 = memref.dim %A, %c1 : memref> - linalg.matmul ins(%A, %C : memref>, - memref>) - outs(%D : memref>) - linalg.matmul ins(%A, %B : memref>, - memref>) - outs(%C : memref>) - %2 = memref.dim %D, %c1 : memref> - scf.for %arg5 = %c0 to %0 step %c2 { - scf.for %arg6 = %c0 to %2 step %c3 { - scf.for %arg7 = %c0 to %1 step %c4 { - %3 = affine.apply #map0(%arg5) - %4 = affine.apply #map1(%arg7) - %5 = memref.subview %A[%arg5, %arg7][%c2, %c4][%c1, %c1] : - memref> to - memref> - %6 = affine.apply #map2(%arg6) - %7 = memref.subview %D[%arg7, %arg6][%c4, %c3][%c1, %c1] : - memref> to - memref> - %8 = memref.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] : - memref> to - memref> - linalg.matmul ins(%5, %7 : memref>, - memref>) - outs(%8 : memref>) - } - } - } - return %E : memref> -} -// CHECK-LABEL: func @f8 -// CHECK: (%[[A:.*]]: memref{{.*}}, %[[B:.*]]: memref{{.*}}, %[[C:.*]]: memref{{.*}}, %[[D:.*]]: memref{{.*}}, %[[E:.*]]: memref{{.*}}) -// CHECK: linalg.matmul -// CHECK: linalg.matmul -// CHECK: scf.for -// CHECK: scf.for -// CHECK: scf.for -// CHECK: linalg.matmul -// CHECK-NOT: linalg.matmul - -// ----- - -#id_2d = affine_map<(i, j) -> (i, j)> -#pointwise_2d_trait = { - indexing_maps = [#id_2d, #id_2d, #id_2d], - iterator_types = ["parallel", "parallel"] -} -func.func @pointwise(%A: memref>, - %B: memref>, - %C: memref>, - %D: memref>) { - %c1 = arith.constant 1 : index - %c0 = arith.constant 0 : index - %c3 = arith.constant 3 : index - %c2 = arith.constant 2 : index - linalg.generic #pointwise_2d_trait - ins(%A, %A: memref>, - memref>) - outs(%B : memref>) { - ^bb0(%E: f32, %arg5: f32, %arg6: f32): - %2 = arith.addf %E, %arg5 : f32 - linalg.yield %2 : f32 - } - %0 = memref.dim %B, %c0 : memref> - %1 = memref.dim %B, %c1 : memref> - scf.for %arg4 = %c0 to %0 step %c2 { - scf.for %arg5 = %c0 to %1 step %c3 { - %4 = memref.subview %B[%arg4, %arg5][%c2, %c3][%c1, %c1] : - memref> to - memref> - %5 = memref.subview %C[%arg4, %arg5][%c2, %c3][%c1, %c1] : - memref> to - memref> - %6 = memref.subview %D[%arg4, %arg5][%c2, %c3][%c1, %c1] : - memref> to - memref> - linalg.generic #pointwise_2d_trait - ins(%4, %5: memref>, - memref>) - outs(%6 : memref>) { - ^bb0(%arg6: f32, %arg7: f32, %arg8: f32): - %7 = arith.mulf %arg6, %arg7 : f32 - linalg.yield %7 : f32 - } - } - } - return -} -// CHECK-LABEL: func @pointwise -// CHECK: scf.for -// CHECK: scf.for -// CHECK-NOT: scf.for -// CHECK: linalg.generic -// CHECK: arith.addf -// CHECK: linalg.generic -// CHECK: arith.mulf - -// ----- - -#id_2d = affine_map<(i, j) -> (i, j)> -#pointwise_2d_trait = { - indexing_maps = [#id_2d, #id_2d, #id_2d], - iterator_types = ["parallel", "parallel"] -} -func.func @pointwise_no_view(%M: index, %N: index) { - %c1 = arith.constant 1 : index - %c0 = arith.constant 0 : index - %c3 = arith.constant 3 : index - %c2 = arith.constant 2 : index - %A = memref.alloc (%M, %N): memref - %B = memref.alloc (%M, %N): memref - %C = memref.alloc (%M, %N): memref - %D = memref.alloc (%M, %N): memref - %E = memref.alloc (%M, %N): memref - linalg.generic #pointwise_2d_trait - ins(%A, %A : memref, memref) - outs(%B : memref) { - ^bb0(%e: f32, %arg5: f32, %arg6: f32): - %2 = arith.addf %e, %arg5 : f32 - linalg.yield %2 : f32 - } - %0 = memref.dim %B, %c0 : memref - %1 = memref.dim %B, %c1 : memref - scf.for %arg4 = %c0 to %0 step %c2 { - scf.for %arg5 = %c0 to %1 step %c3 { - %4 = memref.subview %B[%arg4, %arg5][%c2, %c3][%c1, %c1] : - memref to - memref> - %5 = memref.subview %C[%arg4, %arg5][%c2, %c3][%c1, %c1] : - memref to - memref> - %6 = memref.subview %D[%arg4, %arg5][%c2, %c3][%c1, %c1] : - memref to - memref> - linalg.generic #pointwise_2d_trait - ins(%4, %5: memref>, - memref>) - outs(%6 : memref>) { - ^bb0(%arg6: f32, %arg7: f32, %arg8: f32): - %7 = arith.mulf %arg6, %arg7 : f32 - linalg.yield %7 : f32 - } - } - } - return -} -// CHECK-LABEL: func @pointwise_no_view -// CHECK: scf.for -// CHECK: scf.for -// CHECK-NOT: scf.for -// CHECK: linalg.generic -// CHECK: arith.addf -// CHECK: linalg.generic -// CHECK: arith.mulf - - -// ----- - -#map0 = affine_map<(d0, d1) -> (d0)> -#map1 = affine_map<(d0, d1) -> (d0, d1)> - -func.func @fusion_of_three(%arg0: memref<100x10xf32>, - %arg1: memref<100xf32>, - %arg2: memref<100x10xf32>) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %0 = memref.alloc() {temp = true} : memref<100x10xf32> - linalg.generic { - indexing_maps = [#map0, #map1], - iterator_types = ["parallel", "parallel"]} - ins(%arg1 : memref<100xf32>) - outs(%0 : memref<100x10xf32>) { - ^bb0(%arg3: f32, %arg4: f32): - linalg.yield %arg3 : f32 - } - %1 = memref.alloc() {temp = true} : memref<100x10xf32> - linalg.generic { - indexing_maps = [#map1, #map1, #map1], - iterator_types = ["parallel", "parallel"]} - ins(%arg0, %0: memref<100x10xf32>, memref<100x10xf32>) - outs(%1 : memref<100x10xf32>) { - ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): - %2 = arith.subf %arg3, %arg4 : f32 - linalg.yield %2 : f32 - } - memref.dealloc %0 : memref<100x10xf32> - %2 = memref.dim %1, %c0 : memref<100x10xf32> - %3 = memref.dim %1, %c1 : memref<100x10xf32> - %4 = memref.dim %arg2, %c0 : memref<100x10xf32> - %5 = memref.dim %arg2, %c1 : memref<100x10xf32> - scf.for %i = %c0 to %2 step %c1 { - scf.for %j = %c0 to %3 step %c1 { - %6 = memref.subview %1[%i, %j][%c1, %c1][%c1, %c1] : - memref<100x10xf32> to memref> - %7 = memref.subview %arg2[%i, %j][%c1, %c1][%c1, %c1] : - memref<100x10xf32> to memref> - linalg.generic { - indexing_maps = [#map1, #map1], - iterator_types = ["parallel", "parallel"]} - ins(%6 : memref>) - outs(%7 : memref>) { - ^bb0(%arg3: f32, %arg4: f32): - %8 = math.exp %arg3 : f32 - linalg.yield %8 : f32 - } - } - } - memref.dealloc %1 : memref<100x10xf32> - return -} -// CHECK-LABEL: func @fusion -// CHECK-NOT: linalg.generic -// CHECK: scf.for -// CHECK: scf.for -// CHECK-NOT: scf.for -// CHECK: linalg.generic -// CHECK: linalg.yield -// CHECK: linalg.generic -// CHECK: arith.subf -// CHECK: linalg.yield -// CHECK: linalg.generic -// CHECK: exp -// CHECK: linalg.yield - -// ----- - - -#map0 = affine_map<(d0)[s0] -> (2, -d0 + s0)> -#map1 = affine_map<(d0)[s0] -> (3, -d0 + s0)> -#map3 = affine_map<(d0)[s0, s1] -> (s0 + 1, -d0 + s0 + s1)> -#map4 = affine_map<(d0)[s0, s1] -> (s0 + 2, -d0 + s0 + s1)> - -func.func @fill_and_conv(%arg0: memref, %arg1: memref, %arg2: memref) { - %cst = arith.constant 0.000000e+00 : f32 - %c2 = arith.constant 2 : index - %c3 = arith.constant 3 : index - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - linalg.fill ins(%cst : f32) outs(%arg0 : memref) - %2 = memref.dim %arg1, %c0 : memref - %3 = memref.dim %arg1, %c1 : memref - %4 = memref.dim %arg2, %c0 : memref - %5 = memref.dim %arg2, %c1 : memref - scf.for %arg3 = %c0 to %4 step %c2 { - scf.for %arg4 = %c0 to %5 step %c3 { - %6 = affine.min #map3(%arg3)[%2, %4] - %7 = affine.min #map4(%arg4)[%3, %5] - %8 = memref.subview %arg0[%arg3, %arg4] [%6, %7] [1, 1] : memref to memref> - %9 = affine.min #map0(%arg3)[%4] - %10 = affine.min #map1(%arg4)[%5] - %11 = memref.subview %arg2[%arg3, %arg4] [%9, %10] [1, 1] : memref to memref> - linalg.conv_2d ins(%8, %arg1 : memref>, memref) outs(%11 : memref>) - } - } - return -} -// CHECK-LABEL: func @fill_and_conv -// CHECK: scf.for -// CHECK: scf.for -// CHECK: linalg.fill -// CHECK: linalg.conv_2d - -// ----- - -// Test that different allocation-like ops are recognized and properly handled. -func.func @accept_different_alloc_ops(%dim: index, %s0 : index, %s1: index) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %c3 = arith.constant 3 : index - %c4 = arith.constant 4 : index - - %A = memref.alloca(%dim, %dim)[%s0, %s1] : memref> - %B = memref.alloca(%dim, %dim)[%s0, %s1] : memref> - %C = memref.alloc(%dim, %dim)[%s0, %s1] : memref> - - linalg.matmul ins(%A, %B : memref>, - memref>) - outs(%C : memref>) - - scf.for %i = %c0 to %dim step %c2 { - scf.for %j = %c0 to %dim step %c3 { - scf.for %k = %c0 to %dim step %c4 { - %0 = memref.subview %A[%i, %k][%c2, %c4][%c1, %c1] : - memref> to - memref> - %1 = memref.subview %B[%k, %j][%c4, %c3][%c1, %c1] : - memref> to - memref> - %2 = memref.subview %C[%i, %j][%c2, %c3][%c1, %c1] : - memref> to - memref> - linalg.matmul ins(%0, %1 : memref>, - memref>) - outs(%2 : memref>) - } - } - } - return -} - -// CHECK-LABEL: func @accept_different_alloc_ops -// CHECK-COUNT-3: scf.for -// CHECK-COUNT-2: linalg.matmul diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp --- a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp @@ -12,7 +12,6 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/SCF/Transforms/Transforms.h" #include "mlir/Pass/Pass.h" @@ -39,22 +38,9 @@ bool changed = false; for (LinalgOp linalgOp : llvm::reverse(linalgOps)) { for (OpOperand &opOperand : linalgOp->getOpOperands()) { - if (opOperand.get().getType().isa()) { - // TODO: LinalgDependenceGraph should be able to update itself. - // The current naive and expensive reconstruction of the graph should be - // removed. - linalg::Aliases aliases; - linalg::LinalgDependenceGraph graph(aliases, linalgOps); - auto info = fuseProducerOfBuffer(b, opOperand, graph); - if (failed(info)) - continue; - auto *originalOp = info->originalProducer.getOperation(); - eraseSet.insert(originalOp); - auto *originalOpInLinalgOpsVector = - std::find(linalgOps.begin(), linalgOps.end(), originalOp); - *originalOpInLinalgOpsVector = info->fusedProducer.getOperation(); - changed = true; - } else if (opOperand.get().getType().isa()) { + if (opOperand.get().getType().isa()) + continue; + if (opOperand.get().getType().isa()) { // Tile and Fuse tensor input. if (opOperand.getOperandNumber() >= linalgOp.getNumDpsInputs()) continue; diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -8384,25 +8384,6 @@ deps = [":PassBaseTdFiles"], ) -cc_library( - name = "LinalgAnalysis", - srcs = glob([ - "lib/Dialect/Linalg/Analysis/*.cpp", - "lib/Dialect/Linalg/Analysis/*.h", - ]), - hdrs = glob([ - "include/mlir/Dialect/Linalg/Analysis/*.h", - ]), - includes = ["include"], - deps = [ - ":BufferizationDialect", - ":FuncDialect", - ":IR", - ":LinalgDialect", - "//llvm:Support", - ], -) - cc_library( name = "LinalgUtils", srcs = glob([ @@ -8422,7 +8403,6 @@ ":DialectUtils", ":FuncDialect", ":IR", - ":LinalgAnalysis", ":LinalgDialect", ":MemRefDialect", ":Pass", @@ -8462,7 +8442,6 @@ ":FuncDialect", ":FuncTransforms", ":IR", - ":LinalgAnalysis", ":LinalgDialect", ":LinalgPassIncGen", ":LinalgStructuredOpsIncGen", diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel @@ -564,7 +564,6 @@ "//mlir:FuncTransforms", "//mlir:GPUDialect", "//mlir:IR", - "//mlir:LinalgAnalysis", "//mlir:LinalgDialect", "//mlir:LinalgTransforms", "//mlir:LinalgUtils",