diff --git a/mlir/include/mlir/Dialect/Affine/Utils.h b/mlir/include/mlir/Dialect/Affine/Utils.h
--- a/mlir/include/mlir/Dialect/Affine/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Utils.h
@@ -44,10 +44,11 @@
 /// (mlir::isLoopParallel can be used to detect a parallel affine.for op.) The
 /// reductions specified in `parallelReductions` are also parallelized.
 /// Parallelization will fail in the presence of loop iteration arguments that
-/// are not listed in `parallelReductions`.
-LogicalResult
-affineParallelize(AffineForOp forOp,
-                  ArrayRef<LoopReduction> parallelReductions = {});
+/// are not listed in `parallelReductions`. On success, if `resOp` is non-null,
+/// `*resOp` is set to the newly created affine.parallel op.
+LogicalResult affineParallelize(AffineForOp forOp,
+                                ArrayRef<LoopReduction> parallelReductions = {},
+                                AffineParallelOp *resOp = nullptr);
 
 /// Hoists out affine.if/else to as high as possible, i.e., past all invariant
 /// affine.fors/parallel's. Returns success if any hoisting happened; folded` is
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -344,7 +344,8 @@
 
 LogicalResult
 mlir::affine::affineParallelize(AffineForOp forOp,
-                                ArrayRef<LoopReduction> parallelReductions) {
+                                ArrayRef<LoopReduction> parallelReductions,
+                                AffineParallelOp *resOp) {
   // Fail early if there are iter arguments that are not reductions.
   unsigned numReductions = parallelReductions.size();
   if (numReductions != forOp.getNumIterOperands())
@@ -398,6 +399,8 @@
   newPloop.getBody()->eraseArguments(numIVs, numReductions);
 
   forOp.erase();
+  if (resOp)
+    *resOp = newPloop;
   return success();
 }
 