diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -492,7 +492,7 @@
       AttrSizedOperandSegments,
       DeclareOpInterfaceMethods<LoopLikeOpInterface>,
       RecursiveSideEffects,
-      SingleBlockImplicitTerminator<"linalg::YieldOp">
+      SingleBlockImplicitTerminator<"linalg::TiledYieldOp">
     ]> {
   let summary = "Linalg tiled loop operation";
   let description = [{
@@ -509,7 +509,7 @@
     every tensor argument of TiledLoopOp.

     The body region must contain exactly one block that terminates with
-    `linalg.yield` with the operands resulting from `insert_slice` operations.
+    `linalg.tiled_yield`.

     Example:

@@ -528,9 +528,7 @@
       %result_sub = linalg.generic ...

-      %result = tensor.insert_slice %result_sub into %out[%i, 0][%c4, %c64][1, 1]
-        : tensor<?x?xi8> into tensor<24x64xi8>
-      linalg.yield %result : tensor<24x64xi8>
+      linalg.tiled_yield %result_sub in %out_sub : tensor<?x?xi8>
     }
     ```

@@ -540,7 +538,7 @@
     every memref argument of TiledLoopOp.

     The body region must contain exactly one block that terminates with
-    `linalg.yield` with no operands.
+    `linalg.tiled_yield` with no operands.

     Example:

@@ -558,7 +556,7 @@
         : memref<24x64xi8> to memref<?x?xi8>

       %result_sub = linalg.generic ...
-      linalg.yield
+      linalg.tiled_yield
     }
     ```
   }];
@@ -747,6 +745,29 @@
   let hasFolder = 1;
 }

+def Linalg_TiledYieldOp : Linalg_Op<"tiled_yield",
+    [NoSideEffect, ReturnLike, Terminator, SameVariadicOperandSize]>,
+    Arguments<(ins Variadic<AnyType>:$tiles, Variadic<AnyType>:$outputs)> {
+  let summary = "Linalg tiled yield operation";
+  let description = [{
+    `linalg.tiled_yield` is a special terminator operation for the block inside
+    the region of the `linalg.tiled_loop` op. It updates the parts of the
+    enclosing `linalg.tiled_loop` results specified by the `outputs` operands
+    with the values from the `tiles` operands.
+
+    Example:
+
+    ```mlir
+    linalg.tiled_loop ... outs(%out_ = %out : tensor<...>) {
+      %output = tensor.extract_slice %out_ ... // or %output = %out_
+      %tile = "some_computation"
+      linalg.tiled_yield %tile in %output : tensor<...>
+    }
+    ```
+  }];
+  let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>];
+  let hasCanonicalizer = 1;
+}
+
 def Linalg_IndexOp : Linalg_Op<"index", [NoSideEffect]>,
     Arguments<(ins Confined<I64Attr, [IntMinValue<0>]>:$dim)>,
     Results<(outs Index:$result)> {
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1497,30 +1497,6 @@
     return success();
   }

-  if (auto tiledLoopOp = dyn_cast<TiledLoopOp>(parentOp)) {
-    // Check if output args with tensor types match results types.
-    SmallVector<Value> tensorOuts;
-    llvm::copy_if(
-        tiledLoopOp.outputs(), std::back_inserter(tensorOuts),
-        [&](Value out) { return out.getType().isa<RankedTensorType>(); });
-    if (tensorOuts.size() != op.values().size())
-      return op.emitOpError("expected number of tensor output args = ")
-             << tensorOuts.size()
-             << " to match the number of yield operands = " << op.values().size();
-
-    TypeRange tensorTypes(llvm::makeArrayRef(tensorOuts));
-    for (auto &item :
-         llvm::enumerate(llvm::zip(tensorTypes, op.getOperandTypes()))) {
-      Type outType, resultType;
-      unsigned index = item.index();
-      std::tie(outType, resultType) = item.value();
-      if (outType != resultType)
-        return op.emitOpError("expected yield operand ")
-               << index << " with type = " << resultType
-               << " to match output arg type = " << outType;
-    }
-    return success();
-  }
   return op.emitOpError("expected parent op with LinalgOp interface");
 }
@@ -1892,11 +1868,11 @@
     return failure();

   Block *block = tiledLoop.getBody();
-  auto yieldOp = cast<linalg::YieldOp>(block->getTerminator());
+  auto yieldOp = cast<TiledYieldOp>(block->getTerminator());

   // Match the pattern and collect output buffers that will replace the output
   // tensors and also the ops that will be ignored when cloning the body.
-  SmallVector<Value, 2> newOutputOperands, newYieldArgs;
+  SmallVector<Value, 2> newOutputOperands, newYieldTileArgs, newYieldOutArgs;
   int resultId = 0;
   // Store ids of the corresponding old and new output operands.
   SmallVector<int64_t, 2> oldOutputIdToNew(tiledLoop.outputs().size(),
@@ -1917,13 +1893,15 @@
       continue;
     }
     Value result = tiledLoop.getResult(resultId);
-    Value yieldArg = yieldOp.getOperand(resultId);
-    if (yieldArg != outRegionArg || !result.use_empty()) {
+    Value yieldTileArg = yieldOp.tiles()[resultId];
+    Value yieldOutArg = yieldOp.outputs()[resultId];
+    if (yieldTileArg != outRegionArg || !result.use_empty()) {
       oldOutputIdToNew[index] = newOutputOperands.size();
-      oldResultIdToNew[resultId] = newYieldArgs.size();
+      oldResultIdToNew[resultId] = newYieldTileArgs.size();
       resultReplacement[resultId] = out;
       newOutputOperands.push_back(out);
-      newYieldArgs.push_back(yieldArg);
+      newYieldTileArgs.push_back(yieldTileArg);
+      newYieldOutArgs.push_back(yieldOutArg);
     }
     ++resultId;
   }
