diff --git a/mlir/include/mlir/Dialect/Async/IR/Async.h b/mlir/include/mlir/Dialect/Async/IR/Async.h
--- a/mlir/include/mlir/Dialect/Async/IR/Async.h
+++ b/mlir/include/mlir/Dialect/Async/IR/Async.h
@@ -20,6 +20,7 @@
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
--- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
+++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
@@ -177,6 +177,8 @@
   let arguments = (ins Index:$size);
   let results = (outs Async_GroupType:$result);
 
+  let hasCanonicalizeMethod = 1;
+
   let assemblyFormat = "$size `:` type($result) attr-dict";
 }
 
diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp
--- a/mlir/lib/Dialect/Async/IR/Async.cpp
+++ b/mlir/lib/Dialect/Async/IR/Async.cpp
@@ -245,6 +245,36 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+/// CreateGroupOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult CreateGroupOp::canonicalize(CreateGroupOp op,
+                                          PatternRewriter &rewriter) {
+  // Find all `await_all` users of the group.
+  llvm::SmallVector<AwaitAllOp> awaitAllUsers;
+
+  auto isAwaitAll = [&](Operation *user) -> bool {
+    if (AwaitAllOp awaitAll = dyn_cast<AwaitAllOp>(user)) {
+      awaitAllUsers.push_back(awaitAll);
+      return true;
+    }
+    return false;
+  };
+
+  // Check if all users of the group are `await_all` operations.
+  if (!llvm::all_of(op->getUsers(), isAwaitAll))
+    return failure();
+
+  // If group is only awaited without adding anything to it, we can safely erase
+  // the create operation and all users.
+  for (AwaitAllOp awaitAll : awaitAllUsers)
+    rewriter.eraseOp(awaitAll);
+  rewriter.eraseOp(op);
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 /// AwaitOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
--- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
@@ -513,18 +513,48 @@
   Value groupSize = b.create<SubIOp>(blockCount, c1);
   Value group = b.create<CreateGroupOp>(GroupType::get(ctx), groupSize);
 
-  // Pack the async dispath function operands to launch the work splitting.
-  SmallVector<Value> asyncDispatchOperands = {group, c0, blockCount, blockSize};
-  asyncDispatchOperands.append(tripCounts);
-  asyncDispatchOperands.append(op.lowerBound().begin(), op.lowerBound().end());
-  asyncDispatchOperands.append(op.upperBound().begin(), op.upperBound().end());
-  asyncDispatchOperands.append(op.step().begin(), op.step().end());
-  asyncDispatchOperands.append(parallelComputeFunction.captures);
+  // Appends operands shared by async dispatch and parallel compute functions to
+  // the given operands vector.
+  auto appendBlockComputeOperands = [&](SmallVector<Value> &operands) {
+    operands.append(tripCounts);
+    operands.append(op.lowerBound().begin(), op.lowerBound().end());
+    operands.append(op.upperBound().begin(), op.upperBound().end());
+    operands.append(op.step().begin(), op.step().end());
+    operands.append(parallelComputeFunction.captures);
+  };
+
+  // Check if the block count is one, in this case we can skip the async
+  // dispatch completely. If this is known statically, then canonicalization
+  // will erase async group operations.
+  Value isSingleBlock = b.create<CmpIOp>(CmpIPredicate::eq, blockCount, c1);
+
+  auto syncDispatch = [&](OpBuilder &nestedBuilder, Location loc) {
+    ImplicitLocOpBuilder nb(loc, nestedBuilder);
+
+    // Call parallel compute function for the single block.
+    SmallVector<Value> operands = {c0, blockSize};
+    appendBlockComputeOperands(operands);
+
+    nb.create<CallOp>(parallelComputeFunction.func.sym_name(),
+                      parallelComputeFunction.func.getCallableResults(),
+                      operands);
+    nb.create<scf::YieldOp>();
+  };
+
+  auto asyncDispatch = [&](OpBuilder &nestedBuilder, Location loc) {
+    ImplicitLocOpBuilder nb(loc, nestedBuilder);
+
+    // Launch async dispatch function for [0, blockCount) range.
+    SmallVector<Value> operands = {group, c0, blockCount, blockSize};
+    appendBlockComputeOperands(operands);
+
+    nb.create<CallOp>(asyncDispatchFunction.sym_name(),
+                      asyncDispatchFunction.getCallableResults(), operands);
+    nb.create<scf::YieldOp>();
+  };
 
-  // Launch async dispatch function for [0, blockCount) range.
-  b.create<CallOp>(asyncDispatchFunction.sym_name(),
-                   asyncDispatchFunction.getCallableResults(),
-                   asyncDispatchOperands);
+  // Dispatch either single block compute function, or launch async dispatch.
+  b.create<scf::IfOp>(TypeRange(), isSingleBlock, syncDispatch, asyncDispatch);
 
   // Wait for the completion of all parallel compute operations.
   b.create<AwaitAllOp>(group);
diff --git a/mlir/test/Dialect/Async/async-parallel-for-canonicalize.mlir b/mlir/test/Dialect/Async/async-parallel-for-canonicalize.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Async/async-parallel-for-canonicalize.mlir
@@ -0,0 +1,33 @@
+// RUN: mlir-opt %s \
+// RUN:   -async-parallel-for=async-dispatch=true \
+// RUN:   -canonicalize -inline -symbol-dce \
+// RUN: | FileCheck %s
+
+// RUN: mlir-opt %s \
+// RUN:   -async-parallel-for=async-dispatch=false \
+// RUN:   -canonicalize -inline -symbol-dce \
+// RUN: | FileCheck %s
+
+// Check that if we statically know that the parallel operation has a single
+// block then all async operations will be canonicalized away and we will
+// end up with a single synchronous compute function call.
+
+// CHECK-LABEL: @loop_1d(
+// CHECK:       %[[MEMREF:.*]]: memref<?xf32>
+func @loop_1d(%arg0: memref<?xf32>) {
+  // CHECK-DAG: %[[C0:.*]] = constant 0 : index
+  // CHECK-DAG: %[[C1:.*]] = constant 1 : index
+  // CHECK-DAG: %[[C100:.*]] = constant 100 : index
+  // CHECK-DAG: %[[ONE:.*]] = constant 1.000000e+00 : f32
+  // CHECK:     scf.for %[[I:.*]] = %[[C0]] to %[[C100]] step %[[C1]]
+  // CHECK:     memref.store %[[ONE]], %[[MEMREF]][%[[I]]]
+  %lb = constant 0 : index
+  %ub = constant 100 : index
+  %st = constant 1 : index
+  scf.parallel (%i) = (%lb) to (%ub) step (%st) {
+    %one = constant 1.0 : f32
+    memref.store %one, %arg0[%i] : memref<?xf32>
+  }
+
+  return
+}