diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -732,6 +732,22 @@
       if (it == outputs().end()) return nullptr;
       return it.getBase();
     }
+
+    /// Return whether the op has only MemRef inputs and outputs.
+    bool hasBufferSemantics() {
+      Operation *op = this->getOperation();
+      return op->getNumResults() == 0 &&
+             llvm::all_of(op->getOpOperands(), [&](OpOperand &operand) {
+               return !operand.get().getType().template isa<ShapedType>() ||
+                      operand.get().getType().template isa<MemRefType>();
+             });
+    }
+
+    /// Return whether the loop dimension `dim` is parallel.
+    bool isParallelDimension(unsigned dim) {
+      StringAttr attr = this->iterator_types()[dim].cast<StringAttr>();
+      return attr.getValue() == getParallelIteratorTypeName();
+    }
   }];
 
   let hasCanonicalizer = 1;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -480,36 +480,67 @@
   }
 };
 
+/// Converts tiled_loop to SCF loop nests. All parallel dimensions are
+/// collected into an scf.parallel loop, and all sequential dimensions result
+/// in a nested scf.for loop nest. The pattern assumes that a tiled loop with
+/// iterator_types ["reduction", "parallel", "reduction"] can be reordered.
+/// This is true for the tiling that is currently supported by Linalg.
 struct TiledLoopToSCFPattern : public OpRewritePattern<TiledLoopOp> {
   using OpRewritePattern<TiledLoopOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(TiledLoopOp tiledLoop,
                                 PatternRewriter &rewriter) const override {
-    Location loc = tiledLoop.getLoc();
-
     // Fail conversion if the `tiled_loop` has not been bufferized.
-    if (!llvm::all_of(tiledLoop.outputs(), [&](Value arg) {
-          return arg.getType().isa<MemRefType>();
-        }))
+    if (!tiledLoop.hasBufferSemantics())
       return failure();
 
-    // TODO: Build loop nest with `scf.for` and `scf.parallel` depending on the
-    // iterator type.
-    scf::buildLoopNest(rewriter, loc, tiledLoop.lowerBound(),
-                       tiledLoop.upperBound(), tiledLoop.step(),
-                       [&](OpBuilder &builder, Location loc, ValueRange ivs) {
-                         // Move body without its terminator.
-                         SmallVector<Value> newBlockArgs;
-                         newBlockArgs.append(ivs.begin(), ivs.end());
-                         newBlockArgs.append(tiledLoop.inputs().begin(),
-                                             tiledLoop.inputs().end());
-                         newBlockArgs.append(tiledLoop.outputs().begin(),
-                                             tiledLoop.outputs().end());
-                         Block *newBody = rewriter.getInsertionBlock();
-                         rewriter.mergeBlocks(tiledLoop.getBody(), newBody,
-                                              newBlockArgs);
-                         rewriter.eraseOp(newBody->getTerminator());
-                       });
+    // Collect loop control parameters for parallel and sequential dimensions.
+    SmallVector<Value> seqLBs, seqUBs, seqSteps, seqIVs;
+    SmallVector<Value> parLBs, parUBs, parSteps, parIVs;
+    for (auto en : llvm::enumerate(
+             llvm::zip(tiledLoop.lowerBound(), tiledLoop.upperBound(),
+                       tiledLoop.step(), tiledLoop.getInductionVars()))) {
+      Value lb, ub, step, iv;
+      std::tie(lb, ub, step, iv) = en.value();
+      if (tiledLoop.isParallelDimension(en.index())) {
+        parLBs.push_back(lb);
+        parUBs.push_back(ub);
+        parSteps.push_back(step);
+        parIVs.push_back(iv);
+      } else {
+        seqLBs.push_back(lb);
+        seqUBs.push_back(ub);
+        seqSteps.push_back(step);
+        seqIVs.push_back(iv);
+      }
+    }
+
+    Location loc = tiledLoop.getLoc();
+    auto generateForLoopNestAndCloneBody = [&](OpBuilder &builder, Location loc,
+                                               ValueRange ivs) {
+      BlockAndValueMapping bvm;
+      bvm.map(parIVs, ivs);
+      bvm.map(tiledLoop.getRegionInputArgs(), tiledLoop.inputs());
+      bvm.map(tiledLoop.getRegionOutputArgs(), tiledLoop.outputs());
+
+      // If not all dimensions of the tiled loop are parallel, an scf.for loop
+      // nest is generated.
+      if (!seqIVs.empty()) {
+        scf::LoopNest nest =
+            scf::buildLoopNest(builder, loc, seqLBs, seqUBs, seqSteps,
+                               [&](OpBuilder &builder, Location loc,
                                    ValueRange ivs) { bvm.map(seqIVs, ivs); });
+        builder.setInsertionPointToStart(nest.loops.back().getBody());
+      }
+      for (auto &op : tiledLoop.getBody()->without_terminator())
+        builder.clone(op, bvm);
+    };
+
+    if (parIVs.empty())
+      generateForLoopNestAndCloneBody(rewriter, loc, llvm::None);
+    else
+      rewriter.create<scf::ParallelOp>(loc, parLBs, parUBs, parSteps,
+                                       generateForLoopNestAndCloneBody);
     rewriter.eraseOp(tiledLoop);
     return success();
   }
diff --git a/mlir/test/Dialect/Linalg/tiled-loop-to-scf.mlir b/mlir/test/Dialect/Linalg/tiled-loop-to-scf.mlir
new file
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/tiled-loop-to-scf.mlir
@@ -0,0 +1,184 @@
+// RUN: mlir-opt %s -convert-linalg-tiled-loops-to-scf --split-input-file | FileCheck %s
+
+
+#map0 = affine_map<(d0) -> (24, -d0 + 192)>
+#map1 = affine_map<(d0, d1)[s0] -> (d0 * 192 + s0 + d1)>
+#map2 = affine_map<(d0) -> (16, -d0 + 192)>
+
+func @tiled_loop(%A: memref<192x192xf32>,
+                 %B: memref<192x192xf32>,
+                 %C: memref<192x192xf32>) {
+  %cst = constant 0.000000e+00 : f32
+  %c24 = constant 24 : index
+  %c16 = constant 16 : index
+  %c0 = constant 0 : index
+  %c192 = constant 192 : index
+
+  linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%c192, %c192) step (%c24, %c16)
+      ins (%A_ = %A: memref<192x192xf32>, %B_ = %B: memref<192x192xf32>)
+      outs (%C_ = %C: memref<192x192xf32>) {
+    %0 = affine.min #map0(%i)
+    %1 = memref.subview %A_[%i, 0] [%0, 192] [1, 1]
+      : memref<192x192xf32> to memref<?x192xf32, #map1>
+    %2 = affine.min #map2(%j)
+    %3 = memref.subview %B_[0, %j] [192, %2] [1, 1]
+      : memref<192x192xf32> to memref<192x?xf32, #map1>
+    %4 = memref.subview %C_[%i, %j] [%0, %2] [1, 1]
+      : memref<192x192xf32> to memref<?x?xf32, #map1>
+    linalg.fill(%cst, %4) : f32, memref<?x?xf32, #map1>
+    linalg.matmul ins(%1, %3 : memref<?x192xf32, #map1>,
+                               memref<192x?xf32, #map1>)
+                  outs(%4 : memref<?x?xf32, #map1>)
+    linalg.yield
+  }
+  return
+}
+
+// CHECK-LABEL: @tiled_loop
+// CHECK-SAME: %[[A:.*]]: memref<192x192xf32>, %[[B:.*]]: memref<192x192xf32>,
+// CHECK-SAME: %[[C:.*]]: memref<192x192xf32>) {
+// CHECK: %[[C24:.*]] = constant 24 : index
+// CHECK: %[[C16:.*]] = constant 16 : index
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[C192:.*]] = constant 192 : index
+// CHECK: scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]])
+// CHECK-SAME: to (%[[C192]], %[[C192]]) step (%[[C24]], %[[C16]]) {
+// CHECK: %[[A_sub:.*]] = memref.subview %[[A]][%[[I]]
+// CHECK: %[[B_sub:.*]] = memref.subview %[[B]][0, %[[J]]]
+// CHECK: %[[C_sub:.*]] = memref.subview %[[C]][%[[I]]
+// CHECK: linalg.fill
+// CHECK: linalg.matmul
+
+// -----
+
+func @tiled_loop_reduction(%A: memref<192x192xf32>,
+                           %B: memref<192x192xf32>,
+                           %C: memref<f32>) {
+  %c24 = constant 24 : index
+  %c16 = constant 16 : index
+  %c0 = constant 0 : index
+  %c192 = constant 192 : index
+  %cst = constant 0.000000e+00 : f32
+
+  linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%c192, %c192) step (%c24, %c16)
+      ins (%A_ = %A: memref<192x192xf32>, %B_ = %B: memref<192x192xf32>)
+      outs (%C_ = %C: memref<f32>)
+      iterators["reduction", "reduction"] {
+    linalg.fill(%cst, %A_) : f32, memref<192x192xf32>
+    linalg.yield
+  }
+  return
+}
+
+// CHECK-LABEL: @tiled_loop_reduction
+// CHECK: %[[C24:.*]] = constant 24 : index
+// CHECK: %[[C16:.*]] = constant 16 : index
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[C192:.*]] = constant 192 : index
+// CHECK: scf.for %{{.*}} = %[[C0]] to %[[C192]] step %[[C24]]
+// CHECK: scf.for %{{.*}} = %[[C0]] to %[[C192]] step %[[C16]]
+// CHECK: linalg.fill
+
+// -----
+
+#strided_1d = affine_map<(d0)[s0] -> (d0 + s0)>
+#strided_2d = affine_map<(d0, d1)[s0] -> (d0 * 8 + s0 + d1)>
+
+func @tiled_loop_row_reduction(%A: memref<10x8xf32>,
+                               %B: memref<8xf32>) {
+  %c0 = constant 0 : index
+  %c2 = constant 2 : index
+  %c4 = constant 4 : index
+  %c8 = constant 8 : index
+  %c10 = constant 10 : index
+  %cst = constant 0.000000e+00 : f32
+
+  linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%c10, %c8) step (%c2, %c4)
+      ins (%A_ = %A: memref<10x8xf32>)
+      outs (%B_ = %B: memref<8xf32>)
+      iterators["reduction", "parallel"] {
+    %A_sub = memref.subview %A_[%i, %j][2, 4][1, 1]
+      : memref<10x8xf32> to memref<2x4xf32, #strided_2d>
+    %B_sub = memref.subview %B_[%j][4][1]
+      : memref<8xf32> to memref<4xf32, #strided_1d>
+    linalg.generic {
+      indexing_maps = [affine_map<(i, j) -> (i, j)>,
+                       affine_map<(i, j) -> (j)>],
+      iterator_types = ["reduction", "parallel"]}
+      ins(%A_sub : memref<2x4xf32, #strided_2d>)
+      outs(%B_sub : memref<4xf32, #strided_1d>) {
+    ^bb(%a: f32, %b: f32) :
+      %0 = addf %a, %b: f32
+      linalg.yield %0 : f32
+    }
+    linalg.yield
+  }
+  return
+}
+
+// CHECK-LABEL: @tiled_loop_row_reduction
+
+// CHECK-DAG: %[[C0:.*]] = constant 0 : index
+// CHECK-DAG: %[[C2:.*]] = constant 2 : index
+// CHECK-DAG: %[[C4:.*]] = constant 4 : index
+// CHECK-DAG: %[[C8:.*]] = constant 8 : index
+// CHECK-DAG: %[[C10:.*]] = constant 10 : index
+
+// CHECK: scf.parallel (%[[J:.*]]) = (%[[C0]]) to (%[[C8]]) step (%[[C4]])
+// CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[C10]] step %[[C2]]
+// CHECK-NEXT: memref.subview %arg{{[0-9]+}}[%[[I]], %[[J]]] [2, 4] [1, 1]
+// CHECK-SAME:   : memref<10x8xf32> to memref<2x4xf32, #map{{[0-9]+}}>
+// CHECK-NEXT: memref.subview %arg{{[0-9]+}}[%[[J]]] [4] [1]
+// CHECK-SAME:   : memref<8xf32> to memref<4xf32, #map{{[0-9]+}}>
+
+// -----
+
+#strided_1d = affine_map<(d0)[s0] -> (d0 + s0)>
+#strided_2d = affine_map<(d0, d1)[s0] -> (d0 * 8 + s0 + d1)>
+
+func @tiled_loop_col_reduction(%A: memref<10x8xf32>,
+                               %B: memref<10xf32>) {
+  %c0 = constant 0 : index
+  %c2 = constant 2 : index
+  %c4 = constant 4 : index
+  %c8 = constant 8 : index
+  %c10 = constant 10 : index
+  %cst = constant 0.000000e+00 : f32
+
+  linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%c10, %c8) step (%c2, %c4)
+      ins (%A_ = %A: memref<10x8xf32>)
+      outs (%B_ = %B: memref<10xf32>)
+      iterators["parallel", "reduction"] {
+    %A_sub = memref.subview %A_[%i, %j][2, 4][1, 1]
+      : memref<10x8xf32> to memref<2x4xf32, #strided_2d>
+    %B_sub = memref.subview %B_[%i][2][1]
+      : memref<10xf32> to memref<2xf32, #strided_1d>
+    linalg.generic {
+      indexing_maps = [affine_map<(i, j) -> (i, j)>,
+                       affine_map<(i, j) -> (i)>],
+      iterator_types = ["parallel", "reduction"]}
+      ins(%A_sub : memref<2x4xf32, #strided_2d>)
+      outs(%B_sub : memref<2xf32, #strided_1d>) {
+    ^bb(%a: f32, %b: f32) :
+      %0 = addf %a, %b: f32
+      linalg.yield %0 : f32
+    }
+    linalg.yield
+  }
+  return
+}
+
+// CHECK-LABEL: @tiled_loop_col_reduction
+
+// CHECK-DAG: %[[C0:.*]] = constant 0 : index
+// CHECK-DAG: %[[C2:.*]] = constant 2 : index
+// CHECK-DAG: %[[C4:.*]] = constant 4 : index
+// CHECK-DAG: %[[C8:.*]] = constant 8 : index
+// CHECK-DAG: %[[C10:.*]] = constant 10 : index
+
+// CHECK: scf.parallel (%[[I:.*]]) = (%[[C0]]) to (%[[C10]]) step (%[[C2]])
+// CHECK-NEXT: scf.for %[[J:.*]] = %[[C0]] to %[[C8]] step %[[C4]]
+// CHECK-NEXT: memref.subview %arg{{[0-9]+}}[%[[I]], %[[J]]] [2, 4] [1, 1]
+// CHECK-SAME:   : memref<10x8xf32> to memref<2x4xf32, #map{{[0-9]+}}>
+// CHECK-NEXT: memref.subview %arg{{[0-9]+}}[%[[I]]] [2] [1]
+// CHECK-SAME:   : memref<10xf32> to memref<2xf32, #map{{[0-9]+}}>
diff --git a/mlir/test/Dialect/Linalg/tiled-loops.mlir b/mlir/test/Dialect/Linalg/tiled-loops.mlir
deleted file
--- a/mlir/test/Dialect/Linalg/tiled-loops.mlir
+++ /dev/null
@@ -1,79 +0,0 @@
-// RUN: mlir-opt %s -convert-linalg-tiled-loops-to-scf | FileCheck %s
-
-
-#map0 = affine_map<(d0) -> (24, -d0 + 192)>
-#map1 = affine_map<(d0, d1)[s0] -> (d0 * 192 + s0 + d1)>
-#map2 = affine_map<(d0) -> (16, -d0 + 192)>
-
-func @tiled_loop(%A: memref<192x192xf32>,
-                 %B: memref<192x192xf32>,
-                 %C: memref<192x192xf32>) {
-  %cst = constant 0.000000e+00 : f32
-  %c24 = constant 24 : index
-  %c16 = constant 16 : index
-  %c0 = constant 0 : index
-  %c192 = constant 192 : index
-
-  linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%c192, %c192) step (%c24, %c16)
-      ins (%A_ = %A: memref<192x192xf32>, %B_ = %B: memref<192x192xf32>)
-      outs (%C_ = %C: memref<192x192xf32>) {
-    %0 = affine.min #map0(%i)
-    %1 = memref.subview %A_[%i, 0] [%0, 192] [1, 1]
-      : memref<192x192xf32> to memref<?x192xf32, #map1>
-    %2 = affine.min #map2(%j)
-    %3 = memref.subview %B_[0, %j] [192, %2] [1, 1]
-      : memref<192x192xf32> to memref<192x?xf32, #map1>
-    %4 = memref.subview %C_[%i, %j] [%0, %2] [1, 1]
-      : memref<192x192xf32> to memref<?x?xf32, #map1>
-    linalg.fill(%cst, %4) : f32, memref<?x?xf32, #map1>
-    linalg.matmul ins(%1, %3 : memref<?x192xf32, #map1>,
-                               memref<192x?xf32, #map1>)
-                  outs(%4 : memref<?x?xf32, #map1>)
-    linalg.yield
-  }
-  return
-}
-
-// CHECK-LABEL: @tiled_loop
-// CHECK-SAME: %[[A:.*]]: memref<192x192xf32>, %[[B:.*]]: memref<192x192xf32>,
-// CHECK-SAME: %[[C:.*]]: memref<192x192xf32>) {
-// CHECK: %[[C24:.*]] = constant 24 : index
-// CHECK: %[[C16:.*]] = constant 16 : index
-// CHECK: %[[C0:.*]] = constant 0 : index
-// CHECK: %[[C192:.*]] = constant 192 : index
-// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C192]] step %[[C24]] {
-// CHECK: scf.for %[[J:.*]] = %[[C0]] to %[[C192]] step %[[C16]] {
-// CHECK: %[[A_sub:.*]] = memref.subview %[[A]][%[[I]]
-// CHECK: %[[B_sub:.*]] = memref.subview %[[B]][0, %[[J]]]
-// CHECK: %[[C_sub:.*]] = memref.subview %[[C]][%[[I]]
-// CHECK: linalg.fill
-// CHECK: linalg.matmul
-
-
-func @tiled_loop_reduction(%A: memref<192x192xf32>,
-                           %B: memref<192x192xf32>,
-                           %C: memref<f32>) {
-  %c24 = constant 24 : index
-  %c16 = constant 16 : index
-  %c0 = constant 0 : index
-  %c192 = constant 192 : index
-  %cst = constant 0.000000e+00 : f32
-
-  linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%c192, %c192) step (%c24, %c16)
-      ins (%A_ = %A: memref<192x192xf32>, %B_ = %B: memref<192x192xf32>)
-      outs (%C_ = %C: memref<f32>)
-      iterators["reduction", "reduction"] {
-    linalg.fill(%cst, %A_) : f32, memref<192x192xf32>
-    linalg.yield
-  }
-  return
-}
-
-// CHECK-LABEL: @tiled_loop_reduction
-// CHECK: %[[C24:.*]] = constant 24 : index
-// CHECK: %[[C16:.*]] = constant 16 : index
-// CHECK: %[[C0:.*]] = constant 0 : index
-// CHECK: %[[C192:.*]] = constant 192 : index
-// CHECK: scf.for %{{.*}} = %[[C0]] to %[[C192]] step %[[C24]]
-// CHECK: scf.for %{{.*}} = %[[C0]] to %[[C192]] step %[[C16]]
-// CHECK: linalg.fill