@@ -1952,9 +1930,12 @@
       OpBuilder::atBlockEnd(newTiledLoop.getBody(), rewriter.getListener());
   for (auto &op : tiledLoop.getBody()->without_terminator())
     innerBuilder.clone(op, bvm);
-  innerBuilder.create<linalg::YieldOp>(
-      loc, llvm::to_vector<2>(llvm::map_range(
-               newYieldArgs, [&](Value arg) { return bvm.lookup(arg); })));
+  innerBuilder.create<TiledYieldOp>(
+      loc,
+      llvm::to_vector<2>(llvm::map_range(
+          newYieldTileArgs, [&](Value arg) { return bvm.lookup(arg); })),
+      llvm::to_vector<2>(llvm::map_range(
+          newYieldOutArgs, [&](Value arg) { return bvm.lookup(arg); })));

   for (const auto &en : llvm::enumerate(oldResultIdToNew))
     if (en.value() != kNoMatch)
@@ -1976,6 +1957,146 @@
   return foldMemRefCastInTiledLoopOp(*this);
 }

+//===----------------------------------------------------------------------===//
+// TiledYieldOp
+//===----------------------------------------------------------------------===//
+
+static void print(OpAsmPrinter &p, TiledYieldOp op) {
+  p << op.getOperationName();
+
+  if (!op.tiles().empty()) {
+    llvm::interleaveComma(llvm::zip(op.tiles(), op.outputs()), p, [&](auto it) {
+      p << ' ' << std::get<0>(it) << " in " << std::get<1>(it) << " : "
+        << std::get<1>(it).getType();
+    });
+  }
+  p.printOptionalAttrDict(op->getAttrs());
+}
+
+static ParseResult parseTiledYieldOp(OpAsmParser &parser,
+                                     OperationState &result) {
+  SmallVector<OpAsmParser::OperandType, 4> tiles, outputs;
+  SmallVector<Type, 4> types;
+
+  OpAsmParser::OperandType tile;
+  while (parser.parseOptionalOperand(tile).hasValue()) {
+    Type type;
+    OpAsmParser::OperandType output;
+    if (parser.parseKeyword("in") || parser.parseOperand(output) ||
+        parser.parseColon() || parser.parseType(type))
+      return failure();
+    tiles.push_back(tile);
+    outputs.push_back(output);
+    types.push_back(type);
+    (void)parser.parseOptionalComma();
+  }
+  llvm::SMLoc loc = parser.getCurrentLocation();
+  if (parser.resolveOperands(tiles, types, loc, result.operands) ||
+      parser.resolveOperands(outputs, types, loc, result.operands))
+    return failure();
+
+  // Parse optional attributes.
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+
+  return success();
+}
+
+static LogicalResult verify(TiledYieldOp op) {
+  // Check if output args with tensor types match result types.
+  auto loop = op->getParentOfType<TiledLoopOp>();
+  SmallVector<Value> loopTensorOuts;
+  llvm::copy_if(
+      loop.outputs(), std::back_inserter(loopTensorOuts),
+      [&](Value out) { return out.getType().isa<RankedTensorType>(); });
+  if (loopTensorOuts.size() != op.tiles().size())
+    return op.emitOpError("expected number of tensor output args = ")
+           << loopTensorOuts.size()
+           << " to match the number of yield operands = " << op.tiles().size();
+
+  // Check if the `tiles` argument types match the `outputs` argument types.
+  SmallVector<Value> loopTensorOutsBlockArgs;
+  llvm::copy_if(
+      loop.getRegionOutputArgs(), std::back_inserter(loopTensorOutsBlockArgs),
+      [&](Value out) { return out.getType().isa<RankedTensorType>(); });
+  for (auto en : llvm::enumerate(
+           llvm::zip(op.tiles(), op.outputs(), loopTensorOutsBlockArgs))) {
+    size_t index = en.index();
+    Type tileType = std::get<0>(en.value()).getType();
+    Value yieldOut = std::get<1>(en.value());
+    Type yieldOutType = yieldOut.getType();
+
+    if (tileType != yieldOutType)
+      return op.emitOpError("expected tile operand with type = ")
+             << tileType << " to match output type = " << yieldOutType;
+
+    // Check if yieldOut is either an output bbArg or a slice of it.
+    Value src = yieldOut;
+    if (auto extractSlice = llvm::dyn_cast_or_null<tensor::ExtractSliceOp>(
+            yieldOut.getDefiningOp()))
+      src = extractSlice.source();
+
+    Value loopBlockArg = std::get<2>(en.value());
+    if (src != loopBlockArg)
+      return op.emitOpError("expected output ")
+             << index << " to be a subset of the corresponding block argument";
+  }
+  return success();
+}
+
+namespace {
+/// Pattern to rewrite TiledYieldOp with tensor::CastOp arguments.
+///
+/// Example:
+/// ```
+///   %TILE_CAST = tensor.cast %TILE : tensor<16x16xf32> to tensor<?x?xf32>
+///   %OUT_CAST = tensor.cast %OUT : tensor<16x16xf32> to tensor<?x?xf32>
+///   linalg.tiled_yield %TILE_CAST in %OUT_CAST : tensor<?x?xf32>
+/// ```
+/// is rewritten into:
+/// ```
+///   linalg.tiled_yield %TILE in %OUT : tensor<16x16xf32>
+/// ```
+class TensorYieldOpCastFolder final : public OpRewritePattern<TiledYieldOp> {
+public:
+  using OpRewritePattern<TiledYieldOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TiledYieldOp tiledYieldOp,
+                                PatternRewriter &rewriter) const override {
+    bool foundTensorCasts = false;
+    SmallVector<Value, 4> tiles, outputs;
+    for (auto item : llvm::zip(tiledYieldOp.tiles(), tiledYieldOp.outputs())) {
+      Value tile, output;
+      std::tie(tile, output) = item;
+      auto tileCast = tile.getDefiningOp<tensor::CastOp>();
+      auto outputCast = output.getDefiningOp<tensor::CastOp>();
+
+      // Keep the operand pair unchanged if there is no matching pair of casts
+      // to fold.
+      if (!tileCast || !outputCast ||
+          tileCast.source().getType() != outputCast.source().getType()) {
+        tiles.push_back(tile);
+        outputs.push_back(output);
+        continue;
+      }
+      foundTensorCasts = true;
+      tiles.push_back(tileCast.source());
+      outputs.push_back(outputCast.source());
+    }
+    if (!foundTensorCasts)
+      return failure();
+
+    rewriter.replaceOpWithNewOp<TiledYieldOp>(tiledYieldOp, tiles, outputs);
+    return success();
+  }
+};
+} // namespace
+
+void TiledYieldOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                               MLIRContext *context) {
+  results.add<TensorYieldOpCastFolder>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // IndexOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -372,6 +372,7 @@
           ReturnOp,
           TiledLoopOp,
           VectorTransferOpInterface,
+          linalg::TiledYieldOp,
           linalg::YieldOp,
           scf::YieldOp>(op)
       // clang-format on
@@ -519,7 +520,7 @@
     return None;
   return TypeSwitch<Operation *, OpResult>(opOperand.getOwner())
       // These terminators legitimately have no result.
-      .Case<ReturnOp, linalg::YieldOp, scf::YieldOp>(
+      .Case<ReturnOp, linalg::TiledYieldOp, linalg::YieldOp, scf::YieldOp>(
          [&](auto op) { return OpResult(); })
       // ConstantOp is never inplaceable.
       .Case([&](ConstantOp op) { return op->getResult(0); })
@@ -570,6 +571,11 @@
   if (auto linalgOp = dyn_cast<LinalgOp>(opOperand.getOwner()))
     return linalgOp.isInputTensor(&opOperand) ||
            linalgOp.isInitTensor(&opOperand);
+  // This is questionable. Should we consider TiledYieldOp as an op that
+  // bufferizes to "read" for the `tiles` args and to "write" for the
+  // `outputs` args?
+  if (isa<TiledYieldOp>(opOperand.getOwner()))
+    return false;
   // All other cases are considered to bufferize to memory reads.
   // In particular, terminators are often the last use and need to be considered
   // as reads to return the proper value and avoid WAW clobbers.
@@ -583,7 +589,8 @@
 bufferizesToMemoryWrite(OpOperand &opOperand,
                         InPlaceSpec inPlaceSpec = InPlaceSpec::None) {
   // These terminators are not writes.
-  if (isa<ReturnOp, linalg::YieldOp, scf::YieldOp>(opOperand.getOwner()))
+  if (isa<ReturnOp, linalg::TiledYieldOp, linalg::YieldOp, scf::YieldOp>(
+          opOperand.getOwner()))
     return false;
   // ExtractSliceOp alone doesn't bufferize to a memory write, one of its uses
   // may.
@@ -2110,9 +2117,6 @@
   // No tensors -> success.
   if (!llvm::any_of(yieldOp.getOperandTypes(), isaTensor))
     return success();
-  // linalg::YieldOp nested under TiledLoop must just canonicalize.
-  if (yieldOp->getParentOfType<TiledLoopOp>())
-    return success();
   llvm_unreachable("unexpected yieldOp");
 }

@@ -2131,6 +2135,15 @@
   extractOp.replaceAllUsesWith(l);
   return success();
 }
+
+/// Bufferization for linalg::TiledYieldOp just results in later
+/// canonicalization.
+static LogicalResult bufferize(OpBuilder &b, linalg::TiledYieldOp yieldOp,
+                               BlockAndValueMapping &bvm,
+                               BufferizationAliasInfo &aliasInfo) {
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Bufferization analyses.
 //===----------------------------------------------------------------------===//
@@ -2332,6 +2345,7 @@
           TiledLoopOp,
           VectorTransferOpInterface,
           linalg::YieldOp,
+          linalg::TiledYieldOp,
           scf::YieldOp>([&](auto op) {
         LDBG("Begin bufferize:\n" << op << '\n');
         return bufferize(b, op, bvm, aliasInfo);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -164,6 +164,47 @@
                                  sliceOp.static_sizes(), sliceOp.static_strides());
 }

+template <typename LoopTy>
+static SmallVector<Value>
+collectLoopYieldArgs(OpBuilder &b, LinalgOp clonedOp,
+                     ArrayRef<Value> tiledOperands,
+                     SmallVectorImpl<Value> &tensorResults) {
+
+  Location loc = clonedOp.getLoc();
+  SmallVector<Value> yieldArgs;
+  unsigned resultIdx = 0;
+  for (OpOperand *opOperand : clonedOp.getOutputTensorOperands()) {
+    // TODO: use an interface/adaptor to avoid leaking position in
+    // `tiledOperands`.
+    Value outputTensor = tiledOperands[opOperand->getOperandNumber()];
+    // Insert an insert_slice for each output tensor.
+    if (auto sliceOp = outputTensor.getDefiningOp<tensor::ExtractSliceOp>()) {
+      yieldArgs.push_back(insertSliceIntoTensor(
+          b, loc, sliceOp, clonedOp->getResult(resultIdx), sliceOp.source()));
+    } else {
+      yieldArgs.push_back(clonedOp->getResult(resultIdx));
+    }
+    ++resultIdx;
+  }
+  tensorResults = yieldArgs;
+  return yieldArgs;
+}
+
+template <>
+SmallVector<Value>
+collectLoopYieldArgs<TiledLoopOp>(OpBuilder &b, LinalgOp clonedOp,
+                                  ArrayRef<Value> tiledOperands,
+                                  SmallVectorImpl<Value> &tensorResults) {
+  auto outputTensorOperands = clonedOp.getOutputTensorOperands();
+  size_t numOutputTensors = outputTensorOperands.size();
+
+  SmallVector<Value> yieldArgs(clonedOp->getResults().begin(),
+                               clonedOp->getResults().end());
+  auto tiledOutputOperands = tiledOperands.take_back(numOutputTensors);
+  yieldArgs.append(tiledOutputOperands.begin(), tiledOutputOperands.end());
+
+  return yieldArgs;
+}
+
 template <typename LoopTy>
 static Optional<TiledLinalgOp>
 tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
@@ -224,7 +265,7 @@
   }

   // 2. Create the tiled loops.
-  LinalgOp res = op;
+  LinalgOp clonedOp = op;
   SmallVector<Value, 4> ivs, tensorResults;
   auto tiledLoopBodyBuilder = [&](OpBuilder &b, Location loc,
                                   ValueRange localIvs,
@@ -262,30 +303,18 @@
       resultTensorTypes.push_back(
           tiledOperands[opOperand->getOperandNumber()].getType());

-    res = op.clone(b, loc, resultTensorTypes, tiledOperands);
+    clonedOp = op.clone(b, loc, resultTensorTypes, tiledOperands);

-    // Insert a insert_slice for each output tensor.
-    unsigned resultIdx = 0;
-    for (OpOperand *opOperand : op.getOutputTensorOperands()) {
-      // TODO: use an interface/adaptor to avoid leaking position in
-      // `tiledOperands`.
-      Value outputTensor = tiledOperands[opOperand->getOperandNumber()];
-      if (auto sliceOp = outputTensor.getDefiningOp<tensor::ExtractSliceOp>()) {
-        tensorResults.push_back(insertSliceIntoTensor(
-            b, loc, sliceOp, res->getResult(resultIdx), sliceOp.source()));
-      } else {
-        tensorResults.push_back(res->getResult(resultIdx));
-      }
-      ++resultIdx;
-    }
-    return scf::ValueVector(tensorResults.begin(), tensorResults.end());
+    auto yieldArgs =
+        collectLoopYieldArgs<LoopTy>(b, clonedOp, tiledOperands, tensorResults);
+    return {yieldArgs.begin(), yieldArgs.end()};
   };
   GenerateLoopNest<LoopTy>::doit(b, op.getLoc(), loopRanges, op, iteratorTypes,
                                  tiledLoopBodyBuilder, options.distribution,
                                  options.distributionTypes);

   // 3. Transform IndexOp results w.r.t. the tiling.
-  transformIndexOps(b, res, ivs, loopIndexToRangeIndex);
+  transformIndexOps(b, clonedOp, ivs, loopIndexToRangeIndex);

   // 4. Gather the newly created loops and return them with the new op.
   SmallVector<Operation *, 8> loops;
@@ -308,8 +337,9 @@
     if ((outermostLoop = loop))
       break;

-  return TiledLinalgOp{
-      res, loops, outermostLoop ? outermostLoop->getResults() : tensorResults};
+  return TiledLinalgOp{clonedOp, loops,
+                       outermostLoop ? outermostLoop->getResults()
+                                     : tensorResults};
 }

 template <typename LoopTy>
