Async parallel-for: add inlining hints and fix the unrollable-loop count.

* The outlined parallel compute function gets `passthrough = ["alwaysinline"]`.
  Per the added source comment, it is inlined twice: into the `async.execute`
  body and into the function that computes the result for the first block.
* The async dispatch function gets `passthrough = ["noinline"]`. A TODO in the
  source explains that inlining it caused performance regressions (poor
  vectorization of the compute function) in some cases.
* numUnrollableLoops now also counts the inner-most loop when its static trip
  count is in (0, maxIterations]. For each outer loop i, the test now uses
  numIterations[i], the product of the trip counts of loops i..n-1, instead
  of numIterations[i + 1].
* FileCheck expectations are updated for the new function attributes.

NOTE(review): this patch was recovered from a copy whose line breaks and
`<...>` text had been stripped.
- `SmallVector<int64_t>` was restored from how numIterations is used.
- Context-line indentation is reconstructed, not verified.
- The `memref` types in the CHECK-SAME context lines appear truncated.
Refresh the hunk context against the target files before applying.

diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
--- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
@@ -265,6 +265,12 @@
                                type);
   func.setPrivate();
 
+  // Always inline parallel compute function. It gets inlined twice: into the
+  // `async.execute` body and into the main function computing result for the
+  // first block.
+  auto alwaysinline = rewriter.getStringAttr("alwaysinline");
+  func->setAttr("passthrough", rewriter.getArrayAttr({alwaysinline}));
+
   // Insert function into the module symbol table and assign it unique name.
   SymbolTable symbolTable(module);
   symbolTable.insert(func);
@@ -479,6 +485,11 @@
   FuncOp func = FuncOp::create(loc, "async_dispatch_fn", type);
   func.setPrivate();
 
+  // TODO: We do not inline async dispatch function, in some cases it leads to
+  // performance regressions because of poor vectorization of compute function.
+  auto noinline = rewriter.getStringAttr("noinline");
+  func->setAttr("passthrough", rewriter.getArrayAttr({noinline}));
+
   // Insert function into the module symbol table and assign it unique name.
   SymbolTable symbolTable(module);
   symbolTable.insert(func);
@@ -790,13 +801,18 @@
     SmallVector<int64_t> numIterations(op.getNumLoops());
     numIterations.back() = getInt(staticBounds.tripCounts.back());
 
+    // Check if we can potentially unroll the inner-most loop.
+    if (numIterations.back() > 0 && numIterations.back() <= maxIterations)
+      numUnrollableLoops++;
+
+    // Check if we can potentially unroll other loops in the loop nest.
     for (int i = op.getNumLoops() - 2; i >= 0; --i) {
       int64_t tripCount = getInt(staticBounds.tripCounts[i]);
       int64_t innerIterations = numIterations[i + 1];
       numIterations[i] = tripCount * innerIterations;
 
       // Update the number of inner loops that we can potentially unroll.
-      if (innerIterations > 0 && innerIterations <= maxIterations)
+      if (numIterations[i] > 0 && numIterations[i] <= maxIterations)
         numUnrollableLoops++;
     }
 
diff --git a/mlir/test/Dialect/Async/async-parallel-for-async-dispatch.mlir b/mlir/test/Dialect/Async/async-parallel-for-async-dispatch.mlir
--- a/mlir/test/Dialect/Async/async-parallel-for-async-dispatch.mlir
+++ b/mlir/test/Dialect/Async/async-parallel-for-async-dispatch.mlir
@@ -36,7 +36,7 @@
 // CHECK-SAME: %[[GROUP:arg0]]: !async.group,
 // CHECK-SAME: %[[BLOCK_START:arg1]]: index
 // CHECK-SAME: %[[BLOCK_END:arg2]]: index
-// CHECK-SAME: )
+// CHECK-SAME: ) attributes {passthrough = ["noinline"]}
 // CHECK: %[[C1:.*]] = arith.constant 1 : index
 // CHECK: %[[C2:.*]] = arith.constant 2 : index
 // CHECK: scf.while (%[[S0:.*]] = %[[BLOCK_START]],
diff --git a/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir b/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir
--- a/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir
+++ b/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir
@@ -30,7 +30,7 @@
 // CHECK-SAME: %[[UB:arg[0-9]+]]: index,
 // CHECK-SAME: %[[STEP:arg[0-9]+]]: index,
 // CHECK-SAME: %[[MEMREF:arg[0-9]+]]: memref
-// CHECK-SAME: ) {
+// CHECK-SAME: ) attributes {passthrough = ["alwaysinline"]} {
 // CHECK: %[[CST:.*]] = arith.constant 1.0{{.*}} : f32
 // CHECK: scf.for
 // CHECK: memref.store %[[CST]], %[[MEMREF]]
@@ -59,7 +59,7 @@
 // CHECK-SAME: %[[UB:arg[0-9]+]]: index,
 // CHECK-SAME: %[[STEP:arg[0-9]+]]: index,
 // CHECK-SAME: %[[MEMREF:arg[0-9]+]]: memref
-// CHECK-SAME: ) {
+// CHECK-SAME: ) attributes {passthrough = ["alwaysinline"]} {
 // CHECK: %[[CSTEP:.*]] = arith.constant 123 : index
 // CHECK-NOT: %[[STEP]]
 // CHECK: scf.for %[[I:arg[0-9]+]]
@@ -99,7 +99,7 @@
 // CHECK-SAME: %[[STEP0:arg[0-9]+]]: index,
 // CHECK-SAME: %[[STEP1:arg[0-9]+]]: index,
 // CHECK-SAME: %[[MEMREF:arg[0-9]+]]: memref
-// CHECK-SAME: ) {
+// CHECK-SAME: ) attributes {passthrough = ["alwaysinline"]} {
 // CHECK: scf.for %[[I:arg[0-9]+]]
 // CHECK: arith.select
 // CHECK: scf.for %[[J:arg[0-9]+]]
@@ -117,7 +117,7 @@
 // CHECK-SAME: %[[STEP0:arg[0-9]+]]: index,
 // CHECK-SAME: %[[STEP1:arg[0-9]+]]: index,
 // CHECK-SAME: %[[MEMREF:arg[0-9]+]]: memref
-// CHECK-SAME: ) {
+// CHECK-SAME: ) attributes {passthrough = ["alwaysinline"]} {
 // CHECK: %[[C0:.*]] = arith.constant 0 : index
 // CHECK: %[[C1:.*]] = arith.constant 1 : index
 // CHECK: %[[C10:.*]] = arith.constant 10 : index