diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -24,6 +24,7 @@ #include "mlir/IR/Dominance.h" #include "mlir/Support/LLVM.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/RegionUtils.h" #include "llvm/ADT/MapVector.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" @@ -377,9 +378,13 @@ return true; } +/// For `consumer` with buffer semantics, find the Linalg operation on buffers +/// that is the last writer for the subview at `consumerIdx`. For now the +/// fusable dependence is returned as an instance of the +/// `LinalgDependenceGraphElem`. static Optional -findFusableProducer(LinalgOp consumer, unsigned consumerIdx, - const LinalgDependenceGraph &dependenceGraph) { +findFusableProducerForBufferOp(LinalgOp consumer, unsigned consumerIdx, + const LinalgDependenceGraph &dependenceGraph) { // Only consider RAW and WAW atm. 
for (auto depType : { LinalgDependenceGraph::DependenceType::RAW, @@ -428,7 +433,7 @@ unsigned consumerIdx, const LinalgDependenceGraph &graph) { Optional fusableDependence = - findFusableProducer(consumer, consumerIdx, graph); + findFusableProducerForBufferOp(consumer, consumerIdx, graph); if (!fusableDependence) return {}; @@ -464,10 +469,9 @@ /// Sets `producer` and `outputIndex` if it finds a producer LinalgOp static void getProducerOfTensor(Value tensor, LinalgOp &producer, unsigned &outputIndex) { - if (!tensor.getType().isa()) - return; - while (true) { + if (!tensor.getType().isa()) + return; if (auto linalgOp = tensor.getDefiningOp()) { producer = linalgOp; outputIndex = tensor.cast().getResultNumber(); @@ -478,8 +482,10 @@ continue; } if (auto blockArg = tensor.dyn_cast()) { - if (auto forOp = blockArg.getDefiningOp()) { - tensor = forOp.getResult(blockArg.getArgNumber()); + if (auto forOp = + dyn_cast(blockArg.getOwner()->getParentOp())) { + tensor = *(forOp.getIterOperands().begin() + + (blockArg.getArgNumber() - 1)); // account for iv. continue; } } @@ -490,7 +496,10 @@ Optional mlir::linalg::fuseProducerOfTensor(OpBuilder &b, LinalgOp consumer, unsigned consumerIdx) { - Value inputTensor = consumer.getInput(consumerIdx); + Value inputTensor = + (consumerIdx < consumer.getNumInputs()) + ? consumer.getInput(consumerIdx) + : consumer.getInitTensor(consumerIdx - consumer.getNumInputs()); LinalgOp producerOp; unsigned producerIdx; getProducerOfTensor(inputTensor, producerOp, producerIdx); @@ -722,6 +731,39 @@ return fusableLoops; } +/// For `consumer` with tensor semantics, find the Linalg operation on tensors +/// producing the operand at position `consumerIdx`. This is a simple use-def +/// chain using the SSA value, but returned as an element of the +/// `LinalgDependenceGraphElem` to use the same analysis for both tensors and +/// buffers. 
+static Optional +findFusableProducerForTensorOp(LinalgOp consumer, unsigned consumerIdx) { + // For now only looking for cases where the entire operand is produced by + // another Linalg structured operation. + if (!consumer.hasTensorSemantics()) + return llvm::None; + Value value = consumer.getOperation()->getOperand(consumerIdx); + if (auto linalgOp = value.getDefiningOp()) { + return LinalgDependenceGraph::LinalgDependenceGraphElem{ + {linalgOp.getOperation(), + linalgOp.getNumInputs() + value.cast().getResultNumber()}, + {consumer, consumerIdx}, + LinalgDependenceGraph::DependenceType::RAW}; + } + return llvm::None; +} + +static Optional +findFusableProducer(LinalgOp consumer, unsigned consumerIdx, + const LinalgDependenceGraph &dependenceGraph) { + if (consumer.hasBufferSemantics()) + return findFusableProducerForBufferOp(consumer, consumerIdx, + dependenceGraph); + if (consumer.hasTensorSemantics()) + return findFusableProducerForTensorOp(consumer, consumerIdx); + return llvm::None; +} + /// Find all dependences that are fusable. FusableOpDependencesTy mlir::linalg::findAllFusableDependences( ArrayRef ops, const LinalgDependenceGraph &dependenceGraph) { @@ -809,7 +851,7 @@ /// `fusionCandidates`, i.e. move the operation within the inter-tile loops of /// `tiledOp`. 
static SmallVector -fuseOperations(OpBuilder &builder, LinalgOp tiledOp, +fuseOperations(OpBuilder &builder, LinalgOp rootOp, LinalgOp tiledOp, ArrayRef fusionCandidates, const FusableOpDependencesTy &fusableDependences, const std::set &fusedLoops) { @@ -823,9 +865,33 @@ } SmallVector fusedOps(fusionCandidates.size()); + DenseMap origOpToFusedOp; + origOpToFusedOp[rootOp.getOperation()] = tiledOp; for (auto candidate : enumerate(llvm::reverse(fusionCandidates))) { - LinalgOp fusedOp = fuse(builder, candidate.value(), fusedLoopsAndRanges); + LinalgOp origOp = candidate.value(); + LinalgOp fusedOp = fuse(builder, origOp, fusedLoopsAndRanges); + origOpToFusedOp[origOp.getOperation()] = fusedOp; fusedOps[fusionCandidates.size() - candidate.index() - 1] = fusedOp; + // If the producer consumer operations are linalg operations on tensors, the + // dependence is due to value produced (as a return tensor) by the producer + // and used in the consumer. The returned value of the fused op needs to be + // made the operand of the tiled/fused consumer operation. By construction + // the value returned by the producer is the value used by the consumer. + for (auto &dependence : fusableDependences.lookup(origOp.getOperation())) { + if (origOp.hasTensorSemantics() && + dependence.dependenceType == + LinalgDependenceGraph::DependenceType::RAW) { + unsigned resultIndex = + dependence.dependentOpView.operandIndex - origOp.getNumInputs(); + LinalgOp consumer = + origOpToFusedOp.lookup(dependence.indexingOpView.op); + if (!consumer) + continue; + Value replacementValue = fusedOp.getOperation()->getResult(resultIndex); + consumer.getOperation()->setOperand( + dependence.indexingOpView.operandIndex, replacementValue); + } + } builder.setInsertionPoint(fusedOp); } return fusedOps; @@ -839,14 +905,16 @@ if (ops.empty()) return llvm::None; LinalgOp rootOp = ops.back(); - for (auto op : enumerate(ops)) { - // TODO: Nothing in the fusion of sequence of ops is specific to - // buffers. 
This check can be removed after it is tested on tensors. - LinalgOp linalgOp = op.value(); - if (!linalgOp.hasBufferSemantics()) { - linalgOp.emitError("tile and fuse only tested for buffer operation"); - return llvm::None; - } + if (!llvm::all_of( + ops, + [](LinalgOp linalgOp) { return linalgOp.hasBufferSemantics(); }) && + !llvm::all_of(ops, [](LinalgOp linalgOp) { + return linalgOp.hasTensorSemantics(); + })) { + rootOp.emitError( + "unable to fuse operations that have tensor semantics with operations " + "that have buffer semantics and viceversa."); + return llvm::None; } // TODO: Support interchange with tile + fuse. This might actually help do // better fusion. @@ -888,8 +956,9 @@ ret.fusedLoops.assign(tiledRootOp->loops.begin(), tiledRootOp->loops.end()); // Fuse the other operations into the fused inter-tile loops produced above. - ret.fusedProducers = fuseOperations(builder, ret.op, ops.drop_back(), + ret.fusedProducers = fuseOperations(builder, rootOp, ret.op, ops.drop_back(), fusableDependences, ret.fusedLoopDims); + return ret; } diff --git a/mlir/test/Dialect/Linalg/fusion-sequence.mlir b/mlir/test/Dialect/Linalg/fusion-sequence.mlir --- a/mlir/test/Dialect/Linalg/fusion-sequence.mlir +++ b/mlir/test/Dialect/Linalg/fusion-sequence.mlir @@ -58,7 +58,7 @@ module { func @sequence_of_matmul(%arg0: memref, %arg1: memref, %arg2: memref, %arg3: memref, - %arg4: memref) { + %arg4: memref) { %cst = constant 0.000000e+00 : f32 %c0 = constant 0 : index %c1 = constant 1 : index @@ -131,3 +131,47 @@ // CHECK: scf.yield // CHECK: } +// ----- + +module { + func @tensor_op_fusion(%arg0: tensor, %arg1: tensor, + %arg2: tensor, %arg3: tensor) + -> tensor { + %0 = linalg.matmul ins(%arg0, %arg1 : tensor, tensor) + init(%arg2 : tensor) -> tensor + %1 = linalg.generic + {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0)>, + affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"]} + ins(%0, %arg3 : tensor, tensor) { + 
^bb0(%arg4: f32, %arg5: f32): + %2 = addf %arg4, %arg5 : f32 + linalg.yield %2 : f32 + } -> tensor + return %1 : tensor + } +} +// CHECK-LABEL: func @tensor_op_fusion +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: tensor +// CHECK: %[[INIT:.+]] = linalg.init_tensor +// CHECK: %[[R0:.+]] = scf.for %{{.+}} to %{{.+}} step %{{.+}} iter_args(%[[ARG5:.+]] = %[[INIT]]) -> (tensor) { +// CHECK: %[[R1:.+]] = scf.for %{{.+}} to %{{.+}} step %{{.+}} iter_args(%[[ARG7:.+]] = %[[ARG5]]) -> (tensor) { +// CHECK-DAG: %[[STARG3:.+]] = subtensor %[[ARG3]] +// CHECK-DAG: %[[STARG0:.+]] = subtensor %[[ARG0]] +// CHECK-DAG: %[[STARG1:.+]] = subtensor %[[ARG1]] +// CHECK-DAG: %[[STARG2:.+]] = subtensor %[[ARG2]] +// CHECK: %[[T0:.+]] = linalg.matmul +// CHECK-SAME: ins(%[[STARG0]], %[[STARG1]] : tensor, tensor) +// CHECK-SAME: init(%[[STARG2]] : tensor) -> tensor +// CHECK: %[[T1:.+]] = linalg.generic +// CHECK-SAME: ins(%[[T0:.+]], %[[STARG3]] : tensor, tensor) +// CHECK: %[[RESULT:.+]] = subtensor_insert %[[T1]] into %[[ARG7]] +// CHECK: scf.yield %[[RESULT]] +// CHECK: } +// CHECK: scf.yield %[[R1]] +// CHECK: } +// CHECK: return %[[R0]] diff --git a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp --- a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp +++ b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp @@ -226,14 +226,23 @@ Aliases aliases; LinalgDependenceGraph dependenceGraph(aliases, linalgOps); OpBuilder builder(funcOp.getContext()); + linalg::LinalgTilingLoopType loopType = LinalgTilingLoopType::ParallelLoops; + if (llvm::all_of(linalgOps, [](LinalgOp linalgOp) { + return linalgOp.hasTensorSemantics(); + })) + loopType = LinalgTilingLoopType::Loops; Optional tileAndFuseOps = tileAndFuseLinalgOps( builder, linalgOps, dependenceGraph, - 
LinalgTilingOptions().setTileSizes(tileSizes).setLoopType( - LinalgTilingLoopType::ParallelLoops)); + LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(loopType)); if (!tileAndFuseOps) return signalPassFailure(); + if (linalgOps.back().hasTensorSemantics()) { + linalgOps.back().getOperation()->replaceAllUsesWith( + tileAndFuseOps->fusedLoops.front()); + } for (auto op : linalgOps) - op.erase(); + if (op.hasBufferSemantics()) + op.erase(); } }; } // namespace