diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp --- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp @@ -16,8 +16,6 @@ #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/SCF/Transforms.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" -#include "mlir/Transforms/RegionUtils.h" -#include "llvm/Support/CommandLine.h" using namespace mlir; using namespace mlir::scf; @@ -30,8 +28,8 @@ /// scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3) /// step (%arg4*tileSize[0], /// %arg5*tileSize[1]) -/// scf.parallel (%j0, %j1) = (0, 0) to (min(tileSize[0], %arg2-%i0) -/// min(tileSize[1], %arg3-%i1)) +/// scf.parallel (%j0, %j1) = (0, 0) to (min(%arg4*tileSize[0], %arg2-%i0) +/// min(%arg5*tileSize[1], %arg3-%i1)) /// step (%arg4, %arg5) /// /// where the uses of %i0 and %i1 in the loop body are replaced by @@ -76,12 +74,36 @@ // Create the inner loop with adjusted bounds. SmallVector newBounds; newBounds.reserve(op.upperBound().size()); - for (auto bounds : llvm::zip(tileSizeConstants, outerLoop.upperBound(), - outerLoop.getInductionVars())) { - newBounds.push_back(b.create( - op.getLoc(), b.getIndexType(), minMap, - ValueRange{std::get<0>(bounds), std::get<1>(bounds), - std::get<2>(bounds)})); + for (auto dim : llvm::zip(outerLoop.lowerBound(), outerLoop.upperBound(), + outerLoop.step(), outerLoop.getInductionVars(), + op.step(), tileSizeConstants)) { + Value lowerBound, upperBound, newStep, iv, step, tileSizeConstant; + std::tie(lowerBound, upperBound, newStep, iv, step, tileSizeConstant) = dim; + // Collect the statically known loop bounds + auto lowerBoundConstant = + dyn_cast_or_null(lowerBound.getDefiningOp()); + auto upperBoundConstant = + dyn_cast_or_null(upperBound.getDefiningOp()); + auto stepConstant = dyn_cast_or_null(step.getDefiningOp()); + auto tileSize = + cast(tileSizeConstant.getDefiningOp()).getValue(); + // If the loop bounds and the loop step are constant and if the number of + // loop iterations is an integer multiple of the tile size, we use a static + // bound for the inner loop. + if (lowerBoundConstant && upperBoundConstant && stepConstant) { + auto numIterations = llvm::divideCeil(upperBoundConstant.getValue() - + lowerBoundConstant.getValue(), + stepConstant.getValue()); + if (numIterations % tileSize == 0) { + newBounds.push_back(newStep); + continue; + } + } + // Otherwise, we dynamically compute the bound for + // each iteration of the outer loop. + newBounds.push_back( + b.create(op.getLoc(), b.getIndexType(), minMap, + ValueRange{newStep, upperBound, iv})); } auto innerLoop = b.create( op.getLoc(), SmallVector(newBounds.size(), zero), newBounds, @@ -104,8 +126,8 @@ op.erase(); } -/// Get a list of most nested parallel loops. Assumes that ParallelOps are only -/// directly nested. +/// Get a list of most nested parallel loops. Assumes that ParallelOps are +/// only directly nested. static bool getInnermostNestedLoops(Block *block, SmallVectorImpl &loops) { bool hasInnerLoop = false; @@ -131,7 +153,9 @@ getInnermostNestedLoops(&block, mostNestedParallelOps); } for (ParallelOp pLoop : mostNestedParallelOps) { - tileParallelLoop(pLoop, tileSizes); + // FIXME: Add reduction support. + if (pLoop.getNumReductions() == 0) + tileParallelLoop(pLoop, tileSizes); } } }; diff --git a/mlir/test/Dialect/SCF/parallel-loop-tiling.mlir b/mlir/test/Dialect/SCF/parallel-loop-tiling.mlir --- a/mlir/test/Dialect/SCF/parallel-loop-tiling.mlir +++ b/mlir/test/Dialect/SCF/parallel-loop-tiling.mlir @@ -15,22 +15,52 @@ // CHECK: #map0 = affine_map<(d0, d1, d2) -> (d0, d1 - d2)> // CHECK-LABEL: func @parallel_loop( -// CHECK-SAME: [[VAL_0:%.*]]: index, [[VAL_1:%.*]]: index, [[VAL_2:%.*]]: index, [[VAL_3:%.*]]: index, [[VAL_4:%.*]]: index, [[VAL_5:%.*]]: index, [[VAL_6:%.*]]: memref, [[VAL_7:%.*]]: memref, [[VAL_8:%.*]]: memref, [[VAL_9:%.*]]: memref) { -// CHECK: [[VAL_10:%.*]] = constant 0 : index -// CHECK: [[VAL_11:%.*]] = constant 1 : index -// CHECK: [[VAL_12:%.*]] = constant 4 : index -// CHECK: [[VAL_13:%.*]] = muli [[VAL_4]], [[VAL_11]] : index -// CHECK: [[VAL_14:%.*]] = muli [[VAL_5]], [[VAL_12]] : index -// CHECK: scf.parallel ([[VAL_15:%.*]], [[VAL_16:%.*]]) = ([[VAL_0]], [[VAL_1]]) to ([[VAL_2]], [[VAL_3]]) step ([[VAL_13]], [[VAL_14]]) { -// CHECK: [[VAL_17:%.*]] = affine.min #map0([[VAL_11]], [[VAL_2]], [[VAL_15]]) -// CHECK: [[VAL_18:%.*]] = affine.min #map0([[VAL_12]], [[VAL_3]], [[VAL_16]]) -// CHECK: scf.parallel ([[VAL_19:%.*]], [[VAL_20:%.*]]) = ([[VAL_10]], [[VAL_10]]) to ([[VAL_17]], [[VAL_18]]) step ([[VAL_4]], [[VAL_5]]) { -// CHECK: [[VAL_21:%.*]] = addi [[VAL_19]], [[VAL_15]] : index -// CHECK: [[VAL_22:%.*]] = addi [[VAL_20]], [[VAL_16]] : index -// CHECK: [[VAL_23:%.*]] = load [[VAL_7]]{{\[}}[[VAL_21]], [[VAL_22]]] : memref -// CHECK: [[VAL_24:%.*]] = load [[VAL_8]]{{\[}}[[VAL_21]], [[VAL_22]]] : memref -// CHECK: [[VAL_25:%.*]] = addf [[VAL_23]], [[VAL_24]] : f32 -// CHECK: store [[VAL_25]], [[VAL_9]]{{\[}}[[VAL_21]], [[VAL_22]]] : memref +// CHECK-SAME: [[ARG1:%.*]]: index, [[ARG2:%.*]]: index, [[ARG3:%.*]]: index, [[ARG4:%.*]]: index, [[ARG5:%.*]]: index, [[ARG6:%.*]]: index, [[ARG7:%.*]]: memref, [[ARG8:%.*]]: memref, [[ARG9:%.*]]: memref, [[ARG10:%.*]]: memref) { +// CHECK: [[C0:%.*]] = constant 0 : index +// CHECK: [[C1:%.*]] = constant 1 : index +// CHECK: [[C4:%.*]] = constant 4 : index +// CHECK: [[V1:%.*]] = muli [[ARG5]], [[C1]] : index +// CHECK: [[V2:%.*]] = muli [[ARG6]], [[C4]] : index +// CHECK: scf.parallel ([[V3:%.*]], [[V4:%.*]]) = ([[ARG1]], [[ARG2]]) to ([[ARG3]], [[ARG4]]) step ([[V1]], [[V2]]) { +// CHECK: [[V5:%.*]] = affine.min #map0([[V1]], [[ARG3]], [[V3]]) +// CHECK: [[V6:%.*]] = affine.min #map0([[V2]], [[ARG4]], [[V4]]) +// CHECK: scf.parallel ([[V7:%.*]], [[V8:%.*]]) = ([[C0]], [[C0]]) to ([[V5]], [[V6]]) step ([[ARG5]], [[ARG6]]) { +// CHECK: [[V9:%.*]] = addi [[V7]], [[V3]] : index +// CHECK: [[V10:%.*]] = addi [[V8]], [[V4]] : index +// CHECK: [[V11:%.*]] = load [[ARG8]]{{\[}}[[V9]], [[V10]]] : memref +// CHECK: [[V12:%.*]] = load [[ARG9]]{{\[}}[[V9]], [[V10]]] : memref +// CHECK: [[V13:%.*]] = addf [[V11]], [[V12]] : f32 +// CHECK: store [[V13]], [[ARG10]]{{\[}}[[V9]], [[V10]]] : memref +// CHECK: } +// CHECK: } +// CHECK: return + +// ----- + +func @static_loop_with_step() { + %c0 = constant 0 : index + %c3 = constant 3 : index + %c22 = constant 22 : index + %c24 = constant 24 : index + scf.parallel (%i0, %i1) = (%c0, %c0) to (%c22, %c24) step (%c3, %c3) { + } + return +} + +// CHECK-LABEL: func @static_loop_with_step() { +// CHECK: [[C0:%.*]] = constant 0 : index +// CHECK: [[C3:%.*]] = constant 3 : index +// CHECK: [[C22:%.*]] = constant 22 : index +// CHECK: [[C24:%.*]] = constant 24 : index +// CHECK: [[C0_1:%.*]] = constant 0 : index +// CHECK: [[C1:%.*]] = constant 1 : index +// CHECK: [[C4:%.*]] = constant 4 : index +// CHECK: [[V1:%.*]] = muli [[C3]], [[C1]] : index +// CHECK: [[V2:%.*]] = muli [[C3]], [[C4]] : index +// CHECK: scf.parallel ([[V3:%.*]], [[V4:%.*]]) = ([[C0]], [[C0]]) to ([[C22]], [[C24]]) step ([[V1]], [[V2]]) { +// CHECK: scf.parallel ([[V5:%.*]], [[V6:%.*]]) = ([[C0_1]], [[C0_1]]) to ([[V1]], [[V2]]) step ([[C3]], [[C3]]) { +// CHECK: = addi [[V5]], [[V3]] : index +// CHECK: = addi [[V6]], [[V4]] : index // CHECK: } // CHECK: } // CHECK: return @@ -51,31 +81,33 @@ } // CHECK-LABEL: func @tile_nested_innermost() { -// CHECK: [[VAL_24:%.*]] = constant 2 : index -// CHECK: [[VAL_25:%.*]] = constant 0 : index -// CHECK: [[VAL_26:%.*]] = constant 1 : index -// CHECK: scf.parallel ([[VAL_27:%.*]], [[VAL_28:%.*]]) = ([[VAL_25]], [[VAL_25]]) to ([[VAL_24]], [[VAL_24]]) step ([[VAL_26]], [[VAL_26]]) { -// CHECK: [[VAL_29:%.*]] = constant 0 : index -// CHECK: [[VAL_30:%.*]] = constant 1 : index -// CHECK: [[VAL_31:%.*]] = constant 4 : index -// CHECK: [[VAL_32:%.*]] = muli [[VAL_26]], [[VAL_30]] : index -// CHECK: [[VAL_33:%.*]] = muli [[VAL_26]], [[VAL_31]] : index -// CHECK: scf.parallel ([[VAL_34:%.*]], [[VAL_35:%.*]]) = ([[VAL_25]], [[VAL_25]]) to ([[VAL_24]], [[VAL_24]]) step ([[VAL_32]], [[VAL_33]]) { -// CHECK: [[VAL_36:%.*]] = affine.min #map0([[VAL_30]], [[VAL_24]], [[VAL_34]]) -// CHECK: [[VAL_37:%.*]] = affine.min #map0([[VAL_31]], [[VAL_24]], [[VAL_35]]) -// CHECK: scf.parallel ([[VAL_38:%.*]], [[VAL_39:%.*]]) = ([[VAL_29]], [[VAL_29]]) to ([[VAL_36]], [[VAL_37]]) step ([[VAL_26]], [[VAL_26]]) { +// CHECK: [[C2:%.*]] = constant 2 : index +// CHECK: [[C0:%.*]] = constant 0 : index +// CHECK: [[C1:%.*]] = constant 1 : index +// CHECK: scf.parallel ([[V1:%.*]], [[V2:%.*]]) = ([[C0]], [[C0]]) to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) { +// CHECK: [[C0_1:%.*]] = constant 0 : index +// CHECK: [[C1_1:%.*]] = constant 1 : index +// CHECK: [[C4:%.*]] = constant 4 : index +// CHECK: [[V3:%.*]] = muli [[C1]], [[C1_1]] : index +// CHECK: [[V4:%.*]] = muli [[C1]], [[C4]] : index +// CHECK: scf.parallel ([[V5:%.*]], [[V6:%.*]]) = ([[C0]], [[C0]]) to ([[C2]], [[C2]]) step ([[V3]], [[V4]]) { +// CHECK: [[V7:%.*]] = affine.min #map0([[V4]], [[C2]], [[V6]]) +// CHECK: scf.parallel ([[V8:%.*]], [[V9:%.*]]) = ([[C0_1]], [[C0_1]]) to ([[V3]], [[V7]]) step ([[C1]], [[C1]]) { +// CHECK: = addi [[V8]], [[V5]] : index +// CHECK: = addi [[V9]], [[V6]] : index // CHECK: } // CHECK: } // CHECK: } -// CHECK: [[VAL_40:%.*]] = constant 0 : index -// CHECK: [[VAL_41:%.*]] = constant 1 : index -// CHECK: [[VAL_42:%.*]] = constant 4 : index -// CHECK: [[VAL_43:%.*]] = muli [[VAL_26]], [[VAL_41]] : index -// CHECK: [[VAL_44:%.*]] = muli [[VAL_26]], [[VAL_42]] : index -// CHECK: scf.parallel ([[VAL_45:%.*]], [[VAL_46:%.*]]) = ([[VAL_25]], [[VAL_25]]) to ([[VAL_24]], [[VAL_24]]) step ([[VAL_43]], [[VAL_44]]) { -// CHECK: [[VAL_47:%.*]] = affine.min #map0([[VAL_41]], [[VAL_24]], [[VAL_45]]) -// CHECK: [[VAL_48:%.*]] = affine.min #map0([[VAL_42]], [[VAL_24]], [[VAL_46]]) -// CHECK: scf.parallel ([[VAL_49:%.*]], [[VAL_50:%.*]]) = ([[VAL_40]], [[VAL_40]]) to ([[VAL_47]], [[VAL_48]]) step ([[VAL_26]], [[VAL_26]]) { +// CHECK: [[C0_2:%.*]] = constant 0 : index +// CHECK: [[C1_2:%.*]] = constant 1 : index +// CHECK: [[C4_1:%.*]] = constant 4 : index +// CHECK: [[V10:%.*]] = muli [[C1]], [[C1_2]] : index +// CHECK: [[V11:%.*]] = muli [[C1]], [[C4_1]] : index +// CHECK: scf.parallel ([[V12:%.*]], [[V13:%.*]]) = ([[C0]], [[C0]]) to ([[C2]], [[C2]]) step ([[V10]], [[V11]]) { +// CHECK: [[V14:%.*]] = affine.min #map0([[V11]], [[C2]], [[V13]]) +// CHECK: scf.parallel ([[V15:%.*]], [[V16:%.*]]) = ([[C0_2]], [[C0_2]]) to ([[V10]], [[V14]]) step ([[C1]], [[C1]]) { +// CHECK: = addi [[V15]], [[V12]] : index +// CHECK: = addi [[V16]], [[V13]] : index // CHECK: } // CHECK: } // CHECK: return