diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h --- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h @@ -87,6 +87,20 @@ tilingOptions = options; return *this; } + + /// Callback to check if a value is to be yielded. + /// Parameters are `producer` which is the result of the untiled op + /// that is being fused, and the `sliceOp` that represents the slice + /// being fused (through tile and fuse). + using ControlYieldFusedProducerResultFn = + std::function; + ControlYieldFusedProducerResultFn shouldYieldFusedProducerResult = + [](OpResult, Operation *) { return false; }; + SCFTileAndFuseOptions &setControlYieldFusedProducerResultFn( + const ControlYieldFusedProducerResultFn &fn) { + shouldYieldFusedProducerResult = fn; + return *this; + } }; /// Transformation information returned after tile and fuse. diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -200,7 +200,7 @@ /// } /// ``` /// TODO: This API can be cleaned up by using `SubsetExtractOpInterface`. 
-static FailureOr> +static SmallVector yieldTiledValues(RewriterBase &rewriter, ValueRange initValues, ValueRange yieldedValues, ArrayRef> tileOffsetsList, @@ -391,11 +391,9 @@ } } - FailureOr> replacementOr = yieldTiledValues( + tilingResult.replacements = yieldTiledValues( rewriter, destinationTensors, tilingResult.tiledOps.back()->getResults(), resultOffsetsList, resultSizesList, tilingResult.loops); - if (failed(replacementOr)) - return rewriter.notifyMatchFailure(op, "failed to yield replacement"); if (auto dstOp = dyn_cast(tilingResult.tiledOps.back())) { @@ -408,8 +406,6 @@ innerMostLoop.getRegionIterArgs()); } - tilingResult.replacements = replacementOr.value(); - LLVM_DEBUG({ if (!tilingResult.loops.empty()) { llvm::dbgs() << "After tiled implementation :\n"; @@ -476,11 +472,9 @@ resultSizesList.push_back( b.createOrFold(loc, parallelOp->getResult(0), i)); SmallVector outOffsets(offsets.size(), b.getIndexAttr(0)); - FailureOr> replacementOr = yieldTiledValues( + SmallVector replacements = yieldTiledValues( b, identityTensor.value()->getResults(), parallelOp->getResults(), outOffsets, resultSizesList, loops); - if (failed(replacementOr)) - return b.notifyMatchFailure(op, "failed to yield replacement"); auto dstOp = cast(parallelOp); auto innerMostLoop = loops.back(); @@ -493,8 +487,7 @@ // 4. Apply the merge reduction to combine all the partial values. b.setInsertionPointAfter(*loops.begin()); - Operation *mergeOp = - op.mergeReductions(b, loc, replacementOr.value(), reductionDim); + Operation *mergeOp = op.mergeReductions(b, loc, replacements, reductionDim); b.replaceOp(op, mergeOp->getResults()); SCFReductionTilingResult results; @@ -544,7 +537,7 @@ // 1. First tile the consumer. 
scf::SCFTileAndFuseResult tileAndFuseResult; - llvm::SmallDenseMap yieldedValueToResultNumber; + SmallVector toBeReturned; { FailureOr tilingResult = tileUsingSCFForOp(rewriter, consumer, options.tilingOptions); @@ -553,13 +546,8 @@ for (auto *tiledOp : tilingResult->tiledOps) tileAndFuseResult.tiledAndFusedOps.insert(tiledOp); tileAndFuseResult.loops = std::move(tilingResult->loops); - for (const auto &result : llvm::enumerate( - llvm::zip(consumer->getResults(), tilingResult->replacements))) { - tileAndFuseResult.replacements[std::get<0>(result.value())] = - std::get<1>(result.value()); - yieldedValueToResultNumber[tilingResult->tiledOps.back()->getResult( - result.index())] = result.index(); - } + toBeReturned = llvm::to_vector(llvm::map_range( + consumer->getResults(), [](OpResult r) -> Value { return r; })); } // If there are no loops generated, fusion is immaterial. @@ -605,14 +593,7 @@ continue; rewriter.replaceOp(candidateSliceOp, fusedProducerValue.value()); - // 2d. The operands of the fused producer might themselved be slices of - // values produced by operations that implement the `TilingInterface`. - // Add these operations to the worklist. - Operation *fusedProducer = fusedProducerValue->getDefiningOp(); - tileAndFuseResult.tiledAndFusedOps.insert(fusedProducer); - addCandidateSlices(fusedProducer, candidates); - - // 2e. If the slice is for a destination operand, for example, + // 2d. If the slice is for a destination operand, for example, // // ```mlir // %0 = linalg.init @@ -683,7 +664,58 @@ innerMostLoop.getRegionIterArgs()[iterArgNumber.value()]); } } + + // 2e. Use the callback to yield the value of the fused producer as well. 
+ if (options.shouldYieldFusedProducerResult(fusableProducer, + candidateSliceOp)) { + SmallVector initValues; + FailureOr initValue = tensor::getOrCreateDestination( + rewriter, fusableProducer.getOwner()->getLoc(), fusableProducer); + if (succeeded(initValue)) { + SmallVector resultOffsets = + candidateSliceOp.getMixedOffsets(); + SmallVector resultSizes = + candidateSliceOp.getMixedSizes(); + SmallVector yieldedVals = yieldTiledValues( + rewriter, initValue.value(), fusedProducerValue.value(), + resultOffsets, resultSizes, tileAndFuseResult.loops); + toBeReturned.push_back(fusableProducer); + } + if (auto dstStyleProducer = + fusedProducerValue.value() + .getDefiningOp()) { + Value dstValue = + dstStyleProducer + .getDpsInitOperand(fusableProducer.getResultNumber()) + ->get(); + updateDestinationOperandsForTiledOp( + rewriter, dstValue, + tileAndFuseResult.loops.back().getRegionIterArgs().back()); + } + } + + // 2f. The operands of the fused producer might themselves be slices of + // values produced by operations that implement the `TilingInterface`. + // Add these operations to the worklist. 
+ if (auto tiledAndFusedOp = fusedProducerValue.value().getDefiningOp()) { + tileAndFuseResult.tiledAndFusedOps.insert(tiledAndFusedOp); + addCandidateSlices(tiledAndFusedOp, candidates); + } + + LLVM_DEBUG({ + if (!tileAndFuseResult.loops.empty()) { + llvm::errs() << "After fusing producer: \n"; + tileAndFuseResult.loops.front().dump(); + llvm::errs() << "\n"; + } + }); + } + + for (auto returnedValue : llvm::enumerate(toBeReturned)) { + tileAndFuseResult.replacements[returnedValue.value()] = + tileAndFuseResult.loops.front().getResult(returnedValue.index()); } + return tileAndFuseResult; } diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir --- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir +++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -test-tiling-interface=tile-consumer-and-fuse-producer-using-scf-for -split-input-file %s | FileCheck %s +// RUN: mlir-opt -test-tiling-interface=tile-consumer-and-fuse-producer-using-scf-for -cse -split-input-file %s | FileCheck %s func.func @gemm_fill_fusion(%arg0 : tensor, %arg1 : tensor) -> tensor { %c0 = arith.constant 0 : index @@ -271,18 +271,12 @@ // CHECK-DAG: %[[ST_ARG0:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0] // CHECK-DAG: %[[ST_ARG1:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]] // CHECK-DAG: %[[ST_ARG2:.+]] = tensor.extract_slice %[[ARG2]][%[[IV0]], %[[IV1]]] -// CHECK: %[[LHS:.+]] = linalg.matmul +// CHECK: %[[MATMUL:.+]] = linalg.matmul // CHECK-SAME: ins(%[[ST_ARG0]], %[[ST_ARG1]] : // CHECK-SAME: outs(%[[ST_ARG2]] : -// CHECK-DAG: %[[ST_ARG0_1:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0] -// CHECK-DAG: %[[ST_ARG1_1:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]] -// CHECK-DAG: %[[ST_ARG2_1:.+]] = tensor.extract_slice %[[ARG2]][%[[IV0]], %[[IV1]]] -// CHECK: %[[RHS:.+]] = linalg.matmul -// 
CHECK-SAME: ins(%[[ST_ARG0_1]], %[[ST_ARG1_1]] : -// CHECK-SAME: outs(%[[ST_ARG2_1]] : // CHECK: %[[ST_ARG6:.+]] = tensor.extract_slice %[[ARG6]][%[[IV0]], %[[IV1]]] // CHECK: %[[ST_RESULT:.+]] = linalg.generic -// CHECK-SAME: ins(%[[LHS]], %[[RHS]] : +// CHECK-SAME: ins(%[[MATMUL]], %[[MATMUL]] : // CHECK-SAME: outs(%[[ST_ARG6]] : // CHECK: %[[UPDATE:.+]] = tensor.insert_slice %[[ST_RESULT]] // CHECK-SAME: into %[[ARG6]][%[[IV0]], %[[IV1]]] @@ -401,3 +395,121 @@ // CHECK-SAME: outs(%[[SLICE_ARG6]] : // CHECK: %[[UPDATE:.+]] = tensor.insert_slice %[[TILE_GEMM3]] into %[[ARG8]][%[[IV]], 0] [%[[TILE_M]], %[[N3]]] // CHECK: scf.yield %[[UPDATE]] + +// ----- + +func.func @reduction_sequence(%arg0: tensor<30x3xf32>) -> tensor<30x3xf32> { + %cst = arith.constant 0.000000e+00 : f32 + %cst_0 = arith.constant 0xFF800000 : f32 + %0 = tensor.empty() : tensor<30xf32> + %1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<30xf32>) -> tensor<30xf32> + %2 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], + iterator_types = ["parallel", "reduction"]} + ins(%arg0 : tensor<30x3xf32>) outs(%1 : tensor<30xf32>) { + ^bb0(%arg1: f32, %arg2: f32): + %8 = arith.maxf %arg2, %arg1 : f32 + linalg.yield %8 : f32 + } -> tensor<30xf32> + %3 = tensor.empty() : tensor<30x3xf32> + %4 = linalg.fill ins(%cst : f32) outs(%0 : tensor<30xf32>) -> tensor<30xf32> + %5:2 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, + affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "reduction"]} + ins(%arg0, %2 : tensor<30x3xf32>, tensor<30xf32>) outs(%4, %3 : tensor<30xf32>, tensor<30x3xf32>) { + ^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32): + %8 = arith.subf %arg1, %arg2 : f32 + %9 = math.exp %8 : f32 + %10 = arith.addf %arg3, %9 : f32 + linalg.yield %10, %9 : f32, f32 + } -> (tensor<30xf32>, tensor<30x3xf32>) + %6 = linalg.generic { + 
__internal_linalg_transform__ = "reduction_sequence_fusion", + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, + affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"]} + ins(%5#1, %5#0 : tensor<30x3xf32>, tensor<30xf32>) outs(%3 : tensor<30x3xf32>) { + ^bb0(%arg1: f32, %arg2: f32, %arg3: f32): + %8 = arith.divf %arg1, %arg2 : f32 + linalg.yield %8 : f32 + } -> tensor<30x3xf32> + return %6 : tensor<30x3xf32> +} +// CHECK: func @reduction_sequence(%[[ARG0:.+]]: tensor<30x3xf32>) +// CHECK-DAG: %[[INIT0:.+]] = tensor.empty() : tensor<30xf32> +// CHECK-DAG: %[[INIT1:.+]] = tensor.empty() : tensor<30x3xf32> +// CHECK: %[[RESULT:[a-zA-Z0-9]+]] = scf.for %[[IV:[a-zA-Z0-9]+]] +// CHECK-SAME: iter_args(%[[ITERARG0:[a-zA-Z0-9]+]] = %[[INIT1]]) +// CHECK-DAG: %[[ARG0_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV]], 0] +// CHECK-DAG: %[[INIT0_SLICE:.+]] = tensor.extract_slice %[[INIT0]][%[[IV]]] +// CHECK: %[[FILL0:.+]] = linalg.fill +// CHECK-SAME: outs(%[[INIT0_SLICE]] : +// CHECK: %[[GENERIC0:.+]] = linalg.generic +// CHECK-SAME: ins(%[[ARG0_SLICE]] : +// CHECK-SAME: outs(%[[FILL0]] : +// CHECK: %[[FILL1:.+]] = linalg.fill +// CHECK-SAME: outs(%[[INIT0_SLICE]] : +// CHECK: %[[INIT1_SLICE:.+]] = tensor.extract_slice %[[INIT1]][%[[IV]], 0] +// CHECK: %[[GENERIC1:.+]]:2 = linalg.generic +// CHECK-SAME: ins(%[[ARG0_SLICE]], %[[GENERIC0]] : +// CHECK-SAME: outs(%[[FILL1]], %[[INIT1_SLICE]] : +// CHECK: %[[ITERARG0_SLICE:.+]] = tensor.extract_slice %[[ITERARG0]][%[[IV]], 0] +// CHECK: %[[GENERIC2:.+]] = linalg.generic +// CHECK-SAME: ins(%[[GENERIC1]]#1, %[[GENERIC1]]#0 : +// CHECK-SAME: outs(%[[ITERARG0_SLICE]] : +// CHECK-DAG: %[[INSERTSLICE:.+]] = tensor.insert_slice %[[GENERIC2]] into %[[ITERARG0]][%[[IV]], 0] +// CHECK: scf.yield %[[INSERTSLICE]] +// CHECK: return %[[RESULT]] + +// ----- + +func.func @gemm_gemm_fusion_yield_both(%lhs0 : tensor, %rhs0 : tensor, %rhs1 : tensor) + -> (tensor, tensor) { + %c0 = 
arith.constant 0 : index + %c1 = arith.constant 1 : index + %cst = arith.constant 0.0 : f32 + %d0 = tensor.dim %lhs0, %c0 : tensor + %d1 = tensor.dim %rhs0, %c1 : tensor + %init0 = tensor.empty(%d0, %d1) : tensor + %fill0 = linalg.fill ins(%cst : f32) outs(%init0 : tensor) -> tensor + %gemm0 = linalg.matmul {__yield_result__ = 0} + ins(%lhs0, %rhs0 : tensor, tensor) outs(%fill0 : tensor) -> tensor + %d2 = tensor.dim %rhs1, %c1 : tensor + %init1 = tensor.empty(%d0, %d2) : tensor + %fill1 = linalg.fill ins(%cst : f32) outs(%init1 : tensor) -> tensor + %gemm1 = linalg.matmul {__internal_linalg_transform__ = "gemm_fusion"} + ins(%gemm0, %rhs1 : tensor, tensor) outs(%fill1 : tensor) -> tensor + return %gemm0, %gemm1 : tensor, tensor +} +// CHECK: func.func @gemm_gemm_fusion_yield_both( +// CHECK-SAME: %[[LHS0:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[RHS0:[a-zA-Z0-9]+]]: tensor, +// CHECK-SAME: %[[RHS1:[a-zA-Z0-9]+]]: tensor) +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[LHS0]], %[[C0]] +// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[RHS0]], %[[C1]] +// CHECK-DAG: %[[INIT0:.+]] = tensor.empty(%[[D0]], %[[D1]]) +// CHECK-DAG: %[[D2:.+]] = tensor.dim %[[RHS1]], %[[C1]] +// CHECK: %[[INIT1:.+]] = tensor.empty(%[[D0]], %[[D2]]) +// CHECK: %[[RESULT:.+]]:2 = scf.for %[[IV:[a-zA-Z0-9]+]] = +// CHECK-SAME: iter_args(%[[ITERARG0:[a-zA-Z0-9]+]] = %[[INIT1]], %[[ITERARG1:[a-zA-Z0-9]+]] = %[[INIT0]]) +// CHECK-DAG: %[[LHS0_TILE:.+]] = tensor.extract_slice %[[LHS0]][%[[IV]], 0] +// CHECK-DAG: %[[RHS0_TILE:.+]] = tensor.extract_slice %[[RHS0]][0, 0] +// CHECK-DAG: %[[INIT0_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV]], 0] +// CHECK: %[[FILL0_TILE:.+]] = linalg.fill +// CHECK-SAME: outs(%[[INIT0_TILE]] : +// CHECK: %[[GEMM0_TILE:.+]] = linalg.matmul +// CHECK-SAME: ins(%[[LHS0_TILE]], %[[RHS0_TILE]] : +// CHECK-SAME: outs(%[[FILL0_TILE]] : +// CHECK-DAG: %[[RHS1_TILE:.+]] = 
tensor.extract_slice %[[RHS1]][0, 0] +// CHECK-DAG: %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ITERARG0]][%[[IV]], 0] +// CHECK: %[[FILL1_TILE:.+]] = linalg.fill +// CHECK-SAME: outs(%[[INIT1_TILE]] : +// CHECK: %[[GEMM1_TILE:.+]] = linalg.matmul +// CHECK-SAME: ins(%[[GEMM0_TILE]], %[[RHS1_TILE]] : +// CHECK-SAME: outs(%[[FILL1_TILE]] : +// CHECK: %[[INSERT0:.+]] = tensor.insert_slice %[[GEMM1_TILE]] into %[[ITERARG0]][%[[IV]], 0] +// CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[GEMM0_TILE]] into %[[ITERARG1]][%[[IV]], 0] +// CHECK: scf.yield %[[INSERT0]], %[[INSERT1]] diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp --- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp +++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp @@ -33,6 +33,8 @@ // TODO: this file should disappear and instead tests should make use of the // transform dialect. +static constexpr char yieldMarker[] = "__yield_result__"; + namespace { /// Marker used as attribute name in generated Linalg rewriting transformations. 
@@ -203,6 +205,26 @@ LinalgTransformationFilter filter; }; +/// Method to collect all potential fusable producers. +static llvm::SmallDenseSet +collectFusableProducers(TilingInterface op) { + llvm::SmallDenseSet producers; + producers.insert(op); + SmallVector worklist; + worklist.push_back(op); + while (!worklist.empty()) { + TilingInterface currOp = worklist.pop_back_val(); + for (OpOperand &operand : currOp->getOpOperands()) { + auto producer = operand.get().getDefiningOp(); + if (producer && !producers.count(producer)) { + worklist.push_back(producer); + producers.insert(producer); + } + } + } + return producers; +} + /// Pattern for testing `TileConsumerAndFuseProducersUsingSCFForOp` pattern /// (that tiles and fuses operations using the `TilingInterface` with `scf.for` /// ops for iterating over the tiles) while using a `filter` to avoid recursive @@ -230,6 +252,8 @@ if (failed(filter.checkAndNotify(rewriter, op))) return failure(); + llvm::SmallDenseSet producers = collectFusableProducers(op); + FailureOr tileAndFuseResult = scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(rewriter, op, options); @@ -237,12 +261,20 @@ return failure(); } // Replace the tiled op with replacements. - SmallVector replacements(op->getNumResults()); - for (const auto &result : llvm::enumerate(op->getResults())) { - replacements[result.index()] = - tileAndFuseResult->replacements.lookup(result.value()); + for (auto it : tileAndFuseResult->replacements) { + it.first.replaceUsesWithIf(it.second, [&](OpOperand &use) { + Operation *user = use.getOwner(); + // Replace use if user is + // - Not one of the untiled producers + // - Not a dim op (these are resolved through use of + // `ReifyRankedShapedTypeOpInterface`) + // - Not the outermost loop generated by tile + fuse (or nested in it). 
+ return !producers.count(user) && !isa(user) && + (tileAndFuseResult->loops.empty() || + !tileAndFuseResult->loops.front()->isAncestor(user)); + }); } - rewriter.replaceOp(op, replacements); + rewriter.eraseOp(op); filter.replaceLinalgTransformationFilter( rewriter, tileAndFuseResult->tiledAndFusedOps.front()); @@ -337,6 +369,16 @@ scf::SCFTileAndFuseOptions tileAndFuseOptions; tileAndFuseOptions.tilingOptions.setTileSizes(tileSizes).setInterchange( interchange); + scf::SCFTileAndFuseOptions::ControlYieldFusedProducerResultFn fn = + [&](OpResult producer, Operation * /*sliceOp*/) -> bool { + Operation *producerOp = producer.getOwner(); + auto yieldResult = producerOp->getAttrOfType(yieldMarker); + if (!yieldResult || yieldResult.getInt() != producer.getResultNumber()) + return false; + producerOp->removeAttr(yieldMarker); + return true; + }; + tileAndFuseOptions.setControlYieldFusedProducerResultFn(fn); LinalgTransformationFilter filter(StringAttr::get(context, filterName), StringAttr::get(context, "tiled")); patterns.add( @@ -385,6 +427,9 @@ // 5. Tile and fuse a sequence of GEMMs by tiling and fusing only along M // dimension. addPatternForTileAndFuse(context, patterns, "gemm_sequence_fusion", {10}); + // 6. Fusion of back-to-back-reduction ops + addPatternForTileAndFuse(context, patterns, "reduction_sequence_fusion", + {10}); return; } if (testLoweringToScalar) { @@ -400,6 +445,8 @@ if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(tilingPatterns)))) return signalPassFailure(); + + getOperation().walk([&](Operation *op) { op->removeAttr(yieldMarker); }); } namespace mlir {