Index: mlir/include/mlir/Dialect/SCF/Passes.h
===================================================================
--- mlir/include/mlir/Dialect/SCF/Passes.h
+++ mlir/include/mlir/Dialect/SCF/Passes.h
@@ -33,7 +33,10 @@
 
 /// Creates a pass which tiles innermost parallel loops.
+/// If `withInboundCheck` is true, the inner loops use a fixed upper bound and
+/// guard their body with a runtime in-bound check instead of min-based bounds.
 std::unique_ptr<Pass>
-createParallelLoopTilingPass(llvm::ArrayRef<int64_t> tileSize = {});
+createParallelLoopTilingPass(llvm::ArrayRef<int64_t> tileSize = {},
+                             bool withInboundCheck = false);
 
 /// Creates a pass which folds arith ops on induction variable into
 /// loop range.
Index: mlir/include/mlir/Dialect/SCF/Passes.td
===================================================================
--- mlir/include/mlir/Dialect/SCF/Passes.td
+++ mlir/include/mlir/Dialect/SCF/Passes.td
@@ -40,7 +40,11 @@
   let options = [
     ListOption<"tileSizes", "parallel-loop-tile-sizes", "int64_t",
                "Factors to tile parallel loops by",
-               "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">
+               "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">,
+    Option<"withInboundCheck", "with-inbound-check", "bool",
+           /*default=*/"false",
+           "Perform tiling with fixed upper bound with inbound check "
+           "inside the internal loops">
   ];
   let dependentDialects = ["AffineDialect"];
 }
Index: mlir/include/mlir/Dialect/SCF/Transforms.h
===================================================================
--- mlir/include/mlir/Dialect/SCF/Transforms.h
+++ mlir/include/mlir/Dialect/SCF/Transforms.h
@@ -49,7 +49,11 @@
 /// The function returns the resulting ParallelOps, i.e. {outer_loop_op,
 /// inner_loop_op}.
+/// If `withInboundCheck` is true, the inner loop always iterates over a full
+/// tile and its body is wrapped in an scf.if that checks the original upper
+/// bounds, instead of using min-based dynamic upper bounds.
 std::pair<ParallelOp, ParallelOp>
-tileParallelLoop(ParallelOp op, llvm::ArrayRef<int64_t> tileSizes);
+tileParallelLoop(ParallelOp op, llvm::ArrayRef<int64_t> tileSizes,
+                 bool withInboundCheck = false);
 
 /// Populates patterns for SCF structural type conversions and sets up the
 /// provided ConversionTarget with the appropriate legality configuration for
Index: mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp
===================================================================
--- mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp
+++ mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/SCF/Transforms.h"
 #include "mlir/Dialect/SCF/Utils.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "llvm/ADT/Sequence.h"
 
 using namespace mlir;
 using namespace mlir::scf;
@@ -33,12 +34,25 @@
 ///                                          min(%arg5*tileSize[1], %arg3-%i1))
 ///                                      step (%arg4, %arg5)
 ///
+/// or, when `withInboundCheck` is true, into
+///   scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
+///                                            step (%arg4*tileSize[0],
+///                                                  %arg5*tileSize[1])
+///     scf.parallel (%j0, %j1) = (0, 0) to (%arg4*tileSize[0],
+///                                          %arg5*tileSize[1])
+///                                      step (%arg4, %arg5)
+///        %inbound = (%i0 + %j0 < %arg2) &&
+///                   (%i1 + %j1 < %arg3)
+///        scf.if (%inbound)
+///          ....
+///
 /// where the uses of %i0 and %i1 in the loop body are replaced by
 /// %i0 + j0 and %i1 + %j1.
 //
 /// The old loop is replaced with the new one.
 std::pair<ParallelOp, ParallelOp>
-mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef<int64_t> tileSizes) {
+mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef<int64_t> tileSizes,
+                            bool withInboundCheck) {
   OpBuilder b(op);
   auto zero = b.create<ConstantIndexOp>(op.getLoc(), 0);
   SmallVector<Value, 2> tileSizeConstants;
@@ -76,6 +90,8 @@
   // Create the inner loop with adjusted bounds.
   SmallVector<Value, 2> newBounds;
   newBounds.reserve(op.upperBound().size());
+  // True if some dimension needs the runtime in-bound check.
+  bool needInboundCheck = false;
   for (auto dim : llvm::zip(outerLoop.lowerBound(), outerLoop.upperBound(),
                             outerLoop.step(), outerLoop.getInductionVars(),
                             op.step(), tileSizeConstants)) {
@@ -101,6 +117,15 @@
         continue;
       }
     }
