diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
--- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/RegionUtils.h"
@@ -118,22 +119,126 @@
   int32_t minTaskSize;
 };

-struct ParallelComputeFunctionType {
-  FunctionType type;
-  llvm::SmallVector<Value> captures;
+class ParallelComputeFunctionType {
+public:
+  ParallelComputeFunctionType(FunctionType type, SmallVector<Value> captures,
+                              unsigned numLoops)
+      : functionType(type), implicitCaptures(std::move(captures)),
+        numLoops(numLoops) {}
+
+  FunctionType type() const { return functionType; }
+  const SmallVector<Value> &captures() const { return implicitCaptures; }
+
+  // Helper functions to parse the parallel compute function arguments.
+  BlockArgument blockIndex(ArrayRef<BlockArgument> args) const;
+  BlockArgument blockSize(ArrayRef<BlockArgument> args) const;
+  ArrayRef<BlockArgument> tripCounts(ArrayRef<BlockArgument> args) const;
+  ArrayRef<BlockArgument> lowerBounds(ArrayRef<BlockArgument> args) const;
+  ArrayRef<BlockArgument> upperBounds(ArrayRef<BlockArgument> args) const;
+  ArrayRef<BlockArgument> steps(ArrayRef<BlockArgument> args) const;
+  ArrayRef<BlockArgument> captures(ArrayRef<BlockArgument> args) const;
+
+private:
+  FunctionType functionType;           // signature of the compute function
+  SmallVector<Value> implicitCaptures; // values captured by the scf.parallel
+  unsigned numLoops;                   // depth of the scf.for loop nest
 };

-struct ParallelComputeFunction {
-  FuncOp func;
-  llvm::SmallVector<Value> captures;
+class ParallelComputeFunction {
+public:
+  ParallelComputeFunction(ParallelComputeFunctionType type, FuncOp func)
+      : functionType(std::move(type)), function(func) {}
+
+  FuncOp func() const { return function; }
+  const SmallVector<Value> &captures() const {
+    return functionType.captures();
+  }
+
+  // Helper functions to replace block arguments with statically known
+  // constants in the function body. We rely on later dead argument elimination
+  // to clean up unused arguments.
+  void setTripCounts(ArrayRef<IntegerAttr> attrs);
+  void setLowerBounds(ArrayRef<IntegerAttr> attrs);
+  void setUpperBounds(ArrayRef<IntegerAttr> attrs);
+  void setSteps(ArrayRef<IntegerAttr> attrs);
+
+private:
+  ArrayRef<BlockArgument> args() { return function.getArguments(); }
+
+  ParallelComputeFunctionType functionType;
+  FuncOp function;
 };

 } // namespace

+BlockArgument
+ParallelComputeFunctionType::blockIndex(ArrayRef<BlockArgument> args) const {
+  return args[0];
+}
+
+BlockArgument
+ParallelComputeFunctionType::blockSize(ArrayRef<BlockArgument> args) const {
+  return args[1];
+}
+
+ArrayRef<BlockArgument>
+ParallelComputeFunctionType::tripCounts(ArrayRef<BlockArgument> args) const {
+  return args.drop_front(2).take_front(numLoops);
+}
+
+ArrayRef<BlockArgument>
+ParallelComputeFunctionType::lowerBounds(ArrayRef<BlockArgument> args) const {
+  return args.drop_front(2 + 1 * numLoops).take_front(numLoops);
+}
+
+ArrayRef<BlockArgument>
+ParallelComputeFunctionType::upperBounds(ArrayRef<BlockArgument> args) const {
+  return args.drop_front(2 + 2 * numLoops).take_front(numLoops);
+}
+
+ArrayRef<BlockArgument>
+ParallelComputeFunctionType::steps(ArrayRef<BlockArgument> args) const {
+  return args.drop_front(2 + 3 * numLoops).take_front(numLoops);
+}
+
+ArrayRef<BlockArgument>
+ParallelComputeFunctionType::captures(ArrayRef<BlockArgument> args) const {
+  return args.drop_front(2 + 4 * numLoops);
+}
+
+// Replace all uses of `args` with constants created in the function body from
+// the `attrs` (if the attribute is not null).
+static void replaceWithConstants(FuncOp func, ArrayRef<BlockArgument> args,
+                                 ArrayRef<IntegerAttr> attrs) {
+  // Create constants at the beginning of the entry block.
+  OpBuilder b = OpBuilder::atBlockBegin(&func.body().front());
+  assert(args.size() == attrs.size() && "illegal `args` and `attrs` sizes");
+
+  for (auto tuple : llvm::zip(args, attrs)) {
+    Value value = std::get<0>(tuple);
+    if (IntegerAttr attr = std::get<1>(tuple))
+      value.replaceAllUsesWith(
+          b.create<arith::ConstantOp>(func.getLoc(), attr));
+  }
+}
+
+void ParallelComputeFunction::setTripCounts(ArrayRef<IntegerAttr> attrs) {
+  replaceWithConstants(function, functionType.tripCounts(args()), attrs);
+}
+
+void ParallelComputeFunction::setLowerBounds(ArrayRef<IntegerAttr> attrs) {
+  replaceWithConstants(function, functionType.lowerBounds(args()), attrs);
+}
+
+void ParallelComputeFunction::setUpperBounds(ArrayRef<IntegerAttr> attrs) {
+  replaceWithConstants(function, functionType.upperBounds(args()), attrs);
+}
+
+void ParallelComputeFunction::setSteps(ArrayRef<IntegerAttr> attrs) {
+  replaceWithConstants(function, functionType.steps(args()), attrs);
+}
+
 // Converts one-dimensional iteration index in the [0, tripCount) interval
 // into multidimensional iteration coordinate.
 static SmallVector<Value> delinearize(ImplicitLocOpBuilder &b, Value index,
-                                      ArrayRef<Value> tripCounts) {
+                                      ArrayRef<BlockArgument> tripCounts) {
   SmallVector<Value> coords(tripCounts.size());
   assert(!tripCounts.empty() && "tripCounts must be not empty");
@@ -154,7 +259,7 @@
   llvm::SetVector<Value> captures;
   getUsedValuesDefinedAbove(op.region(), op.region(), captures);

-  llvm::SmallVector<Type> inputs;
+  SmallVector<Type> inputs;
   inputs.reserve(2 + 4 * op.getNumLoops() + captures.size());

   Type indexTy = rewriter.getIndexType();
@@ -167,7 +272,9 @@
   for (unsigned i = 0; i < op.getNumLoops(); ++i)
     inputs.push_back(indexTy); // loop tripCount

-  // Parallel operation lower bound, upper bound and step.
+  // Parallel operation lower bound, upper bound and step. Lower bound, upper
+  // bound and step are passed as contiguous arguments:
+  //   call @compute(%lb0, %lb1, ..., %ub0, %ub1, ..., %step0, %step1, ...)
   for (unsigned i = 0; i < op.getNumLoops(); ++i) {
     inputs.push_back(indexTy); // lower bound
     inputs.push_back(indexTy); // upper bound
@@ -180,7 +287,8 @@
   // Convert captures to vector for later convenience.
   SmallVector<Value> capturesVector(captures.begin(), captures.end());

-  return {rewriter.getFunctionType(inputs, TypeRange()), capturesVector};
+  return {rewriter.getFunctionType(inputs, TypeRange()), capturesVector,
+          op.getNumLoops()};
 }

 // Create a parallel compute fuction from the parallel operation.
@@ -191,14 +299,10 @@
   ModuleOp module = op->getParentOfType<ModuleOp>();

-  // Make sure that all constants will be inside the parallel operation body to
-  // reduce the number of parallel compute function arguments.
-  cloneConstantsIntoTheRegion(op.getLoopBody(), rewriter);
-
   ParallelComputeFunctionType computeFuncType =
       getParallelComputeFunctionType(op, rewriter);

-  FunctionType type = computeFuncType.type;
+  FunctionType type = computeFuncType.type();
   FuncOp func = FuncOp::create(op.getLoc(), "parallel_compute_fn", type);
   func.setPrivate();
@@ -211,27 +315,18 @@
   Block *block = b.createBlock(&func.getBody(), func.begin(), type.getInputs());
   b.setInsertionPointToEnd(block);

-  unsigned offset = 0; // argument offset for arguments decoding
-
-  // Returns `numArguments` arguments starting from `offset` and updates offset
-  // by moving forward to the next argument.
-  auto getArguments = [&](unsigned numArguments) -> ArrayRef<Value> {
-    auto args = block->getArguments();
-    auto slice = args.drop_front(offset).take_front(numArguments);
-    offset += numArguments;
-    return {slice.begin(), slice.end()};
-  };
+  ArrayRef<BlockArgument> args = block->getArguments();

   // Block iteration position defined by the block index and size.
-  Value blockIndex = block->getArgument(offset++);
-  Value blockSize = block->getArgument(offset++);
+  BlockArgument blockIndex = computeFuncType.blockIndex(args);
+  BlockArgument blockSize = computeFuncType.blockSize(args);

   // Constants used below.
   Value c0 = b.create<arith::ConstantIndexOp>(0);
   Value c1 = b.create<arith::ConstantIndexOp>(1);

   // Multi-dimensional parallel iteration space defined by the loop trip counts.
-  ArrayRef<Value> tripCounts = getArguments(op.getNumLoops());
+  ArrayRef<BlockArgument> tripCounts = computeFuncType.tripCounts(args);

   // Compute a product of trip counts to get the size of the flattened
   // one-dimensional iteration space.
@@ -240,12 +335,11 @@
     tripCount = b.create<arith::MulIOp>(tripCount, tripCounts[i]);

   // Parallel operation lower bound and step.
-  ArrayRef<Value> lowerBound = getArguments(op.getNumLoops());
-  offset += op.getNumLoops(); // skip upper bound arguments
-  ArrayRef<Value> step = getArguments(op.getNumLoops());
+  ArrayRef<BlockArgument> lowerBound = computeFuncType.lowerBounds(args);
+  ArrayRef<BlockArgument> step = computeFuncType.steps(args);

   // Remaining arguments are implicit captures of the parallel operation.
-  ArrayRef<Value> captures = getArguments(block->getNumArguments() - offset);
+  ArrayRef<BlockArgument> captures = computeFuncType.captures(args);

   // Find one-dimensional iteration bounds: [blockFirstIndex, blockLastIndex]:
   //   blockFirstIndex = blockIndex * blockSize
@@ -347,7 +441,7 @@
       // Copy the body of the parallel op into the inner-most loop.
       BlockAndValueMapping mapping;
       mapping.map(op.getInductionVars(), computeBlockInductionVars);
-      mapping.map(computeFuncType.captures, captures);
+      mapping.map(computeFuncType.captures(), captures);

       for (auto &bodyOp : op.getLoopBody().getOps())
         nb.clone(bodyOp, mapping);
@@ -358,7 +452,7 @@
                        workLoopBuilder(0));
   b.create<ReturnOp>(ValueRange());

-  return {func, std::move(computeFuncType.captures)};
+  return {computeFuncType, func};
 }

 // Creates recursive async dispatch function for the given parallel compute
