diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -85,6 +85,22 @@
     tilingOptions = options;
     return *this;
   }
+
+  /// Callback to control whether the tiled value of a fused producer is also
+  /// yielded from the generated loop nest, so that the untiled producer result
+  /// can be replaced with the corresponding result of the outermost loop.
+  /// Parameters are `producer`, the result of the untiled op that is being
+  /// fused, and `sliceOp`, the slice that represents the use being fused. The
+  /// callback is invoked before `sliceOp` is replaced and must not erase it.
+  /// When no callback is set, fused producers are not yielded.
+  using ControlYieldFusedProducerFn =
+      std::function<bool(OpResult producer, Operation *sliceOp)>;
+  ControlYieldFusedProducerFn isYieldedFusedProducer;
+  SCFTileAndFuseOptions &
+  setControlYieldFusedProducerFn(const ControlYieldFusedProducerFn &fn) {
+    isYieldedFusedProducer = fn;
+    return *this;
+  }
 };

 /// Transformation information returned after tile and fuse.
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -443,7 +443,7 @@

   // 1. First tile the consumer.
   scf::SCFTileAndFuseResult tileAndFuseResult;
-  llvm::SmallDenseMap<Value, int> yieldedValueToResultNumber;
+  SmallVector<Value> toBeReturned;
   {
     FailureOr<scf::SCFTilingResult> tilingResult =
         tileUsingSCFForOp(rewriter, consumer, options.tilingOptions);
@@ -451,13 +451,15 @@
       return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
     tileAndFuseResult.tiledAndFusedOps.insert(tilingResult->tiledOp);
     tileAndFuseResult.loops = std::move(tilingResult->loops);
-    for (auto result : llvm::enumerate(
-             llvm::zip(consumer->getResults(), tilingResult->replacements))) {
-      tileAndFuseResult.replacements[std::get<0>(result.value())] =
-          std::get<1>(result.value());
-      yieldedValueToResultNumber[tilingResult->tiledOp->getResult(
-          result.index())] = result.index();
-    }
+    // Record the consumer replacements right away so that they are also
+    // returned when no loops are generated and fusion is skipped. When loops
+    // exist, they are refreshed from the final loop nest at the end, since
+    // yielding fused producers may replace the loops.
+    for (auto it :
+         llvm::zip(consumer->getResults(), tilingResult->replacements))
+      tileAndFuseResult.replacements[std::get<0>(it)] = std::get<1>(it);
+    toBeReturned = llvm::to_vector(llvm::map_range(
+        consumer->getResults(), [](OpResult r) -> Value { return r; }));
   }

   // If there are no loops generated, fusion is immaterial.
@@ -503,14 +505,17 @@
       continue;
+
+    // Query the yield control function now: the callback receives the slice
+    // op, and the slice offsets and sizes are needed to yield the fused
+    // producer, but `candidateSliceOp` is erased by `replaceOp` below. An
+    // unset control function means that fused producers are not yielded.
+    bool yieldFusedProducer =
+        options.isYieldedFusedProducer &&
+        options.isYieldedFusedProducer(fusableProducer, candidateSliceOp);
+    SmallVector<OpFoldResult> sliceOffsets = candidateSliceOp.getMixedOffsets();
+    SmallVector<OpFoldResult> sliceSizes = candidateSliceOp.getMixedSizes();
     rewriter.replaceOp(candidateSliceOp, fusedProducerValue.value());

-    // 2d. The operands of the fused producer might themselved be slices of
-    //     values produced by operations that implement the `TilingInterface`.
-    //     Add these operations to the worklist.
-    Operation *fusedProducer = fusedProducerValue->getDefiningOp();
-    tileAndFuseResult.tiledAndFusedOps.insert(fusedProducer);
-    addCandidateSlices(fusedProducer, candidates);
-
-    // 2e. If the slice is for a destination operand, for example,
+    // 2d. If the slice is for a destination operand, for example,
     //
     // ```mlir
     // %0 = linalg.init
@@ -584,7 +589,49 @@
             innerMostLoop.getRegionIterArgs()[iterArgNumber.value()]);
       }
     }
