diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -214,10 +214,11 @@ typename std::conditional::value, AffineIndexedValue, StdIndexedValue>::type; - static void doit(ArrayRef loopRanges, - ArrayRef iteratorTypes, - function_ref bodyBuilderFn, - Optional = None); + static void + doit(ArrayRef loopRanges, ValueRange iterArgInitValues, + ArrayRef iteratorTypes, + function_ref bodyBuilderFn, + Optional = None); }; } // namespace linalg diff --git a/mlir/include/mlir/Dialect/SCF/EDSC/Builders.h b/mlir/include/mlir/Dialect/SCF/EDSC/Builders.h --- a/mlir/include/mlir/Dialect/SCF/EDSC/Builders.h +++ b/mlir/include/mlir/Dialect/SCF/EDSC/Builders.h @@ -32,6 +32,13 @@ scf::ValueVector loopNestBuilder( Value lb, Value ub, Value step, ValueRange iterArgInitValues, function_ref fun = nullptr); +/// Builds an scf loop nest threading `iterArgInitValues` through as iter_args. +/// `fun` receives the induction variables and current iter_args and returns +/// the values to yield; when `fun` is null, the init values are yielded. +scf::ValueVector loopNestBuilder( + ValueRange lbs, ValueRange ubs, ValueRange steps, + ValueRange iterArgInitValues, + function_ref fun = nullptr); /// Adapters for building if conditions using the builder and the location /// stored in ScopedContext.
'thenBody' is mandatory, 'elseBody' can be omitted diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp @@ -515,9 +515,13 @@ map, getViewSizes(builder, linalgOp)); SmallVector allIvs; GenerateLoopNest::doit( - loopRanges, linalgOp.iterator_types().getValue(), [&](ValueRange ivs) { + loopRanges, /*iterArgInitValues*/ {}, + linalgOp.iterator_types().getValue(), + [&](ValueRange ivs, ValueRange iterArgs) -> scf::ValueVector { + assert(iterArgs.empty() && "unexpected iterArgs"); allIvs.append(ivs.begin(), ivs.end()); emitScalarImplementation(allIvs, linalgOp); + return scf::ValueVector{}; }); // Number of loop ops might be different from the number of ivs since some // loops like affine.parallel and scf.parallel have multiple ivs. diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -386,8 +386,8 @@ if (!options.interchangeVector.empty()) applyPermutationToVector(iteratorTypes, options.interchangeVector); GenerateLoopNest::doit( - loopRanges, iteratorTypes, - [&](ValueRange localIvs) { + loopRanges, /*iterArgInitValues*/ {}, iteratorTypes, + [&](ValueRange localIvs, ValueRange iterArgs) -> scf::ValueVector { auto &b = ScopedContext::getBuilderRef(); auto loc = ScopedContext::getLocation(); ivs.assign(localIvs.begin(), localIvs.end()); @@ -406,6 +406,7 @@ auto operands = getAssumedNonViewOperands(op); views.append(operands.begin(), operands.end()); res = op.clone(b, loc, views); + return scf::ValueVector{}; }, options.distribution); diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -194,20 +194,23 @@ /// Specialization to build an
scf "for" nest. template <> void GenerateLoopNest::doit( - ArrayRef loopRanges, ArrayRef iteratorTypes, - function_ref bodyBuilderFn, + ArrayRef loopRanges, ValueRange iterArgInitValues, + ArrayRef iteratorTypes, + function_ref bodyBuilderFn, Optional) { SmallVector lbs, ubs, steps; unpackRanges(loopRanges, lbs, ubs, steps); - edsc::loopNestBuilder(lbs, ubs, steps, bodyBuilderFn); + edsc::loopNestBuilder(lbs, ubs, steps, iterArgInitValues, bodyBuilderFn); } /// Specialization to build affine "for" nest. template <> void GenerateLoopNest::doit( - ArrayRef loopRanges, ArrayRef iteratorTypes, - function_ref bodyBuilderFn, + ArrayRef loopRanges, ValueRange iterArgInitValues, + ArrayRef iteratorTypes, + function_ref bodyBuilderFn, Optional) { + assert(iterArgInitValues.empty() && "unexpected AffineForOp init values"); SmallVector lbs, ubs, steps; unpackRanges(loopRanges, lbs, ubs, steps); @@ -220,7 +223,11 @@ constantSteps.push_back(op.getValue()); } - edsc::affineLoopNestBuilder(lbs, ubs, constantSteps, bodyBuilderFn); + auto bodyBuilderWithoutIterArgsFn = [&](ValueRange loopIvs) { + bodyBuilderFn(loopIvs, {}); + }; + edsc::affineLoopNestBuilder(lbs, ubs, constantSteps, + bodyBuilderWithoutIterArgsFn); } /// Update the `lb`, `ub` and `step` to get per processor `lb`, `ub` and `step`. @@ -357,9 +364,11 @@ /// Specialization for generating a mix of parallel and sequential scf loops. template <> void GenerateLoopNest::doit( - ArrayRef loopRanges, ArrayRef iteratorTypes, - function_ref bodyBuilderFn, + ArrayRef loopRanges, ValueRange iterArgInitValues, + ArrayRef iteratorTypes, + function_ref bodyBuilderFn, Optional distributionOptions) { + assert(iterArgInitValues.empty() && "unexpected ParallelOp init values"); // This function may be passed more iterator types than ranges.
assert(iteratorTypes.size() >= loopRanges.size() && "expected iterator type for all ranges"); @@ -405,7 +414,11 @@ } } ValueRange lbs(lbsStorage), ubs(ubsStorage), steps(stepsStorage); - generateParallelLoopNest(lbs, ubs, steps, iteratorTypes, bodyBuilderFn, ivs, + auto bodyBuilderWithoutIterArgsFn = [&](ValueRange loopIvs) { + bodyBuilderFn(loopIvs, {}); + }; + generateParallelLoopNest(lbs, ubs, steps, iteratorTypes, + bodyBuilderWithoutIterArgsFn, ivs, distributionMethod); assert(ivs.size() == iteratorTypes.size() && "did not generate enough loops"); diff --git a/mlir/lib/Dialect/SCF/EDSC/Builders.cpp b/mlir/lib/Dialect/SCF/EDSC/Builders.cpp --- a/mlir/lib/Dialect/SCF/EDSC/Builders.cpp +++ b/mlir/lib/Dialect/SCF/EDSC/Builders.cpp @@ -61,6 +61,25 @@ }); } +mlir::scf::ValueVector mlir::edsc::loopNestBuilder( + ValueRange lbs, ValueRange ubs, ValueRange steps, + ValueRange iterArgInitValues, + function_ref fun) { + // Delegates actual construction to scf::buildLoopNest by wrapping `fun` into + // the expected function interface. + assert(ScopedContext::getContext() && "EDSC ScopedContext not set up"); + return mlir::scf::buildLoopNest( + ScopedContext::getBuilderRef(), ScopedContext::getLocation(), lbs, ubs, + steps, iterArgInitValues, + [&](OpBuilder &builder, Location loc, ValueRange ivs, ValueRange args) { + ScopedContext context(builder, loc); + if (fun) + return fun(ivs, args); + return scf::ValueVector(iterArgInitValues.begin(), + iterArgInitValues.end()); + }); +} + static std::function wrapIfBody(function_ref body, TypeRange expectedTypes) { (void)expectedTypes;