@@ -383,13 +477,13 @@
 static FuncOp createAsyncDispatchFunction(ParallelComputeFunction &computeFunc,
                                           PatternRewriter &rewriter) {
   OpBuilder::InsertionGuard guard(rewriter);
-  Location loc = computeFunc.func.getLoc();
+  Location loc = computeFunc.func().getLoc();
   ImplicitLocOpBuilder b(loc, rewriter);

-  ModuleOp module = computeFunc.func->getParentOfType<ModuleOp>();
+  ModuleOp module = computeFunc.func()->getParentOfType<ModuleOp>();

   ArrayRef<Type> computeFuncInputTypes =
-      computeFunc.func.type().cast<FunctionType>().getInputs();
+      computeFunc.func().type().cast<FunctionType>().getInputs();

   // Compared to the parallel compute function async dispatch function takes
   // additional !async.group argument. Also instead of a single `blockIndex` it
@@ -485,8 +579,9 @@
   SmallVector<Value> computeFuncOperands = {blockStart};
   computeFuncOperands.append(forwardedInputs.begin(), forwardedInputs.end());
-  b.create<CallOp>(computeFunc.func.sym_name(),
-                   computeFunc.func.getCallableResults(), computeFuncOperands);
+  b.create<CallOp>(computeFunc.func().sym_name(),
+                   computeFunc.func().getCallableResults(),
+                   computeFuncOperands);
   b.create<ReturnOp>(ValueRange());

   return func;
 }