@@ -500,6 +530,7 @@
   memref::ViewOp::getCanonicalizationPatterns(patterns, ctx);
   PadTensorOp::getCanonicalizationPatterns(patterns, ctx);
   ctx->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(patterns);
+  linalg::TiledYieldOp::getCanonicalizationPatterns(patterns, ctx);
   CanonicalizationPatternList<
 #define GET_OP_LIST
 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -311,9 +311,12 @@
                          ValueRange ivs, ValueRange inputs,
                          ValueRange outputs) {
     SmallVector<Value> outputTensors = linalgOp.getOutputTensorOperands();
-    scf::ValueVector results =
+    scf::ValueVector yieldArgs =
         bodyBuilderFn(nestedBuilder, nestedLoc, ivs, outputTensors);
-    nestedBuilder.create<linalg::YieldOp>(nestedLoc, results);
+    auto yieldArgsRef = llvm::makeArrayRef(yieldArgs);
+    nestedBuilder.create<linalg::TiledYieldOp>(
+        nestedLoc, yieldArgsRef.take_front(outputTensors.size()),
+        yieldArgsRef.drop_front(outputTensors.size()));
   };

   SmallVector<Value> inputOperands = linalgOp.getInputOperands();
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -36,13 +36,33 @@
     %16 = memref.subview %out[%arg3] [%14] [1]
       : memref<192xf32, #map> to memref<?xf32, #map>
     linalg.fill(%cst, %16) : f32, memref<?xf32, #map>
