diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp --- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp @@ -18,6 +18,7 @@ #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/RegionUtils.h" @@ -120,16 +121,69 @@ struct ParallelComputeFunctionType { FunctionType type; - llvm::SmallVector captures; + SmallVector captures; +}; + +// Helper struct to parse parallel compute function argument list. +struct ParallelComputeFunctionArgs { + BlockArgument blockIndex(); + BlockArgument blockSize(); + ArrayRef tripCounts(); + ArrayRef lowerBounds(); + ArrayRef upperBounds(); + ArrayRef steps(); + ArrayRef captures(); + + unsigned numLoops; + ArrayRef args; +}; + +struct ParallelComputeFunctionBounds { + SmallVector tripCounts; + SmallVector lowerBounds; + SmallVector upperBounds; + SmallVector steps; }; struct ParallelComputeFunction { + unsigned numLoops; FuncOp func; llvm::SmallVector captures; }; } // namespace +BlockArgument ParallelComputeFunctionArgs::blockIndex() { return args[0]; } +BlockArgument ParallelComputeFunctionArgs::blockSize() { return args[1]; } + +ArrayRef ParallelComputeFunctionArgs::tripCounts() { + return args.drop_front(2).take_front(numLoops); +} + +ArrayRef ParallelComputeFunctionArgs::lowerBounds() { + return args.drop_front(2 + 1 * numLoops).take_front(numLoops); +} + +ArrayRef ParallelComputeFunctionArgs::upperBounds() { + return args.drop_front(2 + 2 * numLoops).take_front(numLoops); +} + +ArrayRef ParallelComputeFunctionArgs::steps() { + return args.drop_front(2 + 3 * numLoops).take_front(numLoops); +} + +ArrayRef ParallelComputeFunctionArgs::captures() { + return args.drop_front(2 + 4 * 
numLoops); +} + +template +static SmallVector integerConstants(ValueRange values) { + SmallVector attrs(values.size()); + for (unsigned i = 0; i < values.size(); ++i) + matchPattern(values[i], m_Constant(&attrs[i])); + return attrs; +} + // Converts one-dimensional iteration index in the [0, tripCount) interval // into multidimensional iteration coordinate. static SmallVector delinearize(ImplicitLocOpBuilder &b, Value index, @@ -154,7 +208,7 @@ llvm::SetVector captures; getUsedValuesDefinedAbove(op.region(), op.region(), captures); - llvm::SmallVector inputs; + SmallVector inputs; inputs.reserve(2 + 4 * op.getNumLoops() + captures.size()); Type indexTy = rewriter.getIndexType(); @@ -167,7 +221,9 @@ for (unsigned i = 0; i < op.getNumLoops(); ++i) inputs.push_back(indexTy); // loop tripCount - // Parallel operation lower bound, upper bound and step. + // Parallel operation lower bound, upper bound and step. Lower bound, upper + // bound and step are passed as contiguous arguments: + // call @compute(%lb0, %lb1, ..., %ub0, %ub1, ..., %step0, %step1, ...) for (unsigned i = 0; i < op.getNumLoops(); ++i) { inputs.push_back(indexTy); // lower bound inputs.push_back(indexTy); // upper bound @@ -185,16 +241,14 @@ // Create a parallel compute fuction from the parallel operation. static ParallelComputeFunction -createParallelComputeFunction(scf::ParallelOp op, PatternRewriter &rewriter) { +createParallelComputeFunction(scf::ParallelOp op, + const ParallelComputeFunctionBounds &bounds, + PatternRewriter &rewriter) { OpBuilder::InsertionGuard guard(rewriter); ImplicitLocOpBuilder b(op.getLoc(), rewriter); ModuleOp module = op->getParentOfType(); - // Make sure that all constants will be inside the parallel operation body to - // reduce the number of parallel compute function arguments. 
- cloneConstantsIntoTheRegion(op.getLoopBody(), rewriter); - ParallelComputeFunctionType computeFuncType = getParallelComputeFunctionType(op, rewriter); @@ -211,27 +265,35 @@ Block *block = b.createBlock(&func.getBody(), func.begin(), type.getInputs()); b.setInsertionPointToEnd(block); - unsigned offset = 0; // argument offset for arguments decoding - - // Returns `numArguments` arguments starting from `offset` and updates offset - // by moving forward to the next argument. - auto getArguments = [&](unsigned numArguments) -> ArrayRef { - auto args = block->getArguments(); - auto slice = args.drop_front(offset).take_front(numArguments); - offset += numArguments; - return {slice.begin(), slice.end()}; - }; + ParallelComputeFunctionArgs args = {op.getNumLoops(), func.getArguments()}; // Block iteration position defined by the block index and size. - Value blockIndex = block->getArgument(offset++); - Value blockSize = block->getArgument(offset++); + BlockArgument blockIndex = args.blockIndex(); + BlockArgument blockSize = args.blockSize(); // Constants used below. Value c0 = b.create(0); Value c1 = b.create(1); + // Materialize known constants as constant operations in the function body. + auto values = [&](ArrayRef args, ArrayRef attrs) { + return llvm::to_vector( + llvm::map_range(llvm::zip(args, attrs), [&](auto tuple) -> Value { + if (IntegerAttr attr = std::get<1>(tuple)) + return b.create(attr); + return std::get<0>(tuple); + })); + }; + // Multi-dimensional parallel iteration space defined by the loop trip counts. - ArrayRef tripCounts = getArguments(op.getNumLoops()); + auto tripCounts = values(args.tripCounts(), bounds.tripCounts); + + // Parallel operation lower bound and step. + auto lowerBounds = values(args.lowerBounds(), bounds.lowerBounds); + auto steps = values(args.steps(), bounds.steps); + + // Remaining arguments are implicit captures of the parallel operation. 
+ ArrayRef captures = args.captures(); // Compute a product of trip counts to get the size of the flattened // one-dimensional iteration space. @@ -239,14 +301,6 @@ for (unsigned i = 1; i < tripCounts.size(); ++i) tripCount = b.create(tripCount, tripCounts[i]); - // Parallel operation lower bound and step. - ArrayRef lowerBound = getArguments(op.getNumLoops()); - offset += op.getNumLoops(); // skip upper bound arguments - ArrayRef step = getArguments(op.getNumLoops()); - - // Remaining arguments are implicit captures of the parallel operation. - ArrayRef captures = getArguments(block->getNumArguments() - offset); - // Find one-dimensional iteration bounds: [blockFirstIndex, blockLastIndex]: // blockFirstIndex = blockIndex * blockSize Value blockFirstIndex = b.create(blockIndex, blockSize); @@ -312,7 +366,7 @@ // Compute induction variable for `loopIdx`. computeBlockInductionVars[loopIdx] = nb.create( - lowerBound[loopIdx], nb.create(iv, step[loopIdx])); + lowerBounds[loopIdx], nb.create(iv, steps[loopIdx])); // Check if we are inside first or last iteration of the loop. isBlockFirstCoord[loopIdx] = nb.create( @@ -359,7 +413,7 @@ workLoopBuilder(0)); b.create(ValueRange()); - return {func, std::move(computeFuncType.captures)}; + return {op.getNumLoops(), func, std::move(computeFuncType.captures)}; } // Creates recursive async dispatch function for the given parallel compute @@ -640,6 +694,10 @@ ImplicitLocOpBuilder b(op.getLoc(), rewriter); + // Make sure that all constants will be inside the parallel operation body to + // reduce the number of parallel compute function arguments. 
+ cloneConstantsIntoTheRegion(op.getLoopBody(), rewriter); + // Compute trip count for each loop induction variable: // tripCount = ceil_div(upperBound - lowerBound, step); SmallVector tripCounts(op.getNumLoops()); @@ -647,8 +705,8 @@ auto lb = op.lowerBound()[i]; auto ub = op.upperBound()[i]; auto step = op.step()[i]; - auto range = b.create(ub, lb); - tripCounts[i] = b.create(range, step); + auto range = b.createOrFold(ub, lb); + tripCounts[i] = b.createOrFold(range, step); } // Compute a product of trip counts to get the 1-dimensional iteration space @@ -699,10 +757,22 @@ Value blockSize = b.create(tripCount, bs1); Value blockCount = b.create(tripCount, blockSize); + // Collect statically known constants defining the loop nest in the parallel + // compute function. LLVM can't always push constants across the non-trivial + // async dispatch call graph; by providing these values explicitly we can + // build a more efficient loop nest and rely on better constant folding, + // loop unrolling and vectorization. + ParallelComputeFunctionBounds staticBounds = { + integerConstants(tripCounts), + integerConstants(op.lowerBound()), + integerConstants(op.upperBound()), + integerConstants(op.step()), + }; + // Create a parallel compute function that takes a block id and computes the // parallel operation body for a subset of iteration space. ParallelComputeFunction parallelComputeFunction = - createParallelComputeFunction(op, rewriter); + createParallelComputeFunction(op, staticBounds, rewriter); // Dispatch parallel compute function using async recursive work splitting, // or by submitting compute task sequentially from a caller thread.
diff --git a/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir b/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir --- a/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir +++ b/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir @@ -1,8 +1,8 @@ -// RUN: mlir-opt %s \ +// RUN: mlir-opt %s -split-input-file \ // RUN: -async-parallel-for=async-dispatch=true \ // RUN: | FileCheck %s -// RUN: mlir-opt %s \ +// RUN: mlir-opt %s -split-input-file \ // RUN: -async-parallel-for=async-dispatch=false \ // RUN: -canonicalize -inline -symbol-dce \ // RUN: | FileCheck %s @@ -34,3 +34,35 @@ // CHECK: %[[CST:.*]] = arith.constant 1.0{{.*}} : f32 // CHECK: scf.for // CHECK: memref.store %[[CST]], %[[MEMREF]] + +// ----- + +// Check that constant loop bounds are sunk into the parallel compute function. + +// CHECK-LABEL: func @sink_constant_step( +func @sink_constant_step(%arg0: memref<?xf32>, %lb: index, %ub: index) { + %one = arith.constant 1.0 : f32 + %st = arith.constant 123 : index + + scf.parallel (%i) = (%lb) to (%ub) step (%st) { + memref.store %one, %arg0[%i] : memref<?xf32> + } + + return +} + +// CHECK-LABEL: func private @parallel_compute_fn( +// CHECK-SAME: %[[BLOCK_INDEX:arg[0-9]+]]: index, +// CHECK-SAME: %[[BLOCK_SIZE:arg[0-9]+]]: index, +// CHECK-SAME: %[[TRIP_COUNT:arg[0-9]+]]: index, +// CHECK-SAME: %[[LB:arg[0-9]+]]: index, +// CHECK-SAME: %[[UB:arg[0-9]+]]: index, +// CHECK-SAME: %[[STEP:arg[0-9]+]]: index, +// CHECK-SAME: %[[MEMREF:arg[0-9]+]]: memref<?xf32> +// CHECK-SAME: ) { +// CHECK: %[[CSTEP:.*]] = arith.constant 123 : index +// CHECK-NOT: %[[STEP]] +// CHECK: scf.for %[[I:arg[0-9]+]] +// CHECK: %[[TMP:.*]] = arith.muli %[[I]], %[[CSTEP]] +// CHECK: %[[IDX:.*]] = arith.addi %[[LB]], %[[TMP]] +// CHECK: memref.store