diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -29,6 +29,7 @@
 struct TiledLinalgOp {
   LinalgOp op;
   SmallVector<Operation *, 8> loops;
+  SmallVector<Value, 4> tensorResults;
 };
 
 struct TiledAndFusedLinalgOps {
@@ -371,8 +372,9 @@
                          LinalgTilingOptions options,
                          LinalgMarker marker = LinalgMarker(),
                          PatternBenefit benefit = 1);
-  LogicalResult matchAndRewrite(Operation *op,
-                                PatternRewriter &rewriter) const override;
+  LogicalResult
+  matchAndRewriteBase(Operation *op, PatternRewriter &rewriter,
+                      SmallVectorImpl<Value> &tensorResults) const;
 
 private:
   /// LinalgTransformMarker handles special attribute manipulations.
@@ -390,9 +392,14 @@
                             marker, benefit) {}
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
-    if (failed(LinalgBaseTilingPattern::matchAndRewrite(op, rewriter)))
+    SmallVector<Value, 4> tensorResults;
+    if (failed(LinalgBaseTilingPattern::matchAndRewriteBase(op, rewriter,
+                                                            tensorResults)))
       return failure();
-    rewriter.eraseOp(op);
+    if (tensorResults.empty())
+      rewriter.eraseOp(op);
+    else
+      rewriter.replaceOp(op, tensorResults);
     return success();
   }
 };
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -95,17 +95,18 @@
                                   unsigned consumerIdx,
                                   OperationFolder *folder = nullptr);
 
-/// Returns the linearized list of all view dimensions in a `linalgOp`. Applying
-/// the inverse, concatenated loopToOperandRangeMaps to this list allows the
-/// derivation of loop ranges for any linalgOp.
-SmallVector<Value, 8> getViewSizes(OpBuilder &builder, LinalgOp linalgOp);
+/// Returns the linearized list of all shape dimensions in a `linalgOp`.
+/// Applying the inverse, concatenated loopToOperandRangeMaps to this list
+/// allows the derivation of loop ranges for any linalgOp.
+SmallVector<Value, 8> getShapeSizes(OpBuilder &builder, LinalgOp linalgOp);
 
 template <typename ConcreteOpTy>
-SmallVector<Value, 8> getViewSizes(OpBuilder &builder, ConcreteOpTy linalgOp) {
-  return getViewSizes(builder, cast<LinalgOp>(linalgOp.getOperation()));
+SmallVector<Value, 8> getShapeSizes(OpBuilder &builder, ConcreteOpTy linalgOp) {
+  return getShapeSizes(builder,
+                       cast<LinalgOp>(linalgOp.getOperation()));
 }
 
 /// Returns the loop ranges of the `linalgOp`. Applies the inverse of the
-/// concatenated indexing maps to the result of `getViewSizes`. Returns None if
+/// concatenated indexing maps to the result of `getShapeSizes`. Returns None if
 /// the bounds computation fails.
 Optional<SmallVector<Value, 4>> getLoopRanges(OpBuilder &builder, LinalgOp linalgOp,
@@ -119,11 +120,6 @@
                                AffineMap map, ValueRange values,
                                OperationFolder *folder = nullptr);
 
-/// Returns all the operands of `linalgOp` that are not views.
-/// Asserts that these operands are value types to allow transformations like
-/// tiling to just use the values when cloning `linalgOp`.
-SmallVector<Value, 4> getAssumedNonViewOperands(LinalgOp linalgOp);
-
 /// Apply the permutation defined by `permutation` to `inVec`.
 /// Element `i` in `inVec` is mapped to location `j = permutation[i]`.
 /// E.g.: for an input vector `inVec = ['a', 'b', 'c']` and a permutation vector
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
@@ -315,7 +315,7 @@
 /// source memref. This is useful to to fold a memref_cast into a consuming op
 /// and implement canonicalization patterns for ops in different dialects that
 /// may consume the results of memref_cast operations. Such foldable memref_cast
-/// operations are typically inserted as `view` and `subview` ops are
+/// operations are typically inserted as `view` and `subview` ops and are
 /// canonicalized, to preserve the type compatibility of their uses.
 ///
 /// Returns true when all conditions are met:
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -199,6 +199,11 @@
   if (isTopLevelValue(dimOp.memrefOrTensor()))
     return true;
 
+  // Conservatively handle remaining BlockArguments as non-valid symbols.
+  // E.g. scf.for iterArgs.
+  if (dimOp.memrefOrTensor().isa<BlockArgument>())
+    return false;
+
   // The dim op is also okay if its operand memref/tensor is a view/subview
   // whose corresponding size is a valid symbol.
   Optional<int64_t> index = dimOp.getConstantIndex();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -97,7 +97,7 @@
     clonedViews.push_back(
         b.create<SubViewOp>(loc, view, offsets, sizes, strides));
   }
-  auto operands = getAssumedNonViewOperands(op);
+  auto operands = op.getAssumedNonShapedOperands();
   clonedViews.append(operands.begin(), operands.end());
 
   Operation *clonedOp = op.clone(b, loc, /*resultTypes*/ {}, clonedViews);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -508,10 +508,10 @@
       linalgOp.indexing_maps().template getAsRange<AffineMapAttr>();
   auto maps = llvm::to_vector<8>(
       llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); }));
-  SmallVector<Value, 8> sizes = getViewSizes(builder, linalgOp);
+  SmallVector<Value, 8> sizes = getShapeSizes(builder, linalgOp);
   AffineMap map = concatAffineMaps(maps);
   auto loopRanges = emitLoopRanges(scope.getBuilderRef(), scope.getLocation(),
-                                   map, getViewSizes(builder, linalgOp));
+                                   map, getShapeSizes(builder, linalgOp));
   SmallVector<Value, 4> allIvs;
   GenerateLoopNest<LoopTy>::doit(
       loopRanges, /*iterInitArgs*/ {}, linalgOp.iterator_types().getValue(),
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -56,18 +56,17 @@
 // indices of newly created loops.
 static std::tuple<SmallVector<Range, 4>, LoopIndexToRangeIndexMap>
 makeTiledLoopRanges(OpBuilder &b, Location loc, AffineMap map,
-                    ArrayRef<Value> allViewSizes,
-                    ArrayRef<Value> allTileSizes) {
+                    ValueRange allShapeSizes, ValueRange allTileSizes) {
   assert(allTileSizes.size() == map.getNumResults());
-  // Apply `map` to get view sizes in loop order.
-  auto viewSizes = applyMapToValues(b, loc, map, allViewSizes);
+  // Apply `map` to get shape sizes in loop order.
+  auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes);
   SmallVector<Value, 4> tileSizes(allTileSizes.begin(), allTileSizes.end());
 
   // Traverse the tile sizes, which are in loop order, erase zeros everywhere.
   LoopIndexToRangeIndexMap loopIndexToRangeIndex;
   for (int idx = 0, e = tileSizes.size(), zerosCount = 0; idx < e; ++idx) {
     if (isZero(tileSizes[idx - zerosCount])) {
-      viewSizes.erase(viewSizes.begin() + idx - zerosCount);
+      shapeSizes.erase(shapeSizes.begin() + idx - zerosCount);
       tileSizes.erase(tileSizes.begin() + idx - zerosCount);
       ++zerosCount;
       continue;
@@ -78,11 +77,10 @@
   // Create a new range with the applied tile sizes.
   SmallVector<Range, 4> res;
   for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx)
-    res.push_back(Range{std_constant_index(0), viewSizes[idx],
-                        tileSizes[idx]});
+    res.push_back(
+        Range{std_constant_index(0), shapeSizes[idx], tileSizes[idx]});
   return std::make_tuple(res, loopIndexToRangeIndex);
 }
-
 namespace {
 
 // Helper visitor to determine whether an AffineExpr is tiled.
@@ -94,7 +92,7 @@
 // `d0 + 2 * d1 + d3` is tiled by [0, 0, 0, 2] but not by [0, 0, 2, 0]
 //
 struct TileCheck : public AffineExprVisitor<TileCheck> {
-  TileCheck(ArrayRef<Value> tileSizes) : isTiled(false), tileSizes(tileSizes) {}
+  TileCheck(ValueRange tileSizes) : isTiled(false), tileSizes(tileSizes) {}
 
   void visitDimExpr(AffineDimExpr expr) {
     isTiled |= !isZero(tileSizes[expr.getPosition()]);
@@ -107,7 +105,7 @@
            "nonpositive multiplying coefficient");
   }
   bool isTiled;
-  ArrayRef<Value> tileSizes;
+  ValueRange tileSizes;
 };
 
 } // namespace
@@ -166,7 +164,6 @@
 static void transformIndexedGenericOpIndices(
     OpBuilder &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
     const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) {
-  assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
   auto indexedGenericOp = dyn_cast<IndexedGenericOp>(op.getOperation());
   if (!indexedGenericOp)
     return;
@@ -203,7 +200,7 @@
   }
 }
 
-static bool isTiled(AffineExpr expr, ArrayRef<Value> tileSizes) {
+static bool isTiled(AffineExpr expr, ValueRange tileSizes) {
   if (!expr)
     return false;
   TileCheck t(tileSizes);
@@ -211,9 +208,8 @@
   return t.isTiled;
 }
 
-// Checks whether the view with index `viewIndex` within `linalgOp` varies with
-// respect to a non-zero `tileSize`.
-static bool isTiled(AffineMap map, ArrayRef<Value> tileSizes) {
+// Checks whether the `map` varies with respect to a non-zero `tileSize`.
+static bool isTiled(AffineMap map, ValueRange tileSizes) {
   if (!map)
     return false;
   for (unsigned r = 0; r < map.getNumResults(); ++r)
@@ -222,13 +218,11 @@
   return false;
 }
 
-static SmallVector<Value, 4> makeTiledViews(OpBuilder &b, Location loc,
-                                            LinalgOp linalgOp, AffineMap map,
-                                            ArrayRef<Value> ivs,
-                                            ArrayRef<Value> tileSizes,
-                                            ArrayRef<Value> allViewSizes) {
-  assert(linalgOp.hasBufferSemantics() &&
-         "expected linalg op with buffer semantics");
+static SmallVector<Value, 4>
+makeTiledShapes(OpBuilder &b, Location loc, LinalgOp linalgOp,
+                ValueRange operands, AffineMap map, ValueRange ivs,
+                ValueRange tileSizes, ValueRange allShapeSizes) {
+  assert(operands.size() == linalgOp.getShapedOperands().size());
   assert(ivs.size() == static_cast<size_t>(llvm::count_if(
                            llvm::make_range(tileSizes.begin(), tileSizes.end()),
                            [](Value v) { return !isZero(v); })) &&
@@ -236,37 +230,34 @@
   using namespace edsc::op;
 
-  auto viewSizes = applyMapToValues(b, loc, map, allViewSizes);
+  auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes);
   // Construct (potentially temporary) mins and maxes on which to apply maps
-  // that define tile subviews.
-  SmallVector<Value, 4> lbs, subViewSizes;
+  // that define tile subshapes.
+  SmallVector<Value, 4> lbs, subShapeSizes;
   for (unsigned idx = 0, idxIvs = 0, e = tileSizes.size(); idx < e; ++idx) {
     bool isTiled = !isZero(tileSizes[idx]);
     lbs.push_back(isTiled ? ivs[idxIvs++] : (Value)std_constant_index(0));
     // Before composing, we need to make range a closed interval.
-    Value size = isTiled ? tileSizes[idx] : viewSizes[idx];
-    subViewSizes.push_back(size - std_constant_index(1));
+    Value size = isTiled ? tileSizes[idx] : shapeSizes[idx];
+    subShapeSizes.push_back(size - std_constant_index(1));
   }
 
   auto *op = linalgOp.getOperation();
 
   SmallVector<Value, 4> res;
   res.reserve(op->getNumOperands());
-  auto viewIteratorBegin = linalgOp.getInputsAndOutputBuffers().begin();
-  for (unsigned viewIndex = 0; viewIndex < linalgOp.getNumInputsAndOutputs();
-       ++viewIndex) {
-    Value view = *(viewIteratorBegin + viewIndex);
-    auto viewType = view.getType().cast<MemRefType>();
-    unsigned rank = viewType.getRank();
-    auto mapAttr = linalgOp.indexing_maps()[viewIndex];
-    auto map = mapAttr.cast<AffineMapAttr>().getValue();
-    // If the view is not tiled, we can use it as is.
+  for (auto en : llvm::enumerate(operands)) {
+    Value shapedOp = en.value();
+    ShapedType shapedType = shapedOp.getType().cast<ShapedType>();
+    unsigned rank = shapedType.getRank();
+    AffineMap map = linalgOp.getIndexingMap(en.index());
+    // If the shape is not tiled, we can use it as is.
     if (!isTiled(map, tileSizes)) {
-      res.push_back(view);
+      res.push_back(shapedOp);
       continue;
     }
 
-    // Construct a new subview for the tile.
+    // Construct a new subview / subtensor for the tile.
     SmallVector<Value, 4> offsets, sizes, strides;
     offsets.reserve(rank);
     sizes.reserve(rank);
@@ -274,27 +265,27 @@
     for (unsigned r = 0; r < rank; ++r) {
       if (!isTiled(map.getSubMap({r}), tileSizes)) {
         offsets.push_back(std_constant_index(0));
-        sizes.push_back(std_dim(view, r));
+        sizes.push_back(std_dim(shapedOp, r));
         strides.push_back(std_constant_index(1));
         continue;
       }
 
       // Tiling creates a new slice at the proper index, the slice step is 1
-      // (i.e. the slice view does not subsample, stepping occurs in the loop).
+      // (i.e. the op does not subsample, stepping occurs in the loop).
       auto m = map.getSubMap({r});
       auto offset = applyMapToValues(b, loc, m, lbs).front();
       offsets.push_back(offset);
-      auto closedIntSize = applyMapToValues(b, loc, m, subViewSizes).front();
+      auto closedIntSize = applyMapToValues(b, loc, m, subShapeSizes).front();
       // Resulting size needs to be made half open interval again.
       auto size = closedIntSize + std_constant_index(1);
 
-      // The size of the subview should be trimmed to avoid out-of-bounds
-      // accesses, unless we statically know the subview size divides the view
-      // size evenly.
-      int64_t viewSize = viewType.getDimSize(r);
+      // The size of the subview / subtensor should be trimmed to avoid
+      // out-of-bounds accesses, unless we statically know the subshape size
+      // divides the shape size evenly.
+      int64_t shapeSize = shapedType.getDimSize(r);
       auto sizeCst = size.getDefiningOp<ConstantIndexOp>();
-      if (ShapedType::isDynamic(viewSize) || !sizeCst ||
-          (viewSize % sizeCst.getValue()) != 0) {
+      if (ShapedType::isDynamic(shapeSize) || !sizeCst ||
+          (shapeSize % sizeCst.getValue()) != 0) {
         // Compute min(size, dim - offset) to avoid out-of-bounds accesses.
         auto minMap = AffineMap::get(
             /*dimCount=*/3, /*symbolCount=*/0,
@@ -302,7 +293,7 @@
              getAffineDimExpr(/*position=*/1, b.getContext()) -
                  getAffineDimExpr(/*position=*/2, b.getContext())},
             b.getContext());
-        auto d = std_dim(view, r);
+        auto d = std_dim(shapedOp, r);
         size =
             affine_min(b.getIndexType(), minMap, ValueRange{size, d, offset});
       }
@@ -311,7 +302,12 @@
       strides.push_back(std_constant_index(1));
     }
 
-    res.push_back(b.create<SubViewOp>(loc, view, offsets, sizes, strides));
+    if (shapedType.isa<MemRefType>())
+      res.push_back(
+          b.create<SubViewOp>(loc, shapedOp, offsets, sizes, strides));
+    else
+      res.push_back(
+          b.create<SubTensorOp>(loc, shapedOp, offsets, sizes, strides));
   }
 
   return res;
@@ -319,7 +315,7 @@
 
 template <typename LoopTy>
 static Optional<TiledLinalgOp>
-tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ArrayRef<Value> tileSizes,
+tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
                  const LinalgTilingOptions &options) {
   auto nLoops = op.getNumLoops();
   // Initial tile sizes may be too big, only take the first nLoops.
@@ -336,20 +332,20 @@
   }
 
   // 1. Build the tiled loop ranges.
-  auto allViewSizes = getViewSizes(b, op);
+  auto allShapeSizes = getShapeSizes(b, op);
   // The flattened loopToOperandRangesMaps is expected to be an invertible
   // permutation map (asserted in the inverse calculation).
   auto mapsRange = op.indexing_maps().getAsRange<AffineMapAttr>();
   auto maps = llvm::to_vector<8>(
       llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); }));
-  auto viewSizesToLoopsMap = inversePermutation(concatAffineMaps(maps));
-  if (!viewSizesToLoopsMap)
+  auto shapeSizesToLoopsMap = inversePermutation(concatAffineMaps(maps));
+  if (!shapeSizesToLoopsMap)
     return llvm::None;
 
   SmallVector<Range, 4> loopRanges;
   LoopIndexToRangeIndexMap loopIndexToRangeIndex;
   std::tie(loopRanges, loopIndexToRangeIndex) = makeTiledLoopRanges(
-      b, op.getLoc(), viewSizesToLoopsMap, allViewSizes, tileSizes);
+      b, op.getLoc(), shapeSizesToLoopsMap, allShapeSizes, tileSizes);
 
   SmallVector<Attribute, 4> iteratorTypes;
   for (auto attr :
        enumerate(op.iterator_types().cast<ArrayAttr>().getValue())) {
@@ -381,29 +377,79 @@
 
   // 2. Create the tiled loops.
   LinalgOp res = op;
-  SmallVector<Value, 4> ivs;
+  SmallVector<Value, 4> ivs, tensorResults;
+  auto initTensors = op.getInitTensors();
   GenerateLoopNest<LoopTy>::doit(
-      loopRanges, /*iterArgInitValues*/ {}, iteratorTypes,
+      loopRanges, /*iterArgInitValues*/ initTensors, iteratorTypes,
       [&](ValueRange localIvs, ValueRange iterArgs) -> scf::ValueVector {
         auto &b = ScopedContext::getBuilderRef();
         auto loc = ScopedContext::getLocation();
         ivs.assign(localIvs.begin(), localIvs.end());
-        SmallVector<Value, 4> ivValues(ivs.begin(), ivs.end());
 
-        // If we have to apply a permutation to the tiled loop nest, we have to
-        // reorder the induction variables This permutation is the right one
-        // assuming that loopRanges have previously been permuted by
-        // (i,j,k)->(k,i,j) So this permutation should be the inversePermutation
-        // of that one: (d0,d1,d2)->(d2,d0,d1)
+        // When an `interchangeVector` is present, it has been applied to the
+        // loop ranges and the iterator types. Apply its inverse to the
+        // resulting loop `ivs` to match the op definition.
+        SmallVector<Value, 4> interchangedIvs;
         if (!options.interchangeVector.empty())
-          ivValues = applyMapToValues(b, loc, invPermutationMap, ivValues);
-
-        auto views = makeTiledViews(b, loc, op, viewSizesToLoopsMap, ivValues,
-                                    tileSizes, allViewSizes);
-        auto operands = getAssumedNonViewOperands(op);
-        views.append(operands.begin(), operands.end());
-        res = op.clone(b, loc, /*resultTypes*/ {}, views);
-        return scf::ValueVector{};
+          interchangedIvs = applyMapToValues(b, loc, invPermutationMap, ivs);
+        else
+          interchangedIvs.assign(ivs.begin(), ivs.end());
+
+        unsigned numInitTensors = op.getNumInitTensors();
+        assert(numInitTensors == iterArgs.size());
+        // This uses knowledge about position of the init tensor in the list
+        // of operands.
+        // TODO: InterfaceAdaptor ?
+        auto operands = llvm::to_vector<4>(op.getShapedOperands());
+        std::copy(iterArgs.begin(), iterArgs.end(),
+                  operands.begin() + op.getNumInputsAndOutputBuffers());
+
+        auto shapedValues =
+            makeTiledShapes(b, loc, op, operands, shapeSizesToLoopsMap,
+                            interchangedIvs, tileSizes, allShapeSizes);
+        SmallVector<Value, 4> tiledOperands(shapedValues);
+        auto nonShapedOperands = op.getAssumedNonShapedOperands();
+        tiledOperands.append(nonShapedOperands.begin(),
+                             nonShapedOperands.end());
+
+        // If LinalgOp has results, they must all be tied to init tensors.
+        // We enforce this to ensure all tiled ops have been rewritten in
+        // "init tensor" form. This ensures tiling has anchor values into which
+        // to subtensor / subtensor_insert. Otherwise tiling would need to
+        // allocate which is not acceptable.
+        // This would not be the case with a special terminator op that
+        // generates the whole tensor (instead of inserting a subtensor). But
+        // the generator-based abstraction has other issues.
+        assert(op.getNumInitTensors() == op.getOperation()->getNumResults() &&
+               "expected same number of init tensors as number of results");
+
+        // Handle init tensor operands.
+        // This uses knowledge about position of the init tensor in the list
+        // of operands.
+        // TODO: InterfaceAdaptor ?
+        SmallVector<Type, 4> resultTensorTypes;
+        for (auto idx : llvm::seq<unsigned>(0, op.getNumInitTensors()))
+          resultTensorTypes.push_back(
+              tiledOperands[op.getNumInputsAndOutputBuffers() + idx].getType());
+
+        res = op.clone(b, loc, resultTensorTypes, tiledOperands);
+
+        // Insert a subtensor_insert for each init subtensor.
+        for (unsigned idx = 0, e = op.getNumInitTensors(); idx != e; ++idx) {
+          Value initTensor =
+              tiledOperands[op.getNumInputsAndOutputBuffers() + idx];
+          if (auto subtensor = initTensor.getDefiningOp<SubTensorOp>()) {
+            tensorResults.push_back(b.create<SubTensorInsertOp>(
+                loc, subtensor.source().getType(),
+                res.getOperation()->getResult(idx), subtensor.source(),
+                subtensor.offsets(), subtensor.sizes(), subtensor.strides(),
+                subtensor.static_offsets(), subtensor.static_sizes(),
+                subtensor.static_strides()));
+          } else {
+            tensorResults.push_back(res.getOperation()->getResult(idx));
+          }
+        }
+        return scf::ValueVector(tensorResults.begin(), tensorResults.end());
       },
       options.distribution);
 
@@ -423,7 +469,16 @@
       loops.push_back(nullptr);
     }
   }
-  return TiledLinalgOp{res, loops};
+
+  // 5. Get the tensor results from the outermost loop if available. Otherwise
+  // use the previously captured `tensorResults`.
+  Operation *outermostLoop = nullptr;
+  for (Operation *loop : loops)
+    if ((outermostLoop = loop))
+      break;
+
+  return TiledLinalgOp{
+      res, loops, outermostLoop ? outermostLoop->getResults() : tensorResults};
 }
 
 template <typename LoopTy>
@@ -433,7 +488,6 @@
   b.setInsertionPoint(op);
   ScopedContext scope(b, op.getLoc());
 
-  assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
   // Enforce the convention that "tiling by zero" skips tiling a particular
   // dimension. This convention is significantly simpler to handle instead of
   // adjusting affine maps to account for missing dimensions.
@@ -514,7 +568,9 @@
   scf::ForOp::getCanonicalizationPatterns(patterns, ctx);
   scf::ParallelOp::getCanonicalizationPatterns(patterns, ctx);
   ConstantIndexOp::getCanonicalizationPatterns(patterns, ctx);
+  SubTensorOp::getCanonicalizationPatterns(patterns, ctx);
   SubViewOp::getCanonicalizationPatterns(patterns, ctx);
+  TensorCastOp::getCanonicalizationPatterns(patterns, ctx);
   ViewOp::getCanonicalizationPatterns(patterns, ctx);
   CanonicalizationPatternList<
 #define GET_OP_LIST
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -111,19 +111,34 @@
     : RewritePattern(opName, {}, benefit, context), marker(marker),
       options(options) {}
 
-LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
+    Operation *op, PatternRewriter &rewriter,
+    SmallVectorImpl<Value> &tensorResults) const {
   LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
   if (!linalgOp)
     return failure();
   if (failed(marker.checkAndNotify(rewriter, linalgOp)))
     return failure();
 
+  // If LinalgOp has results, they must all be tied to init tensors.
+  // We enforce this to ensure all tiled ops have been rewritten in
+  // "init tensor" form. This ensures tiling has anchor values into which to
+  // subtensor / subtensor_insert. Otherwise tiling would need to allocate which
+  // is not acceptable.
+  // This would not be the case with a special terminator op that generates the
+  // whole tensor (instead of inserting a subtensor). But the generator-based
+  // abstraction has other issues.
+  if (linalgOp.getNumInitTensors() != linalgOp.getOperation()->getNumResults())
+    return failure();
+
   Optional<TiledLinalgOp> res = tileLinalgOp(rewriter, linalgOp, options);
 
   if (!res)
     return failure();
 
+  // Return relevant information to derived pattern.
+  tensorResults = res->tensorResults;
+
   // New marker if specified.
   marker.replaceLinalgMarker(rewriter, res->op.getOperation());
   return success();
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -85,26 +85,6 @@
   return res;
 }
 
-/// Returns all the operands of `linalgOp` that are not views.
-/// Asserts that these operands are value types to allow transformations like
-/// tiling to just use the values when cloning `linalgOp`.
-SmallVector<Value, 4>
-mlir::linalg::getAssumedNonViewOperands(LinalgOp linalgOp) {
-  auto *op = linalgOp.getOperation();
-  unsigned numViews = linalgOp.getNumInputsAndOutputs();
-  unsigned nOperands = op->getNumOperands() - numViews;
-  SmallVector<Value, 4> res;
-  res.reserve(nOperands);
-  for (unsigned i = 0; i < nOperands; ++i) {
-    res.push_back(op->getOperand(numViews + i));
-    auto t = res.back().getType();
-    (void)t;
-    assert((t.isSignlessIntOrIndexOrFloat() || t.isa<VectorType>()) &&
-           "expected scalar or vector type");
-  }
-  return res;
-}
-
 bool mlir::linalg::isParallelIteratorType(Attribute attr) {
   if (auto strAttr = attr.dyn_cast<StringAttr>()) {
     return strAttr.getValue() == getParallelIteratorTypeName();
@@ -147,12 +127,12 @@
 namespace linalg {
 
 /// Return the linearized list of all view dimensions in a linalgOp.
-SmallVector<Value, 8> getViewSizes(OpBuilder &builder, LinalgOp linalgOp) {
+SmallVector<Value, 8> getShapeSizes(OpBuilder &builder, LinalgOp linalgOp) {
   auto loc = linalgOp.getLoc();
   SmallVector<Value, 8> res;
   SmallVector<unsigned, 4> ranks;
-  for (auto v : linalgOp.getInputsAndOutputBuffers()) {
-    MemRefType t = v.getType().template cast<MemRefType>();
+  for (Value v : linalgOp.getShapedOperands()) {
+    ShapedType t = v.getType().template cast<ShapedType>();
     ranks.push_back(t.getRank());
     for (unsigned i = 0; i < t.getRank(); ++i)
       res.push_back(builder.create<DimOp>(loc, v, i));
@@ -181,7 +161,7 @@
 Optional<SmallVector<Value, 4>> getLoopRanges(OpBuilder &builder,
                                               LinalgOp linalgOp,
                                               OperationFolder *folder) {
-  SmallVector<Value, 8> viewSizes = getViewSizes(builder, linalgOp);
+  SmallVector<Value, 8> viewSizes = getShapeSizes(builder, linalgOp);
   AffineMap invertedMap =
       inversePermutation(concatAffineMaps(linalgOp.getIndexingMaps()));
   if (!invertedMap)
diff --git a/mlir/test/Dialect/Linalg/tile-tensors.mlir b/mlir/test/Dialect/Linalg/tile-tensors.mlir
new file
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/tile-tensors.mlir
@@ -0,0 +1,28 @@
+// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2,3,4" -mlir-disable-threading=true | FileCheck %s
+
+// CHECK-LABEL: func @matmul_tensors(
+//  CHECK-SAME: %[[TA:[0-9a-z]+]]: tensor<?x?xf32>
+//  CHECK-SAME: %[[TB:[0-9a-z]+]]: tensor<?x?xf32>
+//  CHECK-SAME: %[[TC:[0-9a-z]+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+func @matmul_tensors(
+  %arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>)
+    -> tensor<?x?xf32> {
+// CHECK: %[[TD0:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[TC]]) -> (tensor<?x?xf32>) {
+// CHECK:   %[[TD1:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC1:.*]] = %[[TC0]]) -> (tensor<?x?xf32>) {
+// CHECK:     %[[TD2:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC2:.*]] = %[[TC1]]) -> (tensor<?x?xf32>) {
+// CHECK:       %[[sTA:.*]] = subtensor %[[TA]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK:       %[[sTB:.*]] = subtensor %[[TB]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK:       %[[sTC:.*]] = subtensor %[[TC2]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK:       %[[sTD:.*]] = linalg.matmul ins(%[[sTA]], %[[sTB]] : tensor<?x?xf32>, tensor<?x?xf32>)
+// CHECK-SAME:                             init(%[[sTC]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK:       %[[TD:.*]] = subtensor_insert %[[sTD]] into %[[TC2]][{{.*}}] : tensor<?x?xf32> into tensor<?x?xf32>
+// CHECK:       scf.yield %[[TD]] : tensor<?x?xf32>
+// CHECK:     scf.yield %[[TD2]] : tensor<?x?xf32>
+// CHECK:   scf.yield %[[TD1]] : tensor<?x?xf32>
+  %0 = linalg.matmul  ins(%arg0, %arg1: tensor<?x?xf32>, tensor<?x?xf32>)
+                     init(%arg2: tensor<?x?xf32>)
+    -> tensor<?x?xf32>
+
+// CHECK: return %[[TD0]] : tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
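
For reference, a minimal usage sketch (not part of the patch) of how a caller might drive the new tensor-aware tiling entry point and consume TiledLinalgOp::tensorResults; it mirrors the LinalgTilingPattern logic introduced above. The helper name tileAndReplace, and the assumption that the surrounding code already has the usual mlir/Linalg includes, namespaces, and a LinalgTilingOptions carrying tile sizes, are illustrative.

// Sketch: tile a linalg op and rewire its results. Ops with tensor results
// must already be in "init tensor" form, as enforced by
// LinalgBaseTilingPattern::matchAndRewriteBase in the patch.
static LogicalResult tileAndReplace(PatternRewriter &rewriter,
                                    linalg::LinalgOp linalgOp,
                                    const linalg::LinalgTilingOptions &options) {
  // Every result must be tied to an init tensor so tiling has an anchor
  // value to subtensor / subtensor_insert into.
  if (linalgOp.getNumInitTensors() != linalgOp.getOperation()->getNumResults())
    return failure();

  Optional<linalg::TiledLinalgOp> tiled =
      linalg::tileLinalgOp(rewriter, linalgOp, options);
  if (!tiled)
    return failure();

  // Buffer-semantics ops yield no tensor results and can simply be erased;
  // tensor-semantics ops are replaced by the values produced by the tiled
  // loop nest (the outermost scf.for results, or the subtensor_insert values
  // when no loops were materialized).
  if (tiled->tensorResults.empty())
    rewriter.eraseOp(linalgOp);
  else
    rewriter.replaceOp(linalgOp, tiled->tensorResults);
  return success();
}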