@@ -515,7 +610,7 @@
     operands.append(op.lowerBound().begin(), op.lowerBound().end());
     operands.append(op.upperBound().begin(), op.upperBound().end());
     operands.append(op.step().begin(), op.step().end());
-    operands.append(parallelComputeFunction.captures);
+    operands.append(parallelComputeFunction.captures());
   };

   // Check if the block size is one, in this case we can skip the async dispatch
@@ -531,8 +626,8 @@
     SmallVector<Value> operands = {c0, blockSize};
     appendBlockComputeOperands(operands);

-    nb.create<CallOp>(parallelComputeFunction.func.sym_name(),
-                      parallelComputeFunction.func.getCallableResults(),
+    nb.create<CallOp>(parallelComputeFunction.func().sym_name(),
+                      parallelComputeFunction.func().getCallableResults(),
                       operands);
     nb.create<scf::YieldOp>();
   };
@@ -572,7 +667,7 @@
                                  const SmallVector<Value> &tripCounts) {
   MLIRContext *ctx = op->getContext();

-  FuncOp compute = parallelComputeFunction.func;
+  FuncOp compute = parallelComputeFunction.func();

   Value c0 = b.create<arith::ConstantIndexOp>(0);
   Value c1 = b.create<arith::ConstantIndexOp>(1);
@@ -594,7 +689,7 @@
     computeFuncOperands.append(op.lowerBound().begin(), op.lowerBound().end());
     computeFuncOperands.append(op.upperBound().begin(), op.upperBound().end());
     computeFuncOperands.append(op.step().begin(), op.step().end());
-    computeFuncOperands.append(parallelComputeFunction.captures);
+    computeFuncOperands.append(parallelComputeFunction.captures());

     return computeFuncOperands;
   };