+
+    // In in-bound check mode the inner loop always spans a full tile (the
+    // outer step); iterations past the upper bound are masked by an scf.if.
+    if (withInboundCheck) {
+      newBounds.push_back(newStep);
+      needInboundCheck = true;
+      continue;
+    }
+
     // Otherwise, we dynamically compute the bound for
     // each iteration of the outer loop.
     newBounds.push_back(
@@ -111,17 +136,52 @@
       op.getLoc(), SmallVector<Value, 2>(newBounds.size(), zero), newBounds,
       op.step());
 
-  // Steal the body of the old parallel loop and erase it.
-  innerLoop.region().takeBody(op.region());
-
-  // Insert computation for new index vectors and replace uses.
-  b.setInsertionPointToStart(innerLoop.getBody());
-  for (auto ivs :
-       llvm::zip(innerLoop.getInductionVars(), outerLoop.getInductionVars())) {
-    Value inner_index = std::get<0>(ivs);
-    AddIOp newIndex =
-        b.create<AddIOp>(op.getLoc(), std::get<0>(ivs), std::get<1>(ivs));
-    inner_index.replaceAllUsesExcept(newIndex, newIndex);
+  if (withInboundCheck && needInboundCheck) {
+    b.setInsertionPointToStart(innerLoop.getBody());
+    // Compute the index %outer_iv + %inner_iv of every dimension once; it is
+    // used both by the in-bound check and by the body of the old loop.
+    SmallVector<Value, 2> newIndices;
+    for (auto ivs : llvm::zip(innerLoop.getInductionVars(),
+                              outerLoop.getInductionVars()))
+      newIndices.push_back(b.create<AddIOp>(op.getLoc(), std::get<0>(ivs),
+                                            std::get<1>(ivs)));
+    // %in_bound = %in_bound && (%outer_iv + %inner_iv < %outer_upper_bound)
+    // The comparison must be signed: index values may be negative.
+    Value inbound =
+        b.create<ConstantOp>(op.getLoc(), b.getIntegerType(1),
+                             b.getIntegerAttr(b.getIntegerType(1), 1));
+    for (auto dim : llvm::zip(outerLoop.upperBound(), newIndices)) {
+      Value outerUpperBound, newIndex;
+      std::tie(outerUpperBound, newIndex) = dim;
+      Value dimInbound = b.create<CmpIOp>(op.getLoc(), CmpIPredicate::slt,
+                                          newIndex, outerUpperBound);
+      inbound = b.create<AndOp>(op.getLoc(), inbound, dimInbound);
+    }
+    auto ifInbound = b.create<IfOp>(op.getLoc(),
+                                    /*resultTypes*/ ArrayRef<Type>{}, inbound,
+                                    /*hasElseRegion*/ false);
+    // Move the body of the old loop into the "then" region, redirect its
+    // block arguments (the old induction variables) to the new indices and
+    // drop them, since an scf.if region takes no arguments.
+    ifInbound.thenRegion().takeBody(op.region());
+    Block &thenBlock = ifInbound.thenRegion().front();
+    for (auto en : llvm::enumerate(newIndices))
+      thenBlock.getArgument(en.index()).replaceAllUsesWith(en.value());
+    thenBlock.eraseArguments(llvm::to_vector<4>(
+        llvm::seq((unsigned)0, thenBlock.getNumArguments())));
+  } else {
+    // Steal the body of the old parallel loop and erase it.
+    innerLoop.region().takeBody(op.region());
+
+    // Insert computation for new index vectors and replace uses.
+    b.setInsertionPointToStart(innerLoop.getBody());
+    for (auto ivs : llvm::zip(innerLoop.getInductionVars(),
+                              outerLoop.getInductionVars())) {
+      Value inner_index = std::get<0>(ivs);
+      AddIOp newIndex =
+          b.create<AddIOp>(op.getLoc(), std::get<0>(ivs), std::get<1>(ivs));
+      inner_index.replaceAllUsesExcept(newIndex, newIndex);
+    }
   }
 
   op.erase();
@@ -132,8 +192,10 @@
 struct ParallelLoopTiling
     : public SCFParallelLoopTilingBase<ParallelLoopTiling> {
   ParallelLoopTiling() = default;
-  explicit ParallelLoopTiling(ArrayRef<int64_t> tileSizes) {
+  explicit ParallelLoopTiling(ArrayRef<int64_t> tileSizes,
+                              bool withInboundCheck = false) {
     this->tileSizes = tileSizes;
+    this->withInboundCheck = withInboundCheck;
   }
 
   void runOnFunction() override {
@@ -142,13 +204,14 @@
     for (ParallelOp ploop : innermostPloops) {
       // FIXME: Add reduction support.
       if (ploop.getNumReductions() == 0)
-        tileParallelLoop(ploop, tileSizes);
+        tileParallelLoop(ploop, tileSizes, withInboundCheck);
     }
   }
 };
 } // namespace
 
 std::unique_ptr<Pass>
-mlir::createParallelLoopTilingPass(ArrayRef<int64_t> tileSizes) {
-  return std::make_unique<ParallelLoopTiling>(tileSizes);
+mlir::createParallelLoopTilingPass(ArrayRef<int64_t> tileSizes,
+                                   bool withInboundCheck) {
+  return std::make_unique<ParallelLoopTiling>(tileSizes, withInboundCheck);
 }
Index: mlir/test/Dialect/SCF/parallel-loop-tiling-inbound-check.mlir
===================================================================
--- /dev/null
+++ mlir/test/Dialect/SCF/parallel-loop-tiling-inbound-check.mlir
@@ -0,0 +1,145 @@
+// RUN: mlir-opt %s -pass-pipeline='func(parallel-loop-tiling{parallel-loop-tile-sizes=1,4 with-inbound-check=true})' -split-input-file | FileCheck %s
+
+func @parallel_loop(%arg0 : index, %arg1 : index, %arg2 : index,
+                    %arg3 : index, %arg4 : index, %arg5 : index,
+                    %A: memref<?x?xf32>, %B: memref<?x?xf32>,
+                    %C: memref<?x?xf32>, %result: memref<?x?xf32>) {
+  scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) {
+    %B_elem = memref.load %B[%i0, %i1] : memref<?x?xf32>
+    %C_elem = memref.load %C[%i0, %i1] : memref<?x?xf32>
+    %sum_elem = addf %B_elem, %C_elem : f32
+    memref.store %sum_elem, %result[%i0, %i1] : memref<?x?xf32>
+  }
+  return
+}
+
+// CHECK-LABEL:   func @parallel_loop(
+// CHECK-SAME:                        [[ARG1:%.*]]: index, [[ARG2:%.*]]: index, [[ARG3:%.*]]: index, [[ARG4:%.*]]: index, [[ARG5:%.*]]: index, [[ARG6:%.*]]: index, [[ARG7:%.*]]: memref<?x?xf32>, [[ARG8:%.*]]: memref<?x?xf32>, [[ARG9:%.*]]: memref<?x?xf32>, [[ARG10:%.*]]: memref<?x?xf32>) {
+// CHECK:           [[C0:%.*]] = constant 0 : index
+// CHECK:           [[C1:%.*]] = constant 1 : index
+// CHECK:           [[C4:%.*]] = constant 4 : index
+// CHECK:           [[V1:%.*]] = muli [[ARG5]], [[C1]] : index
+// CHECK:           [[V2:%.*]] = muli [[ARG6]], [[C4]] : index
+// CHECK:           scf.parallel ([[V3:%.*]], [[V4:%.*]]) = ([[ARG1]], [[ARG2]]) to ([[ARG3]], [[ARG4]]) step ([[V1]], [[V2]]) {
+// CHECK:             scf.parallel ([[V7:%.*]], [[V8:%.*]]) = ([[C0]], [[C0]]) to ([[V1]], [[V2]]) step ([[ARG5]], [[ARG6]]) {
+// CHECK:               [[V9:%.*]] = addi [[V7]], [[V3]] : index
+// CHECK:               [[V10:%.*]] = addi [[V8]], [[V4]] : index
+// CHECK:               %true = constant true
+// CHECK:               [[V11:%.*]] = cmpi slt, [[V9]], [[ARG3]] : index
+// CHECK:               [[V12:%.*]] = and %true, [[V11]] : i1
+// CHECK:               [[V13:%.*]] = cmpi slt, [[V10]], [[ARG4]] : index
+// CHECK:               [[V14:%.*]] = and [[V12]], [[V13]] : i1
+// CHECK:               scf.if [[V14]] {
+// CHECK:                 [[V15:%.*]] = memref.load [[ARG8]]{{\[}}[[V9]], [[V10]]] : memref<?x?xf32>
+// CHECK:                 [[V16:%.*]] = memref.load [[ARG9]]{{\[}}[[V9]], [[V10]]] : memref<?x?xf32>
+// CHECK:                 [[V17:%.*]] = addf [[V15]], [[V16]] : f32
+// CHECK:                 memref.store [[V17]], [[ARG10]]{{\[}}[[V9]], [[V10]]] : memref<?x?xf32>
+// CHECK:               }
+// CHECK:             }
+// CHECK:           }
+// CHECK:           return
+
+// -----
+
+func @static_loop_with_step() {
+  %c0 = constant 0 : index
+  %c3 = constant 3 : index
+  %c22 = constant 22 : index
+  %c24 = constant 24 : index
+  scf.parallel (%i0, %i1) = (%c0, %c0) to (%c22, %c24) step (%c3, %c3) {
+  }
+  return
+}
+
+// CHECK-LABEL:   func @static_loop_with_step() {
+// CHECK:           [[C0:%.*]] = constant 0 : index
+// CHECK:           [[C3:%.*]] = constant 3 : index
+// CHECK:           [[C22:%.*]] = constant 22 : index
+// CHECK:           [[C24:%.*]] = constant 24 : index
+// CHECK:           [[C0_1:%.*]] = constant 0 : index
+// CHECK:           [[C1:%.*]] = constant 1 : index
+// CHECK:           [[C4:%.*]] = constant 4 : index
+// CHECK:           [[V1:%.*]] = muli [[C3]], [[C1]] : index
+// CHECK:           [[V2:%.*]] = muli [[C3]], [[C4]] : index
+// CHECK:           scf.parallel ([[V3:%.*]], [[V4:%.*]]) = ([[C0]], [[C0]]) to ([[C22]], [[C24]]) step ([[V1]], [[V2]]) {
+// CHECK:             scf.parallel ([[V5:%.*]], [[V6:%.*]]) = ([[C0_1]], [[C0_1]]) to ([[V1]], [[V2]]) step ([[C3]], [[C3]]) {
+// CHECK-NOT:           scf.if
+// CHECK:               = addi [[V5]], [[V3]] : index
+// CHECK:               = addi [[V6]], [[V4]] : index
+// CHECK:             }
+// CHECK:           }
+// CHECK:           return
+
+// -----
+
+func @tile_nested_innermost() {
+  %c2 = constant 2 : index
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    scf.parallel (%k, %l) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    }
+  }
+  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+  }
+  return
+}
+
+// CHECK-LABEL:   func @tile_nested_innermost() {
+// CHECK:           [[C2:%.*]] = constant 2 : index
+// CHECK:           [[C0:%.*]] = constant 0 : index
+// CHECK:           [[C1:%.*]] = constant 1 : index
+// CHECK:           scf.parallel ([[V1:%.*]], [[V2:%.*]]) = ([[C0]], [[C0]]) to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
+// CHECK:             [[C0_1:%.*]] = constant 0 : index
+// CHECK:             [[C1_1:%.*]] = constant 1 : index
+// CHECK:             [[C4:%.*]] = constant 4 : index
+// CHECK:             [[V3:%.*]] = muli [[C1]], [[C1_1]] : index
+// CHECK:             [[V4:%.*]] = muli [[C1]], [[C4]] : index
+// CHECK:             scf.parallel ([[V5:%.*]], [[V6:%.*]]) = ([[C0]], [[C0]]) to ([[C2]], [[C2]]) step ([[V3]], [[V4]]) {
+// CHECK:               scf.parallel ([[V8:%.*]], [[V9:%.*]]) = ([[C0_1]], [[C0_1]]) to ([[V3]], [[V4]]) step ([[C1]], [[C1]]) {
+// CHECK:                 = addi [[V8]], [[V5]] : index
+// CHECK:                 = addi [[V9]], [[V6]] : index
+// CHECK:                 scf.if
+// CHECK:               }
+// CHECK:             }
+// CHECK:           }
+// CHECK:           [[C0_2:%.*]] = constant 0 : index
+// CHECK:           [[C1_2:%.*]] = constant 1 : index
+// CHECK:           [[C4_1:%.*]] = constant 4 : index
+// CHECK:           [[V10:%.*]] = muli [[C1]], [[C1_2]] : index
+// CHECK:           [[V11:%.*]] = muli [[C1]], [[C4_1]] : index
+// CHECK:           scf.parallel ([[V12:%.*]], [[V13:%.*]]) = ([[C0]], [[C0]]) to ([[C2]], [[C2]]) step ([[V10]], [[V11]]) {
+// CHECK:             scf.parallel ([[V15:%.*]], [[V16:%.*]]) = ([[C0_2]], [[C0_2]]) to ([[V10]], [[V11]]) step ([[C1]], [[C1]]) {
+// CHECK:               = addi [[V15]], [[V12]] : index
+// CHECK:               = addi [[V16]], [[V13]] : index
+// CHECK:               scf.if
+// CHECK:             }
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+
+// -----
+
+func @tile_nested_in_non_ploop() {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  scf.for %i = %c0 to %c2 step %c1 {
+    scf.for %j = %c0 to %c2 step %c1 {
+      scf.parallel (%k, %l) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+      }
+    }
+  }
+  return
+}
+
+// CHECK-LABEL: func @tile_nested_in_non_ploop
+// CHECK:         scf.for
+// CHECK:           scf.for
+// CHECK:             scf.parallel
+// CHECK:               scf.parallel
+// CHECK:               }
+// CHECK:             }
+// CHECK:           }
+// CHECK:         }
+// CHECK:       }