diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
--- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
@@ -190,6 +190,10 @@
 
   ModuleOp module = op->getParentOfType<ModuleOp>();
 
+  // Make sure that all constants will be inside the parallel operation body to
+  // reduce the number of parallel compute function arguments.
+  cloneConstantsIntoTheRegion(op.getLoopBody(), rewriter);
+
   ParallelComputeFunctionType computeFuncType =
       getParallelComputeFunctionType(op, rewriter);
 
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
--- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
@@ -218,6 +218,10 @@
   MLIRContext *ctx = module.getContext();
   Location loc = execute.getLoc();
 
+  // Make sure that all constants will be inside the outlined async function to
+  // reduce the number of function arguments.
+  cloneConstantsIntoTheRegion(execute.body());
+
   // Collect all outlined function inputs.
   SetVector<mlir::Value> functionInputs(execute.dependencies().begin(),
                                         execute.dependencies().end());
diff --git a/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@
   AsyncRuntimeRefCounting.cpp
   AsyncRuntimeRefCountingOpt.cpp
   AsyncToAsyncRuntime.cpp
+  PassDetail.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Async
diff --git a/mlir/lib/Dialect/Async/Transforms/PassDetail.h b/mlir/lib/Dialect/Async/Transforms/PassDetail.h
--- a/mlir/lib/Dialect/Async/Transforms/PassDetail.h
+++ b/mlir/lib/Dialect/Async/Transforms/PassDetail.h
@@ -25,6 +25,24 @@
 #define GEN_PASS_CLASSES
 #include "mlir/Dialect/Async/Passes.h.inc"
 
+// -------------------------------------------------------------------------- //
+// Utility functions shared by Async Transformations.
+// -------------------------------------------------------------------------- //
+
+// Forward declarations.
+class OpBuilder;
+
+namespace async {
+
+/// Clone ConstantLike operations that are defined above the given region and
+/// have users in the region into the region entry block. We do that to reduce
+/// the number of function arguments when we outline `async.execute` and
+/// `scf.parallel` operations body into functions.
+void cloneConstantsIntoTheRegion(Region &region);
+void cloneConstantsIntoTheRegion(Region &region, OpBuilder &builder);
+
+} // namespace async
+
 } // namespace mlir
 
 #endif // DIALECT_ASYNC_TRANSFORMS_PASSDETAIL_H_
diff --git a/mlir/lib/Dialect/Async/Transforms/PassDetail.cpp b/mlir/lib/Dialect/Async/Transforms/PassDetail.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Async/Transforms/PassDetail.cpp
@@ -0,0 +1,43 @@
+//===- PassDetail.cpp - Async Pass class details ----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Transforms/RegionUtils.h"
+
+using namespace mlir;
+
+void mlir::async::cloneConstantsIntoTheRegion(Region &region) {
+  OpBuilder builder(&region);
+  cloneConstantsIntoTheRegion(region, builder);
+}
+
+void mlir::async::cloneConstantsIntoTheRegion(Region &region,
+                                              OpBuilder &builder) {
+  // Values implicitly captured by the region.
+  llvm::SetVector<Value> captures;
+  getUsedValuesDefinedAbove(region, region, captures);
+
+  OpBuilder::InsertionGuard guard(builder);
+  builder.setInsertionPointToStart(&region.front());
+
+  // Clone captured ConstantLike ops to the entry block; rewire in-region uses.
+  for (Value capture : captures) {
+    Operation *op = capture.getDefiningOp();
+    if (!op || !op->hasTrait<OpTrait::ConstantLike>())
+      continue;
+
+    Operation *cloned = builder.clone(*op);
+
+    for (auto tuple : llvm::zip(op->getResults(), cloned->getResults())) {
+      Value orig = std::get<0>(tuple);
+      Value replacement = std::get<1>(tuple);
+      replaceAllUsesInRegionWith(orig, replacement, region);
+    }
+  }
+}
diff --git a/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
--- a/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
+++ b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
@@ -89,13 +89,14 @@
 }
 
 // Function outlined from the inner async.execute operation.
