diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -263,7 +263,7 @@ /// Utility class used to generate nested loops with ranges described by /// `loopRanges` and loop type described by the `iteratorTypes`. `bodyBuilderFn` /// is used to generate the body of the innermost loop. It is passed a range -/// of loop induction variables and a range of iterArgs. +/// of loop induction variables and a range of operand values to use. template struct GenerateLoopNest { static void doit(OpBuilder &b, Location loc, ArrayRef loopRanges, diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp @@ -431,8 +431,9 @@ GenerateLoopNest::doit( rewriter, linalgOp.getLoc(), loopRanges, linalgOp, iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange ivs, - ValueRange iterArgs) -> scf::ValueVector { - assert(iterArgs.empty() && "unexpected iterArgs"); + ValueRange operandValuesToUse) -> scf::ValueVector { + assert(operandValuesToUse == linalgOp->getOperands() && + "expect operands are captured and not passed by loop argument"); allIvs.append(ivs.begin(), ivs.end()); llvm::TypeSwitch(linalgOp) .Case( diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -227,9 +227,9 @@ // 2. Create the tiled loops. LinalgOp res = op; SmallVector ivs, tensorResults; - auto tiledLoopBodyBuilder = [&](OpBuilder &b, Location loc, - ValueRange localIvs, - ValueRange iterArgs) -> scf::ValueVector { + auto tiledLoopBodyBuilder = + [&](OpBuilder &b, Location loc, ValueRange localIvs, + ValueRange operandValuesToUse) -> scf::ValueVector { ivs.assign(localIvs.begin(), localIvs.end()); // When an `interchangeVector` is present, it has been applied to the @@ -241,20 +241,16 @@ else interchangedIvs.assign(ivs.begin(), ivs.end()); - assert(op.getOutputTensorOperands().size() == iterArgs.size() && - "num output tensors must match number of loop iter arguments"); - - SmallVector operands = op.getInputOperands(); - SmallVector outputBuffers = op.getOutputBufferOperands(); - // TODO: thanks to simplifying assumption we do not need to worry about - // order of output buffers and tensors: there is only ever one kind. - assert(outputBuffers.empty() || iterArgs.empty()); - operands.append(outputBuffers.begin(), outputBuffers.end()); - operands.append(iterArgs.begin(), iterArgs.end()); + // Tile the `operandValuesToUse` that either match the `op` operands + // themselves or the tile loop arguments forwarding them. + assert(operandValuesToUse.size() == + static_cast(op.getNumInputsAndOutputs()) && + "expect the number of operands and inputs and outputs to match"); + SmallVector valuesToTile = operandValuesToUse; auto sizeBounds = applyMapToValues(b, loc, shapeSizesToLoopsMap, allShapeSizes); SmallVector tiledOperands = makeTiledShapes( - b, loc, op, operands, interchangedIvs, tileSizes, sizeBounds); + b, loc, op, valuesToTile, interchangedIvs, tileSizes, sizeBounds); // TODO: use an interface/adaptor to avoid leaking position in // `tiledOperands`. diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -225,7 +225,18 @@ SmallVector lbs, ubs, steps; unpackRanges(loopRanges, lbs, ubs, steps); LoopNest loopNest = mlir::scf::buildLoopNest( - b, loc, lbs, ubs, steps, iterArgInitValues, bodyBuilderFn); + b, loc, lbs, ubs, steps, iterArgInitValues, + [&](OpBuilder &b, Location loc, ValueRange ivs, ValueRange iterArgs) { + assert(iterArgs.size() == linalgOp.getOutputTensorOperands().size() && + "expect the number of output tensors and iter args to match"); + SmallVector operandValuesToUse = + linalgOp.getInputAndOutputOperands(); + if (!iterArgs.empty()) { + operandValuesToUse = linalgOp.getInputOperands(); + operandValuesToUse.append(iterArgs.begin(), iterArgs.end()); + } + return bodyBuilderFn(b, loc, ivs, operandValuesToUse); + }); if (!distributionOptions || loopNest.loops.empty()) return; @@ -268,7 +279,9 @@ mlir::buildAffineLoopNest(b, loc, lbs, ubs, constantSteps, [&](OpBuilder &b, Location loc, ValueRange ivs) { - bodyBuilderFn(b, loc, ivs, {}); + SmallVector operandValuesToUse = + linalgOp.getInputAndOutputOperands(); + bodyBuilderFn(b, loc, ivs, operandValuesToUse); }); } @@ -289,9 +302,10 @@ auto wrappedBuilderFn = [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange ivs, ValueRange inputs, ValueRange outputs) { - SmallVector outputTensors = linalgOp.getOutputTensorOperands(); + SmallVector operandValuesToUse = inputs; + operandValuesToUse.append(outputs.begin(), outputs.end()); scf::ValueVector results = - bodyBuilderFn(nestedBuilder, nestedLoc, ivs, outputTensors); + bodyBuilderFn(nestedBuilder, nestedLoc, ivs, operandValuesToUse); nestedBuilder.create(nestedLoc, results); }; @@ -302,15 +316,6 @@ b.getArrayAttr(iteratorTypes), wrappedBuilderFn); if (!distributionTypes.empty()) tiledLoop.setDistributionTypes(b, distributionTypes); - - // Replace inputs/outputs with the corresponding region args. - auto isInsideTiledLoop = [&](OpOperand &operand) { - return operand.getOwner()->getBlock() == tiledLoop.getBody(); - }; - for (auto it : llvm::zip(inputOperands, tiledLoop.getRegionInputArgs())) - std::get<0>(it).replaceUsesWithIf(std::get<1>(it), isInsideTiledLoop); - for (auto it : llvm::zip(outputOperands, tiledLoop.getRegionOutputArgs())) - std::get<0>(it).replaceUsesWithIf(std::get<1>(it), isInsideTiledLoop); } /// Update the `lb`, `ub` and `step` to get per processor `lb`, `ub` and `step`. @@ -505,7 +510,9 @@ generateParallelLoopNest( b, loc, lbs, ubs, steps, iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange ivs) { - bodyBuilderFn(b, loc, ivs, {}); + SmallVector operandValuesToUse = + linalgOp.getInputAndOutputOperands(); + bodyBuilderFn(b, loc, ivs, operandValuesToUse); }, ivs, distributionMethod);