diff --git a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h --- a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h @@ -30,20 +30,6 @@ class LinalgOp; -// TOFO: allow an extra ValueRange to specify an indexing and allow -// non-hyperrectangular shapes. -using LoopRangeBuilder = - std::function(ImplicitLocOpBuilder)>; - -/// Provide a very simple inference procedure to build the loop ranges from the -/// op and its operands. This only works with permutation affine maps and -/// patterns of the form `(m, n)[s] -> (m + n - s floordiv 2)`. -/// A more advanced Tensor-Comprehension like inference is possible but has -/// proven to be ambiguous in unfavorable case. -/// As a consequence, we relax the default behavior very conservatively and -/// provide an op-specified hook so that Linalg ops may override the behavior. -LoopRangeBuilder defaultLoopRangesBuilder(LinalgOp op); - /// Returns the name mangled library call name to disambiguate between different /// overloads at the C level. The name mangling scheme is basic and uses MLIR /// type names: diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -208,6 +208,8 @@ /// necessary. Value materializeOpFoldResult(ImplicitLocOpBuilder &builder, OpFoldResult opFoldResult); +Value materializeOpFoldResult(OpBuilder &b, Location loc, + OpFoldResult opFoldResult); /// Creates an extract_slice/subview op for a single `valueToTile` with /// `builder`. This new operation extracts a tile of `valueToTile`, starting diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h --- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h +++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h @@ -24,9 +24,9 @@ /// operands into a list of triples. Such a list can be more convenient to /// manipulate. struct Range { - Value offset; - Value size; - Value stride; + OpFoldResult offset; + OpFoldResult size; + OpFoldResult stride; }; class OffsetSizeAndStrideOpInterface; diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -1346,6 +1346,26 @@ genericOp, "illegal to collapse specified dimensions"); } + // Bail on non-canonical ranges. + SmallVector loopRanges = + cast(genericOp.getOperation()) + .createLoopRanges(rewriter, genericOp.getLoc()); + auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) { + if (auto attr = ofr.dyn_cast()) + return attr.cast().getInt() == value; + llvm::APInt actual; + return matchPattern(ofr.get(), m_ConstantInt(&actual)) && + actual.getSExtValue() == value; + }; + if (!llvm::all_of(loopRanges, [&](Range range) { + return opFoldIsConstantValue(range.offset, 0) && + opFoldIsConstantValue(range.stride, 1); + })) { + return rewriter.notifyMatchFailure( + genericOp, + "expected all loop ranges to have zero start and unit stride"); + } + // Get the iterator types for the operand. SmallVector iteratorTypes = getCollapsedOpIteratorTypes( genericOp.iterator_types().getValue(), collapsingInfo); @@ -1390,17 +1410,10 @@ // Collect the loop range of the generic op. OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(collapsedGenericOp); - SmallVector loopRanges = - cast(genericOp.getOperation()) - .createLoopRanges(rewriter, genericOp.getLoc()); - assert(llvm::all_of(loopRanges, - [](Range range) { - return matchPattern(range.offset, m_Zero()) && - matchPattern(range.stride, m_One()); - }) && - "expected all loop ranges to have zero start and unit stride"); - SmallVector loopBound = llvm::to_vector( - llvm::map_range(loopRanges, [](Range range) { return range.size; })); + SmallVector loopBound = + llvm::to_vector(llvm::map_range(loopRanges, [&](Range range) { + return materializeOpFoldResult(rewriter, loc, range.size); + })); generateCollapsedIndexingRegion(loc, &collapsedGenericOp->getRegion(0).front(), collapsingInfo, loopBound, rewriter); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -117,7 +117,6 @@ SmallVector loopRanges; Location loc = producer.getLoc(); auto zero = b.create(loc, 0); - auto one = b.create(loc, 1); for (unsigned i = 0, e = producer.getNumLoops(); i < e; ++i) { auto shapeDim = getShapeDefiningLoopRange(producer, i); @@ -125,14 +124,14 @@ sizeBounds.push_back(dim); auto it = fusedLoopsAndRanges.find(i); if (it != fusedLoopsAndRanges.end()) { - ivs.push_back(it->second.offset); - tileSizes.push_back(it->second.size); + ivs.push_back(materializeOpFoldResult(b, loc, it->second.offset)); + tileSizes.push_back(materializeOpFoldResult(b, loc, it->second.size)); loopRanges.push_back(it->second); LLVM_DEBUG(llvm::dbgs() << "tiled loop#" << i << " with LoopRange " << loopRanges.back() << "\n"); } else { tileSizes.push_back(zero); - loopRanges.push_back(Range{zero, dim, one}); + loopRanges.push_back(Range{b.getIndexAttr(0), dim, b.getIndexAttr(1)}); LLVM_DEBUG(llvm::dbgs() << "full loop#" << i << " with LoopRange " << loopRanges.back() << "\n"); } @@ -168,8 +167,9 @@ // Shift all IndexOp results by the tile offset. SmallVector allIvs; - llvm::transform(loopRanges, std::back_inserter(allIvs), - [](Range range) { return range.offset; }); + llvm::transform(loopRanges, std::back_inserter(allIvs), [&](Range range) { + return materializeOpFoldResult(b, loc, range.offset); + }); offsetIndices(b, clonedOp, allIvs); return clonedOp; diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp @@ -143,8 +143,9 @@ // Obtain the `producerOp` loop bounds and the `sliceOp` ranges. SmallVector producerLoopBounds; llvm::transform(producerOp.createLoopRanges(b, loc), - std::back_inserter(producerLoopBounds), - [](Range range) { return range.size; }); + std::back_inserter(producerLoopBounds), [&](Range range) { + return materializeOpFoldResult(b, loc, range.size); + }); SmallVector sliceOpRanges = sliceOp.getOrCreateRanges(b, loc); // Tile the producer operands given the `sliceOp` ranges. Iterate the @@ -157,8 +158,10 @@ for (auto it : zip(tiledSliceDimIndices, tiledProducerLoopIndices)) { int64_t tiledSliceDim = std::get<0>(it); int64_t tiledProducerLoop = std::get<1>(it); - tileIvs[tiledProducerLoop] = sliceOpRanges[tiledSliceDim].offset; - tileSizes[tiledProducerLoop] = sliceOpRanges[tiledSliceDim].size; + tileIvs[tiledProducerLoop] = + materializeOpFoldResult(b, loc, sliceOpRanges[tiledSliceDim].offset); + tileSizes[tiledProducerLoop] = + materializeOpFoldResult(b, loc, sliceOpRanges[tiledSliceDim].size); allIvs[tiledProducerLoop] = tileIvs[tiledProducerLoop]; } erase_value(tileIvs, nullptr); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp @@ -223,14 +223,20 @@ if (droppedDims[en.index()]) continue; auto rangeValue = en.value(); - // Try to extract a tight constant. + // Try to extract a tight constant. If the size is known statically, no need + // to look for the bound. LLVM_DEBUG(llvm::dbgs() << "Extract tightest: " << rangeValue.size << "\n"); - FailureOr upperBound = - getConstantUpperBoundForIndex(rangeValue.size); - Value size = - failed(upperBound) - ? rangeValue.size - : b.create(loc, upperBound.value()); + Value size; + if (auto attr = rangeValue.size.dyn_cast()) { + size = materializeOpFoldResult(b, loc, rangeValue.size); + } else { + Value materializedSize = materializeOpFoldResult(b, loc, rangeValue.size); + FailureOr upperBound = + getConstantUpperBoundForIndex(materializedSize); + size = failed(upperBound) + ? materializedSize + : b.create(loc, upperBound.getValue()); + } LLVM_DEBUG(llvm::dbgs() << "Extracted tightest: " << size << "\n"); fullSizes.push_back(size); partialSizes.push_back( diff --git a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp @@ -74,12 +74,10 @@ if (dimension >= iterationSpace.size()) return std::make_pair(op, TilingInterface()); - SmallVector offsets = - getAsOpFoldResult(llvm::to_vector(llvm::map_range( - iterationSpace, [](const Range &range) { return range.offset; }))); - SmallVector sizes = - getAsOpFoldResult(llvm::to_vector(llvm::map_range( - iterationSpace, [](const Range &range) { return range.size; }))); + SmallVector offsets = llvm::to_vector(llvm::map_range( + iterationSpace, [](const Range &range) { return range.offset; })); + SmallVector sizes = llvm::to_vector(llvm::map_range( + iterationSpace, [](const Range &range) { return range.size; })); // Adjust the split point so that it doesn't overflow the size. AffineExpr d0, d1, d2; @@ -105,7 +103,7 @@ TilingInterface firstPart = createSplitPart( rewriter, op.getLoc(), op, offsets, sizes, op.getDestinationOperands(rewriter), dimension, minSplitPoint, - getAsOpFoldResult(iterationSpace[dimension].offset), firstResults); + iterationSpace[dimension].offset, firstResults); // Need to pretend that the original op now takes as operands firstResults, // otherwise tiling interface implementation will take the wrong value to diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -66,8 +66,7 @@ // Create a new range with the applied tile sizes. SmallVector res; for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx) - res.push_back(Range{b.create(loc, 0), - shapeSizes[idx], tileSizes[idx]}); + res.push_back(Range{b.getIndexAttr(0), shapeSizes[idx], tileSizes[idx]}); return std::make_tuple(res, loopIndexToRangeIndex); } @@ -567,10 +566,12 @@ SmallVector ranges = tilingInterface.getIterationDomain(builder); SmallVector lbs, dims, allDims, steps; for (int64_t i = 0; i < rank; ++i) { - allDims.push_back(ranges[i].size); + Value materializedSize = + materializeOpFoldResult(builder, loc, ranges[i].size); + allDims.push_back(materializedSize); if (!isZero(tileSizes[i])) { - lbs.push_back(ranges[i].offset); - dims.push_back(ranges[i].size); + lbs.push_back(materializeOpFoldResult(builder, loc, ranges[i].offset)); + dims.push_back(materializedSize); steps.push_back(tileSizes[i]); } } diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -129,13 +129,14 @@ /// Given a list of subview ranges, extract individual values for lower, upper /// bounds and steps and put them into the corresponding vectors. -static void unpackRanges(ArrayRef ranges, SmallVectorImpl &lbs, +static void unpackRanges(OpBuilder &builder, Location loc, + ArrayRef ranges, SmallVectorImpl &lbs, SmallVectorImpl &ubs, SmallVectorImpl &steps) { for (Range range : ranges) { - lbs.emplace_back(range.offset); - ubs.emplace_back(range.size); - steps.emplace_back(range.stride); + lbs.emplace_back(materializeOpFoldResult(builder, loc, range.offset)); + ubs.emplace_back(materializeOpFoldResult(builder, loc, range.size)); + steps.emplace_back(materializeOpFoldResult(builder, loc, range.stride)); } } @@ -524,7 +525,7 @@ } SmallVector lbs, ubs, steps; - unpackRanges(loopRanges, lbs, ubs, steps); + unpackRanges(b, loc, loopRanges, lbs, ubs, steps); LoopNest loopNest = mlir::scf::buildLoopNest( b, loc, lbs, ubs, steps, iterArgInitValues, [&](OpBuilder &b, Location loc, ValueRange ivs, ValueRange iterArgs) { @@ -567,7 +568,7 @@ SmallVector iterArgInitValues = linalgOp.getOutputTensorOperands(); assert(iterArgInitValues.empty() && "unexpected AffineForOp init values"); SmallVector lbs, ubs, steps; - unpackRanges(loopRanges, lbs, ubs, steps); + unpackRanges(b, loc, loopRanges, lbs, ubs, steps); // Affine loops require constant steps. SmallVector constantSteps; @@ -744,7 +745,7 @@ stepsStorage.reserve(numLoops); // Get the loop lb, ub, and step. - unpackRanges(loopRanges, lbsStorage, ubsStorage, stepsStorage); + unpackRanges(b, loc, loopRanges, lbsStorage, ubsStorage, stepsStorage); // Modify the lb, ub, and step based on the distribution options. SmallVector distributionMethod; @@ -986,6 +987,12 @@ return builder.create(attr.getValue().getSExtValue()); } +Value materializeOpFoldResult(OpBuilder &builder, Location loc, + OpFoldResult opFoldResult) { + ImplicitLocOpBuilder b(loc, builder); + return materializeOpFoldResult(b, opFoldResult); +} + SmallVector makeTiledShapes(OpBuilder &b, Location loc, LinalgOp linalgOp, ArrayRef valuesToTile, diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Arithmetic/Utils/Utils.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/SCF/Utils/Utils.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -117,23 +118,25 @@ AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, builder.getContext()); for (auto loopRange : llvm::enumerate(loopRanges)) { + Value offset = + getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().offset); + Value size = + getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().size); // No loops if tile size is zero. Set offset and size to the loop // offset and size. if (matchPattern(tileSizeVals[loopRange.index()], m_Zero())) { - offsets[loopRange.index()] = loopRange.value().offset; - sizes[loopRange.index()] = loopRange.value().size; + offsets[loopRange.index()] = offset; + sizes[loopRange.index()] = size; continue; } auto loop = builder.create( - loc, loopRange.value().offset, loopRange.value().size, - tileSizeVals[loopRange.index()], ValueRange{}, + loc, offset, size, tileSizeVals[loopRange.index()], ValueRange{}, [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv, ValueRange /*iterArgs*/) { Value boundedTileSize = builder.create( bodyLoc, minMap, - ValueRange{iv, tileSizeVals[loopRange.index()], - loopRange.value().size}); + ValueRange{iv, tileSizeVals[loopRange.index()], size}); sizes[loopRange.index()] = boundedTileSize; builder.create(loc); });