diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -36,6 +36,10 @@
 createLinalgPromotionPass(bool dynamicBuffers, bool useAlloca);
 std::unique_ptr<OperationPass<FuncOp>> createLinalgPromotionPass();
 
+/// Create a pass to convert Linalg tiled loops to `scf.for` and `scf.parallel`
+/// loops and memref.load/memref.store accesses.
+std::unique_ptr<OperationPass<FuncOp>> createConvertLinalgTiledLoopsToSCFPass();
+
 /// Create a pass to convert Linalg operations to scf.for loops and
 /// memref.load/memref.store accesses.
 std::unique_ptr<OperationPass<FuncOp>> createConvertLinalgToLoopsPass();
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -58,6 +58,17 @@
   let dependentDialects = ["AffineDialect", "memref::MemRefDialect"];
 }
 
+def LinalgLowerTiledLoopsToSCF
+    : FunctionPass<"convert-linalg-tiled-loops-to-scf"> {
+  let summary = "Lower linalg tiled loops to SCF loops and parallel loops";
+  let constructor = "mlir::createConvertLinalgTiledLoopsToSCFPass()";
+  let dependentDialects = [
+    "linalg::LinalgDialect",
+    "scf::SCFDialect",
+    "AffineDialect"
+  ];
+}
+
 def LinalgLowerToAffineLoops : FunctionPass<"convert-linalg-to-affine-loops"> {
   let summary = "Lower the operations from the linalg dialect into affine "
                 "loops";
@@ -76,16 +87,6 @@
   ];
 }
 
-def LinalgBufferize : Pass<"linalg-bufferize", "FuncOp"> {
-  let summary = "Bufferize the linalg dialect";
-  let constructor = "mlir::createLinalgBufferizePass()";
-  let dependentDialects = [
-    "linalg::LinalgDialect",
-    "AffineDialect",
-    "memref::MemRefDialect"
-  ];
-}
-
 def LinalgLowerToParallelLoops
     : FunctionPass<"convert-linalg-to-parallel-loops"> {
   let summary = "Lower the operations from the linalg dialect into parallel "
@@ -99,6 +100,16 @@
   ];
 }
 
+def LinalgBufferize : Pass<"linalg-bufferize", "FuncOp"> {
+  let summary = "Bufferize the linalg dialect";
+  let constructor = "mlir::createLinalgBufferizePass()";
+  let dependentDialects = [
+    "linalg::LinalgDialect",
+    "AffineDialect",
+    "memref::MemRefDialect"
+  ];
+}
+
 def LinalgPromotion : FunctionPass<"linalg-promote-subviews"> {
   let summary = "Promote subview ops to local buffers";
   let constructor = "mlir::createLinalgPromotionPass()";
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -555,7 +555,7 @@
   }
 };
 
