diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -732,45 +732,6 @@
   return fusableLoops;
 }
 
-// /// For `consumer` with tensor semantics, find the Linalg operation on
-// tensors
-// /// producer the operand at position `consumerIdx`. This is a simple use-def
-// /// chain using the SSA value, but returned as an element of the
-// /// `LinalgDependenceGraphElem` to use the same analysis for both tensors and
-// /// buffers.
-// static Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
-// findFusableProducerForTensorOp(OpOperand &consumerOpOperand) {
-//   // For now only looking for cases where the operand is produced by another
-//   // Linalg structured operation.
-//   LinalgOp consumer = cast<LinalgOp>(consumerOpOperand.getOwner());
-//   if (!consumer || !consumer.hasTensorSemantics())
-//     return llvm::None;
-//   unsigned consumerIdx = consumerOpOperand.getOperandNumber();
-//   Value value = consumerOpOperand.get();
-//   if (auto linalgOp = value.getDefiningOp<LinalgOp>()) {
-//     return LinalgDependenceGraph::LinalgDependenceGraphElem{
-//         &(linalgOp
-//               .getOutputOpOperands()[value.cast<OpResult>().getResultNumber()]),
-//         &(consumer.getInputOpOperands()[consumerIdx]),
-//         LinalgDependenceGraph::DependenceType::RAW};
-//   }
-//   return llvm::None;
-// }
-
-// static Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
-// findFusableProducer(OpOperand &consumerOpOperand,
-//                     const LinalgDependenceGraph &dependenceGraph) {
-//   LinalgOp consumer = cast<LinalgOp>(consumerOpOperand.getOwner());
-//   if (!consumer)
-//     return llvm::None;
-//   if (consumer.hasBufferSemantics())
-//     return findFusableProducerForBufferOp(consumerOpOperand,
-//                                           dependenceGraph);
-//   if (consumer.hasTensorSemantics())
-//     return findFusableProducerForTensorOp(consumerOpOperand);
-//   return llvm::None;
-// }
-
 /// Find all dependences that are fusable.
 FusableOpDependencesTy mlir::linalg::findAllFusableDependences(
     ArrayRef<LinalgOp> ops, const LinalgDependenceGraph &dependenceGraph) {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -283,6 +283,19 @@
   return success();
 }
 
+static ValueRange getTiledOpResult(TiledLinalgOp tiledOp) {
+  if (tiledOp.loops.empty())
+    return tiledOp.op.getOperation()->getResults();
+  return tiledOp.loops.front()->getResults();
+}
+
+static ValueRange
+getTiledAndFusedOpResult(TiledAndFusedLinalgOps tiledAndFusedOp) {
+  if (tiledAndFusedOp.fusedLoops.empty())
+    return tiledAndFusedOp.op.getOperation()->getResults();
+  return tiledAndFusedOp.fusedLoops.front()->getResults();
+}
+
 mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
     StringRef opName, MLIRContext *context,
     const LinalgDependenceGraph &dependenceGraph,
@@ -301,8 +314,6 @@
     return failure();
   if (failed(marker.checkAndNotify(rewriter, linalgOp)))
     return failure();
-  if (!linalgOp.hasBufferSemantics())
-    return failure();
 
   DenseSet<Operation *> producers;
   producers.insert(linalgOp);
@@ -359,9 +370,11 @@
         tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions);
     if (!unfusedTiledOp)
       return failure();
-    rewriter.eraseOp(tiledAndFusedOps->op);
+    rewriter.replaceOp(tiledAndFusedOps->op,
+                       getTiledOpResult(unfusedTiledOp.getValue()));
     tiledAndFusedOps->op = unfusedTiledOp->op;
   }
+  op->replaceAllUsesWith(getTiledAndFusedOpResult(tiledAndFusedOps.getValue()));
 
   marker.replaceLinalgTransformationFilter(rewriter,
                                            tiledAndFusedOps->op.getOperation());
diff --git a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
@@ -0,0 +1,142 @@
+// RUN: mlir-opt %s -test-linalg-tensor-fusion-transform-patterns -canonicalize -cse -split-input-file -verify-diagnostics | FileCheck %s
+
+module {
+  func @matmul_fusion(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
+                      %arg2: tensor<?x?xf32>, %arg3: tensor<?x?xf32>,
+                      %arg4: tensor<?x?xf32>) -> tensor<?x?xf32> {
+    %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+    %1 = linalg.matmul {__internal_linalg_transform__ = "lhs_fusion"}
+      ins(%0, %arg3 : tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%arg4 : tensor<?x?xf32>) -> tensor<?x?xf32>
+    return %1 : tensor<?x?xf32>
+  }
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (32, -d0 + s0)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (32, d0 - d1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (64, -d0 + s0)>
+// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1) -> (64, d0 - d1)>
+// CHECK: func @matmul_fusion
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[C32:.+]] = constant 32 : index
+// CHECK-DAG: %[[C64:.+]] = constant 64 : index
+// CHECK-DAG: %[[C16:.+]] = constant 16 : index
+// CHECK-DAG: %[[M:.+]] = dim %[[ARG0]], %[[C0]]
+// CHECK: %[[RESULT:.+]] = scf.for %[[IV0:[a-zA-Z0-9]+]] =
+// CHECK-SAME: %[[C0]] to %[[M]] step %[[C32]]
+// CHECK-SAME: iter_args(%[[ARG6:.+]] = %[[ARG4]]) -> (tensor<?x?xf32>) {
+// CHECK: %[[TILE_M:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M]]]
+// CHECK: %[[M_2:.+]] = dim %[[ARG6]], %[[C0]]
+// CHECK: %[[TILE_M_2:.+]] = affine.min #[[MAP1]](%[[M_2]], %[[IV0]])
+// CHECK: %[[N3:.+]] = dim %[[ARG6]], %[[C1]]
+// CHECK: %[[ST_ARG6:.+]] = subtensor %[[ARG6]][%[[IV0]], 0]
+// CHECK-SAME: [%[[TILE_M_2]], %[[N3]]]
+// CHECK: %[[N2:.+]] = dim %[[ARG1]], %[[C1]]
+// CHECK: %[[N1:.+]] = dim %[[ARG0]], %[[C1]]
+// CHECK: %[[ST_ARG0:.+]] = subtensor %[[ARG0]][%[[IV0]], 0]
+// CHECK-SAME: [%[[TILE_M]], %[[N1]]]
+// CHECK: %[[ST_ARG1:.+]] = subtensor %[[ARG1]][0, 0]
+// CHECK-SAME: [%[[N1]], %[[N2]]]
+// CHECK: %[[ST_ARG2:.+]] = subtensor %[[ARG2]][%[[IV0]], 0]
+// CHECK-SAME: [%[[TILE_M]], %[[N2]]]
+// CHECK: %[[LHS:.+]] = linalg.matmul
+// CHECK-SAME: __internal_linalg_transform__ = "after_lhs_fusion_producer"
+// CHECK-SAME: ins(%[[ST_ARG0]], %[[ST_ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>)
+// CHECK-SAME: outs(%[[ST_ARG2]] : tensor<?x?xf32>)
+// CHECK: %[[N3_2:.+]] = dim %[[ARG3]], %[[C1]]
+// CHECK: %[[YIELD0:.+]] = scf.for %[[IV1:[a-zA-Z0-9]+]] =
+// CHECK-SAME: %[[C0]] to %[[N3_2]] step %[[C64]]
+// CHECK-SAME: iter_args(%[[ARG8:.+]] = %[[ST_ARG6]]) -> (tensor<?x?xf32>) {
+// CHECK: %[[YIELD1:.+]] = scf.for %[[IV2:[a-zA-Z0-9]+]] =
+// CHECK-SAME: %[[C0]] to %[[N2]] step %[[C16]]
+// CHECK-SAME: iter_args(%[[ARG10:.+]] = %[[ARG8]]) -> (tensor<?x?xf32>) {
+// CHECK: %[[TILE_N2:.+]] = affine.min #[[MAP2]](%[[IV2]])[%[[N2]]]
+// CHECK: %[[ST_LHS:.+]] = subtensor %[[LHS]][0, %[[IV2]]]
+// CHECK-SAME: [%[[TILE_M]], %[[TILE_N2]]]
+// CHECK: %[[N2_3:.+]] = dim %[[ARG3]], %[[C0]]
+// CHECK: %[[TILE_N2_2:.+]] = affine.min #[[MAP2]](%[[IV2]])[%[[N2_3]]]
+// CHECK: %[[TILE_N3:.+]] = affine.min #[[MAP3]](%[[IV1]])[%[[N3_2]]]
+// CHECK: %[[ST_ARG3:.+]] = subtensor %[[ARG3]][%[[IV2]], %[[IV1]]]
+// CHECK-SAME: [%[[TILE_N2_2]], %[[TILE_N3]]]
+// CHECK: %[[M_4:.+]] = dim %[[ARG10]], %[[C0]]
+// CHECK: %[[N3_3:.+]] = dim %[[ARG10]], %[[C1]]
+// CHECK: %[[TILE_N3_2:.+]] = affine.min #[[MAP4]](%[[N3_3]], %[[IV1]])
+// CHECK: %[[ST_ARG4:.+]] = subtensor %[[ARG10]][0, %[[IV1]]]
+// CHECK-SAME: [%[[M_4]], %[[TILE_N3_2]]]
+// CHECK: %[[ST_RESULT:.+]] = linalg.matmul
+// CHECK-SAME: __internal_linalg_transform__ = "after_lhs_fusion"
+// CHECK-SAME: ins(%[[ST_LHS]], %[[ST_ARG3]]
+// CHECK-SAME: : tensor<?x?xf32>, tensor<?x?xf32>)
+// CHECK-SAME: outs(%[[ST_ARG4]] : tensor<?x?xf32>)
+// CHECK: %[[UPDATE1:.+]] = subtensor_insert %[[ST_RESULT]]
+// CHECK-SAME: into %[[ARG10]][0, %[[IV1]]] [%[[M_4]], %[[TILE_N3_2]]]
+// CHECK: scf.yield %[[UPDATE1]]
+// CHECK: }
+// CHECK: scf.yield %[[YIELD1]]
+// CHECK: }
+// CHECK: %[[UPDATE0:.+]] = subtensor_insert %[[YIELD0]] into
+// CHECK-SAME: %[[ARG6]][%[[IV0]], 0] [%[[TILE_M_2]], %[[N3]]]
+// CHECK: scf.yield %[[UPDATE0]]
+// CHECK: }
+// CHECK: return %[[RESULT]]
+
+// -----
+
+module {
+  func @matmul_plus_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
+                           %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+    %c0 = constant 0 : index
+    %c1 = constant 1 : index
+    %0 = dim %arg2, %c0 : tensor<?x?xf32>
+    %1 = dim %arg2, %c1 : tensor<?x?xf32>
+    %2 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+    %3 = dim %2, %c0 : tensor<?x?xf32>
+    %4 = dim %2, %c1 : tensor<?x?xf32>
+    %5 = linalg.init_tensor [%3, %4] : tensor<?x?xf32>
+    %6 = linalg.generic
+      {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                        affine_map<(d0, d1) -> (d0, d1)>,
+                        affine_map<(d0, d1) -> (d0, d1)>],
+       iterator_types = ["parallel", "parallel"],
+       __internal_linalg_transform__ = "transpose_fusion"}
+      ins(%2, %2 : tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%5 : tensor<?x?xf32>) {
+      ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32) :
+        %7 = addf %arg3, %arg4 : f32
+        linalg.yield %7 : f32
+    } -> tensor<?x?xf32>
+    return %6 : tensor<?x?xf32>
+  }
+}
+// CHECK: func @matmul_plus_matmul
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK: %[[RESULT:.+]] = scf.for %[[IV0:[a-zA-Z0-9_]+]]
+// CHECK-SAME: iter_args(%[[ARG4:.+]] = %{{[a-zA-Z0-9_]+}})
+// CHECK: %[[YIELD:.+]] = scf.for %[[IV1:[a-zA-Z0-9_]+]]
+// CHECK-SAME: iter_args(%[[ARG6:.+]] = %[[ARG4]])
+// CHECK: %[[ST_ARG6:.+]] = subtensor %[[ARG6]][%[[IV0]], %[[IV1]]]
+// CHECK: %[[ST_ARG0:.+]] = subtensor %[[ARG0]][%[[IV0]], 0]
+// CHECK: %[[ST_ARG1:.+]] = subtensor %[[ARG1]][0, %[[IV1]]]
+// CHECK: %[[ST_ARG2:.+]] = subtensor %[[ARG2]][%[[IV0]], %[[IV1]]]
+// CHECK: %[[LHS:.+]] = linalg.matmul
+// CHECK-SAME: ins(%[[ST_ARG0]], %[[ST_ARG1]]
+// CHECK-SAME: : tensor<?x?xf32>, tensor<?x?xf32>)
+// CHECK-SAME: outs(%[[ST_ARG2]] : tensor<?x?xf32>)
+// CHECK: %[[ST_RESULT:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[LHS]] : tensor<?x?xf32>)
+// CHECK-SAME: outs(%[[ST_ARG6]] : tensor<?x?xf32>)
+// CHECK: %[[UPDATE:.+]] = subtensor_insert %[[ST_RESULT]]
+// CHECK-SAME: into %[[ARG6]][%[[IV0]], %[[IV1]]]
+// CHECK: scf.yield %[[UPDATE]]
+// CHECK: scf.yield %[[YIELD]]
+// CHECK: return %[[RESULT]]
diff --git a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
--- a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
@@ -20,30 +20,14 @@
 using namespace mlir;
 using namespace mlir::linalg;
 
-namespace {
-struct TestLinalgFusionTransforms
-    : public PassWrapper<TestLinalgFusionTransforms, FunctionPass> {
-  TestLinalgFusionTransforms() = default;
-  TestLinalgFusionTransforms(const TestLinalgFusionTransforms &pass) {}
-
-  void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<AffineDialect, linalg::LinalgDialect, scf::SCFDialect>();
-  }
-
-  void runOnFunction() override;
-};
-} // namespace
-
+template <LinalgTilingLoopType LoopType = LinalgTilingLoopType::ParallelLoops>
 static void fillFusionPatterns(MLIRContext *context,
                                const LinalgDependenceGraph &dependenceGraph,
                                OwningRewritePatternList &patterns) {
   patterns.insert<LinalgTileAndFusePattern<MatmulOp>,
                   LinalgTileAndFusePattern<ConvOp>>(
       context, dependenceGraph,
-      LinalgTilingOptions()
-          .setTileSizes({32, 64, 16})
-          .setLoopType(LinalgTilingLoopType::ParallelLoops),
+      LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
       LinalgFusionOptions().setIndicesToFuse({2}),
       LinalgTransformationFilter(
           Identifier::get("basic_fusion", context),
@@ -57,9 +41,7 @@
   patterns.insert<LinalgTileAndFusePattern<MatmulOp>>(
       context, dependenceGraph,
-      LinalgTilingOptions()
-          .setTileSizes({32, 64, 16})
-          .setLoopType(LinalgTilingLoopType::ParallelLoops),
+      LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
       LinalgFusionOptions().setIndicesToFuse({0}),
       LinalgTransformationFilter(Identifier::get("lhs_fusion", context),
                                  Identifier::get("after_lhs_fusion", context)),
@@ -72,9 +54,7 @@
   patterns.insert<LinalgTileAndFusePattern<MatmulOp>>(
       context, dependenceGraph,
-      LinalgTilingOptions()
-          .setTileSizes({32, 64, 16})
-          .setLoopType(LinalgTilingLoopType::ParallelLoops),
+      LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
       LinalgFusionOptions().setIndicesToFuse({1}),
       LinalgTransformationFilter(Identifier::get("rhs_fusion", context),
                                  Identifier::get("after_rhs_fusion", context)),
@@ -87,9 +67,7 @@
   patterns.insert<LinalgTileAndFusePattern<MatmulOp>>(
       context, dependenceGraph,
-      LinalgTilingOptions()
-          .setTileSizes({32, 64, 16})
-          .setLoopType(LinalgTilingLoopType::ParallelLoops),
+      LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
       LinalgFusionOptions().setIndicesToFuse({0, 2}),
       LinalgTransformationFilter(
Identifier::get("two_operand_fusion", context), @@ -103,8 +81,7 @@ patterns.insert>( context, dependenceGraph, - LinalgTilingOptions().setTileSizes({32, 64}).setLoopType( - LinalgTilingLoopType::ParallelLoops), + LinalgTilingOptions().setTileSizes({32, 64}).setLoopType(LoopType), LinalgFusionOptions().setIndicesToFuse({0, 1}), LinalgTransformationFilter( Identifier::get("transpose_fusion", context), @@ -117,18 +94,30 @@ Identifier::get("after_transpose_fusion_original", context))); } -static void applyFusionPatterns(MLIRContext *context, FuncOp funcOp) { - OwningRewritePatternList fusionPatterns; - Aliases alias; - LinalgDependenceGraph dependenceGraph = - LinalgDependenceGraph::buildDependenceGraph(alias, funcOp); - fillFusionPatterns(context, dependenceGraph, fusionPatterns); - applyPatternsAndFoldGreedily(funcOp, std::move(fusionPatterns)); -} +namespace { +template +struct TestLinalgFusionTransforms + : public PassWrapper, FunctionPass> { + TestLinalgFusionTransforms() = default; + TestLinalgFusionTransforms(const TestLinalgFusionTransforms &pass) {} -void TestLinalgFusionTransforms::runOnFunction() { - applyFusionPatterns(&getContext(), getFunction()); -} + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnFunction() override { + MLIRContext *context = &this->getContext(); + FuncOp funcOp = this->getFunction(); + OwningRewritePatternList fusionPatterns; + Aliases alias; + LinalgDependenceGraph dependenceGraph = + LinalgDependenceGraph::buildDependenceGraph(alias, funcOp); + fillFusionPatterns(context, dependenceGraph, fusionPatterns); + applyPatternsAndFoldGreedily(funcOp, std::move(fusionPatterns)); + } +}; +} // namespace static LogicalResult fuseLinalgOpsGreedily(FuncOp f) { OpBuilder b(f); @@ -237,7 +226,7 @@ LinalgDependenceGraph dependenceGraph(aliases, linalgOps); OpBuilder builder(funcOp.getContext()); linalg::LinalgTilingLoopType loopType = LinalgTilingLoopType::ParallelLoops; - if (llvm::all_of(linalgOps, [](LinalgOp linalgOp) { + if (llvm::any_of(linalgOps, [](LinalgOp linalgOp) { return linalgOp.hasTensorSemantics(); })) loopType = LinalgTilingLoopType::Loops; @@ -260,10 +249,17 @@ namespace mlir { namespace test { void registerTestLinalgFusionTransforms() { - PassRegistration testFusionTransformsPass( + PassRegistration> testFusionTransformsPass( "test-linalg-fusion-transform-patterns", "Test Linalg fusion transformation patterns by applying them greedily."); } +void registerTestLinalgTensorFusionTransforms() { + PassRegistration> + testTensorFusionTransformsPass( + "test-linalg-tensor-fusion-transform-patterns", + "Test Linalg on tensor fusion transformation " + "patterns by applying them greedily."); +} void registerTestLinalgGreedyFusion() { PassRegistration testFusionTransformsPass( "test-linalg-greedy-fusion", diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -74,6 +74,7 @@ void registerTestInterfaces(); void registerTestLinalgCodegenStrategy(); void registerTestLinalgFusionTransforms(); +void registerTestLinalgTensorFusionTransforms(); void registerTestLinalgGreedyFusion(); void registerTestLinalgHoisting(); void registerTestLinalgTileAndFuseSequencePass(); @@ -145,6 +146,7 @@ test::registerTestInterfaces(); test::registerTestLinalgCodegenStrategy(); test::registerTestLinalgFusionTransforms(); + test::registerTestLinalgTensorFusionTransforms(); test::registerTestLinalgGreedyFusion(); 
   test::registerTestLinalgHoisting();
   test::registerTestLinalgTileAndFuseSequencePass();
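
Note: to drive the new tensor tile-and-fuse path from a custom pass rather than through the test flag above, the following minimal sketch mirrors the runOnFunction body added in this patch. The helper name populateAndApplyTensorFusion is illustrative (not part of the patch), and it assumes the trailing LinalgTransformationFilter parameters of LinalgTileAndFusePattern keep their defaults; the tile sizes and marker names are copied from fillFusionPatterns.

// Sketch: tile linalg.matmul on tensors with sizes {32, 64, 16} and fuse the
// producer of operand 0 (the LHS). scf.for loops are requested because loops
// over tensors must carry partial results via iter_args, which scf.parallel
// does not support (the same reason for the any_of change in
// fuseLinalgOpsGreedily above).
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;
using namespace mlir::linalg;

static void populateAndApplyTensorFusion(FuncOp funcOp) {
  MLIRContext *context = funcOp.getContext();
  // Build the dependence graph first; it is what the fusion pattern consults
  // (via findAllFusableDependences) to locate fusable producers.
  Aliases aliases;
  LinalgDependenceGraph dependenceGraph =
      LinalgDependenceGraph::buildDependenceGraph(aliases, funcOp);
  OwningRewritePatternList patterns;
  patterns.insert<LinalgTileAndFusePattern<MatmulOp>>(
      context, dependenceGraph,
      LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(
          LinalgTilingLoopType::Loops),
      LinalgFusionOptions().setIndicesToFuse({0}),
      LinalgTransformationFilter(
          Identifier::get("lhs_fusion", context),
          Identifier::get("after_lhs_fusion", context)));
  applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}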