diff --git a/mlir/include/mlir/Analysis/Utils.h b/mlir/include/mlir/Analysis/Utils.h --- a/mlir/include/mlir/Analysis/Utils.h +++ b/mlir/include/mlir/Analysis/Utils.h @@ -354,8 +354,9 @@ Optional getMemoryFootprintBytes(AffineForOp forOp, int memorySpace = -1); -/// Returns true if `forOp' is a parallel loop. -bool isLoopParallel(AffineForOp forOp); +/// Returns true if `forOp' is a parallel loop. If `reductionsAreParallel` is +/// set, treats loop iteration arguments as not preventing parallelization. +bool isLoopParallel(AffineForOp forOp, bool reductionsAreParallel = false); /// Simplify the integer set by simplifying the underlying affine expressions by /// flattening and some simple inference. Also, drop any duplicate constraints. diff --git a/mlir/include/mlir/Dialect/Affine/Passes.td b/mlir/include/mlir/Dialect/Affine/Passes.td --- a/mlir/include/mlir/Dialect/Affine/Passes.td +++ b/mlir/include/mlir/Dialect/Affine/Passes.td @@ -123,6 +123,9 @@ Option<"maxNested", "max-nested", "unsigned", /*default=*/"-1u", "Maximum number of nested parallel loops to produce. " "Defaults to unlimited (UINT_MAX).">, + Option<"parallelReductions", "parallel-reductions", "bool", + /*default=*/"false", + "Whether to parallelize reduction loops. Defaults to false."> ]; } diff --git a/mlir/include/mlir/Dialect/Affine/Utils.h b/mlir/include/mlir/Dialect/Affine/Utils.h --- a/mlir/include/mlir/Dialect/Affine/Utils.h +++ b/mlir/include/mlir/Dialect/Affine/Utils.h @@ -29,7 +29,8 @@ /// Replaces parallel affine.for op with 1-d affine.parallel op. /// mlir::isLoopParallel detect the parallel affine.for ops. /// There is no cost model currently used to drive this parallelization. -void affineParallelize(AffineForOp forOp); +/// Returns failure if the loop could not be parallelized. +LogicalResult affineParallelize(AffineForOp forOp); /// Hoists out affine.if/else to as high as possible, i.e., past all invariant /// affine.fors/parallel's.
Returns success if any hoisting happened; folded` is diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -1269,11 +1269,11 @@ } /// Returns true if 'forOp' is parallel. -bool mlir::isLoopParallel(AffineForOp forOp) { +bool mlir::isLoopParallel(AffineForOp forOp, bool reductionsAreParallel) { // Loop is not parallel if it has SSA loop-carried dependences. // TODO: Conditionally support reductions and other loop-carried dependences // that could be handled in the context of a parallel loop. - if (forOp.getNumIterOperands() > 0) + if (forOp.getNumIterOperands() > 0 && !reductionsAreParallel) return false; // Collect all load and store ops in loop nest rooted at 'forOp'. diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineParallelize.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineParallelize.cpp --- a/mlir/lib/Dialect/Affine/Transforms/AffineParallelize.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/AffineParallelize.cpp @@ -43,7 +43,7 @@ // the front of a deque.
std::deque parallelizableLoops; f.walk([&](AffineForOp loop) { - if (isLoopParallel(loop)) + if (isLoopParallel(loop, parallelReductions)) parallelizableLoops.push_front(loop); }); @@ -56,8 +56,15 @@ ++numParentParallelOps; } - if (numParentParallelOps < maxNested) - affineParallelize(loop); + if (numParentParallelOps < maxNested) { + if (failed(affineParallelize(loop))) { + LLVM_DEBUG(llvm::dbgs() << "[" DEBUG_TYPE "] failed to parallelize\n" + << loop << "\n"); + } + } else { + LLVM_DEBUG(llvm::dbgs() << "[" DEBUG_TYPE "] too many nested loops\n" + << loop << "\n"); + } } } diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp --- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp @@ -17,6 +17,7 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/IntegerSet.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" @@ -129,9 +130,59 @@ return hoistedIfOp; } +/// Get the value that is being reduced by `pos`-th reduction in the loop if +/// such a reduction can be performed by affine parallel loops. This assumes +/// floating-point operations are associative. On success, `kind` will be the +/// reduction kind suitable for use in affine parallel loop builder. The list +/// of supported cases must be kept in sync with `affineParallelize`. If the +/// reduction is not supported, returns null.
+static Value getSupportedReduction(AffineForOp forOp, unsigned pos, + AtomicRMWKind &kind) { + auto yieldOp = cast(forOp.getBody()->back()); + Value yielded = yieldOp.operands()[pos]; + Operation *definition = yielded.getDefiningOp(); + if (!definition) + return nullptr; + // Require single uses: the reduction op is moved out of the loop later. + if (!forOp.getRegionIterArgs()[pos].hasOneUse() || !yielded.hasOneUse()) + return nullptr; + Optional maybeKind = + TypeSwitch>(definition) + .Case([](Operation *) { return AtomicRMWKind::addf; }) + .Case([](Operation *) { return AtomicRMWKind::mulf; }) + .Case([](Operation *) { return AtomicRMWKind::addi; }) + .Case([](Operation *) { return AtomicRMWKind::muli; }) + .Default([](Operation *) -> Optional { + return llvm::None; + }); + if (!maybeKind) + return nullptr; + + kind = *maybeKind; + if (definition->getOperand(0) == forOp.getRegionIterArgs()[pos]) + return definition->getOperand(1); + if (definition->getOperand(1) == forOp.getRegionIterArgs()[pos]) + return definition->getOperand(0); + + return nullptr; +} + /// Replace affine.for with a 1-d affine.parallel and clone the former's body -/// into the latter while remapping values. -void mlir::affineParallelize(AffineForOp forOp) { +/// into the latter while remapping values. Supported reductions become +/// parallel reductions. Returns failure if the loop cannot be parallelized.
+LogicalResult mlir::affineParallelize(AffineForOp forOp) { + unsigned numReductions = forOp.getNumRegionIterArgs(); + SmallVector reductionKinds; + SmallVector reducedValues; + reductionKinds.reserve(numReductions); + reducedValues.reserve(numReductions); + for (unsigned i = 0; i < numReductions; ++i) { + reducedValues.push_back( + getSupportedReduction(forOp, i, reductionKinds.emplace_back())); + if (!reducedValues.back()) + return failure(); + } + Location loc = forOp.getLoc(); OpBuilder outsideBuilder(forOp); @@ -148,7 +199,7 @@ if (needsMax || needsMin) { if (forOp->getParentOp() && !forOp->getParentOp()->hasTrait()) - return; + return failure(); identityMap = AffineMap::getMultiDimIdentityMap(1, loc->getContext()); } @@ -169,11 +220,39 @@ // Creating empty 1-D affine.parallel op. AffineParallelOp newPloop = outsideBuilder.create( - loc, llvm::None, llvm::None, lowerBoundMap, lowerBoundOperands, - upperBoundMap, upperBoundOperands); - // Steal the body of the old affine for op and erase it. + loc, ValueRange(reducedValues).getTypes(), reductionKinds, lowerBoundMap, + lowerBoundOperands, upperBoundMap, upperBoundOperands); + // Steal the body of the old affine for op. newPloop.region().takeBody(forOp.region()); + Operation *yieldOp = &newPloop.getBody()->back(); + + // Handle the initial values of reductions because the parallel loop always + // starts from the neutral value. + for (unsigned i = 0; i < numReductions; ++i) { + Value init = forOp.getIterOperands()[i]; + // This works because we are only handling single-op reductions at the + // moment. A switch on reduction kind or a mechanism to collect operations + // participating in the reduction will be necessary for multi-op reductions.
+ Operation *reductionOp = yieldOp->getOperand(i).getDefiningOp(); + assert(reductionOp && "yielded value is expected to be produced by an op"); + outsideBuilder.getInsertionBlock()->getOperations().splice( + outsideBuilder.getInsertionPoint(), newPloop.getBody()->getOperations(), + reductionOp); + reductionOp->setOperands({init, newPloop->getResult(i)}); + forOp->getResult(i).replaceAllUsesWith(reductionOp->getResult(0)); + } + + // Update the loop terminator to yield reduced values bypassing the reduction + // operation itself (now moved outside of the loop) and erase the block + // arguments that correspond to reductions. Note that the loop always has one + // "main" induction variable when coming from a non-parallel for. + unsigned numIVs = 1; + yieldOp->setOperands(reducedValues); + newPloop.getBody()->eraseArguments( + llvm::to_vector<4>(llvm::seq(numIVs, numReductions + numIVs))); + + forOp.erase(); + return success(); } // Returns success if any hoisting happened. diff --git a/mlir/test/Dialect/Affine/parallelize.mlir b/mlir/test/Dialect/Affine/parallelize.mlir --- a/mlir/test/Dialect/Affine/parallelize.mlir +++ b/mlir/test/Dialect/Affine/parallelize.mlir @@ -1,5 +1,6 @@ -// RUN: mlir-opt %s -allow-unregistered-dialect -affine-parallelize| FileCheck %s +// RUN: mlir-opt %s -allow-unregistered-dialect -affine-parallelize | FileCheck %s // RUN: mlir-opt %s -allow-unregistered-dialect -affine-parallelize='max-nested=1' | FileCheck --check-prefix=MAX-NESTED %s +// RUN: mlir-opt %s -allow-unregistered-dialect -affine-parallelize='parallel-reductions=1' | FileCheck --check-prefix=REDUCE %s // CHECK-LABEL: func @reduce_window_max() { func @reduce_window_max() { @@ -159,10 +160,12 @@ return } -// CHECK-LABEL: @unsupported_iter_args -func @unsupported_iter_args(%in: memref<10xf32>) { +// CHECK-LABEL: @iter_args +// REDUCE-LABEL: @iter_args +func @iter_args(%in: memref<10xf32>) { %cst = constant 0.000000e+00 : f32 // CHECK-NOT: affine.parallel + // REDUCE:
affine.parallel (%{{.*}}) = (0) to (10) reduce ("addf") %final_red = affine.for %i = 0 to 10 iter_args(%red_iter = %cst) -> (f32) { %ld = affine.load %in[%i] : memref<10xf32> %add = addf %red_iter, %ld : f32 @@ -171,12 +174,15 @@ return } -// CHECK-LABEL: @unsupported_nested_iter_args -func @unsupported_nested_iter_args(%in: memref<20x10xf32>) { +// CHECK-LABEL: @nested_iter_args +// REDUCE-LABEL: @nested_iter_args +func @nested_iter_args(%in: memref<20x10xf32>) { %cst = constant 0.000000e+00 : f32 // CHECK: affine.parallel affine.for %i = 0 to 20 { - // CHECK: affine.for + // CHECK-NOT: affine.parallel + // REDUCE: affine.parallel + // REDUCE: reduce ("addf") %final_red = affine.for %j = 0 to 10 iter_args(%red_iter = %cst) -> (f32) { %ld = affine.load %in[%i, %j] : memref<20x10xf32> %add = addf %red_iter, %ld : f32