diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -1036,6 +1036,31 @@ PatternRewriter &rewriter) const override; }; +/// Rewrite a TiledLoopOp with bounds/step that potentially do not divide evenly +/// into a TiledLoopOp where the step divides the iteration space evenly, +/// followed by another TiledLoopOp for the last (partial) iteration (if any). +/// This transformation is called "loop peeling". +/// +/// This function peels the `idx`-th loop of the TiledLoopOp. To peel all loops +/// in the loop nest, this function must be called multiple times. +/// +/// After loop peeling, this function tries to simplify/canonicalize affine.min +/// and affine.max ops in the body of the two TiledLoopOps. For more details, +/// refer to `mlir::scf::peelAndCanonicalizeForLoop`. +/// +/// The return value indicates whether the loop was rewritten or not. Loops are +/// not rewritten if: +/// * Loop step size is 1 or +/// * Loop bounds and step size are static, and step already divides the +/// iteration space evenly. +/// +/// Note: This function rewrites the given TiledLoopOp in-place and clones the +/// TiledLoopOp operation for the last iteration. It replaces all uses of the +/// unpeeled TiledLoopOp with the results of the newly generated TiledLoopOp. +LogicalResult peelAndCanonicalizeTiledLoop(RewriterBase &rewriter, + TiledLoopOp loopOp, int64_t idx, + TiledLoopOp &result); + //===----------------------------------------------------------------------===// // Support for staged pattern application. 
//===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/SCF/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms.h --- a/mlir/include/mlir/Dialect/SCF/Transforms.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms.h @@ -111,6 +111,24 @@ LogicalResult peelAndCanonicalizeForLoop(RewriterBase &rewriter, ForOp forOp, scf::IfOp &ifOp); +/// Try to simplify a min/max operation `op` after loop peeling. This function +/// can simplify min/max operations such as (ub is the previous upper bound of +/// the unpeeled loop): +/// ``` +/// #map = affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)> +/// %r = affine.min #map(%iv)[%step, %ub] +/// ``` +/// and rewrites them into (in the case of the peeled loop): +/// ``` +/// %r = %step +/// ``` +/// min/max operations inside the partial iteration are rewritten in a similar +/// way. +LogicalResult rewritePeeledMinMaxOp(RewriterBase &rewriter, Operation *op, + AffineMap map, ValueRange operands, + bool isMin, Value iv, Value ub, Value step, + bool insideLoop); + /// Tile a parallel loop of the form /// scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3) /// step (%arg4, %arg5) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/Dialect/SCF/Transforms.h" #include "mlir/Dialect/StandardOps/Utils/Utils.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" @@ -633,6 +634,119 @@ }; } // namespace +/// Rewrite a TiledLoopOp with bounds/step that potentially do not divide evenly +/// into two TiledLoopOps: One where the step divides the iteration space +/// evenly, followed by another for the last (partial) iteration (if any). 
This +/// function only rewrites the `idx`-th loop of the loop nest represented by +/// the TiledLoopOp. To peel the entire loop nest, this function must be called +/// multiple times. +/// +/// This function rewrites the given TiledLoopOp in-place and creates a new +/// TiledLoopOp for the last iteration. It replaces all uses of the original +/// TiledLoopOp with the results of the newly generated one. +/// +/// The newly generated TiledLoopOp is returned via `result`. The boundary +/// at which the loop is split (new upper bound) is returned via `splitBound`. +/// The return value indicates whether the TiledLoopOp was rewritten or not. +static LogicalResult peelTiledLoop(RewriterBase &b, TiledLoopOp loopOp, + int64_t idx, TiledLoopOp &result, + Value &splitBound) { + Value lb = loopOp.lowerBound()[idx], ub = loopOp.upperBound()[idx], + step = loopOp.step()[idx]; + auto ubInt = getConstantIntValue(ub); + + auto loc = loopOp.getLoc(); + AffineExpr exprLb, exprUb, exprStep; + bindSymbols(b.getContext(), exprLb, exprUb, exprStep); + // New upper bound: %ub - (%ub - %lb) mod %step + auto modMap = AffineMap::get(0, 3, {exprUb - ((exprUb - exprLb) % exprStep)}); + SmallVector operands{lb, ub, step}; + mlir::canonicalizeMapAndOperands(&modMap, &operands); + modMap = mlir::simplifyAffineMap(modMap); + RewriterBase::InsertionGuard guard(b); + b.setInsertionPoint(loopOp); + splitBound = b.createOrFold(loc, modMap, operands); + // No peeling necessary if step already evenly divides the iteration space. + if (splitBound == ub || (ubInt && ubInt == getConstantIntValue(splitBound))) + return failure(); + + // Create remainder loop. + b.setInsertionPointAfter(loopOp); + auto remainderLoop = cast(b.clone(*loopOp.getOperation())); + loopOp.replaceAllUsesWith(remainderLoop->getResults()); + // Outputs: Take tensors from main loop's results. Take memrefs from main + // loop's outputs. 
+ SmallVector remainderOutputs; + for (unsigned o = 0, t = 0; o < loopOp.getNumOutputs(); ++o) { + remainderOutputs.push_back(loopOp.outputs()[o].getType().isa() + ? loopOp.outputs()[o] + : loopOp->getResult(t++)); + } + remainderLoop.outputsMutable().assign(remainderOutputs); + + // Set new loop bounds. + b.updateRootInPlace(loopOp, [&]() { + SmallVector ubs = loopOp.upperBound(); + ubs[idx] = splitBound; + loopOp.upperBoundMutable().assign(ubs); + }); + SmallVector lbs = remainderLoop.lowerBound(); + lbs[idx] = splitBound; + remainderLoop.lowerBoundMutable().assign(lbs); + + result = remainderLoop; + return success(); +} + +template +static void +rewriteAffineOpAfterPeeling(RewriterBase &rewriter, TiledLoopOp mainLoop, + TiledLoopOp remainderLoop, Value mainIv, + Value remainderIv, Value ub, Value step) { + mainLoop.walk([&](OpTy affineOp) { + AffineMap map = affineOp.getAffineMap(); + (void)scf::rewritePeeledMinMaxOp(rewriter, affineOp, map, + affineOp.operands(), IsMin, mainIv, ub, + step, /*insideLoop=*/true); + }); + remainderLoop.walk([&](OpTy affineOp) { + AffineMap map = affineOp.getAffineMap(); + (void)scf::rewritePeeledMinMaxOp(rewriter, affineOp, map, + affineOp.operands(), IsMin, remainderIv, + ub, step, /*insideLoop=*/false); + }); +} + +LogicalResult mlir::linalg::peelAndCanonicalizeTiledLoop(RewriterBase &rewriter, + TiledLoopOp loopOp, + int64_t idx, + TiledLoopOp &result) { + int64_t numLoops = loopOp.iterator_types().size(); + if (idx < 0 || numLoops <= idx) + return failure(); + // Only parallel iterator supported. + if (!isParallelIterator(loopOp.iterator_types()[idx])) + return failure(); + + Value ub = loopOp.upperBound()[idx]; + TiledLoopOp remainderLoop; + Value splitBound; + if (failed(peelTiledLoop(rewriter, loopOp, idx, remainderLoop, splitBound))) + return failure(); + + // Rewrite affine.min and affine.max ops. 
+ Value mainIv = loopOp.getInductionVars()[idx], step = loopOp.step()[idx], + remainderIv = remainderLoop.getInductionVars()[idx]; + + rewriteAffineOpAfterPeeling( + rewriter, loopOp, remainderLoop, mainIv, remainderIv, ub, step); + rewriteAffineOpAfterPeeling( + rewriter, loopOp, remainderLoop, mainIv, remainderIv, ub, step); + + result = remainderLoop; + return success(); +} + void mlir::linalg::populateTiledLoopToSCFPattern(RewritePatternSet &patterns) { patterns.add(patterns.getContext()); } diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp --- a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp @@ -324,25 +324,25 @@ /// ``` /// %r = %step /// ``` -/// min/max operations inside the generated scf.if operation are rewritten in -/// a similar way. +/// min/max operations inside the partial iteration are rewritten in a similar +/// way. /// /// This function builds up a set of constraints, capable of proving that: /// * Inside the peeled loop: min(step, ub - iv) == step -/// * Inside the scf.if operation: min(step, ub - iv) == ub - iv +/// * Inside the partial iteration: min(step, ub - iv) == ub - iv /// /// Returns `success` if the given operation was replaced by a new operation; /// `failure` otherwise. /// /// Note: `ub` is the previous upper bound of the loop (before peeling). /// `insideLoop` must be true for min/max ops inside the loop and false for -/// affine.min ops inside the scf.for op. For an explanation of the other +/// min/max ops inside the partial iteration. For an explanation of the other /// parameters, see comment of `canonicalizeMinMaxOpInLoop`. 
-static LogicalResult rewritePeeledMinMaxOp(RewriterBase &rewriter, - Operation *op, AffineMap map, - ValueRange operands, bool isMin, - Value iv, Value ub, Value step, - bool insideLoop) { +LogicalResult mlir::scf::rewritePeeledMinMaxOp(RewriterBase &rewriter, + Operation *op, AffineMap map, + ValueRange operands, bool isMin, + Value iv, Value ub, Value step, + bool insideLoop) { FlatAffineValueConstraints constraints; constraints.appendDimId({iv, ub, step}); if (auto constUb = getConstantIntValue(ub)) @@ -374,14 +374,16 @@ rewriteAffineOpAfterPeeling(RewriterBase &rewriter, ForOp forOp, scf::IfOp ifOp, Value iv, Value splitBound, Value ub, Value step) { forOp.walk([&](OpTy affineOp) { - (void)rewritePeeledMinMaxOp(rewriter, affineOp, affineOp.getAffineMap(), - affineOp.operands(), IsMin, iv, ub, step, - /*insideLoop=*/true); + AffineMap map = affineOp.getAffineMap(); + (void)scf::rewritePeeledMinMaxOp(rewriter, affineOp, map, + affineOp.operands(), IsMin, iv, ub, step, + /*insideLoop=*/true); }); ifOp.walk([&](OpTy affineOp) { - (void)rewritePeeledMinMaxOp(rewriter, affineOp, affineOp.getAffineMap(), - affineOp.operands(), IsMin, splitBound, ub, - step, /*insideLoop=*/false); + AffineMap map = affineOp.getAffineMap(); + (void)scf::rewritePeeledMinMaxOp(rewriter, affineOp, map, + affineOp.operands(), IsMin, splitBound, ub, + step, /*insideLoop=*/false); }); } diff --git a/mlir/test/Dialect/Linalg/tiled-loop-peeling.mlir b/mlir/test/Dialect/Linalg/tiled-loop-peeling.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/tiled-loop-peeling.mlir @@ -0,0 +1,215 @@ +// RUN: mlir-opt %s -allow-unregistered-dialect -test-linalg-transform-patterns=test-tiled-loop-peeling=2 -split-input-file | FileCheck %s -check-prefix=CHECK-TILE-2 +// RUN: mlir-opt %s -allow-unregistered-dialect -test-linalg-transform-patterns=test-tiled-loop-peeling=0,1,2 -split-input-file | FileCheck %s -check-prefix=CHECK-TILE-012 + +// CHECK-TILE-2-LABEL: func @tiled_loop_3d_tensor( 
+// CHECK-TILE-2-SAME: %[[input:.*]]: tensor, %[[s0:.*]]: index, %[[s1:.*]]: index, %[[s2:.*]]: index +// CHECK-TILE-2-DAG: %[[c0:.*]] = constant 0 : index +// CHECK-TILE-2-DAG: %[[c1:.*]] = constant 1 : index +// CHECK-TILE-2-DAG: %[[c2:.*]] = constant 2 : index +// CHECK-TILE-2: %[[dim0:.*]] = tensor.dim %[[input]], %[[c0]] +// CHECK-TILE-2: %[[dim1:.*]] = tensor.dim %[[input]], %[[c1]] +// CHECK-TILE-2: %[[dim2:.*]] = tensor.dim %[[input]], %[[c2]] +// CHECK-TILE-2: %[[init_tensor:.*]] = linalg.init_tensor +// CHECK-TILE-2: %[[split_bound:.*]] = affine.apply +// CHECK-TILE-2: %[[r1:.*]] = linalg.tiled_loop (%[[iv0:.*]], %[[iv1:.*]], %[[iv2:.*]]) = (%[[c0]], %[[c0]], %[[c0]]) +// CHECK-TILE-2-SAME: to (%[[dim0]], %[[dim1]], %[[split_bound]]) +// CHECK-TILE-2-SAME: step (%[[s0]], %[[s1]], %[[s2]]) +// CHECK-TILE-2-SAME: ins (%[[loop_in1:.*]] = %[[input]]: tensor) +// CHECK-TILE-2-SAME: outs (%[[loop_out1:.*]] = %[[init_tensor]]: tensor) { +// CHECK-TILE-2: %[[min0_1:.*]] = affine.min +// CHECK-TILE-2: %[[min1_1:.*]] = affine.min +// CHECK-TILE-2: %[[in_slice1:.*]] = tensor.extract_slice %[[loop_in1]][%[[iv0]], %[[iv1]], %[[iv2]]] [%[[min0_1]], %[[min1_1]], %[[s2]]] +// CHECK-TILE-2: %[[out_slice1:.*]] = tensor.extract_slice %[[loop_out1]][%[[iv0]], %[[iv1]], %[[iv2]]] [%[[min0_1]], %[[min1_1]], %[[s2]]] +// CHECK-TILE-2: %[[mod_slice1:.*]] = tensor.insert_slice %{{.*}} into %[[loop_out1]][%[[iv0]], %[[iv1]], %[[iv2]]] [%[[min0_1]], %[[min1_1]], %[[s2]]] +// CHECK-TILE-2: linalg.yield %[[mod_slice1]] +// CHECK-TILE-2: %[[r2:.*]] = linalg.tiled_loop (%[[iv0:.*]], %[[iv1:.*]], %[[iv2:.*]]) = (%[[c0]], %[[c0]], %[[split_bound]]) +// CHECK-TILE-2-SAME: to (%[[dim0]], %[[dim1]], %[[dim2]]) +// CHECK-TILE-2-SAME: step (%[[s0]], %[[s1]], %[[s2]]) +// CHECK-TILE-2-SAME: ins (%[[loop_in2:.*]] = %[[input]]: tensor) +// CHECK-TILE-2-SAME: outs (%[[loop_out2:.*]] = %[[r1]]: tensor) { +// CHECK-TILE-2: %[[min0_2:.*]] = affine.min +// CHECK-TILE-2: %[[min1_2:.*]] = affine.min 
+// CHECK-TILE-2: %[[apply2:.*]] = affine.apply +// CHECK-TILE-2: %[[in_slice2:.*]] = tensor.extract_slice %[[loop_in2]][%[[iv0]], %[[iv1]], %[[iv2]]] [%[[min0_2]], %[[min1_2]], %[[apply2]]] +// CHECK-TILE-2: %[[out_slice2:.*]] = tensor.extract_slice %[[loop_out2]][%[[iv0]], %[[iv1]], %[[iv2]]] [%[[min0_2]], %[[min1_2]], %[[apply2]]] +// CHECK-TILE-2: %[[mod_slice2:.*]] = tensor.insert_slice %{{.*}} into %[[loop_out2]][%[[iv0]], %[[iv1]], %[[iv2]]] [%[[min0_2]], %[[min1_2]], %[[apply2]]] +// CHECK-TILE-2: linalg.yield %[[mod_slice2]] +// CHECK-TILE-2: return %[[r2]] + +// CHECK-TILE-012-LABEL: func @tiled_loop_3d_tensor +// CHECK-TILE-012: linalg.tiled_loop {{.*}} { +// CHECK-TILE-012: linalg.yield +// CHECK-TILE-012: } +// CHECK-TILE-012: linalg.tiled_loop {{.*}} { +// CHECK-TILE-012: linalg.yield +// CHECK-TILE-012: } +// CHECK-TILE-012: linalg.tiled_loop {{.*}} { +// CHECK-TILE-012: linalg.yield +// CHECK-TILE-012: } +// CHECK-TILE-012: linalg.tiled_loop {{.*}} { +// CHECK-TILE-012: linalg.yield +// CHECK-TILE-012: } +// CHECK-TILE-012: linalg.tiled_loop {{.*}} { +// CHECK-TILE-012: linalg.yield +// CHECK-TILE-012: } +// CHECK-TILE-012: linalg.tiled_loop {{.*}} { +// CHECK-TILE-012: linalg.yield +// CHECK-TILE-012: } +// CHECK-TILE-012: linalg.tiled_loop {{.*}} { +// CHECK-TILE-012: linalg.yield +// CHECK-TILE-012: } +// CHECK-TILE-012: linalg.tiled_loop {{.*}} { +// CHECK-TILE-012: linalg.yield +// CHECK-TILE-012: } +// CHECK-TILE-012-NOT: linalg.tiled_loop + +func @tiled_loop_3d_tensor(%arg0: tensor, %s0: index, %s1: index, + %s2: index) -> tensor { + %cst = constant 0.000000e+00 : f32 + %c0 = constant 0 : index + %c1 = constant 1 : index + %c2 = constant 2 : index + %c8 = constant 8 : index + %dim0 = tensor.dim %arg0, %c0 : tensor + %dim1 = tensor.dim %arg0, %c1 : tensor + %dim2 = tensor.dim %arg0, %c2 : tensor + %output = linalg.init_tensor [%dim0, %dim1, %dim2] : tensor + %result = linalg.tiled_loop + (%arg1, %arg2, %arg3) = (%c0, %c0, %c0) to (%dim0, 
%dim1, %dim2) + step (%s0, %s1, %s2) ins (%arg4 = %arg0: tensor) + outs (%arg5 = %output: tensor) { + %min0 = affine.min affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)>(%arg1, %s0)[%dim0] + %min1 = affine.min affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)>(%arg2, %s1)[%dim1] + %min2 = affine.min affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)>(%arg3, %s2)[%dim2] + %in_slice = tensor.extract_slice %arg4[%arg1, %arg2, %arg3] [%min0, %min1, %min2] [1, 1, 1]: tensor to tensor + %out_slice = tensor.extract_slice %arg5[%arg1, %arg2, %arg3] [%min0, %min1, %min2] [1, 1, 1] : tensor to tensor + %comp = "computation"(%in_slice, %out_slice) : (tensor, tensor) -> tensor + %updated_slice = tensor.insert_slice %comp into %arg5[%arg1, %arg2, %arg3] [%min0, %min1, %min2] [1, 1, 1] : tensor into tensor + linalg.yield %updated_slice : tensor + } + return %result : tensor +} + +// ----- + +// CHECK-TILE-2-LABEL: func @tiled_loop_3d_memref( +// CHECK-TILE-2-SAME: %[[input:.*]]: memref, %[[output:.*]]: memref, %[[s0:.*]]: index, %[[s1:.*]]: index, %[[s2:.*]]: index +// CHECK-TILE-2-DAG: %[[c0:.*]] = constant 0 : index +// CHECK-TILE-2-DAG: %[[c1:.*]] = constant 1 : index +// CHECK-TILE-2-DAG: %[[c2:.*]] = constant 2 : index +// CHECK-TILE-2: %[[dim0:.*]] = memref.dim %[[input]], %[[c0]] +// CHECK-TILE-2: %[[dim1:.*]] = memref.dim %[[input]], %[[c1]] +// CHECK-TILE-2: %[[dim2:.*]] = memref.dim %[[input]], %[[c2]] +// CHECK-TILE-2: %[[split_bound:.*]] = affine.apply +// CHECK-TILE-2: linalg.tiled_loop (%[[iv0:.*]], %[[iv1:.*]], %[[iv2:.*]]) = (%[[c0]], %[[c0]], %[[c0]]) +// CHECK-TILE-2-SAME: to (%[[dim0]], %[[dim1]], %[[split_bound]]) +// CHECK-TILE-2-SAME: step (%[[s0]], %[[s1]], %[[s2]]) +// CHECK-TILE-2-SAME: ins (%[[loop_in1:.*]] = %[[input]]: memref) +// CHECK-TILE-2-SAME: outs (%[[loop_out1:.*]] = %[[output]]: memref) { +// CHECK-TILE-2: %[[min0_1:.*]] = affine.min +// CHECK-TILE-2: %[[min1_1:.*]] = affine.min +// CHECK-TILE-2: memref.subview %[[loop_in1]][%[[iv0]], %[[iv1]], %[[iv2]]] 
[%[[min0_1]], %[[min1_1]], %[[s2]]] +// CHECK-TILE-2: linalg.yield +// CHECK-TILE-2: linalg.tiled_loop (%[[iv0:.*]], %[[iv1:.*]], %[[iv2:.*]]) = (%[[c0]], %[[c0]], %[[split_bound]]) +// CHECK-TILE-2-SAME: to (%[[dim0]], %[[dim1]], %[[dim2]]) +// CHECK-TILE-2-SAME: step (%[[s0]], %[[s1]], %[[s2]]) +// CHECK-TILE-2-SAME: ins (%[[loop_in2:.*]] = %[[input]]: memref) +// CHECK-TILE-2-SAME: outs (%[[loop_out2:.*]] = %[[output]]: memref) { +// CHECK-TILE-2: %[[min0_2:.*]] = affine.min +// CHECK-TILE-2: %[[min1_2:.*]] = affine.min +// CHECK-TILE-2: %[[apply2:.*]] = affine.apply +// CHECK-TILE-2: memref.subview %[[loop_in2]][%[[iv0]], %[[iv1]], %[[iv2]]] [%[[min0_2]], %[[min1_2]], %[[apply2]]] +// CHECK-TILE-2: linalg.yield +// CHECK-TILE-2: return + +// CHECK-TILE-012-LABEL: func @tiled_loop_3d_memref + +!memref_subview_type = type memref (d0 * s1 + s0 + d1 * s2 + d2)>> + +func @tiled_loop_3d_memref(%arg0: memref, %output: memref, + %s0: index, %s1: index, %s2: index) { + %cst = constant 0.000000e+00 : f32 + %c0 = constant 0 : index + %c1 = constant 1 : index + %c2 = constant 2 : index + %c8 = constant 8 : index + %dim0 = memref.dim %arg0, %c0 : memref + %dim1 = memref.dim %arg0, %c1 : memref + %dim2 = memref.dim %arg0, %c2 : memref + linalg.tiled_loop + (%arg1, %arg2, %arg3) = (%c0, %c0, %c0) to (%dim0, %dim1, %dim2) + step (%s0, %s1, %s2) ins (%arg4 = %arg0: memref) + outs (%arg5 = %output : memref) { + %min0 = affine.min affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)>(%arg1, %s0)[%dim0] + %min1 = affine.min affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)>(%arg2, %s1)[%dim1] + %min2 = affine.min affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)>(%arg3, %s2)[%dim2] + %in_slice = memref.subview %arg4[%arg1, %arg2, %arg3] [%min0, %min1, %min2] [1, 1, 1]: memref to !memref_subview_type + "computation"(%in_slice) : (!memref_subview_type) -> memref + linalg.yield + } + return +} + +// ----- + +// CHECK-TILE-2-LABEL: func @step_1_do_not_peel +// CHECK-TILE-2: linalg.tiled_loop +// 
CHECK-TILE-2-NOT: linalg.tiled_loop + +// CHECK-TILE-012-LABEL: func @step_1_do_not_peel + +func @step_1_do_not_peel(%arg0: tensor) -> tensor { + %cst = constant 0.000000e+00 : f32 + %c0 = constant 0 : index + %c1 = constant 1 : index + %c2 = constant 2 : index + %c8 = constant 8 : index + %dim0 = tensor.dim %arg0, %c0 : tensor + %dim1 = tensor.dim %arg0, %c1 : tensor + %dim2 = tensor.dim %arg0, %c2 : tensor + %output = linalg.init_tensor [%dim0, %dim1, %dim2] : tensor + %result = linalg.tiled_loop + (%arg1, %arg2, %arg3) = (%c0, %c0, %c0) to (%dim0, %dim1, %dim2) + step (%c1, %c1, %c1) ins (%arg4 = %arg0: tensor) + outs (%arg5 = %output: tensor) { + %in_slice = tensor.extract_slice %arg4[%arg1, %arg2, %arg3] [%c1, %c1, %c1] [1, 1, 1]: tensor to tensor + %out_slice = tensor.extract_slice %arg5[%arg1, %arg2, %arg3] [%c1, %c1, %c1] [1, 1, 1] : tensor to tensor + %comp = "computation"(%in_slice, %out_slice) : (tensor, tensor) -> tensor + %updated_slice = tensor.insert_slice %comp into %arg5[%arg1, %arg2, %arg3] [%c1, %c1, %c1] [1, 1, 1] : tensor into tensor + linalg.yield %updated_slice : tensor + } + return %result : tensor +} + +// ----- + +// CHECK-TILE-2-LABEL: func @divides_evenly_do_not_peel +// CHECK-TILE-2: linalg.tiled_loop +// CHECK-TILE-2-NOT: linalg.tiled_loop + +// CHECK-TILE-012-LABEL: func @divides_evenly_do_not_peel + +func @divides_evenly_do_not_peel(%arg0: tensor, %s: index) + -> tensor { + %cst = constant 0.000000e+00 : f32 + %c0 = constant 0 : index + %c1 = constant 1 : index + %c2 = constant 2 : index + %c8 = constant 8 : index + %c64 = constant 64 : index + %dim0 = tensor.dim %arg0, %c0 : tensor + %dim1 = tensor.dim %arg0, %c1 : tensor + %dim2 = tensor.dim %arg0, %c2 : tensor + %output = linalg.init_tensor [%dim0, %dim1, %dim2] : tensor + %result = linalg.tiled_loop + (%arg1, %arg2, %arg3) = (%c0, %c0, %c0) to (%dim0, %dim1, %c64) + step (%s, %s, %c8) ins (%arg4 = %arg0: tensor) + outs (%arg5 = %output: tensor) { + %in_slice = 
tensor.extract_slice %arg4[%arg1, %arg2, %arg3] [%c1, %c1, %c1] [1, 1, 1]: tensor to tensor + %out_slice = tensor.extract_slice %arg5[%arg1, %arg2, %arg3] [%c1, %c1, %c1] [1, 1, 1] : tensor to tensor + %comp = "computation"(%in_slice, %out_slice) : (tensor, tensor) -> tensor + %updated_slice = tensor.insert_slice %comp into %arg5[%arg1, %arg2, %arg3] [%c1, %c1, %c1] [1, 1, 1] : tensor into tensor + linalg.yield %updated_slice : tensor + } + return %result : tensor +} diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp --- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp @@ -22,6 +22,7 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallVector.h" using namespace mlir; using namespace mlir::linalg; @@ -109,6 +110,10 @@ ListOption testInterchangePattern{ *this, "test-interchange-pattern", llvm::cl::MiscFlags::CommaSeparated, llvm::cl::desc("Test the interchange pattern.")}; + ListOption testTiledLoopPeeling{ + *this, "test-tiled-loop-peeling", + llvm::cl::desc("Test peeling of linalg.tiled_loop ops"), + llvm::cl::OneOrMore, llvm::cl::MiscFlags::CommaSeparated}; }; } // end anonymous namespace @@ -575,6 +580,65 @@ (void)applyPatternsAndFoldGreedily(funcOp, std::move(interchangePattern)); } +static constexpr char kPeeledLoopsLabel[] = "__peeled_loops__"; + +namespace { +/// Peel TiledLoopOps, i.e., split them into two loops: One loop where the +/// `idx`-th loop contains only "full" iterations and a second loop for the +/// remaining partial iteration (if any). 
+struct TiledLoopPeelingPattern : public OpRewritePattern { + TiledLoopPeelingPattern(MLIRContext *ctx, int64_t idx) + : OpRewritePattern(ctx), idx(idx) {} + + LogicalResult matchAndRewrite(TiledLoopOp loopOp, + PatternRewriter &rewriter) const override { + SmallVector peeledLoops; + if (loopOp->hasAttr(kPeeledLoopsLabel)) { + auto attr = loopOp->getAttr(kPeeledLoopsLabel).cast(); + peeledLoops = + llvm::to_vector<4>(llvm::map_range(attr, [](Attribute attr) { + return attr.cast().getInt(); + })); + // Check if the loop was already peeled. + if (llvm::find(peeledLoops, idx) != peeledLoops.end()) + return failure(); + } + + if (static_cast(loopOp.iterator_types().size()) <= idx) + return failure(); + + // Peel loop and canonicalize. + TiledLoopOp result; + if (failed(linalg::peelAndCanonicalizeTiledLoop(rewriter, loopOp, idx, + result))) + return failure(); + peeledLoops.push_back(idx); + + // Apply label, so that the same loop is not rewritten a second time. + rewriter.updateRootInPlace(loopOp, [&]() { + loopOp->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops)); + }); + result->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops)); + return success(); + } + + /// Index of loop to peel. + int64_t idx; +}; +} // namespace + +static void applyTiledLoopPeelingPattern(FuncOp funcOp, + ArrayRef loops) { + MLIRContext *ctx = funcOp.getContext(); + RewritePatternSet patterns(ctx); + for (unsigned idx : loops) + patterns.add(ctx, idx); + (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); + + // Drop the marker. + funcOp.walk([](TiledLoopOp op) { op->removeAttr(kPeeledLoopsLabel); }); +} + /// Apply transformations specified as patterns. 
void TestLinalgTransforms::runOnFunction() { auto lambda = [&](void *) { @@ -612,6 +676,8 @@ return applyGeneralizePadTensorPatterns(getFunction()); if (testSwapSubTensorPadTensor) return applyExtractSliceOfPadTensorSwapPattern(getFunction()); + if (testTiledLoopPeeling.hasValue()) + return applyTiledLoopPeelingPattern(getFunction(), testTiledLoopPeeling); if (testTileAndPadPattern) return applyTileAndPadPattern(getFunction(), tileSizesForPadding); if (testHoistPadding) {