diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -18,7 +18,6 @@ namespace mlir {
 std::unique_ptr<OperationPass<FuncOp>> createLinalgFoldUnitExtentDimsPass();
 
-std::unique_ptr<OperationPass<FuncOp>> createLinalgFusionPass();
 std::unique_ptr<Pass> createLinalgFusionOfTensorOpsPass();
 std::unique_ptr<Pass> createFoldReshapeOpsByLinearizationPass();
 
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -23,12 +23,6 @@
   let dependentDialects = ["linalg::LinalgDialect"];
 }
 
-def LinalgFusion : FunctionPass<"linalg-fusion"> {
-  let summary = "Fuse operations in the linalg dialect";
-  let constructor = "mlir::createLinalgFusionPass()";
-  let dependentDialects = ["linalg::LinalgDialect"];
-}
-
 def LinalgFusionOfTensorOps : Pass<"linalg-fusion-for-tensor-ops"> {
   let summary = "Fuse operations on RankedTensorType in linalg dialect";
   let constructor = "mlir::createLinalgFusionOfTensorOpsPass()";
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -88,27 +88,22 @@
 /// transformation and thus requires the `consumerIdx`^th operand of `consumer`
 /// to be a `subview` op (generally obtained by applying the tiling
 /// transformation).
-/// When non-null, the optional pointer `folder` is used to call into the
-/// `createAndFold` builder method. If `folder` is null, the regular `create`
-/// method is called.
 Optional<FusionInfo> fuseProducerOfBuffer(OpBuilder &b, LinalgOp consumer,
                                           unsigned consumerIdx,
-                                          const LinalgDependenceGraph &graph,
-                                          OperationFolder *folder = nullptr);
+                                          const LinalgDependenceGraph &graph);
 
 /// Tensor counterpart of `fuseProducerOfBuffer`.
 /// This implements the fusion part of the "tileAndFuse on tensors"
 /// transformation and thus requires the `consumerIdx`^th operand of `consumer`
 /// to be the result of a `subtensor` op (generally obtained by applying the
 /// tiling transformation).
 Optional<FusionInfo> fuseProducerOfTensor(OpBuilder &b, LinalgOp consumer,
-                                          unsigned consumerIdx,
-                                          OperationFolder *folder);
+                                          unsigned consumerIdx);
 
 /// Fuse linalg operation on tensors, with the producer of the operand at
 /// position `consumerIdx` of the consumer.
-Optional<SmallVector<Value, 1>>
-fuseTensorOps(PatternRewriter &rewriter, Operation *consumer,
-              unsigned consumerIdx, OperationFolder *folder = nullptr);
+Optional<SmallVector<Value, 1>> fuseTensorOps(PatternRewriter &rewriter,
+                                              Operation *consumer,
+                                              unsigned consumerIdx);
 
 /// Returns the linearized list of all shape dimensions in a `linalgOp`.
 /// Applying the inverse, concatenated loopToOperandRangeMaps to this list
@@ -122,17 +117,12 @@
 /// Returns the loop ranges of the `linalgOp`. Applies the inverse of the
 /// concatenated indexing maps to the result of `getShape`. Returns None if
 /// the bounds computation fails.
-Optional<SmallVector<Value, 4>>
-getLoopRanges(OpBuilder &builder, LinalgOp linalgOp,
-              OperationFolder *folder = nullptr);
+Optional<SmallVector<Value, 4>> getLoopRanges(OpBuilder &builder,
+                                              LinalgOp linalgOp);
 
 /// Returns the values obtained by applying `map` to the list of values.
-/// When non-null, the optional pointer `folder` is used to call into the
-/// `createAndFold` builder method. If `folder` is null, the regular `create`
-/// method is called.
 SmallVector<Value, 4> applyMapToValues(OpBuilder &b, Location loc,
-                                       AffineMap map, ValueRange values,
-                                       OperationFolder *folder = nullptr);
+                                       AffineMap map, ValueRange values);
 
 /// Apply the permutation defined by `permutation` to `inVec`.
 /// Element `i` in `inVec` is mapped to location `j = permutation[i]`.
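
Callers of these utilities migrate mechanically: the trailing `OperationFolder *` argument is dropped and nothing else changes, because eager folding now happens inside the callee (see the `createOrFold` change in Utils.cpp further down). A minimal sketch of a hypothetical call site — `b`, `loc`, `map`, and `operands` are assumed to be in scope and are not names from this patch:

    // Before: applyMapToValues(b, loc, map, operands, /*folder=*/&folder);
    // After:  the folder argument is gone; results are still eagerly folded.
    SmallVector<Value, 4> shapes =
        mlir::linalg::applyMapToValues(b, loc, map, operands);
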
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -13,7 +13,6 @@
 #include "PassDetail.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
-#include "mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
 #include "mlir/Dialect/Linalg/Passes.h"
@@ -24,7 +23,6 @@
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Dominance.h"
 #include "mlir/Support/LLVM.h"
-#include "mlir/Transforms/FoldUtils.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/Support/CommandLine.h"
@@ -37,8 +35,6 @@
 using namespace mlir::edsc::intrinsics;
 using namespace mlir::linalg;
 
-using folded_std_constant_index = FoldedValueBuilder<ConstantIndexOp>;
-
 using llvm::dbgs;
 
 /// Implements a simple high-level fusion pass on linalg structured operations.
@@ -201,8 +197,7 @@
 /// 2. Tensor case: `producerIdx` is the index of the tensor in
 ///    `producer.getResults()`.
 static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx,
-                     LinalgOp consumer, unsigned consumerIdx,
-                     OperationFolder *folder = nullptr) {
+                     LinalgOp consumer, unsigned consumerIdx) {
   Operation *shapeProducingOp =
       consumer.getShapedOperand(consumerIdx).getDefiningOp();
   assert((isa<SubViewOp>(shapeProducingOp) ||
@@ -244,9 +239,9 @@
               << "existing LoopRange: " << loopRanges[i] << "\n");
     else {
       auto shapeDim = getShapeDefiningLoopRange(producer, i);
-      loopRanges[i] = Range{folded_std_constant_index(folder, 0),
+      loopRanges[i] = Range{std_constant_index(0),
                             std_dim(shapeDim.shape, shapeDim.dimension),
-                            folded_std_constant_index(folder, 1)};
+                            std_constant_index(1)};
       LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n");
     }
   }
@@ -396,15 +391,21 @@
   return {};
 }
 
-Optional<FusionInfo> mlir::linalg::fuseProducerOfBuffer(
-    OpBuilder &b, LinalgOp consumer, unsigned consumerIdx,
-    const LinalgDependenceGraph &graph, OperationFolder *folder) {
+Optional<FusionInfo>
+mlir::linalg::fuseProducerOfBuffer(OpBuilder &b, LinalgOp consumer,
+                                   unsigned consumerIdx,
+                                   const LinalgDependenceGraph &graph) {
   Optional<LinalgDependenceGraph::LinalgDependenceGraphElem> fusableDependence =
       findFusableProducer(consumer, consumerIdx, graph);
   if (!fusableDependence)
     return {};
 
   LinalgOp producerOp = cast<LinalgOp>(fusableDependence->dependentOpView.op);
+  // If producer is already in the same block as consumer, we are done.
+  if (consumer.getOperation()->getBlock() ==
+      producerOp.getOperation()->getBlock())
+    return {};
+
   Value producerView = fusableDependence->dependentOpView.view;
   Value consumerView = fusableDependence->indexingView;
@@ -427,8 +428,7 @@
   assert(producerIdxOpt.hasValue() && "incorrect operand index");
   unsigned producerIdx = producerIdxOpt.getValue();
 
-  auto fusedProducer =
-      fuse(b, producerOp, producerIdx, consumer, consumerIdx, folder);
+  auto fusedProducer = fuse(b, producerOp, producerIdx, consumer, consumerIdx);
   return FusionInfo{producerOp, fusedProducer};
 }
 
@@ -459,10 +459,9 @@
   }
 }
 
-Optional<FusionInfo>
-mlir::linalg::fuseProducerOfTensor(OpBuilder &b, LinalgOp consumer,
-                                   unsigned consumerIdx,
-                                   OperationFolder *folder) {
+Optional<FusionInfo> mlir::linalg::fuseProducerOfTensor(OpBuilder &b,
+                                                        LinalgOp consumer,
+                                                        unsigned consumerIdx) {
   Value inputTensor = consumer.getInput(consumerIdx);
   LinalgOp producerOp;
   unsigned producerIdx;
@@ -475,13 +474,18 @@
     return {};
   }
 
+  // If producer is already in the same block as consumer, we are done.
+  if (consumer.getOperation()->getBlock() ==
+      producerOp.getOperation()->getBlock())
+    return {};
+
   // Insert fused `producer` just before `consumer`.
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(consumer.getOperation());
   ScopedContext scope(b, consumer.getLoc());
   LLVM_DEBUG(dbgs() << "Fuse into consumer: " << *consumer << "\n");
   LinalgOp fusedProducer =
-      fuse(b, producerOp, producerIdx, consumer, consumerIdx, folder);
+      fuse(b, producerOp, producerIdx, consumer, consumerIdx);
 
   // Replace use.
   // Canonicalizations are not guaranteed to have happened before constructing
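
Both entry points now also bail out early when the producer already sits in the same block as the consumer: there is no enclosing loop to fuse it into, and reporting "nothing to do" is what lets greedy drivers that call these functions in a loop reach a fixed point. Typical driver-side usage, as a hedged sketch (`b`, `consumer`, `idx`, and `graph` are assumed to be in scope; the in-tree driver below batches the erasures in a set rather than erasing inline):

    // Try to pull the producer of operand `idx` next to the consumer.
    // llvm::None means: no producer found, fusion structurally illegal, or
    // the producer already lives in the consumer's block.
    if (auto info = linalg::fuseProducerOfBuffer(b, consumer, idx, graph)) {
      // `fusedProducer` is the clone emitted before `consumer`; erasing the
      // original is only safe because fuseProducerOfBuffer verified that no
      // covering read or write exists between producer and consumer.
      info->originalProducer.getOperation()->erase();
    }
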
@@ -796,72 +800,3 @@
   }
   return llvm::None;
 }
-
-static void fuseLinalgOpsGreedily(FuncOp f) {
-  LLVM_DEBUG(f.print(dbgs() << "\nBefore linalg-fusion: \n"));
-
-  OpBuilder b(f);
-  OperationFolder folder(f.getContext());
-  DenseSet<Operation *> eraseSet;
-
-  // Save original Linalg ops, we only want to make a pass over those.
-  SmallVector<Operation *, 8> linalgOps;
-  f.walk([&](LinalgOp op) {
-    // TODO: support multi-results.
-    if (op.getOperation()->getNumResults() <= 1)
-      linalgOps.push_back(op);
-  });
-
-  // Tile and Fuse for tensors inputs (TODO: all tensor operands).
-  for (auto *op : llvm::reverse(linalgOps)) {
-    LinalgOp linalgOp = cast<LinalgOp>(op);
-    for (auto en : llvm::enumerate(linalgOp.getShapedOperands())) {
-      if (en.value().getType().isa<MemRefType>()) {
-        // TODO: LinalgDependenceGraph should be able to update itself.
-        // The current naive and expensive reconstruction of the graph should
-        // be removed.
-        linalg::Aliases aliases;
-        linalg::LinalgDependenceGraph graph(aliases, linalgOps);
-        if (auto info =
-                fuseProducerOfBuffer(b, op, en.index(), graph, &folder)) {
-          auto *originalOp = info->originalProducer.getOperation();
-          eraseSet.insert(originalOp);
-          auto *originalOpInLinalgOpsVector =
-              std::find(linalgOps.begin(), linalgOps.end(), originalOp);
-          *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
-        }
-      } else {
-        assert(en.value().getType().isa<RankedTensorType>());
-        // Tile and Fuse tensor input (TODO: init_tensors too).
-        if (en.index() >= linalgOp.getNumInputs())
-          continue;
-        if (auto info = fuseProducerOfTensor(b, op, en.index(), &folder)) {
-          auto *originalOp = info->originalProducer.getOperation();
-          auto *originalOpInLinalgOpsVector =
-              std::find(linalgOps.begin(), linalgOps.end(), originalOp);
-          *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
-          // Don't mark for erasure in the tensor case, let DCE handle this.
-        }
-      }
-    }
-  }
-  // The `fuseProducerOfBuffer` function performs structural checks and in
-  // particular that no covering read or write exist between the consumer and
-  // the producer. As a consequence, the only fusions that may occur preserve
-  // subsequent dependences and are guaranteed by construction to produce the
-  // whole view. We may thus erase the producer once it is fused.
-  for (auto *e : eraseSet)
-    e->erase();
-
-  LLVM_DEBUG(f.print(dbgs() << "\nAfter linalg-fusion: \n"));
-}
-
-namespace {
-struct LinalgFusionPass : public LinalgFusionBase<LinalgFusionPass> {
-  void runOnFunction() override { fuseLinalgOpsGreedily(getFunction()); }
-};
-} // namespace
-
-std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgFusionPass() {
-  return std::make_unique<LinalgFusionPass>();
-}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -177,8 +177,7 @@
 static Optional<SmallVector<Value, 1>>
 fuseTensorOpsImpl(LinalgOp producer, LinalgOp consumer, unsigned consumerIdx,
-                  PatternRewriter &rewriter,
-                  OperationFolder *folder = nullptr) {
+                  PatternRewriter &rewriter) {
   if (!areTensorOpsFusable(producer, consumer, consumerIdx))
     return llvm::None;
 
@@ -440,8 +439,8 @@
 /// conditions have been satisfied.
 static Optional<SmallVector<Value, 1>>
 fuseWithReshapeByExpansion(LinalgOp linalgOp, TensorReshapeOp reshapeOp,
-                           unsigned fusedTensorIndex, PatternRewriter &rewriter,
-                           OperationFolder *folder = nullptr) {
+                           unsigned fusedTensorIndex,
+                           PatternRewriter &rewriter) {
   assert(isFusableWithReshapeByDimExpansion(linalgOp, fusedTensorIndex) &&
          "preconditions for fuse operation failed");
   // Check if reshape is expanding or collapsing.
@@ -929,7 +928,7 @@
 Optional<SmallVector<Value, 1>>
 mlir::linalg::fuseTensorOps(PatternRewriter &rewriter, Operation *consumer,
-                            unsigned consumerIdx, OperationFolder *folder) {
+                            unsigned consumerIdx) {
   if (consumerIdx >= consumer->getNumOperands())
     return llvm::None;
   Operation *producer = consumer->getOperand(consumerIdx).getDefiningOp();
@@ -942,7 +941,7 @@
     return llvm::None;
 
   return fuseTensorOpsImpl(cast<LinalgOp>(producer), cast<LinalgOp>(consumer),
-                           consumerIdx, rewriter, folder);
+                           consumerIdx, rewriter);
 }
 
 namespace {
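
The tensor-side entry point keeps the same shape: a `PatternRewriter` alone is now sufficient. A hypothetical caller — not part of the patch — that tries each operand of `op` in turn (`op` and `rewriter` assumed in scope):

    // Attempt elementwise fusion of `op` with each operand's producer; on
    // success the returned values replace `op`'s results.
    for (unsigned idx = 0, e = op->getNumOperands(); idx != e; ++idx) {
      if (auto fusedResults = linalg::fuseTensorOps(rewriter, op, idx)) {
        rewriter.replaceOp(op, *fusedResults);
        break; // `op` is gone; stop iterating over its operands.
      }
    }
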
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -24,7 +24,6 @@
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/FoldUtils.h"
 
 using namespace mlir;
 using namespace mlir::linalg;
@@ -57,30 +56,27 @@
   return llvm::None;
 }
 
-static Value emitOrFoldComposedAffineApply(OpBuilder &b, Location loc,
-                                           AffineMap map,
-                                           ValueRange operandsRef,
-                                           OperationFolder *folder) {
+static Value createFoldedComposedAffineApply(OpBuilder &b, Location loc,
+                                             AffineMap map,
+                                             ValueRange operandsRef) {
   SmallVector<Value, 4> operands(operandsRef.begin(), operandsRef.end());
   fullyComposeAffineMapAndOperands(&map, &operands);
   canonicalizeMapAndOperands(&map, &operands);
-  return folder ? folder->create<AffineApplyOp>(b, loc, map, operands)
-                : b.create<AffineApplyOp>(loc, map, operands);
+  return b.createOrFold<AffineApplyOp>(loc, map, operands);
 }
 
 SmallVector<Value, 4> mlir::linalg::applyMapToValues(OpBuilder &b, Location loc,
                                                      AffineMap map,
-                                                     ValueRange values,
-                                                     OperationFolder *folder) {
+                                                     ValueRange values) {
   SmallVector<Value, 4> res;
   res.reserve(map.getNumResults());
   unsigned numDims = map.getNumDims(), numSym = map.getNumSymbols();
   // For each `expr` in `map`, applies the `expr` to the values extracted from
   // ranges. If the resulting application can be folded into a Value, the
-  // folding occurs eagerly. Otherwise, an affine.apply operation is emitted.
+  // folding occurs eagerly.
   for (auto expr : map.getResults()) {
     AffineMap map = AffineMap::get(numDims, numSym, expr);
-    res.push_back(emitOrFoldComposedAffineApply(b, loc, map, values, folder));
+    res.push_back(createFoldedComposedAffineApply(b, loc, map, values));
   }
   return res;
 }
@@ -159,15 +155,14 @@
   return res;
 }
 
-Optional<SmallVector<Value, 4>>
-getLoopRanges(OpBuilder &builder, LinalgOp linalgOp, OperationFolder *folder) {
+Optional<SmallVector<Value, 4>> getLoopRanges(OpBuilder &builder,
+                                              LinalgOp linalgOp) {
   SmallVector<Value, 4> viewSizes = getShape(builder, linalgOp);
   AffineMap invertedMap =
       inversePermutation(concatAffineMaps(linalgOp.getIndexingMaps()));
   if (!invertedMap)
     return {};
-  return applyMapToValues(builder, linalgOp.getLoc(), invertedMap, viewSizes,
-                          folder);
+  return applyMapToValues(builder, linalgOp.getLoc(), invertedMap, viewSizes);
 }
 
 /// Specialization to build an scf "for" nest.
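
The `createOrFold` rewrite above is the crux of the whole patch: `OpBuilder::createOrFold<AffineApplyOp>` constructs the op, immediately runs its fold hook, and only leaves an `affine.apply` in the IR when folding fails, so an `OperationFolder` no longer needs to be threaded through every utility and fusion API. Roughly, with `b`, `loc`, `map`, and `operands` in scope:

    // Either returns an already-available Value (e.g. a constant or one of
    // the operands, for an identity expression) without inserting anything,
    // or inserts an affine.apply and returns its result.
    Value v = b.createOrFold<AffineApplyOp>(loc, map, operands);

One behavioral difference is worth noting: `OperationFolder` also de-duplicated the constants it created, while `createOrFold` does not. The test pass below compensates by running canonicalization and CSE between fusion rounds, which is why the updated CHECK lines in the tests capture a single `%[[C0]]`/`%[[C1]]` constant instead of matching numbered duplicates via `%c0{{_[0-9]*}}`.
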
diff --git a/mlir/test/Dialect/Linalg/fusion-2-level.mlir b/mlir/test/Dialect/Linalg/fusion-2-level.mlir
--- a/mlir/test/Dialect/Linalg/fusion-2-level.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-2-level.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -linalg-fusion | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-greedy-fusion | FileCheck %s
 
 func @f1(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>, %B: memref<?x?xf32, offset: 0, strides: [?, ?]>, %C: memref<?x?xf32, offset: 0, strides: [?, ?]>, %D: memref<?x?xf32, offset: 0, strides: [?, ?]>, %E: memref<?x?xf32, offset: 0, strides: [?, ?]>) -> memref<?x?xf32, offset: 0, strides: [?, ?]> {
   %c1 = constant 1 : index
diff --git a/mlir/test/Dialect/Linalg/fusion-indexed-generic.mlir b/mlir/test/Dialect/Linalg/fusion-indexed-generic.mlir
--- a/mlir/test/Dialect/Linalg/fusion-indexed-generic.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-indexed-generic.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -linalg-fusion -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-greedy-fusion -split-input-file | FileCheck %s
 
 #map = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
 #id_2d = affine_map<(d0, d1) -> (d0, d1)>
@@ -82,8 +82,11 @@
   ^bb0(%i: index, %j: index, %a: f32, %b: f32, %c: f32): // no predecessors
     %i_int = index_cast %i: index to i32
     %i_float = sitofp %i_int : i32 to f32
+    %j_int = index_cast %j: index to i32
+    %j_float = sitofp %j_int : i32 to f32
     %ab = addf %a, %b : f32
-    %out = addf %ab, %i_float : f32
+    %tmp = addf %ab, %i_float : f32
+    %out = addf %tmp, %j_float : f32
     linalg.yield %out : f32
   }
   %C_X = dim %C, %c0 : memref<?x?xf32>
@@ -115,6 +118,7 @@
 // CHECK:   [[i_new:%.*]] = addi [[i]], [[I]] : index
 // CHECK:   [[j_new:%.*]] = addi [[j]], [[J]] : index
 // CHECK:   {{.*}} = index_cast [[i_new]] : index to i32
+// CHECK:   {{.*}} = index_cast [[j_new]] : index to i32
 // CHECK: linalg.generic
 // CHECK:   addf
@@ -137,10 +141,13 @@
      ins(%A, %B: memref<?x?xf32>, memref<?x?xf32>)
     outs(%C : memref<?x?xf32>) {
   ^bb0(%i: index, %j: index, %a: f32, %b: f32, %c: f32): // no predecessors
+    %i_int = index_cast %i: index to i32
+    %i_float = sitofp %i_int : i32 to f32
     %j_int = index_cast %j: index to i32
    %j_float = sitofp %j_int : i32 to f32
     %ab = addf %a, %b : f32
-    %out = addf %ab, %j_float : f32
+    %tmp = addf %ab, %i_float : f32
+    %out = addf %tmp, %j_float : f32
     linalg.yield %out : f32
   }
   %C_X = dim %C, %c0 : memref<?x?xf32>
@@ -176,8 +183,8 @@
 // CHECK-NOT: scf.parallel
 // CHECK: linalg.indexed_generic
 // CHECK: ^bb0([[i:%.*]]: index, [[j:%.*]]: index
-// CHECK:   [[i_new:%.*]] = addi [[i]], [[C0]] : index
 // CHECK:   [[j_new:%.*]] = addi [[j]], [[J]] : index
+// CHECK:   {{.*}} = index_cast [[i]] : index to i32
 // CHECK:   {{.*}} = index_cast [[j_new]] : index to i32
 // CHECK: linalg.generic
 // CHECK:   addf
diff --git a/mlir/test/Dialect/Linalg/fusion.mlir b/mlir/test/Dialect/Linalg/fusion.mlir
--- a/mlir/test/Dialect/Linalg/fusion.mlir
+++ b/mlir/test/Dialect/Linalg/fusion.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -linalg-fusion -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-greedy-fusion -split-input-file | FileCheck %s
 
 func @f1(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
         %B: memref<?x?xf32, offset: 0, strides: [?, ?]>,
@@ -98,6 +98,8 @@
 
 // -----
 
+// CHECK-DAG: #[[$strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s0 + d1 * s1)>
+
 func @f3(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
          %B: memref<?x?xf32, offset: 0, strides: [?, ?]>,
          %C: memref<?x?xf32, offset: 0, strides: [?, ?]>,
@@ -137,9 +139,11 @@
 }
 // CHECK-LABEL: func @f3
 // CHECK:  (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}})
-// CHECK:  %[[D_0:.*]] = dim %[[D]], %c0{{_[0-9]*}} : memref
-// CHECK:  %[[D_1:.*]] = dim %[[D]], %c1{{_[0-9]*}} : memref
-// CHECK:  %[[C_1:.*]] = dim %[[C]], %c1{{_[0-9]*}} : memref
+// CHECK-DAG:  %[[C0:.*]] = constant 0 : index
+// CHECK-DAG:  %[[C1:.*]] = constant 1 : index
+// CHECK:  %[[D_0:.*]] = dim %[[D]], %[[C0]] : memref
+// CHECK:  %[[D_1:.*]] = dim %[[D]], %[[C1]] : memref
+// CHECK:  %[[C_1:.*]] = dim %[[C]], %[[C1]] : memref
 // CHECK:  scf.for %{{.*}} = %{{.*}} to %[[D_0]] step %{{.*}} {
 // CHECK:    scf.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} {
 // CHECK:      scf.for %{{.*}} = %{{.*}} to %[[D_1]] step %{{.*}} {
@@ -148,6 +152,8 @@
 
 // -----
 
+// CHECK-DAG: #[[$strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s0 + d1 * s1)>
+
 func @f4(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
          %B: memref<?x?xf32, offset: 0, strides: [?, ?]>,
         %C: memref<?x?xf32, offset: 0, strides: [?, ?]>,
@@ -190,9 +196,11 @@
 }
 // CHECK-LABEL: func @f4
 // CHECK:  (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}})
-// CHECK:  %[[C_0:.*]] = dim %[[C]], %c0{{_[0-9]*}} : memref
-// CHECK:  %[[C_1:.*]] = dim %[[C]], %c1{{_[0-9]*}} : memref
-// CHECK:  %[[D_1:.*]] = dim %[[D]], %c1{{_[0-9]*}} : memref
+// CHECK-DAG:  %[[C0:.*]] = constant 0 : index
+// CHECK-DAG:  %[[C1:.*]] = constant 1 : index
+// CHECK:  %[[C_0:.*]] = dim %[[C]], %[[C0:.*]] : memref
+// CHECK:  %[[C_1:.*]] = dim %[[C]], %[[C1:.*]] : memref
+// CHECK:  %[[D_1:.*]] = dim %[[D]], %[[C1:.*]] : memref
 // CHECK:  scf.for %{{.*}} = %{{.*}} to %[[C_0]] step %{{.*}} {
 // CHECK:    scf.for %{{.*}} = %{{.*}} to %[[D_1]] step %{{.*}} {
 // CHECK:      scf.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} {
@@ -246,26 +254,24 @@
 }
 // CHECK-LABEL: func @f5
 // CHECK:  (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}})
-// CHECK-DAG:  %[[B_1:.*]] = dim %[[B]], %c1{{_[0-9]*}} : memref
-// CHECK-DAG:  %[[D_0:.*]] = dim %[[D]], %c0{{_[0-9]*}} : memref
-// CHECK-DAG:  %[[D_1:.*]] = dim %[[D]], %c1{{_[0-9]*}} : memref
-// CHECK:  scf.for %[[I:.*]] = %{{.*}} to %[[D_0]] step %{{.*}} {
-// CHECK:    scf.for %[[J:.*]] = %{{.*}} to %[[B_1]] step %{{.*}} {
-// CHECK:      scf.for %[[K:.*]] = %{{.*}} to %[[D_1]] step %{{.*}} {
-// CHECK-DAG:    %[[D_IK:.*]] = subview %[[D]][%[[I]], %[[K]]]
-// CHECK-DAG:    %[[B_KJ:.*]] = subview %[[B]][%[[K]], %[[J]]]
-// CHECK-DAG:    %[[E_IJ:.*]] = subview %[[E]][%[[I]], %[[J]]]
-// CHECK:        dim
-// CHECK-DAG:    %[[C_I0:.*]] = subview %[[C]][%[[I]], %{{.*}}]
-// CHECK-DAG:    %[[B_0K:.*]] = subview %[[B]][%{{.*}}, %[[K]]]
-// CHECK-DAG:    %[[D_IK_:.*]] = subview %[[D]][%[[I]], %[[K]]]
-// CHECK:        dim
-// CHECK-DAG:    %[[A_I0:.*]] = subview %[[A]][%[[I]], %{{.*}}]
-// CHECK-DAG:    %[[B_00:.*]] = subview %[[B]][%{{.*}}, %{{.*}}]
-// CHECK-DAG:    %[[C_I0_:.*]] = subview %[[C]][%[[I]], %{{.*}}]
-// CHECK:        linalg.matmul ins(%[[A_I0]], %[[B_00]]{{.*}} outs(%[[C_I0_]]
-// CHECK:        linalg.matmul ins(%[[C_I0]], %[[B_0K]]{{.*}} outs(%[[D_IK_]]
-// CHECK:        linalg.matmul ins(%[[D_IK]], %[[B_KJ]]{{.*}} outs(%[[E_IJ]]
+// CHECK-DAG:  %[[C0:.*]] = constant 0 : index
+// CHECK-DAG:  %[[C1:.*]] = constant 1 : index
+// CHECK-DAG:  %[[B_1:.*]] = dim %[[B]], %[[C1:.*]] : memref
+// CHECK-DAG:  %[[D_0:.*]] = dim %[[D]], %[[C0:.*]] : memref
+// CHECK-DAG:  %[[D_1:.*]] = dim %[[D]], %[[C1:.*]] : memref
+// CHECK-DAG:  %[[B_00:.*]] = subview %[[B]][0, 0]{{.*}}
+// CHECK:  scf.for %[[I:.*]] = %{{.*}} to %[[D_0]] step %{{.*}} {
+// CHECK-DAG:    %[[A_I0:.*]] = subview %[[A]][%[[I]], 0]
+// CHECK-DAG:    %[[C_I0:.*]] = subview %[[C]][%[[I]], 0]
+// CHECK:    scf.for %[[J:.*]] = %{{.*}} to %[[B_1]] step %{{.*}} {
+// CHECK:      %[[E_IJ:.*]] = subview %[[E]][%[[I]], %[[J]]]
+// CHECK:      scf.for %[[K:.*]] = %{{.*}} to %[[D_1]] step %{{.*}} {
+// CHECK-DAG:      %[[D_IK:.*]] = subview %[[D]][%[[I]], %[[K]]]
+// CHECK-DAG:      %[[B_0K:.*]] = subview %[[B]][0, %[[K]]]
+// CHECK-DAG:      %[[B_KJ:.*]] = subview %[[B]][%[[K]], %[[J]]]
+// CHECK:        linalg.matmul ins(%[[A_I0]], %[[B_00]]{{.*}} outs(%[[C_I0]]
+// CHECK:        linalg.matmul ins(%[[C_I0]], %[[B_0K]]{{.*}} outs(%[[D_IK]]
+// CHECK:        linalg.matmul ins(%[[D_IK]], %[[B_KJ]]{{.*}} outs(%[[E_IJ]]
 
 // -----
 
@@ -390,11 +396,13 @@
 }
 // CHECK-LABEL: func @f7
 // CHECK:  (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}})
-// CHECK:  %[[A_0:.*]] = dim %[[A]], %c0{{_[0-9]*}} : memref
-// CHECK:  %[[A_1:.*]] = dim %[[A]], %c1{{_[0-9]*}} : memref
-// CHECK:  %[[C_1:.*]] = dim %[[C]], %c1{{_[0-9]*}} : memref
-// CHECK:  %[[C_0:.*]] = dim %[[C]], %c0{{_[0-9]*}} : memref
-// CHECK:  %[[D_1:.*]] = dim %[[D]], %c1{{_[0-9]*}} : memref
+// CHECK-DAG:  %[[C0:.*]] = constant 0 : index
+// CHECK-DAG:  %[[C1:.*]] = constant 1 : index
+// CHECK:  %[[A_0:.*]] = dim %[[A]], %[[C0:.*]] : memref
+// CHECK:  %[[A_1:.*]] = dim %[[A]], %[[C1:.*]] : memref
+// CHECK:  %[[C_1:.*]] = dim %[[C]], %[[C1:.*]] : memref
+// CHECK:  %[[C_0:.*]] = dim %[[C]], %[[C0:.*]] : memref
+// CHECK:  %[[D_1:.*]] = dim %[[D]], %[[C1:.*]] : memref
 // CHECK:  linalg.matmul ins(%[[A]], %[[C]]{{.*}} outs(%[[E]]
 // CHECK:  scf.for %{{.*}} = %{{.*}} to %[[A_0]] step %{{.*}} {
 // CHECK:    scf.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} {
diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
--- a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
@@ -1,5 +1,4 @@
-// RUN: mlir-opt %s -linalg-fusion -split-input-file | FileCheck %s
-// RUN: mlir-opt %s -linalg-fusion -canonicalize -cse -split-input-file | FileCheck %s --check-prefix=CANONICALIZED
+// RUN: mlir-opt %s -test-linalg-greedy-fusion -split-input-file | FileCheck %s
 
 #map0 = affine_map<(d0)[s0] -> (2, -d0 + s0)>
 #map1 = affine_map<(d0)[s0] -> (4, -d0 + s0)>
@@ -41,44 +40,19 @@
 // CHECK-SAME:  %[[A:[0-9a-z]*]]: tensor<?x?xf32>
 // CHECK-SAME:  %[[B:[0-9a-z]*]]: tensor<?x?xf32>
 // CHECK-SAME:  %[[C:[0-9a-z]*]]: tensor<?x?xf32>
-// CHECK:  %[[C0:.*]] = constant 0 : index
+// CHECK-DAG:  %[[C0:.*]] = constant 0 : index
+// CHECK-DAG:  %[[C1:.*]] = constant 1 : index
+// CHECK-DAG:  %[[dA1:.*]] = dim %[[A]], %[[C1]] : tensor<?x?xf32>
 // CHECK:  scf.for %[[I:[0-9a-z]*]]
+// CHECK:    %[[stA:.*]] = subtensor %[[A]][%[[I]], 0] [2, %[[dA1]]] [1, 1] : tensor<?x?xf32> to tensor<2x?xf32>
 // CHECK-NEXT:  scf.for %[[J:[0-9a-z]*]]
-// CHECK-NEXT:    scf.for %[[K:[0-9a-z]*]]
-//
-// subtensor of the original program, first one refers to the unfused matmul and becomes a dead SSA value.
-// CHECK:  subtensor %{{.*}}[%[[I]], %[[K]]] {{.*}} : tensor<?x?xf32> to tensor<?x?xf32>
-// CHECK:  %[[stB1:.*]] = subtensor %[[B]][%[[K]], %[[J]]] {{.*}} : tensor<?x?xf32> to tensor<4x?xf32>
-// CHECK:  %[[stF:.*]] = subtensor %{{.*}}[%[[I]], %[[J]]] {{.*}} : tensor<?x?xf32> to tensor<?x?xf32>
-//
-// subtensors of the producing matmul.
-// CHECK:  %[[stA:.*]] = subtensor %[[A]][%[[I]], %[[C0]]] {{.*}} : tensor<?x?xf32> to tensor<?x?xf32>
-// CHECK-NEXT:  %[[stB2:.*]] = subtensor %[[B]][%[[C0]], %[[K]]] {{.*}} : tensor<?x?xf32> to tensor<?x?xf32>
-// CHECK-NEXT:  %[[stC:.*]] = subtensor %[[C]][%[[I]], %[[K]]] {{.*}} : tensor<?x?xf32> to tensor<?x?xf32>
-// CHECK-NEXT:  %[[stD:.*]] = linalg.matmul ins(%[[stA]], %[[stB2]] : tensor<?x?xf32>, tensor<?x?xf32>) init(%[[stC]] : tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK-NEXT:  %[[stD2:.*]] = tensor_cast %[[stD]] : tensor<?x?xf32> to tensor<?x4xf32>
-// CHECK-NEXT:  %[[stG:.*]] = linalg.matmul ins(%[[stD2]], %[[stB1]] : tensor<?x4xf32>, tensor<4x?xf32>) init(%[[stF]] : tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK-NEXT:  subtensor_insert %[[stG]]
-
-
-// CANONICALIZED-LABEL: func @matmul_tensors(
-// CANONICALIZED-SAME:  %[[A:[0-9a-z]*]]: tensor<?x?xf32>
-// CANONICALIZED-SAME:  %[[B:[0-9a-z]*]]: tensor<?x?xf32>
-// CANONICALIZED-SAME:  %[[C:[0-9a-z]*]]: tensor<?x?xf32>
-// CANONICALIZED:  %[[C0:.*]] = constant 0 : index
-// CANONICALIZED:  %[[C1:.*]] = constant 1 : index
-// CANONICALIZED:  scf.for %[[I:[0-9a-z]*]]
-// CANONICALIZED-NEXT:  scf.for %[[J:[0-9a-z]*]]
-// CANONICALIZED-NEXT:  scf.for %[[K:[0-9a-z]*]]
-//
-// CANONICALIZED:  %[[stB1:.*]] = subtensor %[[B]][%[[K]], %[[J]]] [4, 3] [1, 1] : tensor<?x?xf32> to tensor<4x3xf32>
-// CANONICALIZED:  %[[stF:.*]] = subtensor %{{.*}}[%[[I]], %[[J]]] [2, 3] [1, 1] : tensor<?x?xf32> to tensor<2x3xf32>
+// CHECK-NEXT:    scf.for %[[K:[0-9a-z]*]] {{.*}} iter_args(%[[RES:[0-9a-z]*]]
+// CHECK-DAG:      %[[stB1:.*]] = subtensor %[[B]][%[[K]], %[[J]]] [4, 3] [1, 1] : tensor<?x?xf32> to tensor<4x3xf32>
+// CHECK-DAG:      %[[stF:.*]] = subtensor %[[RES]][%[[I]], %[[J]]] [2, 3] [1, 1] : tensor<?x?xf32> to tensor<2x3xf32>
 //
 // subtensors of the producing matmul.
-// CANONICALIZED:  %[[dA1:.*]] = dim %[[A]], %[[C1]] : tensor<?x?xf32>
-// CANONICALIZED:  %[[stA:.*]] = subtensor %[[A]][%[[I]], 0] [2, %[[dA1]]] [1, 1] : tensor<?x?xf32> to tensor<2x?xf32>
-// CANONICALIZED-NEXT:  %[[stB2:.*]] = subtensor %[[B]][0, %[[K]]] [%[[dA1]], 4] [1, 1] : tensor<?x?xf32> to tensor<?x4xf32>
-// CANONICALIZED-NEXT:  %[[stC:.*]] = subtensor %[[C]][%[[I]], %[[K]]] [2, 4] [1, 1] : tensor<?x?xf32> to tensor<2x4xf32>
-// CANONICALIZED-NEXT:  %[[stD:.*]] = linalg.matmul ins(%[[stA]], %[[stB2]] : tensor<2x?xf32>, tensor<?x4xf32>) init(%[[stC]] : tensor<2x4xf32>) -> tensor<2x4xf32>
-// CANONICALIZED-NEXT:  %[[stG:.*]] = linalg.matmul ins(%[[stD]], %[[stB1]] : tensor<2x4xf32>, tensor<4x3xf32>) init(%[[stF]] : tensor<2x3xf32>) -> tensor<2x3xf32>
-// CANONICALIZED-NEXT:  subtensor_insert %[[stG]]
+// CHECK-DAG:      %[[stB2:.*]] = subtensor %[[B]][0, %[[K]]] [%[[dA1]], 4] [1, 1] : tensor<?x?xf32> to tensor<?x4xf32>
+// CHECK-DAG:      %[[stC:.*]] = subtensor %[[C]][%[[I]], %[[K]]] [2, 4] [1, 1] : tensor<?x?xf32> to tensor<2x4xf32>
+// CHECK:      %[[stD:.*]] = linalg.matmul ins(%[[stA]], %[[stB2]] : tensor<2x?xf32>, tensor<?x4xf32>) init(%[[stC]] : tensor<2x4xf32>) -> tensor<2x4xf32>
+// CHECK-NEXT:      %[[stG:.*]] = linalg.matmul ins(%[[stD]], %[[stB1]] : tensor<2x4xf32>, tensor<4x3xf32>) init(%[[stF]] : tensor<2x3xf32>) -> tensor<2x3xf32>
+// CHECK-NEXT:      subtensor_insert %[[stG]] into %[[RES]][%[[I]], %[[J]]]
diff --git a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
--- a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
@@ -13,7 +13,9 @@
 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/Passes.h"
 
 using namespace mlir;
 using namespace mlir::linalg;
@@ -104,10 +106,96 @@
   applyFusionPatterns(&getContext(), getFunction());
 }
 
+static LogicalResult fuseLinalgOpsGreedily(FuncOp f) {
+  OpBuilder b(f);
+  DenseSet<Operation *> eraseSet;
+
+  // Save original Linalg ops, we only want to make a pass over those.
+  SmallVector<Operation *, 8> linalgOps;
+  f.walk([&](LinalgOp op) {
+    // TODO: support multi-results.
+    if (op.getOperation()->getNumResults() <= 1)
+      linalgOps.push_back(op);
+  });
+
+  // Tile and Fuse for tensors inputs (TODO: all tensor operands).
+  bool changed = false;
+  for (auto *op : llvm::reverse(linalgOps)) {
+    LinalgOp linalgOp = cast<LinalgOp>(op);
+    for (auto en : llvm::enumerate(linalgOp.getShapedOperands())) {
+      if (en.value().getType().isa<MemRefType>()) {
+        // TODO: LinalgDependenceGraph should be able to update itself.
+        // The current naive and expensive reconstruction of the graph should
+        // be removed.
+        linalg::Aliases aliases;
+        linalg::LinalgDependenceGraph graph(aliases, linalgOps);
+        if (auto info = fuseProducerOfBuffer(b, op, en.index(), graph)) {
+          auto *originalOp = info->originalProducer.getOperation();
+          eraseSet.insert(originalOp);
+          auto *originalOpInLinalgOpsVector =
+              std::find(linalgOps.begin(), linalgOps.end(), originalOp);
+          *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
+          changed = true;
+        }
+      } else {
+        assert(en.value().getType().isa<RankedTensorType>());
+        // Tile and Fuse tensor input (TODO: init_tensors too).
+        if (en.index() >= linalgOp.getNumInputs())
+          continue;
+        if (auto info = fuseProducerOfTensor(b, op, en.index())) {
+          auto *originalOp = info->originalProducer.getOperation();
+          auto *originalOpInLinalgOpsVector =
+              std::find(linalgOps.begin(), linalgOps.end(), originalOp);
+          *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
+          // Don't mark for erasure in the tensor case, let DCE handle this.
+          changed = true;
+        }
+      }
+    }
+  }
+  // The `fuseProducerOfBuffer` function performs structural checks and in
+  // particular that no covering read or write exist between the consumer and
+  // the producer. As a consequence, the only fusions that may occur preserve
+  // subsequent dependences and are guaranteed by construction to produce the
+  // whole view. We may thus erase the producer once it is fused.
+  for (auto *e : eraseSet)
+    e->erase();
+
+  return changed ? success() : failure();
+}
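
Unlike the void helper deleted from Fusion.cpp, this version reports whether any fusion happened, encoded as a `LogicalResult`: `success()` means "changed", `failure()` means "fixed point reached", not an error. A hedged restatement of the contract, with `f` assumed in scope:

    // success() <=> at least one producer was fused this round, so iterate;
    // failure() <=> nothing left to fuse (the fixed point), not an error.
    while (succeeded(fuseLinalgOpsGreedily(f))) {
      // ... interleave cleanup here, as TestLinalgGreedyFusion does below.
    }
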
+
+namespace {
+struct TestLinalgGreedyFusion
+    : public PassWrapper<TestLinalgGreedyFusion, FunctionPass> {
+  void runOnFunction() override {
+    MLIRContext *context = &getContext();
+    OwningRewritePatternList patterns =
+        linalg::getLinalgTilingCanonicalizationPatterns(context);
+    patterns.insert<AffineMinSCFCanonicalizationPattern>(context);
+    FrozenRewritePatternList frozenPatterns(std::move(patterns));
+    while (succeeded(fuseLinalgOpsGreedily(getFunction()))) {
+      applyPatternsAndFoldGreedily(getFunction(), frozenPatterns);
+      PassManager pm(context);
+      pm.addPass(createLoopInvariantCodeMotionPass());
+      pm.addPass(createCanonicalizerPass());
+      pm.addPass(createCSEPass());
+      LogicalResult res = pm.run(getFunction().getParentOfType<ModuleOp>());
+      if (failed(res))
+        this->signalPassFailure();
+    }
+  }
+};
+} // namespace
+
 namespace mlir {
 void registerTestLinalgFusionTransforms() {
   PassRegistration<TestLinalgFusionTransforms> testFusionTransformsPass(
       "test-linalg-fusion-transform-patterns",
       "Test Linalg fusion transformation patterns by applying them greedily.");
 }
+void registerTestLinalgGreedyFusion() {
+  PassRegistration<TestLinalgGreedyFusion> testFusionTransformsPass(
+      "test-linalg-greedy-fusion",
+      "Test Linalg fusion by applying a greedy test transformation.");
+}
 } // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -61,6 +61,7 @@
 void registerTestInterfaces();
 void registerTestLinalgCodegenStrategy();
 void registerTestLinalgFusionTransforms();
+void registerTestLinalgGreedyFusion();
 void registerTestLinalgHoisting();
 void registerTestLinalgTransforms();
 void registerTestLivenessPass();
@@ -121,6 +122,7 @@
   registerTestInterfaces();
   registerTestLinalgCodegenStrategy();
   registerTestLinalgFusionTransforms();
+  registerTestLinalgGreedyFusion();
   registerTestLinalgHoisting();
   registerTestLinalgTransforms();
   registerTestLivenessPass();