diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -232,11 +232,11 @@
   else
     interchangedIvs.assign(ivs.begin(), ivs.end());

-  assert(op.getNumOutputTensors() == iterArgs.size() &&
+  assert(op.getOutputTensorOperands().size() == iterArgs.size() &&
          "num output tensors must match number of loop iter arguments");

-  auto operands = llvm::to_vector<4>(op.getInputs());
-  SmallVector<Value, 4> outputBuffers = op.getOutputBuffers();
+  SmallVector<Value> operands = op.getInputOperands();
+  SmallVector<Value> outputBuffers = op.getOutputBufferOperands();
   // TODO: thanks to simplifying assumption we do not need to worry about
   // order of output buffers and tensors: there is only ever one kind.
   assert(outputBuffers.empty() || iterArgs.empty());
@@ -252,7 +252,7 @@
   // TODO: use an interface/adaptor to avoid leaking position in
   // `tiledOperands`.
   SmallVector<Type, 4> resultTensorTypes;
-  for (OpOperand *opOperand : op.getOutputTensorsOpOperands())
+  for (OpOperand *opOperand : op.getOutputTensorOperands())
     resultTensorTypes.push_back(
         tiledOperands[opOperand->getOperandNumber()].getType());

@@ -260,7 +260,7 @@

   // Insert a subtensor_insert for each output tensor.
   unsigned resultIdx = 0;
-  for (OpOperand *opOperand : op.getOutputTensorsOpOperands()) {
+  for (OpOperand *opOperand : op.getOutputTensorOperands()) {
     // TODO: use an interface/adaptor to avoid leaking position in
     // `tiledOperands`.
     Value outputTensor = tiledOperands[opOperand->getOperandNumber()];
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -200,7 +200,7 @@
         bodyBuilderFn,
     Optional<LinalgLoopDistributionOptions> distributionOptions,
     ArrayRef<StringRef> distributionTypes) {
-  auto iterArgInitValues = linalgOp.getOutputTensors();
+  SmallVector<Value> iterArgInitValues = linalgOp.getOutputTensorOperands();
   // Create procInfo so it dominates loops, if appropriate.
   SmallVector<ProcInfo, 4> procInfo;
   SmallVector<DistributionMethod, 0> distributionMethod;
@@ -248,7 +248,7 @@
                                   ValueRange)>
         bodyBuilderFn,
     Optional<LinalgLoopDistributionOptions>, ArrayRef<StringRef>) {
-  auto iterArgInitValues = linalgOp.getOutputTensors();
+  SmallVector<Value> iterArgInitValues = linalgOp.getOutputTensorOperands();
   assert(iterArgInitValues.empty() && "unexpected AffineForOp init values");
   SmallVector<Value, 4> lbs, ubs, steps;
   unpackRanges(loopRanges, lbs, ubs, steps);
@@ -285,14 +285,17 @@
   auto wrappedBuilderFn = [&](OpBuilder &nestedBuilder, Location nestedLoc,
                               ValueRange ivs, ValueRange inputs,
                               ValueRange outputs) {
-    scf::ValueVector results = bodyBuilderFn(nestedBuilder, nestedLoc, ivs,
-                                             linalgOp.getOutputTensors());
+    SmallVector<Value> outputTensors = linalgOp.getOutputTensorOperands();
+    scf::ValueVector results =
+        bodyBuilderFn(nestedBuilder, nestedLoc, ivs, outputTensors);
     nestedBuilder.create<linalg::YieldOp>(nestedLoc, results);
   };

-  auto tiledLoop = b.create<TiledLoopOp>(
-      loc, lbs, ubs, steps, linalgOp.getInputs(), linalgOp.getOutputs(),
-      b.getArrayAttr(iteratorTypes), wrappedBuilderFn);
+  SmallVector<Value> inputOperands = linalgOp.getInputOperands();
+  SmallVector<Value> outputOperands = linalgOp.getOutputOperands();
+  auto tiledLoop =
+      b.create<TiledLoopOp>(loc, lbs, ubs, steps, inputOperands, outputOperands,
+                            b.getArrayAttr(iteratorTypes), wrappedBuilderFn);
   if (!distributionTypes.empty())
     tiledLoop.setDistributionTypes(b, distributionTypes);

@@ -300,11 +303,9 @@
   auto isInsideTiledLoop = [&](OpOperand &operand) {
     return operand.getOwner()->getBlock() == tiledLoop.getBody();
   };
-  for (auto it :
-       llvm::zip(linalgOp.getInputs(), tiledLoop.getRegionInputArgs()))
+  for (auto it : llvm::zip(inputOperands, tiledLoop.getRegionInputArgs()))
     std::get<0>(it).replaceUsesWithIf(std::get<1>(it), isInsideTiledLoop);
-  for (auto it :
-       llvm::zip(linalgOp.getOutputs(), tiledLoop.getRegionOutputArgs()))
+  for (auto it : llvm::zip(outputOperands, tiledLoop.getRegionOutputArgs()))
     std::get<0>(it).replaceUsesWithIf(std::get<1>(it), isInsideTiledLoop);
 }

@@ -452,7 +453,7 @@
         bodyBuilderFn,
     Optional<LinalgLoopDistributionOptions> distributionOptions,
     ArrayRef<StringRef> distributionTypes) {
-  auto iterArgInitValues = linalgOp.getOutputTensors();
+  SmallVector<Value> iterArgInitValues = linalgOp.getOutputTensorOperands();
   assert(iterArgInitValues.empty() && "unexpected ParallelOp init values");
   // This function may be passed more iterator types than ranges.
   assert(iteratorTypes.size() >= loopRanges.size() &&
@@ -509,7 +510,7 @@

 SmallVector<Value, 4> makeTiledShapes(OpBuilder &b, Location loc,
                                       LinalgOp linalgOp,
-                                      ArrayRef<Value> tiledOperands,
+                                      ArrayRef<Value> valuesToTile,
                                       ValueRange ivs, ValueRange tileSizes,
                                       ArrayRef<Value> sizeBounds) {
   assert(ivs.size() == static_cast<size_t>(llvm::count_if(
@@ -533,20 +534,22 @@
     LLVM_DEBUG(llvm::dbgs() << "size: " << subShapeSizes.back() << "\n");
   }

+  assert(valuesToTile.size() == linalgOp.getNumInputsAndOutputs() &&
+         "expected one value to tile for every operand");
   MLIRContext *context = b.getContext();
   SmallVector<Value, 4> tiledShapes;
-  tiledShapes.reserve(tiledOperands.size());
-  for (auto en : llvm::enumerate(tiledOperands)) {
-    Value shapedOp = en.value();
+  tiledShapes.reserve(valuesToTile.size());
+  for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
+    Value shapedOp = valuesToTile[opOperand->getOperandNumber()];
     LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for operand " << shapedOp);
-    ShapedType shapedType = shapedOp.getType().cast<ShapedType>();
-    unsigned rank = shapedType.getRank();
-    AffineMap map = linalgOp.getIndexingMap(en.index());
+    int64_t rank = linalgOp.getRank(opOperand);
+    ArrayRef<int64_t> shape = linalgOp.getShape(opOperand);
+    AffineMap map = linalgOp.getTiedIndexingMap(opOperand);
     // If the shape is not tiled, we can use it as is.
     if (!isTiled(map, tileSizes)) {
       tiledShapes.push_back(shapedOp);
-      LLVM_DEBUG(llvm::dbgs()
-                 << ": not tiled: use shape: " << shapedType << "\n");
+      LLVM_DEBUG(llvm::dbgs() << ": not tiled: use shape: "
+                              << opOperand->get().getType() << "\n");
       continue;
     }
     LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subshape...\n");
@@ -583,7 +586,7 @@
       // The size of the subview / subtensor should be trimmed to avoid
       // out-of-bounds accesses, unless we statically know the subshape size
       // divides the shape size evenly.
-      int64_t shapeSize = shapedType.getDimSize(r);
+      int64_t shapeSize = shape[r];
       auto sizeCst = size.getDefiningOp<ConstantIndexOp>();
       if (ShapedType::isDynamic(shapeSize) || !sizeCst ||
           (shapeSize % sizeCst.getValue()) != 0) {
@@ -610,7 +613,7 @@
       strides.push_back(b.getIndexAttr(1));
     }

-    if (shapedType.isa<MemRefType>())
+    if (opOperand->get().getType().isa<MemRefType>())
       tiledShapes.push_back(
           b.create<memref::SubViewOp>(loc, shapedOp, offsets, sizes, strides));
     else
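
Note for reviewers: the OpOperand-based getters used above (getInputOperands, getOutputOperands, getOutputTensorOperands, getOutputBufferOperands) appear to return an OpOperandVector that implicitly converts to SmallVector<Value>, which is what lets assignments like `SmallVector<Value> operands = op.getInputOperands();` compile. The sketch below shows how a client walks operands through the new interface after this change; it only uses calls that appear in this patch (getInputAndOutputOperands, getTiedIndexingMap, getRank, getShape). It is illustrative, not part of the patch: the helper name dumpOperandInfo and the exact include path are assumptions.

// Illustrative sketch (not part of the patch): iterate all input/output
// operands via the OpOperand-based API instead of positional getters.
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;

static void dumpOperandInfo(linalg::LinalgOp op) {
  for (OpOperand *opOperand : op.getInputAndOutputOperands()) {
    // Each OpOperand carries its own position, replacing the en.index()
    // bookkeeping that the old makeTiledShapes loop needed.
    unsigned idx = opOperand->getOperandNumber();
    // Indexing map tied to this operand (was getIndexingMap(index)).
    AffineMap map = op.getTiedIndexingMap(opOperand);
    // Rank and static shape via the interface (was a ShapedType cast).
    int64_t rank = op.getRank(opOperand);
    ArrayRef<int64_t> shape = op.getShape(opOperand);
    llvm::errs() << "operand #" << idx << ": rank=" << rank
                 << ", dims=" << shape.size()
                 << ", map results=" << map.getNumResults() << "\n";
  }
}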