diff --git a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h --- a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h +++ b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h @@ -45,7 +45,7 @@ public: struct LinalgOpView { Operation *op; - Value view; + unsigned operandIndex; }; struct LinalgDependenceGraphElem { // dependentOpView may be either: @@ -55,7 +55,7 @@ // View in the op that is used to index in the graph: // 1. src in the case of dependencesFromDstGraphs. // 2. dst in the case of dependencesIntoGraphs. - Value indexingView; + LinalgOpView indexingOpView; }; using LinalgDependences = SmallVector; using DependenceGraph = DenseMap; diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td @@ -555,7 +555,7 @@ >, InterfaceMethod< /*desc=*/[{ - Return the position of the shaped operand in the operand list. + Return the first position of the shaped operand in the operand list. }], /*retTy=*/"Optional", /*methodName=*/"getIndexOfShapedOperand", @@ -573,6 +573,67 @@ return llvm::None; }] >, + InterfaceMethod< + /*desc=*/[{ + Returns the operand index given the input index. Returns None + if the input index is invalid. + }], + /*retTy=*/"Optional", + /*methodName=*/"getOperandIndexForInputIndex", + /*args=*/(ins "unsigned":$input_index), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + if (input_index >= $_op.getNumInputs()) + return llvm::None; + return input_index; + }] + >, + InterfaceMethod< + /*desc=*/[{ + Returns the operand index given the output index. Returns None + if the output index is invalid. 
+ }], + /*retTy=*/"Optional", + /*methodName=*/"getOperandIndexForOutputIndex", + /*args=*/(ins "unsigned":$output_index), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + if (output_index >= $_op.getNumOutputs()) + return llvm::None; + return output_index + $_op.getNumInputs(); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Returns the input index given the operand index. Return None + if the operand index does not correspond to an input. + }], + /*retTy=*/"Optional", + /*methodName=*/"getInputIndex", + /*args=*/(ins "unsigned":$operand_index), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + if (operand_index >= $_op.getNumInputs()) + return llvm::None; + return operand_index; + }] + >, + InterfaceMethod< + /*desc=*/[{ + Returns the output index given the operand index. Return None + if the operand index does not correspond to an output. + }], + /*retTy=*/"Optional", + /*methodName=*/"getOutputIndex", + /*args=*/(ins "unsigned":$operand_index), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + if (operand_index < $_op.getNumInputs() || + operand_index >= $_op.getNumInputs() + $_op.getNumOutputs()) + return llvm::None; + return operand_index - $_op.getNumInputs(); + }] + >, //===------------------------------------------------------------------===// // Other interface methods. diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -15,6 +15,7 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/Bufferize.h" #include "llvm/ADT/SmallBitVector.h" +#include "llvm/ADT/SmallSet.h" namespace mlir { class BufferizeTypeConverter; @@ -429,12 +430,10 @@ }; struct LinalgFusionOptions { - /// Optional list of operands indices to use for fusion. When unspecified, - /// only one fusion is done, i.e., the pattern returns after the first fusion. 
- Optional> indicesToFuse = None; + /// List of operands indices to use for fusion. + llvm::SmallSet indicesToFuse = {}; LinalgFusionOptions &setIndicesToFuse(ArrayRef operands) { - indicesToFuse = DenseSet(); - indicesToFuse->insert(operands.begin(), operands.end()); + indicesToFuse.insert(operands.begin(), operands.end()); return *this; } }; diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h --- a/mlir/include/mlir/IR/AffineMap.h +++ b/mlir/include/mlir/IR/AffineMap.h @@ -323,6 +323,9 @@ /// ``` AffineMap concatAffineMaps(ArrayRef maps); +AffineMap getProjectedMap(AffineMap map, + ArrayRef projectedDimensions); + inline raw_ostream &operator<<(raw_ostream &os, AffineMap map) { map.print(os); return os; diff --git a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp --- a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp +++ b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp @@ -108,12 +108,14 @@ void LinalgDependenceGraph::addDependenceElem(DependenceType dt, LinalgOpView indexingOpView, LinalgOpView dependentOpView) { - LLVM_DEBUG(dbgs() << "\nAdd dep type " << getDependenceTypeStr(dt) << ":\t" - << *indexingOpView.op << " -> " << *dependentOpView.op); + LLVM_DEBUG(dbgs() << "\nAdd dep type " << getDependenceTypeStr(dt) << ":\t (" + << *indexingOpView.op << ", " << indexingOpView.operandIndex + << ") -> \n\t\t(" << *dependentOpView.op << ", " + << dependentOpView.operandIndex << ")"); dependencesFromGraphs[dt][indexingOpView.op].push_back( - LinalgDependenceGraphElem{dependentOpView, indexingOpView.view}); + LinalgDependenceGraphElem{dependentOpView, indexingOpView}); dependencesIntoGraphs[dt][dependentOpView.op].push_back( - LinalgDependenceGraphElem{indexingOpView, dependentOpView.view}); + LinalgDependenceGraphElem{indexingOpView, dependentOpView}); } LinalgDependenceGraph::dependence_range @@ -147,39 +149,55 @@ } void 
LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) { - for (auto srcView : src.getOutputBuffers()) { // W + for (auto srcView : llvm::enumerate(src.getOutputBuffers())) { // W + unsigned srcIndex = + src.getOperandIndexForOutputIndex(srcView.index()).getValue(); // RAW graph - for (auto dstView : dst.getInputBuffers()) { // R - if (aliases.alias(srcView, dstView)) { // if alias, fill RAW + for (auto dstView : llvm::enumerate(dst.getInputBuffers())) { // R + if (aliases.alias(srcView.value(), + dstView.value())) { // if alias, fill RAW + unsigned dstIndex = + dst.getOperandIndexForInputIndex(dstView.index()).getValue(); addDependenceElem(DependenceType::RAW, - LinalgOpView{src.getOperation(), srcView}, - LinalgOpView{dst.getOperation(), dstView}); + LinalgOpView{src.getOperation(), srcIndex}, + LinalgOpView{dst.getOperation(), dstIndex}); } } // WAW graph - for (auto dstView : dst.getOutputBuffers()) { // W - if (aliases.alias(srcView, dstView)) { // if alias, fill WAW + for (auto dstView : llvm::enumerate(dst.getOutputBuffers())) { // W + if (aliases.alias(srcView.value(), + dstView.value())) { // if alias, fill WAW + unsigned dstIndex = + dst.getOperandIndexForOutputIndex(dstView.index()).getValue(); addDependenceElem(DependenceType::WAW, - LinalgOpView{src.getOperation(), srcView}, - LinalgOpView{dst.getOperation(), dstView}); + LinalgOpView{src.getOperation(), srcIndex}, + LinalgOpView{dst.getOperation(), dstIndex}); } } } - for (auto srcView : src.getInputBuffers()) { // R + for (auto srcView : llvm::enumerate(src.getInputBuffers())) { // R + unsigned srcIndex = + src.getOperandIndexForInputIndex(srcView.index()).getValue(); // RAR graph - for (auto dstView : dst.getInputBuffers()) { // R - if (aliases.alias(srcView, dstView)) { // if alias, fill RAR + for (auto dstView : llvm::enumerate(dst.getInputBuffers())) { // R + if (aliases.alias(srcView.value(), + dstView.value())) { // if alias, fill RAR + unsigned dstIndex = + 
dst.getOperandIndexForInputIndex(dstView.index()).getValue(); addDependenceElem(DependenceType::RAR, - LinalgOpView{src.getOperation(), srcView}, - LinalgOpView{dst.getOperation(), dstView}); + LinalgOpView{src.getOperation(), srcIndex}, + LinalgOpView{dst.getOperation(), dstIndex}); } } // WAR graph - for (auto dstView : dst.getOutputBuffers()) { // W - if (aliases.alias(srcView, dstView)) { // if alias, fill WAR + for (auto dstView : llvm::enumerate(dst.getOutputBuffers())) { // W + if (aliases.alias(srcView.value(), + dstView.value())) { // if alias, fill WAR + unsigned dstIndex = + dst.getOperandIndexForOutputIndex(dstView.index()).getValue(); addDependenceElem(DependenceType::WAR, - LinalgOpView{src.getOperation(), srcView}, - LinalgOpView{dst.getOperation(), dstView}); + LinalgOpView{src.getOperation(), srcIndex}, + LinalgOpView{dst.getOperation(), dstIndex}); } } } @@ -227,12 +245,16 @@ // Skip if not interleaved. if (interimPos >= dstPos || interimPos <= srcPos) continue; - if (view && !aliases.alias(view, dependence.indexingView)) + linalg::LinalgOp consumer = + cast(dependence.indexingOpView.op); + Value consumerView = + consumer.getShapedOperand(dependence.indexingOpView.operandIndex); + if (view && !aliases.alias(view, consumerView)) continue; auto *op = dependence.dependentOpView.op; LLVM_DEBUG(dbgs() << "\n***Found covering dependence of type " << getDependenceTypeStr(dt) << ": " << *src << " -> " - << *op << " on " << dependence.indexingView); + << *op << " on " << consumerView); res.push_back(op); } } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -24,10 +24,12 @@ #include "mlir/IR/Dominance.h" #include "mlir/Support/LLVM.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/MapVector.h" #include "llvm/Support/CommandLine.h" 
#include "llvm/Support/Debug.h" +#include + #define DEBUG_TYPE "linalg-fusion" using namespace mlir; @@ -95,8 +97,8 @@ for (auto en : llvm::enumerate(op.getShapedOperands())) { unsigned shapedOperandIdx = en.index(); AffineMap map = op.getIndexingMap(shapedOperandIdx); - LLVM_DEBUG(dbgs() << "shapedOperandIdx: " << shapedOperandIdx - << " with indexingMap: " << map << "\n"); + LLVM_DEBUG(llvm::dbgs() << "shapedOperandIdx: " << shapedOperandIdx + << " with indexingMap: " << map << "\n"); SmallVector offsets, sizes, strides; inferShapeComponents(map, loopRanges, offsets, sizes, strides); Value shape = en.value(); @@ -169,16 +171,18 @@ for (auto en : llvm::enumerate(ios)) { unsigned idx = en.index(); auto map = maps[idx].cast().getValue(); - LLVM_DEBUG(dbgs() << "getShapeDefiningLoopRange I/O idx: " << idx << "\n"); - LLVM_DEBUG(dbgs() << "getShapeDefiningLoopRange map: " << map << "\n"); + LLVM_DEBUG(llvm::dbgs() + << "getShapeDefiningLoopRange I/O idx: " << idx << "\n"); + LLVM_DEBUG(llvm::dbgs() + << "getShapeDefiningLoopRange map: " << map << "\n"); Value shape = en.value(); SmallVector shapeRanges(map.getNumResults(), nullptr); for (auto en2 : llvm::enumerate(map.getResults())) { if (loopDepth == en2.value().cast().getPosition()) { - LLVM_DEBUG(dbgs() << "getShapeDefiningLoopRange loopDepth: " - << loopDepth << "\n"); - LLVM_DEBUG(dbgs() << "getShapeDefiningLoopRange shape: " << shape - << "\n"); + LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange loopDepth: " + << loopDepth << "\n"); + LLVM_DEBUG(llvm::dbgs() + << "getShapeDefiningLoopRange shape: " << shape << "\n"); return ShapeDimension{shape, static_cast(en2.index())}; } } @@ -209,8 +213,8 @@ // dimension. // TODO: extend this with range inference. 
AffineMap producerMap = producer.getOutputIndexingMap(producerIdx); - LLVM_DEBUG(dbgs() << "Producer Idx: " << producerIdx - << ", producer map: " << producerMap << "\n"); + LLVM_DEBUG(llvm::dbgs() << "Producer Idx: " << producerIdx + << ", producer map: " << producerMap << "\n"); unsigned nPar = producer.getNumParallelLoops(); unsigned nRed = producer.getNumReductionLoops(); @@ -258,7 +262,7 @@ assert(consumer.hasBufferSemantics() && "expected linalg op with buffer semantics"); if (producer.getNumOutputs() != 1) { - LLVM_DEBUG(dbgs() << "\nNot structurally fusable (multi-output)"); + LLVM_DEBUG(llvm::dbgs() << "\nNot structurally fusable (multi-output)"); return false; } // Only fuse when the producer block dominates. @@ -266,7 +270,7 @@ if (!dom.dominates(producer.getOperation()->getBlock(), consumer.getOperation()->getBlock())) { LLVM_DEBUG( - dbgs() + llvm::dbgs() << "\nNot structurally fusable (producer block does not dominate)"); return false; } @@ -284,14 +288,14 @@ // Make some simple structural checks that alleviate the need for more // complex analyses. if (!isStructurallyFusableProducer(producer, consumedView, consumer)) { - LLVM_DEBUG(dbgs() << "\n***Not static last write due to structure:\t" - << *producer.getOperation()); + LLVM_DEBUG(llvm::dbgs() << "\n***Not static last write due to structure:\t" + << *producer.getOperation()); return false; } // Check for any interleaved write to consumedView. if (!graph.findCoveringWrites(producer, consumer, consumedView).empty()) { - LLVM_DEBUG(dbgs() << "\n***Not fusable due to interleaved write:\t" - << *producer.getOperation()); + LLVM_DEBUG(llvm::dbgs() << "\n***Not fusable due to interleaved write:\t" + << *producer.getOperation()); return false; } return true; @@ -309,8 +313,9 @@ // Check for any fusion-preventing dependence to any shape read/written that // would violate dependences. 
if (!graph.findCoveringDependences(producer, consumer).empty()) { - LLVM_DEBUG(dbgs() << "\n***Not fusable due to an interleaved dependence:\t" - << *producer.getOperation()); + LLVM_DEBUG(llvm::dbgs() + << "\n***Not fusable due to an interleaved dependence:\t" + << *producer.getOperation()); return false; } if (auto convOp = dyn_cast(producer.getOperation())) { @@ -360,26 +365,33 @@ LinalgDependenceGraph::DependenceType::RAW, LinalgDependenceGraph::DependenceType::WAW, }) { - for (auto dependence : - dependenceGraph.getDependencesInto(consumer, depType)) { + for (auto dependence : llvm::make_filter_range( + dependenceGraph.getDependencesInto(consumer, depType), + [consumerIdx]( + LinalgDependenceGraph::LinalgDependenceGraphElem elem) { + return elem.indexingOpView.operandIndex == consumerIdx; + })) { auto producer = cast(dependence.dependentOpView.op); // Check that the dependence is indeed on the input `consumerIdx` view. - auto consumedView = dependence.indexingView; + auto consumedView = + consumer.getBuffer(dependence.indexingOpView.operandIndex); if (!isSameSubView(consumer.getBuffer(consumerIdx), consumedView)) continue; // Consumer consumes this view, `isStructurallyFusableProducer` also // checks whether it is a strict subview of the producer view. - auto producedView = dependence.dependentOpView.view; - auto producerIdx = - producer.getIndexOfOutputBuffer(producedView).getValue(); - // `consumerIdx` and `producerIdx` exist by construction. 
- LLVM_DEBUG(dbgs() << "\n" - << LinalgDependenceGraph::getDependenceTypeStr(depType) - << "producer: " << *producer.getOperation() << " view: " - << producedView << " output index: " << producerIdx); - (void)producerIdx; + auto producedView = + producer.getBuffer(dependence.dependentOpView.operandIndex); + LLVM_DEBUG(llvm::dbgs() + << "\n" + << LinalgDependenceGraph::getDependenceTypeStr(depType) + << "producer: " << *producer.getOperation() + << " view: " << producedView << " output index: " + << dependence.dependentOpView.operandIndex - + producer.getNumInputs() + << "\n"); + (void)producedView; // Simple fusability checks. if (!isFusableInto(dependenceGraph, consumer, consumedView, producer)) @@ -406,15 +418,16 @@ producerOp.getOperation()->getBlock()) return {}; - Value producerView = fusableDependence->dependentOpView.view; - Value consumerView = fusableDependence->indexingView; + unsigned producerIdx = fusableDependence->dependentOpView.operandIndex - + producerOp.getNumInputs(); + Value consumerView = consumer.getShapedOperand(consumerIdx); // Must be a subview or a slice to guarantee there are loops we can fuse // into. 
auto subView = consumerView.getDefiningOp(); auto slice = consumerView.getDefiningOp(); if (!subView && !slice) { - LLVM_DEBUG(dbgs() << "\nNot fusable (not a subview or slice)"); + LLVM_DEBUG(llvm::dbgs() << "\nNot fusable (not a subview or slice)"); return {}; } @@ -422,11 +435,7 @@ OpBuilder::InsertionGuard g(b); b.setInsertionPoint(consumer.getOperation()); ScopedContext scope(b, consumer.getLoc()); - LLVM_DEBUG(dbgs() << "Fuse into consumer: " << *consumer << "\n"); - Optional producerIdxOpt = - producerOp.getIndexOfOutputBuffer(producerView); - assert(producerIdxOpt.hasValue() && "incorrect operand index"); - unsigned producerIdx = producerIdxOpt.getValue(); + LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumer << "\n"); auto fusedProducer = fuse(b, producerOp, producerIdx, consumer, consumerIdx); return FusionInfo{producerOp, fusedProducer}; @@ -470,7 +479,7 @@ // Must be a subtensor to guarantee there are loops we can fuse into. auto subTensor = inputTensor.getDefiningOp(); if (!subTensor || !producerOp) { - LLVM_DEBUG(dbgs() << "\nNot fusable (not a subtensor)"); + LLVM_DEBUG(llvm::dbgs() << "\nNot fusable (not a subtensor)"); return {}; } @@ -483,7 +492,7 @@ OpBuilder::InsertionGuard g(b); b.setInsertionPoint(consumer.getOperation()); ScopedContext scope(b, consumer.getLoc()); - LLVM_DEBUG(dbgs() << "Fuse into consumer: " << *consumer << "\n"); + LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumer << "\n"); LinalgOp fusedProducer = fuse(b, producerOp, producerIdx, consumer, consumerIdx); @@ -501,6 +510,21 @@ return FusionInfo{producerOp, fusedProducer}; } +/// Prune all dimensions that are of reduction iterator type from `map`. 
+static AffineMap pruneReductionDimsFromMap(ArrayRef iteratorTypes, + AffineMap map) { + SmallVector projectedDims; + for (auto attr : llvm::enumerate(iteratorTypes)) { + if (!isParallelIterator(attr.value())) + projectedDims.push_back(attr.index()); + } + return getProjectedMap(map, projectedDims); +} + +using FusableOpDependencesTy = llvm::MapVector< + Operation *, + SmallVector>; + /// Returns the positions of the loop in `op` that can be tiled based on the /// operations that are to be fused with it. For example, in a /// @@ -508,12 +532,58 @@ /// /// if the producer of %a needs to be fused with this op, only the `i` loop of /// the matmul can be tiled while fusing. If producer of %a, and %b are to be -/// fused, then no loops can be tiled while fusing. -static DenseSet collectTileAndFuseLoops( - LinalgOp op, ArrayRef - fusableDependences) { - // 1. Only parallel loops can be used for tile + fuse. Find the number of - // common outer parallel loops between the op and its producers being fused. +/// fused, then no loops can be tiled while fusing. The conditions used are: +/// 1. Only parallel loops can be used for tile + fuse. Find the number of +/// common outer parallel loops between the op and its producers being fused. +/// 2. Of the parallel loops only some can be fused. Only those loops can be +/// fused such where the fusable loops iteration space only touches one tile +/// of the fused operation. This is because the producer (which is writing +/// the fused subview) has update semantics. To compute this, +/// a. Find the mapping from iterations in the consumer that write to the +/// same location as the iterations in the producer. 
To do so use +/// - indexing map of the fused view in the consumer : consumerIndexMap +/// - indexing map of the fused view in the producer : producerIndexMap +/// consumerLoopToProducerLoop = +/// inverse(producerIndexMap).compose(consumerIndexMap) +/// +/// Since an inverse computation is needed, we need to consider the projection +/// of the producerIndexMap w.r.t the parallel loops. The actual fusable loops +/// are the dimensions of the consumerLoopToProducerLoop map that correspond to +/// parallel loops and appear in the result of the map +/// +/// Example 1: +/// linalg.fill(%c, %cst) +/// linalg.matmul ins(%a, %b) outs(%c) +/// Number of parallel loops : 2 +/// producerIndexMap = affine_map<(i, j) ->(i , j)> +/// consumerIndexMap = affine_map<(i, j, k) -> (i, j)> +/// consumerLoopToProducerLoop = affine_map<(i, j, k) -> (i, j)> +/// Fused dimensions : i, j +/// +/// Example 2: +/// linalg.matmul ins(%a, %b) outs(%c) +/// linalg.generic {indexing_maps = [affine_map<(i, j) -> (j, i)>, ... +/// iterator_types = ["parallel", "parallel"]} +/// ins(%c) ... 
+/// +/// Number of parallel loops = 2: +/// producerIndexMap (projected to parallel loops) = +/// affine_map<(i, j) -> (i, j)> +/// consumerLoopToProducerLoop = affine_map<(i, j) -> (j, i)> +/// Fused dimensions : i, j +/// +/// Example 3: +/// linalg.copy(%s, %b) +/// linalg.matmul ins(%a, %b) outs(%c) +/// +/// Number of parallel loops = 2 +/// producerIndexMap : affine_map<(i, j) -> (i, j)> +/// consumerLoopToProducerLoop = affine_map<(i, j, k) -> (k, j)> +/// submap with only parallel loops = affine_map<(i, j) -> (j)> +/// Fused dimensions : j +static std::set +collectTileAndFuseLoops(LinalgOp op, + const FusableOpDependencesTy &fusableDependences) { auto getNumOuterParallelLoops = [](LinalgOp linalgOp) { return linalgOp.iterator_types() .getValue() @@ -524,135 +594,149 @@ .size(); }; + LLVM_DEBUG({ + llvm::dbgs() << "Op : "; + op.getOperation()->print(llvm::dbgs(), OpPrintingFlags().useLocalScope()); + llvm::dbgs() << "\n"; + }); + size_t numOuterParallelLoops = getNumOuterParallelLoops(op); for (auto dependence : fusableDependences) { + linalg::LinalgOp producer = cast(dependence.first); numOuterParallelLoops = - std::min(numOuterParallelLoops, getNumOuterParallelLoops(cast( - dependence.dependentOpView.op))); + std::min(numOuterParallelLoops, getNumOuterParallelLoops(producer)); } - // Need to compute what tiled loops can be "fused". Given the precondition - // that all indexing map for the producer view is a projected permutation, we - // can assert that the producer iterates over the dimensions of the "fused - // view" only once. To be used a fused loop the producer should use this loop - // to access the fused view. For example, consider - // - // ``` - // linalg.add ins(%a, %b) outs(%c) - // linalg.matmul ins(%d, %c) outs(%e) - // ``` - // - // if `linalg.add` has the semantics of `c = a + b`, then the following - // tile+fuse code is correct. - // - // ``` - // for j ... += TSj - // %sa = subview %a[0, %j][...] - // %sb = subview %b[0, %j][...] 
- // %sc = subview %c[0, %j][...] - // %sd = subview %d[0, 0][...] - // %se = subview %e[0, %j][...] - // linalg.add ins(%sa, %sb) outs(%sc) - // linalg.matmul ins(%sd, %sc) outs(%se) - // ``` - // - // On the other hand tiling along i would be incorrect - // - // ``` - // for %i .. += TSi - // %sa = subview %a[%i, 0][...] - // %sb = subview %b[%i, 0][...] - // %sc = subview %c[%i, 0][...] - // %sc2 = subview %c[0, 0][...] - // %sd = subview %d[%i, 0][...] - // %se = subview %e[%i, 0][...] - // linalg.add ins(%sa, %sb) outs(%sc) - // linalg.matmul ins(%sd, %sc2) outs(%se) - // ``` - // - // The write to the subview `%sc` in `linalg.add` is performed after the read - // from it using `%sc2` violating the RAW dependence of the original code. To - // find such loops indexing map of the fused view in the consumer op is - // used. For the above example, this indexing map is - // - // affine_map<(d0, d1, d2) -> (d2, d1)> - // - // Since d0 is not in the result expressions of this map, it is not treated as - // tile + fuse loop, (but d1 is). - // - // TODO: The above is probably restrictive and there might be a generalization - // of these that might allow for more fusion opportunities. Explore based on - // needs. - SmallVector, 1> commonTilableLoops; + std::set fusableLoops; + auto range = llvm::seq(0, numOuterParallelLoops); + fusableLoops.insert(range.begin(), range.end()); for (auto dependence : fusableDependences) { - unsigned consumerIdx = - op.getIndexOfShapedOperand(dependence.indexingView).getValue(); - AffineMap consumerAccess = op.getIndexingMap(consumerIdx); - // Previously asserted that the consumerAccess map is a projected - // permutation, so all results are known to be AffineDimExprs. To remove - // this restriction walk the expression to find which dimensions of the - // consumer loop appear in the `consumerAccess`. 
- DenseSet positions; - for (auto expr : consumerAccess.getResults()) - positions.insert(expr.cast().getPosition()); - commonTilableLoops.emplace_back(std::move(positions)); + LLVM_DEBUG({ + llvm::dbgs() << "\t fusable :"; + for (unsigned i : fusableLoops) + llvm::dbgs() << " " << i; + llvm::dbgs() << "\n"; + }); + linalg::LinalgOp producer = cast(dependence.first); + + assert(!dependence.second.empty() && + "unexpected producer but not dependences"); + AffineMap producerIndexingMap = producer.getIndexingMap( + dependence.second.front().dependentOpView.operandIndex); + AffineMap prunedProducerIndexingMap = pruneReductionDimsFromMap( + producer.iterator_types().getValue(), producerIndexingMap); + if (!prunedProducerIndexingMap.isPermutation()) + return {}; + + AffineMap consumerIndexingMap = op.getIndexingMap( + dependence.second.front().indexingOpView.operandIndex); + if (consumerIndexingMap.getNumResults() != + prunedProducerIndexingMap.getNumResults()) + return {}; + + LLVM_DEBUG({ + llvm::dbgs() << "\t producerMap : "; + producerIndexingMap.print(llvm::dbgs()); + llvm::dbgs() << " pruned : "; + prunedProducerIndexingMap.print(llvm::dbgs()); + llvm::dbgs() << "\n"; + llvm::dbgs() << "\t consumerMap : "; + consumerIndexingMap.print(llvm::dbgs()); + llvm::dbgs() << "\n"; + }); + + AffineMap invProducerIndexMap = + inversePermutation(prunedProducerIndexingMap); + if (!invProducerIndexMap) + return {}; + + AffineMap consumerLoopToProducerLoop = + invProducerIndexMap.compose(consumerIndexingMap); + + LLVM_DEBUG({ + llvm::dbgs() << "\t consumerLoopToProducerLoop : "; + consumerLoopToProducerLoop.print(llvm::dbgs()); + }); + + std::set candidates; + for (AffineExpr expr : consumerLoopToProducerLoop.getResults()) { + AffineDimExpr dimExpr = expr.dyn_cast(); + if (!dimExpr) + continue; + unsigned position = dimExpr.getPosition(); + if (fusableLoops.count(position)) + candidates.insert(position); + } + LLVM_DEBUG({ + llvm::dbgs() << "\t candidates :"; + for (unsigned i : 
candidates) + llvm::dbgs() << " " << i; + llvm::dbgs() << "\n"; + }); + if (candidates.empty()) + return {}; + std::swap(candidates, fusableLoops); } - // 2. Of the outer parallel loops, only those loops can be tiled + fused as - // computed above for all the fused dependences can be used to tile and fuse. - DenseSet tilableParallelLoops; - for (auto index : llvm::seq(0, numOuterParallelLoops)) { - if (llvm::all_of(commonTilableLoops, - [&](const DenseSet &tilableLoops) { - return tilableLoops.count(index); - })) - tilableParallelLoops.insert(index); - } - return tilableParallelLoops; + return fusableLoops; } /// Find all dependences that are to be fusable. -static Optional< - SmallVector> +static FusableOpDependencesTy findAllFusableDependences(LinalgOp op, const LinalgDependenceGraph &dependenceGraph, const LinalgFusionOptions &fusionOptions) { - SmallVector - fusableDependences; - for (auto operand : llvm::enumerate(op.getInputsAndOutputBuffers())) { - if (fusionOptions.indicesToFuse && - !fusionOptions.indicesToFuse->count(operand.index())) - continue; - Optional - fusableDependence = - findFusableProducer(op, operand.index(), dependenceGraph); + FusableOpDependencesTy fusableDependences; + // TODO: Currently fusion would not be legal if the fusable dependence is to + // the same producer but different indexing map in the consumer. Fix this, but + // in the meanwhile disallow such a fusion. + DenseMap fusedProducerIndexingMap; + for (auto operandIndex : fusionOptions.indicesToFuse) { + auto fusableDependence = + findFusableProducer(op, operandIndex, dependenceGraph); if (!fusableDependence) - continue; + return FusableOpDependencesTy{}; + LinalgOp producerOp = cast(fusableDependence->dependentOpView.op); + // Do not fuse dependences that are to operations not in the same basic + // block. This avoids moving fused operations across loops that might + // themselves carry a dependence, making the fusion illegal. 
+ if (producerOp.getOperation()->getBlock() != + op.getOperation()->getBlock()) { + op.emitRemark("unhandled fusion of ops in different basic blocks"); + return FusableOpDependencesTy{}; + } // Make sure that the indexing map of the view used for fusion in the // producer is a projected permutation. - LinalgOp producerOp = cast(fusableDependence->dependentOpView.op); - Value producerView = fusableDependence->dependentOpView.view; - unsigned producerIdx = - producerOp.getIndexOfOutputBuffer(producerView).getValue(); - AffineMap producerMap = producerOp.getOutputIndexingMap(producerIdx); + unsigned producerIdx = fusableDependence->dependentOpView.operandIndex; + AffineMap producerMap = producerOp.getIndexingMap(producerIdx); if (!producerMap.isProjectedPermutation()) { - op.emitError("unhandled non permutation indexing map for fused view in " - "producer for operand at index ") - << operand.index(); - return llvm::None; + op.emitRemark("unhandled non permutation indexing map for fused view in " + "producer for operand at index ") + << operandIndex; + return FusableOpDependencesTy{}; } - Value consumerView = fusableDependence->indexingView; - unsigned consumerIdx = op.getIndexOfShapedOperand(consumerView).getValue(); - if (!op.getIndexingMap(consumerIdx).isProjectedPermutation()) { - op.emitError( + + unsigned consumerIdx = fusableDependence->indexingOpView.operandIndex; + AffineMap consumerMap = op.getIndexingMap(consumerIdx); + if (!consumerMap.isProjectedPermutation()) { + op.emitRemark( "unhandled case where indexing map for fused view in the consumer is " - "not a projected permuration while fusing at index ") - << operand.index(); - return llvm::None; + "not a projected permutation while fusing at index ") + << operandIndex; + return FusableOpDependencesTy{}; + } + + // Check if the producer is already a fusion candidate. Cannot fuse this + // dependence if it has a different indexing map when used in the consumer. 
+ if (fusedProducerIndexingMap.count(producerOp.getOperation()) && + fusedProducerIndexingMap[producerOp.getOperation()] != consumerMap) { + op.emitRemark("unhandled fusion to the same producer but with different " + "indexing maps"); + return FusableOpDependencesTy{}; } - fusableDependences.push_back(*fusableDependence); - if (!fusionOptions.indicesToFuse) - break; + fusedProducerIndexingMap[producerOp.getOperation()] = consumerMap; + + fusableDependences[producerOp.getOperation()].push_back(*fusableDependence); } return fusableDependences; } @@ -682,13 +766,10 @@ ScopedContext scope(rewriter, op.getLoc()); // Find all the producers. - Optional> - fusableDependencesOpt = - findAllFusableDependences(op, dependenceGraph, fusionOptions); - if (!fusableDependencesOpt) + FusableOpDependencesTy fusableDependences = + findAllFusableDependences(op, dependenceGraph, fusionOptions); + if (fusableDependences.empty()) return llvm::None; - ArrayRef fusableDependences( - *fusableDependencesOpt); // Enforce the convention that "tiling by zero" skips tiling a particular // dimension. This convention is significantly simpler to handle instead of @@ -704,12 +785,12 @@ TiledAndFusedLinalgOps ret; // Find the loops that can be tiled and fused. - DenseSet tileFuseLoops = + std::set tileFuseLoops = collectTileAndFuseLoops(op, fusableDependences); // If there are no fusable dependences or there are no tile+fusable loops, // just return. - if (fusableDependences.empty() || tileFuseLoops.empty()) { + if (tileFuseLoops.empty()) { return llvm::None; } @@ -752,15 +833,15 @@ rewriter.setInsertionPoint(ret.op); // Fuse the operands. 
- for (auto producer : enumerate(fusableDependences)) { - LinalgOp producerOp = cast(producer.value().dependentOpView.op); + for (auto dependence : fusableDependences) { + LinalgOp producerOp = cast(dependence.first); unsigned producerIdx = - producerOp.getIndexOfOutputBuffer(producer.value().dependentOpView.view) - .getValue(); + dependence.second.front().dependentOpView.operandIndex; unsigned consumerIdx = - op.getIndexOfShapedOperand(producer.value().indexingView).getValue(); - LinalgOp fusedOp = - fuse(rewriter, producerOp, producerIdx, ret.op, consumerIdx); + dependence.second.front().indexingOpView.operandIndex; + LinalgOp fusedOp = fuse(rewriter, producerOp, + producerOp.getOutputIndex(producerIdx).getValue(), + ret.op, consumerIdx); ret.fusedProducers.push_back(fusedOp); ret.originalProducers.push_back(producerOp); } diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp --- a/mlir/lib/IR/AffineMap.cpp +++ b/mlir/lib/IR/AffineMap.cpp @@ -12,6 +12,7 @@ #include "mlir/IR/StandardTypes.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Support/MathExtras.h" +#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/raw_ostream.h" @@ -450,6 +451,22 @@ maps.front().getContext()); } +AffineMap mlir::getProjectedMap(AffineMap map, + ArrayRef projectedDimensions) { + DenseSet projectedDims(projectedDimensions.begin(), + projectedDimensions.end()); + MLIRContext *context = map.getContext(); + SmallVector resultExprs; + for (auto dim : enumerate(llvm::seq(0, map.getNumDims()))) { + if (!projectedDims.count(dim.value())) + resultExprs.push_back(getAffineDimExpr(dim.index(), context)); + else + resultExprs.push_back(getAffineConstantExpr(0, context)); + } + return map.compose(AffineMap::get( + map.getNumDims() - projectedDimensions.size(), 0, resultExprs, context)); +} + //===----------------------------------------------------------------------===// // MutableAffineMap. 
//===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Linalg/fusion-pattern.mlir b/mlir/test/Dialect/Linalg/fusion-pattern.mlir --- a/mlir/test/Dialect/Linalg/fusion-pattern.mlir +++ b/mlir/test/Dialect/Linalg/fusion-pattern.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -test-linalg-fusion-transform-patterns -canonicalize -cse -split-input-file | FileCheck %s +// RUN: mlir-opt %s -test-linalg-fusion-transform-patterns -canonicalize -cse -split-input-file -verify-diagnostics | FileCheck %s module { func @basic_fusion(%arg0: memref, %arg1: memref, @@ -295,3 +295,121 @@ // CHECK: } // CHECK: linalg.matmul // CHECK-SAME: __internal_linalg_transform__ = "after_lhs_fusion_original" + +// ----- + +module { + func @matmul_plus_matmul(%arg0: memref, %arg1: memref, + %arg2: memref) { + %c0 = constant 0 : index + %c1 = constant 1 : index + %0 = dim %arg2, %c0 : memref + %1 = dim %arg2, %c1 : memref + %2 = alloc(%0, %1) : memref + linalg.matmul ins(%arg0, %arg1 : memref, memref) + outs(%2 : memref) + linalg.generic + {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"], + __internal_linalg_transform__ = "transpose_fusion"} + ins(%2, %2 : memref, memref) + outs(%arg2 : memref) { + ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32) : + %3 = addf %arg3, %arg4 : f32 + linalg.yield %3 : f32 + } + return + } +} +// CHECK: func @matmul_plus_matmul +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: memref +// CHECK: %[[T2:.+]] = alloc(%{{.*}}, %{{.*}}) : memref +// CHECK: linalg.matmul +// CHECK-SAME: after_transpose_fusion_original +// CHECK: scf.parallel (%[[ARG3:[a-zA-Z0-9_]+]], %[[ARG4:.[a-zA-Z0-9_]+]]) +// CHECK: %[[T5:.+]] = subview %[[T2]][%[[ARG3]], %[[ARG4]]] +// CHECK: %[[T6:.+]] = subview %[[ARG2]][%[[ARG3]], %[[ARG4]]] +// 
CHECK: %[[T8:.+]] = subview %[[ARG0]][%[[ARG3]], 0] +// CHECK: %[[T9:.+]] = subview %[[ARG1]][0, %[[ARG4]]] +// CHECK: linalg.matmul +// CHECK-SAME: after_transpose_fusion_producer +// CHECK-SAME: ins(%[[T8]], %[[T9]] +// CHECK-SAME: outs(%[[T5]] +// CHECK-NOT: linalg.matmul +// CHECK: linalg.generic +// CHECK-SAME: ins(%[[T5]], %[[T5]] +// CHECK-SAME: outs(%[[T6]] +// CHECK-SAME: after_transpose_fusion + +// ----- + +module { + func @matmul_plus_transpose_matmul(%arg0: memref, + %arg1: memref, + %arg2: memref) { + %c0 = constant 0 : index + %c1 = constant 1 : index + %0 = dim %arg2, %c0 : memref + %1 = dim %arg2, %c1 : memref + %2 = alloc(%0, %1) : memref + linalg.matmul ins(%arg0, %arg1 : memref, memref) + outs(%2 : memref) + // expected-remark @+1 {{unhandled fusion to the same producer but with different indexing maps}} + linalg.generic + {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d1, d0)>, + affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"], + __internal_linalg_transform__ = "transpose_fusion"} + ins(%2, %2 : memref, memref) + outs(%arg2 : memref) { + ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32) : + %3 = addf %arg3, %arg4 : f32 + linalg.yield %3 : f32 + } + return + } +} + +// ----- + +#map0 = affine_map<(d0)[s0] -> (32, -d0 + s0)> +#map1 = affine_map<(d0)[s0] -> (64, -d0 + s0)> +#map2 = affine_map<(d0)[s0] -> (16, -d0 + s0)> +#map3 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> +module { + func @basic_no_fusion(%arg0: memref, %arg1: memref, + %arg2: memref) { + %c0 = constant 0 : index + %c1 = constant 1 : index + %c2 = constant 2 : index + %c32 = constant 32 : index + %c64 = constant 64 : index + %c16 = constant 16 : index + %cst = constant 0.000000e+00 : f32 + linalg.fill(%arg2, %cst) : memref, f32 + %0 = dim %arg0, %c0 : memref + %1 = dim %arg1, %c1 : memref + %2 = dim %arg0, %c1 : memref + scf.parallel (%arg3, %arg4) = (%c0, %c0) to (%0, %1) step (%c32, %c64) { + scf.for %arg5 = %c0 to 
%2 step %c16 { + %3 = affine.min #map0(%arg3)[%0] + %4 = affine.min #map1(%arg4)[%1] + %5 = affine.min #map2(%arg5)[%2] + %6 = subview %arg0[%arg3, %arg5] [%3, %5] [1, 1] : memref to memref + %7 = subview %arg1[%arg5, %arg4] [%5, %4] [1, 1] : memref to memref + %8 = subview %arg2[%arg3, %arg4] [%3, %4] [1, 1] : memref to memref + // expected-remark @+1 {{unhandled fusion of ops in different basic blocks}} + linalg.matmul {__internal_linalg_transform__ = "basic_fusion"} + ins(%6, %7 : memref, memref) + outs(%8 : memref) + } + scf.yield + } + return + } +} diff --git a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp --- a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp +++ b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp @@ -43,7 +43,7 @@ LinalgTilingOptions() .setTileSizes({32, 64, 16}) .setLoopType(LinalgTilingLoopType::ParallelLoops), - LinalgFusionOptions(), + LinalgFusionOptions().setIndicesToFuse({2}), LinalgMarker(Identifier::get("basic_fusion", context), Identifier::get("after_basic_fusion", context)), LinalgMarker(ArrayRef(), @@ -91,6 +91,19 @@ LinalgMarker( ArrayRef(), Identifier::get("after_two_operand_fusion_original", context))); + + patterns.insert>( + context, dependenceGraph, + LinalgTilingOptions().setTileSizes({32, 64}).setLoopType( + LinalgTilingLoopType::ParallelLoops), + LinalgFusionOptions().setIndicesToFuse({0, 1}), + LinalgMarker(Identifier::get("transpose_fusion", context), + Identifier::get("after_transpose_fusion", context)), + LinalgMarker(ArrayRef(), + Identifier::get("after_transpose_fusion_producer", context)), + LinalgMarker( + ArrayRef(), + Identifier::get("after_transpose_fusion_original", context))); } static void applyFusionPatterns(MLIRContext *context, FuncOp funcOp) {