-// CHECK-LABEL: func private @async_execute_fn(%arg0: f32, %arg1: memref<1xf32>, %arg2: index)
+// CHECK-LABEL: func private @async_execute_fn(%arg0: f32, %arg1: memref<1xf32>)
 // CHECK-SAME: -> !llvm.ptr<i8>
 // CHECK: %[[RET_0:.*]] = call @mlirAsyncRuntimeCreateToken()
 // CHECK: %[[HDL_0:.*]] = llvm.intr.coro.begin
 // CHECK: call @mlirAsyncRuntimeExecute
 // CHECK: llvm.intr.coro.suspend
-// CHECK: memref.store %arg0, %arg1[%arg2] : memref<1xf32>
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: memref.store %arg0, %arg1[%[[C0]]] : memref<1xf32>
 // CHECK: call @mlirAsyncRuntimeEmplaceToken(%[[RET_0]])
 
 // Function outlined from the outer async.execute operation.
diff --git a/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir b/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir
@@ -0,0 +1,36 @@
+// RUN: mlir-opt %s \
+// RUN: -async-parallel-for=async-dispatch=true \
+// RUN: | FileCheck %s
+
+// RUN: mlir-opt %s \
+// RUN: -async-parallel-for=async-dispatch=false \
+// RUN: -canonicalize -inline -symbol-dce \
+// RUN: | FileCheck %s
+
+// Check that constants defined outside of the `scf.parallel` body will be
+// sunk into the parallel compute function to avoid blowing up the number
+// of parallel compute function arguments.
+
+// CHECK-LABEL: func @clone_constant(
+func @clone_constant(%arg0: memref<?xf32>, %lb: index, %ub: index, %st: index) {
+  %one = constant 1.0 : f32
+
+  scf.parallel (%i) = (%lb) to (%ub) step (%st) {
+    memref.store %one, %arg0[%i] : memref<?xf32>
+  }
+
+  return
+}
+
+// CHECK-LABEL: func private @parallel_compute_fn(
+// CHECK-SAME: %[[BLOCK_INDEX:arg[0-9]+]]: index,
+// CHECK-SAME: %[[BLOCK_SIZE:arg[0-9]+]]: index,
+// CHECK-SAME: %[[TRIP_COUNT:arg[0-9]+]]: index,
+// CHECK-SAME: %[[LB:arg[0-9]+]]: index,
+// CHECK-SAME: %[[UB:arg[0-9]+]]: index,
+// CHECK-SAME: %[[STEP:arg[0-9]+]]: index,
+// CHECK-SAME: %[[MEMREF:arg[0-9]+]]: memref<?xf32>
+// CHECK-SAME: ) {
+// CHECK: %[[CST:.*]] = constant 1.0{{.*}} : f32
+// CHECK: scf.for
+// CHECK: memref.store %[[CST]], %[[MEMREF]]
diff --git a/mlir/test/Dialect/Async/async-to-async-runtime.mlir b/mlir/test/Dialect/Async/async-to-async-runtime.mlir
--- a/mlir/test/Dialect/Async/async-to-async-runtime.mlir
+++ b/mlir/test/Dialect/Async/async-to-async-runtime.mlir
@@ -406,3 +406,26 @@
 // Check that structured control flow lowered to CFG.
 // CHECK-NOT: scf.if
 // CHECK: cond_br %[[FLAG]]
+
+// -----
+// Constants captured by the async.execute region should be cloned into the
+// outlined async execute function.
+
+// CHECK-LABEL: @clone_constants
+func @clone_constants(%arg0: f32, %arg1: memref<1xf32>) {
+  %c0 = constant 0 : index
+  %token = async.execute {
+    memref.store %arg0, %arg1[%c0] : memref<1xf32>
+    async.yield
+  }
+  async.await %token : !async.token
+  return
+}
+
+// Function outlined from the async.execute operation.
+// CHECK-LABEL: func private @async_execute_fn(
+// CHECK-SAME: %[[VALUE:arg[0-9]+]]: f32,
+// CHECK-SAME: %[[MEMREF:arg[0-9]+]]: memref<1xf32>
+// CHECK-SAME: ) -> !async.token
+// CHECK: %[[CST:.*]] = constant 0 : index
+// CHECK: memref.store %[[VALUE]], %[[MEMREF]][%[[CST]]]