+
+    // 2e. If requested by the control function, also yield the tiled value of
+    //     the fused producer from the loop nest, so that the untiled producer
+    //     result can be replaced with the corresponding loop result.
+    if (yieldFusedProducer) {
+      if (auto interfaceProducerOp =
+              dyn_cast<TilingInterface>(fusableProducer.getOwner())) {
+        Value destination = interfaceProducerOp.getDestinationOperands(
+            rewriter)[fusableProducer.getResultNumber()];
+        if (succeeded(yieldTiledValues(rewriter, destination,
+                                       fusedProducerValue.value(), sliceOffsets,
+                                       sliceSizes, tileAndFuseResult.loops)))
+          toBeReturned.push_back(fusableProducer);
+      }
+    }
+
+    // 2f. The operands of the fused producer might themselves be slices of
+    //     values produced by operations that implement the `TilingInterface`.
+    //     Add these operations to the worklist.
+    if (Operation *tiledAndFusedOp =
+            fusedProducerValue.value().getDefiningOp()) {
+      tileAndFuseResult.tiledAndFusedOps.insert(tiledAndFusedOp);
+      addCandidateSlices(tiledAndFusedOp, candidates);
+    }
+
+    LLVM_DEBUG({
+      if (!tileAndFuseResult.loops.empty()) {
+        llvm::errs() << "After fusing producer: \n";
+        tileAndFuseResult.loops.front().dump();
+        llvm::errs() << "\n";
+      }
+    });
+  }
+
+  // Map each returned value to the matching result of the outermost loop.
+  // This relies on every successful `yieldTiledValues` call appending its
+  // result after the existing loop results: the consumer results come first,
+  // followed by the yielded producer results in the order they were fused.
+  for (auto returnedValue : llvm::enumerate(toBeReturned)) {
+    tileAndFuseResult.replacements[returnedValue.value()] =
+        tileAndFuseResult.loops.front().getResult(returnedValue.index());
   }
+
   return tileAndFuseResult;
 }

diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-tiling-interface=tile-consumer-and-fuse-producer-using-scf-for -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -test-tiling-interface=tile-consumer-and-fuse-producer-using-scf-for -cse -split-input-file %s | FileCheck %s
func.func @gemm_fill_fusion(%arg0 : tensor, %arg1 : tensor) -> tensor { %c0 = arith.constant 0 : index @@ -271,18 +271,12 @@ // CHECK-DAG: %[[ST_ARG0:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0] // CHECK-DAG: %[[ST_ARG1:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]] // CHECK-DAG: %[[ST_ARG2:.+]] = tensor.extract_slice %[[ARG2]][%[[IV0]], %[[IV1]]] -// CHECK: %[[LHS:.+]] = linalg.matmul +// CHECK: %[[MATMUL:.+]] = linalg.matmul // CHECK-SAME: ins(%[[ST_ARG0]], %[[ST_ARG1]] : // CHECK-SAME: outs(%[[ST_ARG2]] : -// CHECK-DAG: %[[ST_ARG0_1:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0] -// CHECK-DAG: %[[ST_ARG1_1:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]] -// CHECK-DAG: %[[ST_ARG2_1:.+]] = tensor.extract_slice %[[ARG2]][%[[IV0]], %[[IV1]]] -// CHECK: %[[RHS:.+]] = linalg.matmul -// CHECK-SAME: ins(%[[ST_ARG0_1]], %[[ST_ARG1_1]] : -// CHECK-SAME: outs(%[[ST_ARG2_1]] : // CHECK: %[[ST_ARG6:.+]] = tensor.extract_slice %[[ARG6]][%[[IV0]], %[[IV1]]] // CHECK: %[[ST_RESULT:.+]] = linalg.generic -// CHECK-SAME: ins(%[[LHS]], %[[RHS]] : +// CHECK-SAME: ins(%[[MATMUL]], %[[MATMUL]] : // CHECK-SAME: outs(%[[ST_ARG6]] : // CHECK: %[[UPDATE:.+]] = tensor.insert_slice %[[ST_RESULT]] // CHECK-SAME: into %[[ARG6]][%[[IV0]], %[[IV1]]] @@ -401,3 +395,123 @@ // CHECK-SAME: outs(%[[SLICE_ARG6]] : // CHECK: %[[UPDATE:.+]] = tensor.insert_slice %[[TILE_GEMM3]] into %[[ARG8]][%[[IV]], 0] [%[[TILE_M]], %[[N3]]] // CHECK: scf.yield %[[UPDATE]] + +// ----- + +func.func @reduction_sequence(%arg0: tensor<30x3xf32>) -> (tensor<30x3xf32>, tensor<30x3xf32>) { + %cst = arith.constant 0.000000e+00 : f32 + %cst_0 = arith.constant 0xFF800000 : f32 + %0 = linalg.init_tensor [30] : tensor<30xf32> + %1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<30xf32>) -> tensor<30xf32> + %2 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], + iterator_types = ["parallel", "reduction"]} + ins(%arg0 : tensor<30x3xf32>) outs(%1 : 
tensor<30xf32>) { + ^bb0(%arg1: f32, %arg2: f32): + %8 = arith.maxf %arg2, %arg1 : f32 + linalg.yield %8 : f32 + } -> tensor<30xf32> + %3 = linalg.init_tensor [30, 3] : tensor<30x3xf32> + %4 = linalg.fill ins(%cst : f32) outs(%0 : tensor<30xf32>) -> tensor<30xf32> + %5:2 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, + affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "reduction"], + __yield_result__ = 1} + ins(%arg0, %2 : tensor<30x3xf32>, tensor<30xf32>) outs(%4, %3 : tensor<30xf32>, tensor<30x3xf32>) { + ^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32): + %8 = arith.subf %arg1, %arg2 : f32 + %9 = math.exp %8 : f32 + %10 = arith.addf %arg3, %9 : f32 + linalg.yield %10, %9 : f32, f32 + } -> (tensor<30xf32>, tensor<30x3xf32>) + %6 = linalg.generic { + __internal_linalg_transform__ = "reduction_sequence_fusion", + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, + affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"]} + ins(%5#1, %5#0 : tensor<30x3xf32>, tensor<30xf32>) outs(%3 : tensor<30x3xf32>) { + ^bb0(%arg1: f32, %arg2: f32, %arg3: f32): + %8 = arith.divf %arg1, %arg2 : f32 + linalg.yield %8 : f32 + } -> tensor<30x3xf32> + return %5#1,%6 : tensor<30x3xf32>, tensor<30x3xf32> +} +// CHECK: func @reduction_sequence(%[[ARG0:.+]]: tensor<30x3xf32>) +// CHECK-DAG: %[[INIT0:.+]] = linalg.init_tensor [30] +// CHECK-DAG: %[[INIT1:.+]] = linalg.init_tensor [30, 3] +// CHECK: %[[RESULT:[a-zA-Z0-9]+]]:2 = scf.for %[[IV:[a-zA-Z0-9]+]] +// CHECK-SAME: iter_args(%[[ITERARG0:[a-zA-Z0-9]+]] = %[[INIT1]], %[[ITERARG1:[a-zA-Z0-9]+]] = %[[INIT1]]) +// CHECK-DAG: %[[ARG0_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV]], 0] +// CHECK-DAG: %[[INIT0_SLICE:.+]] = tensor.extract_slice %[[INIT0]][%[[IV]]] +// CHECK: %[[FILL0:.+]] = linalg.fill +// CHECK-SAME: outs(%[[INIT0_SLICE]] : +// CHECK: %[[GENERIC0:.+]] = 
linalg.generic +// CHECK-SAME: ins(%[[ARG0_SLICE]] : +// CHECK-SAME: outs(%[[FILL0]] : +// CHECK: %[[FILL1:.+]] = linalg.fill +// CHECK-SAME: outs(%[[INIT0_SLICE]] : +// CHECK: %[[ITERARG1_SLICE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV]], 0] +// CHECK: %[[GENERIC1:.+]]:2 = linalg.generic +// CHECK-SAME: ins(%[[ARG0_SLICE]], %[[GENERIC0]] : +// CHECK-SAME: outs(%[[FILL1]], %[[ITERARG1_SLICE]] : +// CHECK: %[[ITERARG0_SLICE:.+]] = tensor.extract_slice %[[ITERARG0]][%[[IV]], 0] +// CHECK: %[[GENERIC2:.+]] = linalg.generic +// CHECK-SAME: ins(%[[GENERIC1]]#1, %[[GENERIC1]]#0 : +// CHECK-SAME: outs(%[[ITERARG0_SLICE]] : +// CHECK-DAG: %[[INSERTSLICE0:.+]] = tensor.insert_slice %[[GENERIC2]] into %[[ITERARG0]][%[[IV]], 0] +// CHECK-DAG: %[[INSERTSLICE1:.+]] = tensor.insert_slice %[[GENERIC1]]#1 into %[[ITERARG1]][%[[IV]], 0] +// CHECK: scf.yield %[[INSERTSLICE0]], %[[INSERTSLICE1]] +// CHECK: return %[[RESULT]]#1, %[[RESULT]]#0 + +// ----- + +func.func @gemm_gemm_fusion_yield_both(%lhs0 : tensor, %rhs0 : tensor, %rhs1 : tensor) + -> (tensor, tensor) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %cst = arith.constant 0.0 : f32 + %d0 = tensor.dim %lhs0, %c0 : tensor + %d1 = tensor.dim %rhs0, %c1 : tensor + %init0 = linalg.init_tensor [%d0, %d1] : tensor + %fill0 = linalg.fill ins(%cst : f32) outs(%init0 : tensor) -> tensor + %gemm0 = linalg.matmul {__yield_result__ = 0} + ins(%lhs0, %rhs0 : tensor, tensor) outs(%fill0 : tensor) -> tensor + %d2 = tensor.dim %rhs1, %c1 : tensor + %init1 = linalg.init_tensor [%d0, %d2] : tensor + %fill1 = linalg.fill ins(%cst : f32) outs(%init1 : tensor) -> tensor + %gemm1 = linalg.matmul {__internal_linalg_transform__ = "gemm_fusion"} + ins(%gemm0, %rhs1 : tensor, tensor) outs(%fill1 : tensor) -> tensor + return %gemm0, %gemm1 : tensor, tensor +} +// CHECK: func.func @gemm_gemm_fusion_yield_both( +// CHECK-SAME: %[[LHS0:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[RHS0:[a-zA-Z0-9]+]]: tensor, +// CHECK-SAME: 
%[[RHS1:[a-zA-Z0-9]+]]: tensor) +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[LHS0]], %[[C0]] +// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[RHS0]], %[[C1]] +// CHECK-DAG: %[[INIT0:.+]] = linalg.init_tensor [%[[D0]], %[[D1]]] +// CHECK-DAG: %[[D2:.+]] = tensor.dim %[[RHS1]], %[[C1]] +// CHECK: %[[INIT1:.+]] = linalg.init_tensor [%[[D0]], %[[D2]]] +// CHECK: %[[RESULT:.+]]:2 = scf.for %[[IV:[a-zA-Z0-9]+]] = +// CHECK-SAME: iter_args(%[[ITERARG0:[a-zA-Z0-9]+]] = %[[INIT1]], %[[ITERARG1:[a-zA-Z0-9]+]] = %[[INIT0]]) +// CHECK-DAG: %[[LHS0_TILE:.+]] = tensor.extract_slice %[[LHS0]][%[[IV]], 0] +// CHECK-DAG: %[[RHS0_TILE:.+]] = tensor.extract_slice %[[RHS0]][0, 0] +// CHECK-DAG: %[[INIT0_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV]], 0] +// CHECK: %[[FILL0_TILE:.+]] = linalg.fill +// CHECK-SAME: outs(%[[INIT0_TILE]] : +// CHECK: %[[GEMM0_TILE:.+]] = linalg.matmul +// CHECK-SAME: ins(%[[LHS0_TILE]], %[[RHS0_TILE]] : +// CHECK-SAME: outs(%[[FILL0_TILE]] : +// CHECK-DAG: %[[RHS1_TILE:.+]] = tensor.extract_slice %[[RHS1]][0, 0] +// CHECK-DAG: %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ITERARG0]][%[[IV]], 0] +// CHECK: %[[FILL1_TILE:.+]] = linalg.fill +// CHECK-SAME: outs(%[[INIT1_TILE]] : +// CHECK: %[[GEMM1_TILE:.+]] = linalg.matmul +// CHECK-SAME: ins(%[[GEMM0_TILE]], %[[RHS1_TILE]] : +// CHECK-SAME: outs(%[[FILL1_TILE]] : +// CHECK: %[[INSERT0:.+]] = tensor.insert_slice %[[GEMM1_TILE]] into %[[ITERARG0]][%[[IV]], 0] +// CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[GEMM0_TILE]] into %[[ITERARG1]][%[[IV]], 0] +// CHECK: scf.yield %[[INSERT0]], %[[INSERT1]] diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp --- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp +++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp @@ -31,6 +31,8 @@ using 
namespace mlir; +static constexpr char yieldMarker[] = "__yield_result__"; + namespace { /// Pattern for testing `TileUsingSCFForOp` pattern (that tiles operations using @@ -79,6 +81,26 @@ linalg::LinalgTransformationFilter filter; }; +/// Method to collect all potential fusable producer +static llvm::SmallDenseSet +collectFusableProducers(TilingInterface op) { + llvm::SmallDenseSet producers; + producers.insert(op); + SmallVector worklist; + worklist.push_back(op); + while (!worklist.empty()) { + TilingInterface currOp = worklist.pop_back_val(); + for (OpOperand &operand : currOp->getOpOperands()) { + auto producer = operand.get().getDefiningOp(); + if (producer && !producers.count(producer)) { + worklist.push_back(producer); + producers.insert(producer); + } + } + } + return producers; +} + /// Pattern for testing `TileConsumerAndFuseProducersUsingSCFForOp` pattern /// (that tiles and fuses operations using the `TilingInterface` with `scf.for` /// ops for iterating over the tiles) while using a `filter` to avoid recursive @@ -108,6 +130,8 @@ if (failed(filter.checkAndNotify(rewriter, op))) return failure(); + llvm::SmallDenseSet producers = collectFusableProducers(op); + FailureOr tileAndFuseResult = scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(rewriter, op, options); @@ -115,12 +139,20 @@ return failure(); } // Replace the tiled op with replacements. - SmallVector replacements(op->getNumResults()); - for (auto result : llvm::enumerate(op->getResults())) { - replacements[result.index()] = - tileAndFuseResult->replacements.lookup(result.value()); + for (auto it : tileAndFuseResult->replacements) { + it.first.replaceUsesWithIf(it.second, [&](OpOperand &use) { + Operation *user = use.getOwner(); + // Replace use if user is + // - Not one of the untiled producers + // - Not a dim op (these are resolved through use of + // `ReifyRankedShapedTypeOpInterface`) + // - is not the outer most loop generated by tile + fuse. 
+ return !producers.count(user) && !isa(user) && + (tileAndFuseResult->loops.empty() || + !tileAndFuseResult->loops.front()->isAncestor(user)); + }); } - rewriter.replaceOp(op, replacements); + rewriter.eraseOp(op); filter.replaceLinalgTransformationFilter( rewriter, tileAndFuseResult->tiledAndFusedOps.front()); @@ -215,6 +247,16 @@ scf::SCFTileAndFuseOptions tileAndFuseOptions; tileAndFuseOptions.tilingOptions.setTileSizes(tileSizes).setInterchange( interchange); + scf::SCFTileAndFuseOptions::ControlYieldFusedProducerFn fn = + [&](OpResult producer, Operation * /*sliceOp*/) -> bool { + Operation *producerOp = producer.getOwner(); + auto yieldResult = producerOp->getAttrOfType(yieldMarker); + if (!yieldResult || yieldResult.getInt() != producer.getResultNumber()) + return false; + producerOp->removeAttr(yieldMarker); + return true; + }; + tileAndFuseOptions.setControlYieldFusedProducerFn(fn); linalg::LinalgTransformationFilter filter( StringAttr::get(context, filterName), StringAttr::get(context, "tiled")); patterns.add( @@ -263,6 +305,9 @@ // 5. Tile and fuse a sequence of GEMMs by tiling and fusing only along M // dimension. addPatternForTileAndFuse(context, patterns, "gemm_sequence_fusion", {10}); + // 6. Fusion of back-to-back-reduction ops + addPatternForTileAndFuse(context, patterns, "reduction_sequence_fusion", + {10}); return; } if (testLoweringToScalar) { @@ -278,6 +323,8 @@ if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(tilingPatterns)))) return signalPassFailure(); + + getOperation().walk([&](Operation *op) { op->removeAttr(yieldMarker); }); } namespace mlir {