-    linalg.yield
+    linalg.tiled_yield
   }
   return
 }

 // -----

+// CHECK-LABEL: @tiled_yield_fold_tensor_cast
+func @tiled_yield_fold_tensor_cast(%init: tensor<200xf32>) -> tensor<200xf32> {
+  %c0 = constant 0 : index
+  %c10 = constant 10 : index
+  %c200 = constant 200 : index
+  %cst = constant 0.000000e+00 : f32
+  %0 = linalg.tiled_loop (%i) = (%c0) to (%c200) step (%c10)
+      outs (%init_ = %init: tensor<200xf32>) {
+    %sub_init = tensor.extract_slice %init_[%i] [%c10] [1]
+      : tensor<200xf32> to tensor<?xf32>
+    %sub_fill = linalg.fill(%cst, %sub_init)
+      : f32, tensor<?xf32> -> tensor<?xf32>
+    linalg.tiled_yield %sub_fill in %sub_init : tensor<?xf32>
+  }
+  return %0 : tensor<200xf32>
+}
+// CHECK: linalg.tiled_yield %{{.*}} in %{{.*}} : tensor<10xf32>
+
+// -----
+
 // CHECK-LABEL: zero_rank_reshape_multi
 func @zero_rank_reshape_multi(%arg0: tensor<f32>) -> tensor<f32> {
   // CHECK: return %arg0
@@ -706,8 +726,9 @@
                      %CT_ = %C_tensor:
tensor<48xf32>, %C_ = %C: memref<48xf32>) { %result = call @foo(%A_, %B_, %C_) - : (memref<48xf32>, tensor<48xf32>, memref<48xf32>)-> (tensor<48xf32>) - linalg.yield %result, %CT_ : tensor<48xf32>, tensor<48xf32> + : (memref<48xf32>, tensor<48xf32>, memref<48xf32>) -> (tensor<48xf32>) + linalg.tiled_yield %result in %B_ : tensor<48xf32>, + %CT_ in %CT_ : tensor<48xf32> } return %useful : tensor<48xf32> } @@ -726,7 +747,7 @@ // CHECK-SAME: ins (%[[A_:.*]] = %[[A]]: [[BUF_TY]]) // CHECK-SAME: outs (%[[B_:.*]] = %[[B]]: [[TY]], %[[C_:.*]] = %[[C]]: [[BUF_TY]]) { // CHECK-NEXT: %[[RES:.*]] = call @foo(%[[A_]], %[[B_]], %[[C_]]) -// CHECK-NEXT: linalg.yield %[[RES]] : +// CHECK-NEXT: linalg.tiled_yield %[[RES]] in %[[B_]] // CHECK: return %[[RESULT]] @@ -743,7 +764,7 @@ ins (%A_ = %A: memref<192xf32>, %AT_ = %A_tensor: tensor<192xf32>) outs (%BT_ = %B_tensor: tensor<192xf32>) { %0 = call @foo(%A_, %BT_) : (memref<192xf32>, tensor<192xf32>) -> tensor<192xf32> - linalg.yield %0 : tensor<192xf32> + linalg.tiled_yield %0 in %BT_ : tensor<192xf32> } return %result : tensor<192xf32> } diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir @@ -507,25 +507,25 @@ // of %r3 is read. // CHECK: linalg.tiled_loop // CHECK-NEXT: call - // CHECK-NEXT: linalg.yield + // CHECK-NEXT: linalg.tiled_yield // CHECK-NEXT: {__inplace_results_attr__ = ["false"]} %r2 = linalg.tiled_loop (%i) = (%lb) to (%ub) step (%step) ins() outs(%t = %B: tensor) { call @some_use(%t) : (tensor) -> () - linalg.yield %t : tensor + linalg.tiled_yield %t in %t : tensor } // %r3 bufferizes inplace fine. 
// CHECK: linalg.tiled_loop // CHECK-NEXT: call - // CHECK-NEXT: linalg.yield + // CHECK-NEXT: linalg.tiled_yield // CHECK-NEXT: {__inplace_results_attr__ = ["true"]} %r3 = linalg.tiled_loop (%i) = (%lb) to (%ub) step (%step) ins() outs(%t = %B: tensor) { call @some_use(%t) : (tensor) -> () - linalg.yield %t : tensor + linalg.tiled_yield %t in %t : tensor } return %r1, %r3: tensor, tensor diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir @@ -550,10 +550,11 @@ // CHECK: linalg.tiled_loop {{.*}} to (%[[M]]) {{.*}} %[[A]]{{.*}}%[[B]]{{.*}}outs{{.*}}%[[c]] %1 = linalg.tiled_loop (%arg3) = (%c0) to (%0) step (%c3) - ins (%arg4 = %A: tensor, %use = %effecting : memref, %arg5 = %B: tensor) + ins (%arg4 = %A: tensor, + %use = %effecting : memref, + %arg5 = %B: tensor) outs (%arg6 = %c: tensor) - iterators["reduction"] - { + iterators["reduction"] { // CHECK-NOT: alloc %2 = tensor.dim %arg4, %c0 : tensor @@ -573,8 +574,8 @@ // CHECK: call @some_use(%{{.*}}) : (memref) -> () call @some_use(%use) : (memref) -> () - linalg.yield %8 : tensor - // CHECK: linalg.yield + linalg.tiled_yield %8 in %arg6 : tensor + // CHECK: linalg.tiled_yield // CHECK-NOT: tensor } diff --git a/mlir/test/Dialect/Linalg/distribute-tiled-loop.mlir b/mlir/test/Dialect/Linalg/distribute-tiled-loop.mlir --- a/mlir/test/Dialect/Linalg/distribute-tiled-loop.mlir +++ b/mlir/test/Dialect/Linalg/distribute-tiled-loop.mlir @@ -14,7 +14,7 @@ distribution ["block_x", "block_y"] { %0 = call @foo(%A_, %B_) : (tensor<64x64xf32>, tensor<64x64xf32>) -> tensor<64x64xf32> - linalg.yield %0 : tensor<64x64xf32> + linalg.tiled_yield %0 in %B_ : tensor<64x64xf32> } return %0 : tensor<64x64xf32> } diff --git a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir --- a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir +++ b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir @@ -124,7 +124,7 @@ // TLOOP: %[[DIM_B_1:.*]] = tensor.dim %[[B_]], %[[C1]] : [[TY]] // TLOOP: %[[DIM_C_1:.*]] = tensor.dim %[[C_]], %[[C1]] : [[TY]] -// TLOOP: %[[ABC_SUB_:.*]] = linalg.tiled_loop (%[[IV1:.*]], %[[IV2:.*]]) = +// TLOOP: %[[ABC_SUB:.*]] = linalg.tiled_loop (%[[IV1:.*]], %[[IV2:.*]]) = // TLOOP-SAME: (%[[C0]], %[[C0]]) to (%[[DIM_C_1]], %[[DIM_B_1]]) // TLOOP-SAME: step (%[[C64]], %[[C16]]) // TLOOP-SAME: ins (%[[AB_SUB_:.*]] = %[[AB_SUB]]: [[TY]], @@ -134,18 +134,15 @@ // TLOOP: %[[AB_SUB_SUB:.*]] = tensor.extract_slice %[[AB_SUB_]][0, %[[IV2]]] // TLOOP: %[[C__SUB:.*]] = tensor.extract_slice %[[C__]][%[[IV2]], %[[IV1]]] -// TLOOP: %[[ABS_INIT_SUB_SUB:.*]] = tensor.extract_slice %[[ABC_INIT_SUB_]][0, %[[IV1]]] +// TLOOP: %[[ABC_INIT_SUB_SUB:.*]] = tensor.extract_slice %[[ABC_INIT_SUB_]][0, %[[IV1]]] // TLOOP: %[[ABC_SUB_SUB:.*]] = linalg.matmul // TLOOP-SAME: ins(%[[AB_SUB_SUB]], %[[C__SUB]] : [[TY]], [[TY]]) -// TLOOP-SAME: outs(%[[ABS_INIT_SUB_SUB]] : [[TY]]) -> [[TY]] +// TLOOP-SAME: outs(%[[ABC_INIT_SUB_SUB]] : [[TY]]) -> [[TY]] -// TLOOP: %[[RES0:.*]] = tensor.insert_slice %[[ABC_SUB_SUB]] -// TLOOP-SAME: into %[[ABC_INIT_SUB_]][0, %[[IV1]]] -// TLOOP: linalg.yield %[[RES0]] : [[TY]] +// TLOOP: linalg.tiled_yield %[[ABC_SUB_SUB]] in %[[ABC_INIT_SUB_SUB]] : [[TY]] // TLOOP: } -// TLOOP: %[[RES1:.*]] = tensor.insert_slice %[[ABC_SUB_]] into %[[ABC_INIT_]][%[[IV0]], 0] -// TLOOP: 
linalg.yield %[[RES1]] : [[TY]] +// TLOOP: linalg.tiled_yield %[[ABC_SUB]] in %[[ABC_INIT_SUB]] : [[TY]] // TLOOP: } // TLOOP: return %[[ABC]] : [[TY]] @@ -238,10 +235,7 @@ // TLOOP: %[[DOUBLE_AB:.*]] = linalg.generic // TLOOP-SAME: ins(%[[AB_SUB]] : [[TY]]) outs(%[[INIT_SUB]] : [[TY]]) -// TLOOP: %[[RESULT_SUB:.*]] = tensor.insert_slice -// TLOOP-SAME: %[[DOUBLE_AB:.*]] into %[[INIT_]][%[[IV0]], %[[IV1]]] - -// TLOOP: linalg.yield %[[RESULT_SUB]] : [[TY]] +// TLOOP: linalg.tiled_yield %[[DOUBLE_AB]] in %[[INIT_SUB]] : [[TY]] // TLOOP: } // TLOOP: return %[[RESULT]] : [[TY]] @@ -304,7 +298,8 @@ // TLOOP: %[[A_SUB:.*]] = tensor.extract_slice %[[A_]][%[[I]], 0] // TLOOP: %[[B_SUB:.*]] = tensor.extract_slice %[[B_]][0, %[[J]]] // TLOOP: %[[OUT_SUB:.*]] = tensor.extract_slice %[[OUT_]][%[[I]], %[[J]]] -// TLOOP: %[[INIT_SUB:.*]] = linalg.fill(%[[C0_F32_]], %[[OUT_SUB]]) +// TLOOP: %[[OUT_SUB_2:.*]] = tensor.extract_slice %[[OUT_]][%[[I]], %[[J]]] +// TLOOP: %[[INIT_SUB:.*]] = linalg.fill(%[[C0_F32_]], %[[OUT_SUB_2]]) // TLOOP: %[[AB_SUB:.*]] = linalg.tiled_loop (%[[K:.*]]) = (%[[C0]]) // TLOOP-SAME: to (%[[DIM_A__1]]) step (%[[C16]]) @@ -319,11 +314,9 @@ // TLOOP: %[[AB_SUB_SUB:.*]] = linalg.matmul // TLOOP-SAME: ins(%[[A_SUB_SUB]], %[[B_SUB_SUB]] : [[TY]], [[TY]]) // TLOOP-SAME: outs(%[[INIT_SUB_]] : [[TY]]) -> [[TY]] -// TLOOP: linalg.yield %[[AB_SUB_SUB]] : [[TY]] +// TLOOP: linalg.tiled_yield %[[AB_SUB_SUB]] in %[[INIT_SUB_]] : [[TY]] // TLOOP: } -// TLOOP: %[[SUB_RESULT:.*]] = tensor.insert_slice %[[AB_SUB]] -// TLOOP-SAME: into %[[OUT_]][%[[I]], %[[J]]] -// TLOOP: linalg.yield %[[SUB_RESULT]] : [[TY]] +// TLOOP: linalg.tiled_yield %[[AB_SUB]] in %[[OUT_SUB]] : [[TY]] // TLOOP: } // TLOOP: return %[[AB]] : [[TY]] @@ -375,9 +368,10 @@ // TLOOP: %[[A_SUB:.*]] = tensor.extract_slice %[[A_]][%[[I]], 0] // TLOOP: %[[B_SUB:.*]] = tensor.extract_slice %[[B_]][0, %[[J]]] // TLOOP: %[[OUT_SUB:.*]] = tensor.extract_slice %[[OUT_]][%[[I]], %[[J]]] +// TLOOP: %[[OUT_SUB_2:.*]] = tensor.extract_slice %[[OUT_]][%[[I]], %[[J]]] // TLOOP: %[[INIT_SUB:.*]] = linalg.generic // TLOOP-SAME: ins(%[[C0_F32_]] -// TLOOP-SAME: outs(%[[OUT_SUB]] +// TLOOP-SAME: outs(%[[OUT_SUB_2]] // TLOOP: %[[AB_SUB:.*]] = linalg.tiled_loop (%[[K:.*]]) = (%[[C0]]) // TLOOP-SAME: to (%[[DIM_A__1]]) step (%[[C16]]) @@ -392,11 +386,9 @@ // TLOOP: %[[AB_SUB_SUB:.*]] = linalg.matmul // TLOOP-SAME: ins(%[[A_SUB_SUB]], %[[B_SUB_SUB]] : [[TY]], [[TY]]) // TLOOP-SAME: outs(%[[INIT_SUB_]] : [[TY]]) -> [[TY]] -// TLOOP: linalg.yield %[[AB_SUB_SUB]] : [[TY]] +// TLOOP: linalg.tiled_yield %[[AB_SUB_SUB]] in %[[INIT_SUB_]] : [[TY]] // TLOOP: } -// TLOOP: %[[SUB_RESULT:.*]] = tensor.insert_slice %[[AB_SUB]] -// TLOOP-SAME: into %[[OUT_]][%[[I]], %[[J]]] -// TLOOP: linalg.yield %[[SUB_RESULT]] : [[TY]] +// TLOOP: linalg.tiled_yield %[[AB_SUB]] in %[[OUT_SUB]] : [[TY]] // TLOOP: } // TLOOP: return %[[AB]] : [[TY]] diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -582,10 +582,6 @@ // ----- -#map0 = affine_map<(d0) -> (24, -d0 + 192)> -#map1 = affine_map<(d0, d1)[s0] -> (d0 * 192 + s0 + d1)> -#map2 = affine_map<(d0) -> (16, -d0 + 192)> - func private @foo(%A: memref<192x192xf32>, %B: memref<192x192xf32>, %C: memref<192x192xf32>) -> () @@ -603,11 +599,34 @@ call @foo(%A_, %B_, %C_) : (memref<192x192xf32>, memref<192x192xf32>, memref<192x192xf32>)-> () // expected-error @+1 {{expected number of tensor output args = 1 
to match the number of yield operands = 0}} - linalg.yield + linalg.tiled_yield } return } +// ----- + +func @tiled_loop_incorrect_destination_for_tile(%A: tensor<4xf32>, + %B: tensor<4xf32>) { + %c2 = constant 2 : index + %c4 = constant 2 : index + %c0 = constant 0 : index + %0 = linalg.tiled_loop (%i) = (%c0) to (%c4) step (%c2) + ins (%A_ = %A: tensor<4xf32>) + outs (%B_ = %B: tensor<4xf32>) { + %A_sub = tensor.extract_slice %A_[%i][2][1] + : tensor<4xf32> to tensor<2xf32> + %B_sub = tensor.extract_slice %B_[%i][2][1] + : tensor<4xf32> to tensor<2xf32> + %c0_f32 = constant 0.0 : f32 + %tile = linalg.fill(%c0_f32, %A_sub) : f32, tensor<2xf32> -> tensor<2xf32> + // expected-error @+1 {{expected output 0 to be a subset of the corresponding block argument}} + linalg.tiled_yield %tile in %A_sub : tensor<2xf32> + } + return +} + + // ----- #map0 = affine_map<(d0) -> (24, -d0 + 192)> @@ -630,8 +649,8 @@ %C_ = %C: memref<192x192xf32>) { %1 = call @foo(%A_, %B_, %C_) : (memref<192x192xf32>, memref<192x192xf32>, memref<192x192xf32>)-> tensor - // expected-error @+1 {{expected yield operand 0 with type = 'tensor' to match output arg type = 'tensor<192x192xf32>}} - linalg.yield %1 : tensor + // expected-error @+1 {{expected tile operand with type = 'tensor' to match output type = 'tensor<192x192xf32>}} + "linalg.tiled_yield" (%1, %CT_) : (tensor, tensor<192x192xf32>) -> () } return } @@ -639,7 +658,7 @@ // ----- func private @foo(%A: memref<192x192xf32>, %B: memref<192x192xf32>, - %C: memref<192x192xf32>) -> () + %C: memref<192x192xf32>) -> (tensor<192x192xf32>) func @tiled_loop_incorrent_iterator_types_count(%A: memref<192x192xf32>, %B: memref<192x192xf32>, %C: memref<192x192xf32>, @@ -652,9 +671,10 @@ ^bb0(%arg4: index, %arg5: index, %A_: memref<192x192xf32>, %B_: memref<192x192xf32>, %CT_: tensor<192x192xf32>, %C_: memref<192x192xf32>): - call @foo(%A_, %B_, %C_) - : (memref<192x192xf32>, memref<192x192xf32>, memref<192x192xf32>)-> () - linalg.yield %CT_ : tensor<192x192xf32> + %tile = call @foo(%A_, %B_, %C_) + : (memref<192x192xf32>, memref<192x192xf32>, memref<192x192xf32>) + -> (tensor<192x192xf32>) + linalg.tiled_yield %tile in %CT_ : tensor<192x192xf32> }) { iterator_types = ["parallel"], operand_segment_sizes = dense<2> : vector<5xi32> @@ -676,7 +696,7 @@ "linalg.tiled_loop"(%c0, %c192, %c24, %A) ( { ^bb0(%arg4: index, %A_: memref<100xf32>): call @foo(%A_) : (memref<100xf32>)-> () - linalg.yield + linalg.tiled_yield }) { iterator_types = ["parallel"], operand_segment_sizes = dense<[1, 1, 1, 0, 1]> : vector<5xi32> diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir --- a/mlir/test/Dialect/Linalg/roundtrip.mlir +++ b/mlir/test/Dialect/Linalg/roundtrip.mlir @@ -648,9 +648,7 @@ linalg.yield %s : i8 } -> tensor - %sum_sub = tensor.insert_slice %sum into %out_[%i, 0][%c4, %c64][1, 1] - : tensor into tensor<24x64xi8> - linalg.yield %sum_sub : tensor<24x64xi8> + linalg.tiled_yield %sum in %out_sub : tensor } return %prod : tensor<24x64xi8> } @@ -711,9 +709,7 @@ linalg.yield %1 : f32 } -> tensor<4xf32> - %sum_sub = tensor.insert_slice %acc into %o_[%j][%c4][1] - : tensor<4xf32> into tensor<24xf32> - linalg.yield %sum_sub : tensor<24xf32> + linalg.tiled_yield %acc in %sub_out : tensor<4xf32> } return %result : tensor<24xf32> } @@ -773,7 +769,7 @@ %1 = addf %0, %i1d : f32 linalg.yield %1 : f32 } - linalg.yield + linalg.tiled_yield } return } diff --git a/mlir/test/Dialect/Linalg/tile-tensors.mlir b/mlir/test/Dialect/Linalg/tile-tensors.mlir --- 
a/mlir/test/Dialect/Linalg/tile-tensors.mlir +++ b/mlir/test/Dialect/Linalg/tile-tensors.mlir @@ -58,8 +58,7 @@ // TLOOP: %[[PROD:.*]] = linalg.matmul ins(%[[SUB_ARG_0]], %[[SUB_ARG_1]] // TLOOP-SE: outs(%[[SUB_ARG_2]] : [[TY]]) -> [[TY]] -// TLOOP: %[[O:.*]] = tensor.insert_slice %[[PROD]] into %[[A2]][%[[I]], %[[J]]] -// TLOOP: linalg.yield %[[O]] : [[TY]] +// TLOOP: linalg.tiled_yield %[[PROD]] in %[[SUB_ARG_2]] : [[TY]] // ----- diff --git a/mlir/test/Dialect/Linalg/tiled-loops.mlir b/mlir/test/Dialect/Linalg/tiled-loops.mlir --- a/mlir/test/Dialect/Linalg/tiled-loops.mlir +++ b/mlir/test/Dialect/Linalg/tiled-loops.mlir @@ -29,7 +29,7 @@ linalg.matmul ins(%1, %3 : memref, memref<192x?xf32, #map1>) outs(%4 : memref) - linalg.yield + linalg.tiled_yield } return } @@ -64,7 +64,7 @@ outs (%C_ = %C: memref) iterators["reduction", "reduction"] { linalg.fill(%cst, %A_) : f32, memref<192x192xf32> - linalg.yield + linalg.tiled_yield } return }