@@ -639,6 +734,10 @@
   ImplicitLocOpBuilder b(op.getLoc(), rewriter);

+  // Make sure that all constants will be inside the parallel operation body to
+  // reduce the number of parallel compute function arguments.
+  cloneConstantsIntoTheRegion(op.getLoopBody(), rewriter);
+
   // Compute trip count for each loop induction variable:
   //   tripCount = ceil_div(upperBound - lowerBound, step);
   SmallVector<Value> tripCounts(op.getNumLoops());
@@ -646,8 +745,8 @@
     auto lb = op.lowerBound()[i];
     auto ub = op.upperBound()[i];
     auto step = op.step()[i];
-    auto range = b.create<arith::SubIOp>(ub, lb);
-    tripCounts[i] = b.create<arith::CeilDivSIOp>(range, step);
+    auto range = b.createOrFold<arith::SubIOp>(ub, lb);
+    tripCounts[i] = b.createOrFold<arith::CeilDivSIOp>(range, step);
   }

   // Compute a product of trip counts to get the 1-dimensional iteration space
@@ -703,6 +802,21 @@
   ParallelComputeFunction parallelComputeFunction =
       createParallelComputeFunction(op, rewriter);

+  // Sink known loop nest constants into the parallel compute function to
+  // simplify the compute loop. LLVM can't always do that later because of the
+  // complex IR structure, especially in the async dispatch case.
+  auto constants = [&](auto values) -> SmallVector<IntegerAttr> {
+    SmallVector<IntegerAttr> attrs(values.size());
+    for (unsigned i = 0; i < values.size(); ++i)
+      matchPattern(values[i], m_Constant(&attrs[i]));
+    return attrs;
+  };
+
+  parallelComputeFunction.setTripCounts(constants(tripCounts));
+  parallelComputeFunction.setLowerBounds(constants(op.lowerBound()));
+  parallelComputeFunction.setUpperBounds(constants(op.upperBound()));
+  parallelComputeFunction.setSteps(constants(op.step()));
+
   // Dispatch parallel compute function using async recursive work splitting,
   // or by submitting compute tasks sequentially from a caller thread.
   if (asyncDispatch) {
diff --git a/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir b/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir
--- a/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir
+++ b/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir
@@ -1,8 +1,8 @@
-// RUN: mlir-opt %s \
+// RUN: mlir-opt %s -split-input-file \
 // RUN:   -async-parallel-for=async-dispatch=true \
 // RUN:   | FileCheck %s

-// RUN: mlir-opt %s \
+// RUN: mlir-opt %s -split-input-file \
 // RUN:   -async-parallel-for=async-dispatch=false \
 // RUN:   -canonicalize -inline -symbol-dce \
 // RUN:   | FileCheck %s
@@ -34,3 +34,35 @@
 // CHECK:      %[[CST:.*]] = arith.constant 1.0{{.*}} : f32
 // CHECK:      scf.for
 // CHECK:        memref.store %[[CST]], %[[MEMREF]]
+
+// -----
+
+// Check that the constant loop bound is sunk into the parallel compute function.
+
+// CHECK-LABEL: func @sink_constant_step(
+func @sink_constant_step(%arg0: memref<?xf32>, %lb: index, %ub: index) {
+  %one = arith.constant 1.0 : f32
+  %st = arith.constant 123 : index
+
+  scf.parallel (%i) = (%lb) to (%ub) step (%st) {
+    memref.store %one, %arg0[%i] : memref<?xf32>
+  }
+
+  return
+}
+
+// CHECK-LABEL: func private @parallel_compute_fn(
+// CHECK-SAME:   %[[BLOCK_INDEX:arg[0-9]+]]: index,
+// CHECK-SAME:   %[[BLOCK_SIZE:arg[0-9]+]]: index,
+// CHECK-SAME:   %[[TRIP_COUNT:arg[0-9]+]]: index,
+// CHECK-SAME:   %[[LB:arg[0-9]+]]: index,
+// CHECK-SAME:   %[[UB:arg[0-9]+]]: index,
+// CHECK-SAME:   %[[STEP:arg[0-9]+]]: index,
+// CHECK-SAME:   %[[MEMREF:arg[0-9]+]]: memref<?xf32>
+// CHECK-SAME: ) {
+// CHECK:        %[[CSTEP:.*]] = arith.constant 123 : index
+// CHECK-NOT:    %[[STEP]]
+// CHECK:        scf.for %[[I:arg[0-9]+]]
+// CHECK:          %[[TMP:.*]] = arith.muli %[[I]], %[[CSTEP]]
+// CHECK:          %[[IDX:.*]] = arith.addi %[[LB]], %[[TMP]]
+// CHECK:          memref.store
\ No newline at end of file