diff --git a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h --- a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h +++ b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h @@ -63,6 +63,8 @@ using dependence_range = iterator_range; enum DependenceType { RAR = 0, RAW, WAR, WAW, NumTypes }; + /// Returns the printable name ("RAR", "RAW", "WAR" or "WAW") of `depType`. + static StringRef getDependenceTypeStr(DependenceType depType); // Builds a linalg dependence graph for the ops of type LinalgOp under `f`. static LinalgDependenceGraph buildDependenceGraph(Aliases &aliases, FuncOp f); diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td @@ -101,6 +101,10 @@ // Input and Output arguments handling. //===------------------------------------------------------------------===// InterfaceMethod< + "Return one single buffer at position `$i`.", + "Value", "getBuffer", (ins "unsigned":$i) + >, + InterfaceMethod< "Return the number of inputs and outputs, irrespective of their buffer " "or tensor type.", "unsigned", "getNumInputsAndOutputs" diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h @@ -184,6 +184,11 @@ //==========================================================================// // Input and Output arguments handling.
//==========================================================================// + /// Return the input or output buffer operand at position `i`. + Value getBuffer(unsigned i) { + assert(i < getNumInputsAndOutputBuffers() && "overflowing buffers index"); + return this->getOperation()->getOperand(i); + } /// Return the number of inputs and outputs, irrespective of their buffer or /// tensor type. unsigned getNumInputsAndOutputs() { return nInputs() + nOutputs(); } diff --git a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp --- a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp +++ b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp @@ -24,24 +24,6 @@ using llvm::dbgs; -#ifndef NDEBUG -static StringRef toStringRef(LinalgDependenceGraph::DependenceType dt) { - switch (dt) { - case LinalgDependenceGraph::DependenceType::RAW: - return "RAW"; - case LinalgDependenceGraph::DependenceType::RAR: - return "RAR"; - case LinalgDependenceGraph::DependenceType::WAR: - return "WAR"; - case LinalgDependenceGraph::DependenceType::WAW: - return "WAW"; - default: - break; - } - llvm_unreachable("Unexpected DependenceType"); -} -#endif - Value Aliases::find(Value v) { if (v.isa()) return v; @@ -76,6 +58,22 @@ } } +StringRef LinalgDependenceGraph::getDependenceTypeStr(DependenceType depType) { + switch (depType) { + case LinalgDependenceGraph::DependenceType::RAW: + return "RAW"; + case LinalgDependenceGraph::DependenceType::RAR: + return "RAR"; + case LinalgDependenceGraph::DependenceType::WAR: + return "WAR"; + case LinalgDependenceGraph::DependenceType::WAW: + return "WAW"; + default: + break; + } + llvm_unreachable("Unexpected DependenceType"); +} + LinalgDependenceGraph LinalgDependenceGraph::buildDependenceGraph(Aliases &aliases, FuncOp f) { SmallVector linalgOps; @@ -100,7 +98,7 @@ void LinalgDependenceGraph::addDependenceElem(DependenceType dt, LinalgOpView indexingOpView, LinalgOpView dependentOpView) { - LLVM_DEBUG(dbgs() << "\nAdd dep type " << toStringRef(dt)
<< ":\t" + LLVM_DEBUG(dbgs() << "\nAdd dep type " << getDependenceTypeStr(dt) << ":\t" << *indexingOpView.op << " -> " << *dependentOpView.op); dependencesFromGraphs[dt][indexingOpView.op].push_back( LinalgDependenceGraphElem{dependentOpView, indexingOpView.view}); @@ -227,8 +225,8 @@ continue; auto *op = dependence.dependentOpView.op; LLVM_DEBUG(dbgs() << "\n***Found covering dependence of type " - << toStringRef(dt) << ": " << *src << " -> " << *op - << " on " << dependence.indexingView); + << getDependenceTypeStr(dt) << ": " << *src << " -> " + << *op << " on " << dependence.indexingView); res.push_back(op); } } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -157,9 +157,9 @@ } auto subView = dyn_cast_or_null( - consumer.getInput(consumerIdx).getDefiningOp()); - auto slice = - dyn_cast_or_null(consumer.getInput(consumerIdx).getDefiningOp()); + consumer.getBuffer(consumerIdx).getDefiningOp()); + auto slice = dyn_cast_or_null( + consumer.getBuffer(consumerIdx).getDefiningOp()); assert(subView || slice); (void)subView; (void)slice; @@ -274,16 +274,15 @@ return true; } -// Only consider RAW atm.
-Optional mlir::linalg::fuseProducerOf( - OpBuilder &b, LinalgOp consumer, unsigned consumerIdx, - const LinalgDependenceGraph &graph, OperationFolder *folder) { +static Optional +fuseProducerOfDep(OpBuilder &b, LinalgOp consumer, unsigned consumerIdx, + const LinalgDependenceGraph &graph, OperationFolder *folder, + LinalgDependenceGraph::DependenceType depType) { assert(consumer.hasBufferSemantics() && "expected linalg op with buffer semantics"); LLVM_DEBUG(dbgs() << "\nStart examining consumer: " << *consumer.getOperation()); - for (auto dependence : graph.getDependencesInto( - consumer, LinalgDependenceGraph::DependenceType::RAW)) { + for (auto dependence : graph.getDependencesInto(consumer, depType)) { LLVM_DEBUG(dbgs() << "\n***Consider producer:\t" << *dependence.dependentOpView.op << "\n"); auto producer = cast(dependence.dependentOpView.op); @@ -294,7 +293,7 @@ - // Check that the dependence is indeed on the input `consumerIdx` view. + // Check that the dependence is indeed on the `consumerIdx` buffer view. auto consumedView = dependence.indexingView; - if (consumer.getInput(consumerIdx) != consumedView) + if (consumer.getBuffer(consumerIdx) != consumedView) continue; // Consumer consumes this view, `isStructurallyFusableProducer` also checks @@ -302,9 +301,10 @@ auto producedView = dependence.dependentOpView.view; auto producerIdx = producer.getIndexOfOutputBuffer(producedView).getValue(); // `consumerIdx` and `producerIdx` exist by construction. - LLVM_DEBUG(dbgs() << "\nRAW producer: " << *producer.getOperation() - << " view: " << producedView - << " output index: " << producerIdx); + LLVM_DEBUG(dbgs() << "\n" + << LinalgDependenceGraph::getDependenceTypeStr(depType) + << " producer: " << *producer.getOperation() << " view: " + << producedView << " output index: " << producerIdx); // Must be a subview or a slice to guarantee there are loops we can fuse // into. @@ -332,6 +332,22 @@ return llvm::None; } +// Only consider RAW and WAW atm.
+Optional mlir::linalg::fuseProducerOf( + OpBuilder &b, LinalgOp consumer, unsigned consumerIdx, + const LinalgDependenceGraph &graph, OperationFolder *folder) { + SmallVector deps = { + LinalgDependenceGraph::DependenceType::RAW, + LinalgDependenceGraph::DependenceType::WAW, + }; + for (auto dep : deps) { + if (auto res = + fuseProducerOfDep(b, consumer, consumerIdx, graph, folder, dep)) + return res; + } + return llvm::None; +} + /// Checks if two Generic ops are fusible, when one is a producer and another is /// a consumer (with the result of the producer being the `consumerIdx` operand /// of the consumer). @@ -498,7 +514,8 @@ // The current naive and expensive reconstruction of the graph should be // removed. for (auto *op : llvm::reverse(linalgOps)) { - for (unsigned id = 0, e = LinalgOp(op).getNumInputs(); id < e; ++id) { + for (unsigned id = 0, e = LinalgOp(op).getNumInputsAndOutputBuffers(); + id < e; ++id) { linalg::Aliases aliases; linalg::LinalgDependenceGraph graph(aliases, linalgOps); if (auto info = fuseProducerOf(b, op, id, graph, &folder)) { diff --git a/mlir/test/Dialect/Linalg/fusion.mlir b/mlir/test/Dialect/Linalg/fusion.mlir --- a/mlir/test/Dialect/Linalg/fusion.mlir +++ b/mlir/test/Dialect/Linalg/fusion.mlir @@ -41,12 +41,11 @@ } // CHECK-LABEL: func @f1 // CHECK: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}}) -// No RAW dependences, the pass does not fuse RAR atm. -// CHECK: linalg.matmul // CHECK: loop.for // CHECK: loop.for // CHECK: loop.for // CHECK: linalg.matmul +// CHECK: linalg.matmul // ----- @@ -334,15 +333,13 @@ } // CHECK-LABEL: func @f6 // CHECK: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}}) -// Cannot fuse C due to interleaved read of C that would be bypassed. -// Cannot fuse E (WAW). -// CHECK: linalg.matmul -// CHECK: linalg.matmul +// Fuse the producer of E (WAW) then the producer of C (RAW).
// CHECK: loop.for // CHECK: loop.for // CHECK: loop.for // CHECK: linalg.matmul -// CHECK-NOT: linalg.matmul +// CHECK: linalg.matmul +// CHECK: linalg.matmul // ----- @@ -785,3 +782,54 @@ // CHECK: linalg.generic // CHECK: exp // CHECK: linalg.yield + +// ----- + +#map0 = affine_map<(d0, d1, d2) -> (d0, d1 - d2)> +#map1 = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3 * s4)> +#map2 = affine_map<()[s0] -> (s0 + 3)> + +func @fill_and_conv(%arg0: memref<1x4x5x1xf32>, %arg1: memref<2x3x1x1xf32>, %arg2: memref<1x4x5x1xf32>) { + %cst = constant 0.000000e+00 : f32 + linalg.fill(%arg2, %cst) : memref<1x4x5x1xf32>, f32 + + %c4 = constant 4 : index + %c1 = constant 1 : index + %c0 = constant 0 : index + %c2 = constant 2 : index + %c3 = constant 3 : index + %4 = dim %arg1, 0 : memref<2x3x1x1xf32> + %5 = dim %arg1, 1 : memref<2x3x1x1xf32> + %6 = dim %arg0, 0 : memref<1x4x5x1xf32> + %7 = dim %arg0, 1 : memref<1x4x5x1xf32> + %8 = dim %arg0, 3 : memref<1x4x5x1xf32> + %9 = dim %arg2, 0 : memref<1x4x5x1xf32> + %10 = dim %arg2, 1 : memref<1x4x5x1xf32> + %11 = dim %arg2, 2 : memref<1x4x5x1xf32> + %12 = dim %arg2, 3 : memref<1x4x5x1xf32> + %13 = linalg.range %c0 : %6 : %c2 : !linalg.range + %14 = linalg.range %c0 : %10 : %c3 : !linalg.range + loop.for %arg3 = %c0 to %6 step %c2 { + loop.for %arg4 = %c0 to %10 step %c3 { + %15 = affine.min #map0(%c2, %c1, %arg3) + %16 = affine.apply #map2()[%7] + %17 = affine.min #map0(%16, %c4, %arg4) + %18 = dim %arg0, 2 : memref<1x4x5x1xf32> + %19 = dim %arg0, 3 : memref<1x4x5x1xf32> + %20 = subview %arg0[%arg3, %arg4, %c0, %c0] [%15, %17, %18, %19] [%c1, %c1, %c1, %c1] : memref<1x4x5x1xf32> to memref + %21 = affine.min #map0(%c2, %c1, %arg3) + %22 = affine.min #map0(%c3, %c4, %arg4) + %23 = dim %arg2, 2 : memref<1x4x5x1xf32> + %24 = dim %arg2, 3 : memref<1x4x5x1xf32> + %25 = subview %arg2[%arg3, %arg4, %c0, %c0] [%21, %22, %23, %24] [%c1, %c1, %c1, %c1] : memref<1x4x5x1xf32> to memref
+ linalg.conv(%arg1, %20, %25) {dilations = [1, 1], strides = [1, 1]} : memref<2x3x1x1xf32>, memref, memref + } + } + return +} +// CHECK-LABEL: func @fill_and_conv +// CHECK-NOT: linalg.fill +// CHECK: loop.for +// CHECK: loop.for +// CHECK: linalg.fill +// CHECK: linalg.conv