-struct TiledLoopPattern : public OpRewritePattern<TiledLoopOp> {
+struct TiledLoopToSCFPattern : public OpRewritePattern<TiledLoopOp> {
   using OpRewritePattern<TiledLoopOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(TiledLoopOp tiledLoop,
@@ -597,7 +597,7 @@
 static void lowerLinalgToLoopsImpl(FuncOp funcOp) {
   MLIRContext *context = funcOp.getContext();
   RewritePatternSet patterns(context);
-  patterns.add<LinalgRewritePattern<LoopType>, TiledLoopPattern>(context);
+  patterns.add<LinalgRewritePattern<LoopType>>(context);
   memref::DimOp::getCanonicalizationPatterns(patterns, context);
   AffineApplyOp::getCanonicalizationPatterns(patterns, context);
   patterns.add<FoldAffineOp>(context);
@@ -668,8 +668,23 @@
     lowerLinalgToLoopsImpl<scf::ForOp>(getFunction());
   }
 };
+
+struct LowerTiledLoopsToSCF
+    : public LinalgLowerTiledLoopsToSCFBase<LowerTiledLoopsToSCF> {
+  void runOnFunction() override {
+    MLIRContext *context = &getContext();
+    RewritePatternSet patterns(context);
+    patterns.add<TiledLoopToSCFPattern>(context);
+    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
+  }
+};
 } // namespace
 
+std::unique_ptr<OperationPass<FuncOp>>
+mlir::createConvertLinalgTiledLoopsToSCFPass() {
+  return std::make_unique<LowerTiledLoopsToSCF>();
+}
+
 std::unique_ptr<OperationPass<FuncOp>> mlir::createConvertLinalgToLoopsPass() {
   return std::make_unique<LowerToLoops>();
 }
diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -1522,78 +1522,3 @@
 // CHECKPARALLEL:   %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
 // CHECKPARALLEL:   %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
 // CHECKPARALLEL:   store %[[res]], %[[arg2]][%[[arg3]], %[[arg4]], %[[arg5]]] : memref<?x?x?xf32>
-
-
-#map0 = affine_map<(d0) -> (24, -d0 + 192)>
-#map1 = affine_map<(d0, d1)[s0] -> (d0 * 192 + s0 + d1)>
-#map2 = affine_map<(d0) -> (16, -d0 + 192)>
-
-func @tiled_loop_to_parallel(%A: memref<192x192xf32>,
-                             %B: memref<192x192xf32>,
-                             %C: memref<192x192xf32>) {
-  %cst = constant 0.000000e+00 : f32
-  %c24 = constant 24 : index
-  %c16 = constant 16 : index
-  %c0 = constant 0 : index
-  %c192 = constant 192 : index
-
-  linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%c192, %c192) step (%c24, %c16)
-      ins (%A_ = %A: memref<192x192xf32>, %B_ = %B: memref<192x192xf32>)
-      outs (%C_ = %C: memref<192x192xf32>) {
-    %0 = affine.min #map0(%i)
-    %1 = memref.subview %A_[%i, 0] [%0, 192] [1, 1]
-      : memref<192x192xf32> to memref<?x192xf32, #map1>
-    %2 = affine.min #map2(%j)
-    %3 = memref.subview %B_[0, %j] [192, %2] [1, 1]
-      : memref<192x192xf32> to memref<192x?xf32, #map1>
-    %4 = memref.subview %C_[%i, %j] [%0, %2] [1, 1]
-      : memref<192x192xf32> to memref<?x?xf32, #map1>
-    linalg.fill(%4, %cst) : memref<?x?xf32, #map1>, f32
-    linalg.matmul ins(%1, %3 : memref<?x192xf32, #map1>,
-                               memref<192x?xf32, #map1>)
-                  outs(%4 : memref<?x?xf32, #map1>)
-    linalg.yield
-  }
-  return
-}
-
-// CHECKLOOP-LABEL: @tiled_loop_to_parallel
-// CHECKLOOP-SAME: %[[A:.*]]: memref<192x192xf32>, %[[B:.*]]: memref<192x192xf32>,
-// CHECKLOOP-SAME: %[[C:.*]]: memref<192x192xf32>) {
-// CHECKLOOP:  %[[C24:.*]] = constant 24 : index
-// CHECKLOOP:  %[[C16:.*]] = constant 16 : index
-// CHECKLOOP:  %[[C192:.*]] = constant 192 : index
-// CHECKLOOP:  %[[C0:.*]] = constant 0 : index
-// CHECKLOOP:  scf.for %[[I:.*]] = %[[C0]] to %[[C192]] step %[[C24]] {
-// CHECKLOOP:    scf.for %[[J:.*]] = %[[C0]] to %[[C192]] step %[[C16]] {
-// CHECKLOOP:      %[[A_sub:.*]] = memref.subview %[[A]][%[[I]]
-// CHECKLOOP:      %[[B_sub:.*]] = memref.subview %[[B]][0, %[[J]]]
-// CHECKLOOP:      %[[C_sub:.*]] = memref.subview %[[C]][%[[I]]
-
-
-func @tiled_loop_to_for(%A: memref<192x192xf32>,
-                        %B: memref<192x192xf32>,
-                        %C: memref<f32>) {
-  %c24 = constant 24 : index
-  %c16 = constant 16 : index
-  %c0 = constant 0 : index
-  %c192 = constant 192 : index
-  %cst = constant 0.000000e+00 : f32
-
-  linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%c192, %c192) step (%c24, %c16)
-      ins (%A_ = %A: memref<192x192xf32>, %B_ = %B: memref<192x192xf32>)
-      outs (%C_ = %C: memref<f32>)
-      iterators["reduction", "reduction"] {
-    linalg.fill(%A_, %cst) : memref<192x192xf32>, f32
-    linalg.yield
-  }
-  return
-}
-
-// CHECKLOOP-LABEL: @tiled_loop_to_for
-// CHECKLOOP:  %[[C24:.*]] = constant 24 : index
-// CHECKLOOP:  %[[C16:.*]] = constant 16 : index
-// CHECKLOOP:  %[[C192:.*]] = constant 192 : index
-// CHECKLOOP:  %[[C0:.*]] = constant 0 : index
-// CHECKLOOP:  scf.for %{{.*}} = %[[C0]] to %[[C192]] step %[[C24]]
-// CHECKLOOP:    scf.for %{{.*}} = %[[C0]] to %[[C192]] step %[[C16]]
diff --git a/mlir/test/Dialect/Linalg/tiled-loops.mlir b/mlir/test/Dialect/Linalg/tiled-loops.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/tiled-loops.mlir
@@ -0,0 +1,79 @@
+// RUN: mlir-opt %s -convert-linalg-tiled-loops-to-scf | FileCheck %s
+
+
+#map0 = affine_map<(d0) -> (24, -d0 + 192)>
+#map1 = affine_map<(d0, d1)[s0] -> (d0 * 192 + s0 + d1)>
+#map2 = affine_map<(d0) -> (16, -d0 + 192)>
+
+func @tiled_loop(%A: memref<192x192xf32>,
+                 %B: memref<192x192xf32>,
+                 %C: memref<192x192xf32>) {
+  %cst = constant 0.000000e+00 : f32
+  %c24 = constant 24 : index
+  %c16 = constant 16 : index
+  %c0 = constant 0 : index
+  %c192 = constant 192 : index
+
+  linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%c192, %c192) step (%c24, %c16)
+      ins (%A_ = %A: memref<192x192xf32>, %B_ = %B: memref<192x192xf32>)
+      outs (%C_ = %C: memref<192x192xf32>) {
+    %0 = affine.min #map0(%i)
+    %1 = memref.subview %A_[%i, 0] [%0, 192] [1, 1]
+      : memref<192x192xf32> to memref<?x192xf32, #map1>
+    %2 = affine.min #map2(%j)
+    %3 = memref.subview %B_[0, %j] [192, %2] [1, 1]
+      : memref<192x192xf32> to memref<192x?xf32, #map1>
+    %4 = memref.subview %C_[%i, %j] [%0, %2] [1, 1]
+      : memref<192x192xf32> to memref<?x?xf32, #map1>
+    linalg.fill(%4, %cst) : memref<?x?xf32, #map1>, f32
+    linalg.matmul ins(%1, %3 : memref<?x192xf32, #map1>,
+                               memref<192x?xf32, #map1>)
+                  outs(%4 : memref<?x?xf32, #map1>)
+    linalg.yield
+  }
+  return
+}
+
+// CHECK-LABEL: @tiled_loop
+// CHECK-SAME: %[[A:.*]]: memref<192x192xf32>, %[[B:.*]]: memref<192x192xf32>,
+// CHECK-SAME: %[[C:.*]]: memref<192x192xf32>) {
+// CHECK:  %[[C24:.*]] = constant 24 : index
+// CHECK:  %[[C16:.*]] = constant 16 : index
+// CHECK:  %[[C0:.*]] = constant 0 : index
+// CHECK:  %[[C192:.*]] = constant 192 : index
+// CHECK:  scf.for %[[I:.*]] = %[[C0]] to %[[C192]] step %[[C24]] {
+// CHECK:    scf.for %[[J:.*]] = %[[C0]] to %[[C192]] step %[[C16]] {
+// CHECK:      %[[A_sub:.*]] = memref.subview %[[A]][%[[I]]
+// CHECK:      %[[B_sub:.*]] = memref.subview %[[B]][0, %[[J]]]
+// CHECK:      %[[C_sub:.*]] = memref.subview %[[C]][%[[I]]
+// CHECK:      linalg.fill
+// CHECK:      linalg.matmul
+
+
+func @tiled_loop_reduction(%A: memref<192x192xf32>,
+                           %B: memref<192x192xf32>,
+                           %C: memref<f32>) {
+  %c24 = constant 24 : index
+  %c16 = constant 16 : index
+  %c0 = constant 0 : index
+  %c192 = constant 192 : index
+  %cst = constant 0.000000e+00 : f32
+
+  linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%c192, %c192) step (%c24, %c16)
+      ins (%A_ = %A: memref<192x192xf32>, %B_ = %B: memref<192x192xf32>)
+      outs (%C_ = %C: memref<f32>)
+      iterators["reduction", "reduction"] {
+    linalg.fill(%A_, %cst) : memref<192x192xf32>, f32
+    linalg.yield
+  }
+  return
+}
+
+// CHECK-LABEL: @tiled_loop_reduction
+// CHECK:  %[[C24:.*]] = constant 24 : index
+// CHECK:  %[[C16:.*]] = constant 16 : index
+// CHECK:  %[[C0:.*]] = constant 0 : index
+// CHECK:  %[[C192:.*]] = constant 192 : index
+// CHECK:  scf.for %{{.*}} = %[[C0]] to %[[C192]] step %[[C24]]
+// CHECK:    scf.for %{{.*}} = %[[C0]] to %[[C192]] step %[[C16]]
+// CHECK:  linalg.fill