diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -555,6 +555,41 @@
   }
 };
 
+struct TiledLoopPattern : public OpRewritePattern<TiledLoopOp> {
+  using OpRewritePattern<TiledLoopOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TiledLoopOp tiledLoop,
+                                PatternRewriter &rewriter) const override {
+    Location loc = tiledLoop.getLoc();
+
+    // Fail conversion if the `tiled_loop` has not been bufferized.
+    if (!llvm::all_of(tiledLoop.outputs(), [&](Value arg) {
+          return arg.getType().isa<MemRefType>();
+        }))
+      return failure();
+
+    // TODO: Build loop nest with `scf.for` and `scf.parallel` depending on the
+    // iterator type.
+    scf::buildLoopNest(rewriter, loc, tiledLoop.lowerBound(),
+                       tiledLoop.upperBound(), tiledLoop.step(),
+                       [&](OpBuilder &builder, Location loc, ValueRange ivs) {
+                         // Move body without its terminator.
+                         SmallVector<Value, 16> newBlockArgs;
+                         newBlockArgs.append(ivs.begin(), ivs.end());
+                         newBlockArgs.append(tiledLoop.inputs().begin(),
+                                             tiledLoop.inputs().end());
+                         newBlockArgs.append(tiledLoop.outputs().begin(),
+                                             tiledLoop.outputs().end());
+                         Block *newBody = rewriter.getInsertionBlock();
+                         rewriter.mergeBlocks(tiledLoop.getBody(), newBody,
+                                              newBlockArgs);
+                         rewriter.eraseOp(newBody->getTerminator());
+                       });
+    rewriter.eraseOp(tiledLoop);
+    return success();
+  }
+};
+
 struct FoldAffineOp;
 } // namespace
 
@@ -562,7 +597,7 @@
 static void lowerLinalgToLoopsImpl(FuncOp funcOp) {
   MLIRContext *context = funcOp.getContext();
   RewritePatternSet patterns(context);
-  patterns.add<LinalgRewritePattern<LoopType>>(context);
+  patterns.add<LinalgRewritePattern<LoopType>, TiledLoopPattern>(context);
   memref::DimOp::getCanonicalizationPatterns(patterns, context);
   AffineApplyOp::getCanonicalizationPatterns(patterns, context);
   patterns.add<FoldAffineOp>(context);
diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -1522,3 +1522,78 @@
 // CHECKPARALLEL:   %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
 // CHECKPARALLEL:   %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
 // CHECKPARALLEL:   store %[[res]], %[[arg2]][%[[arg3]], %[[arg4]], %[[arg5]]] : memref
+
+
+#map0 = affine_map<(d0) -> (24, -d0 + 192)>
+#map1 = affine_map<(d0, d1)[s0] -> (d0 * 192 + s0 + d1)>
+#map2 = affine_map<(d0) -> (16, -d0 + 192)>
+
+func @tiled_loop_to_parallel(%A: memref<192x192xf32>,
+                             %B: memref<192x192xf32>,
+                             %C: memref<192x192xf32>) {
+  %cst = constant 0.000000e+00 : f32
+  %c24 = constant 24 : index
+  %c16 = constant 16 : index
+  %c0 = constant 0 : index
+  %c192 = constant 192 : index
+
+  linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%c192, %c192) step (%c24, %c16)
+      ins (%A_ = %A: memref<192x192xf32>, %B_ = %B: memref<192x192xf32>)
+      outs (%C_ = %C: memref<192x192xf32>) {
+    %0 = affine.min #map0(%i)
+    %1 = memref.subview %A_[%i, 0] [%0, 192] [1, 1]
+      : memref<192x192xf32> to memref<?x192xf32, #map1>
+    %2 = affine.min #map2(%j)
+    %3 = memref.subview %B_[0, %j] [192, %2] [1, 1]
+      : memref<192x192xf32> to memref<192x?xf32, #map1>
+    %4 = memref.subview %C_[%i, %j] [%0, %2] [1, 1]
+      : memref<192x192xf32> to memref<?x?xf32, #map1>
+    linalg.fill(%4, %cst) : memref<?x?xf32, #map1>, f32
+    linalg.matmul ins(%1, %3 : memref<?x192xf32, #map1>,
+                               memref<192x?xf32, #map1>)
+                  outs(%4 : memref<?x?xf32, #map1>)
+    linalg.yield
+  }
+  return
+}
+
+// CHECKLOOP-LABEL: @tiled_loop_to_parallel
+// CHECKLOOP-SAME:  %[[A:.*]]: memref<192x192xf32>, %[[B:.*]]: memref<192x192xf32>,
+// CHECKLOOP-SAME:  %[[C:.*]]: memref<192x192xf32>) {
+// CHECKLOOP:  %[[C24:.*]] = constant 24 : index
+// CHECKLOOP:  %[[C16:.*]] = constant 16 : index
+// CHECKLOOP:  %[[C192:.*]] = constant 192 : index
+// CHECKLOOP:  %[[C0:.*]] = constant 0 : index
+// CHECKLOOP:  scf.for %[[I:.*]] = %[[C0]] to %[[C192]] step %[[C24]] {
+// CHECKLOOP:    scf.for %[[J:.*]] = %[[C0]] to %[[C192]] step %[[C16]] {
+// CHECKLOOP:      %[[A_sub:.*]] = memref.subview %[[A]][%[[I]]
+// CHECKLOOP:      %[[B_sub:.*]] = memref.subview %[[B]][0, %[[J]]]
+// CHECKLOOP:      %[[C_sub:.*]] = memref.subview %[[C]][%[[I]]
+
+
+func @tiled_loop_to_for(%A: memref<192x192xf32>,
+                        %B: memref<192x192xf32>,
+                        %C: memref<f32>) {
+  %c24 = constant 24 : index
+  %c16 = constant 16 : index
+  %c0 = constant 0 : index
+  %c192 = constant 192 : index
+  %cst = constant 0.000000e+00 : f32
+
+  linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%c192, %c192) step (%c24, %c16)
+      ins (%A_ = %A: memref<192x192xf32>, %B_ = %B: memref<192x192xf32>)
+      outs (%C_ = %C: memref<f32>)
+      iterators["reduction", "reduction"] {
+    linalg.fill(%A_, %cst) : memref<192x192xf32>, f32
+    linalg.yield
+  }
+  return
+}
+
+// CHECKLOOP-LABEL: @tiled_loop_to_for
+// CHECKLOOP:  %[[C24:.*]] = constant 24 : index
+// CHECKLOOP:  %[[C16:.*]] = constant 16 : index
+// CHECKLOOP:  %[[C192:.*]] = constant 192 : index
+// CHECKLOOP:  %[[C0:.*]] = constant 0 : index
+// CHECKLOOP:  scf.for %{{.*}} = %[[C0]] to %[[C192]] step %[[C24]]
+// CHECKLOOP:    scf.for %{{.*}} = %[[C0]] to %[[C192]] step %[[C16]]
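Note for reviewers: below is a hand-derived sketch of the IR that TiledLoopPattern should produce for @tiled_loop_to_parallel under the CHECKLOOP prefix (the -convert-linalg-to-loops run of loops.mlir). It is illustrative only: the SSA names are invented, the tests assert just a subset of these lines, and greedy canonicalization may reorder the constants. The loop body is simply the `linalg.tiled_loop` body with the induction variables substituted for %i/%j, the function arguments substituted for the %A_/%B_/%C_ block arguments, and the `linalg.yield` terminator erased.

#map0 = affine_map<(d0) -> (24, -d0 + 192)>
#map1 = affine_map<(d0, d1)[s0] -> (d0 * 192 + s0 + d1)>
#map2 = affine_map<(d0) -> (16, -d0 + 192)>

func @tiled_loop_to_parallel(%A: memref<192x192xf32>,
                             %B: memref<192x192xf32>,
                             %C: memref<192x192xf32>) {
  %cst = constant 0.000000e+00 : f32
  %c24 = constant 24 : index
  %c16 = constant 16 : index
  %c0 = constant 0 : index
  %c192 = constant 192 : index
  // The tiled_loop body lands inside the scf.for nest built by
  // scf::buildLoopNest; %A/%B/%C replace the in/out block arguments.
  scf.for %i = %c0 to %c192 step %c24 {
    scf.for %j = %c0 to %c192 step %c16 {
      %0 = affine.min #map0(%i)
      %1 = memref.subview %A[%i, 0] [%0, 192] [1, 1]
        : memref<192x192xf32> to memref<?x192xf32, #map1>
      %2 = affine.min #map2(%j)
      %3 = memref.subview %B[0, %j] [192, %2] [1, 1]
        : memref<192x192xf32> to memref<192x?xf32, #map1>
      %4 = memref.subview %C[%i, %j] [%0, %2] [1, 1]
        : memref<192x192xf32> to memref<?x?xf32, #map1>
      linalg.fill(%4, %cst) : memref<?x?xf32, #map1>, f32
      linalg.matmul ins(%1, %3 : memref<?x192xf32, #map1>,
                                 memref<192x?xf32, #map1>)
                    outs(%4 : memref<?x?xf32, #map1>)
    }
  }
  return
}

As the TODO in the pattern says, parallel iterators currently also lower to scf.for; emitting scf.parallel for parallel dimensions is left for a follow-up, which is why the first test is already named @tiled_loop_to_parallel.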