diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -324,6 +324,7 @@
 //===----------------------------------------------------------------------===//

 def ForeachThreadOp : SCF_Op<"foreach_thread", [
+       AttrSizedOperandSegments,
        SingleBlockImplicitTerminator<"scf::PerformConcurrentlyOp">,
        RecursiveSideEffects,
        AutomaticAllocationScope,
@@ -335,6 +336,17 @@
     parallel body and it takes index operands that indicate how many parallel
     instances of that function are created.

+    The op also takes a variadic number of tensor operands (`shared_outs`).
+    The future buffers corresponding to these tensors are shared among all
+    threads. Shared tensors should be accessed via their corresponding block
+    arguments. If multiple threads write to a shared buffer in a racy
+    fashion, these writes will execute in some unspecified order. Tensors that
+    are not shared can be used inside the body (i.e., the op is not isolated
+    from above); however, if a use of such a tensor bufferizes to a memory
+    write, the tensor is privatized, i.e., a thread-local copy of the tensor is
+    used. This ensures that memory side effects of a thread are not visible to
+    other threads (or in the parent body), apart from explicitly shared tensors.
+
     The name "thread" conveys the fact that the parallel execution is mapped
     (i.e. distributed) to a set of virtual threads of execution, one function
     application per thread. Further lowerings are responsible for specifying
@@ -349,26 +361,20 @@
     context of the concrete target the op is lowered to, or to ignore it when
     the specification is ill-formed or unsupported for a particular target.

-    The only allowed terminator is `scf.foreach_thread.perform_concurrently`,
-    which dictates how the partial results of all parallel invocations should be
-    reconciled into a full value.
+    The only allowed terminator is `scf.foreach_thread.perform_concurrently`.
+    `scf.foreach_thread` returns one value per `shared_out` operand. The
+    actions of the `perform_concurrently` terminators specify how to combine the
+    partial results of all parallel invocations into a full value, in some
+    unspecified order. The "destination" of each such op must be a `shared_out`
+    block argument of the `scf.foreach_thread` op.

-    `scf.foreach_thread` returns values that are formed by aggregating the
-    actions of all the `perform_concurrently` terminator of all the virtual
-    threads, in some unspecified order.
-    In other words, `scf.foreach_thread` performs all actions specified in the
-    `perform_concurrently` terminator, after it receives the control back from
-    its body along each virtual thread of execution.
     The actions involved in constructing the return values are further described
-    by [parallel_insert_slice](#parallelinsertslice-parallelinsertsliceop).
+    by `tensor.parallel_insert_slice`.

     `scf.foreach_thread` acts as an implicit synchronization point.

-    Multi-value returns are encoded by including multiple operations inside the
-    `perform_concurrently` block.
-
-    When the parallel function body has side effects, the order of reads and
-    writes to memory is unspecified across threads.
+    When the parallel function body has side effects, their order is unspecified
+    across threads.

     Example:

@@ -377,7 +383,8 @@
     // Sequential context.
// %matmul_and_pointwise:2 = scf.foreach_thread (%thread_id_1, %thread_id_2) in - (%num_threads_1, %numthread_id_2) -> (tensor, tensor) { + (%num_threads_1, %numthread_id_2) shared_outs(%o1 = %C, %o2 = %pointwise) + -> (tensor, tensor) { // // Parallel context, each thread with id = (%thread_id_1, %thread_id_2) // runs its version of the code. @@ -386,21 +393,19 @@ tensor to tensor %sB = tensor.extract_slice %B[g((%thread_id_1, %thread_id_2))]: tensor to tensor - %sC = tensor.extract_slice %C[h((%thread_id_1, %thread_id_2))]: + %sC = tensor.extract_slice %o1[h((%thread_id_1, %thread_id_2))]: tensor to tensor %sD = matmul ins(%sA, %sB) outs(%sC) - %spointwise = subtensor %pointwise[i((%thread_id_1, %thread_id_2))]: + %spointwise = subtensor %o2[i((%thread_id_1, %thread_id_2))]: tensor to tensor %sE = add ins(%spointwise) outs(%sD) scf.foreach_thread.perform_concurrently { - // First op within the parallel terminator contributes to producing %matmul_and_pointwise#0. - scf.foreach_thread.parallel_insert_slice %sD into %C[h((%thread_id_1, %thread_id_2))]: + scf.foreach_thread.parallel_insert_slice %sD into %o1[h((%thread_id_1, %thread_id_2))]: tensor into tensor - // Second op within the parallel terminator contributes to producing %matmul_and_pointwise#1. - scf.foreach_thread.parallel_insert_slice %spointwise into %pointwise[i((%thread_id_1, %thread_id_2))]: + scf.foreach_thread.parallel_insert_slice %spointwise into %o2[i((%thread_id_1, %thread_id_2))]: tensor into tensor } } @@ -414,7 +419,8 @@ // Sequential context. // %matmul_and_pointwise:2 = scf.foreach_thread (%thread_id_1, %thread_id_2) in - (%num_threads_1, %numthread_id_2) -> (tensor, tensor) { + (%num_threads_1, %numthread_id_2) shared_outs(...) + -> (tensor, tensor) { // // Parallel context, each thread with id = **(%thread_id_2, %thread_id_1)** // runs its version of the code. @@ -426,9 +432,23 @@ // Implicit synchronization point. // Sequential context. // + + Example with privatized tensors: + %t0 = ... + %t1 = ... + %r = scf.foreach_thread ... shared_outs(%o = t0) -> tensor { + // %t0 and %t1 are privatized. %t0 is definitely copied for each thread + // because the scf.foreach_thread op's %t0 use bufferizes to a memory + // write. In the absence of other conflicts, %t1 is copied only if there + // are uses of %t1 in the body that bufferize to a memory read and to a + // memory write. + "some_use"(%t0) + "some_use"(%t1) + } }]; let arguments = (ins Variadic:$num_threads, - DefaultValuedAttr:$thread_dim_mapping); + Variadic:$outputs, + DefaultValuedAttr:$thread_dim_mapping); let results = (outs Variadic:$results); let regions = (region SizedRegion<1>:$region); @@ -439,19 +459,48 @@ // The default builder does not add the proper body BBargs, roll our own. let skipDefaultBuilders = 1; let builders = [ - // Bodyless builder, result types must be specified. - OpBuilder<(ins "TypeRange":$resultTypes, "ValueRange":$num_threads, + // Bodyless builder, outputs must be specified. + OpBuilder<(ins "ValueRange":$outputs, "ValueRange":$num_threads, CArg<"ArrayRef", "{}">:$thread_dim_mapping)>, - // Builder that takes a bodyBuilder lambda, result types are inferred from - // the terminator. - OpBuilder<(ins "ValueRange":$num_threads, + // Builder that takes a bodyBuilder lambda. 
+ OpBuilder<(ins "ValueRange":$outputs, "ValueRange":$num_threads, "ArrayRef":$thread_dim_mapping, "function_ref":$bodyBuilder)> ]; let extraClassDeclaration = [{ int64_t getRank() { return getNumThreads().size(); } - ::mlir::ValueRange getThreadIndices() { return getBody()->getArguments(); } - ::mlir::Value getThreadIndex(int64_t idx) { return getBody()->getArgument(idx); } + + OpResult getTiedOpResult(OpOperand *opOperand) { + assert(opOperand->getOperandNumber() >= getRank() && "invalid operand"); + return getOperation()->getOpResult( + opOperand->getOperandNumber() - getRank()); + } + + OpOperand *getTiedOpOperand(BlockArgument bbArg) { + assert(bbArg.getArgNumber() >= getRank() && "invalid bbArg"); + return &getOperation()->getOpOperand(bbArg.getArgNumber()); + } + + BlockArgument getTiedBlockArgument(OpOperand *opOperand) { + assert(opOperand->getOperandNumber() >= getRank() && "invalid operand"); + return getBody()->getArgument(opOperand->getOperandNumber()); + } + + ArrayRef getOutputBlockArguments() { + return getBody()->getArguments().drop_front(getRank()); + } + + ::mlir::ValueRange getThreadIndices() { + return getBody()->getArguments().take_front(getRank()); + } + + ::mlir::Value getThreadIndex(int64_t idx) { + return getThreadIndices()[idx]; + } + + ::mlir::Block::BlockArgListType getRegionOutArgs() { + return getBody()->getArguments().drop_front(getRank()); + } // The ensureTerminator method generated by SingleBlockImplicitTerminator is // unaware of the fact that our terminator also needs a region to be @@ -497,7 +546,7 @@ // TODO: Add a `PerformConcurrentlyOpInterface` interface for ops that can // appear inside perform_concurrently. let extraClassDeclaration = [{ - ::llvm::SmallVector<::mlir::Type> getYieldedTypes(); + ::llvm::SmallVector<::mlir::BlockArgument> getDests(); ::llvm::iterator_range<::mlir::Block::iterator> getYieldingOps(); ::mlir::OpResult getParentResult(int64_t idx); }]; diff --git a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td --- a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td +++ b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td @@ -17,11 +17,7 @@ def ParallelCombiningOpInterface : OpInterface<"ParallelCombiningOpInterface"> { let description = [{ - A parallel combining op is an op with a region, that is not isolated from - above and yields values to its parent op without itself returning an SSA - value. The yielded values are determined by subvalues produced by the ops - contained in the region (the `yieldingOps`) and combined in any unspecified - order to produce the values yielded to the parent op. + A parallel combining op is an op with a region. This is useful as a terminator to parallel operations that iterate over some set and return tensors while avoiding tight coupling between the @@ -53,18 +49,6 @@ return $_op.getYieldingOps(); }] >, - InterfaceMethod< - /*desc=*/[{ - Return the contained ops that yield subvalues that this op combines to - yield to its parent. - }], - /*retTy=*/"::llvm::SmallVector<::mlir::Type>", - /*methodName=*/"getYieldedTypes", - /*args=*/(ins), - /*methodBody=*/[{ - return $_op.getYieldedTypes(); - }] - >, ]; // TODO: Single region single block interface on interfaces ? 
let verify = [{ diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -235,8 +235,8 @@ if (llvm::any_of(loopRanges, hasStrideOne)) return op->emitOpError("only stride-1 supported atm"); // TODO: support `getTiledImplementation` with >1 produced tiled ops. - auto destOperands = op.getDestinationOperands(b); - if (destOperands.size() != 1) + auto dest = op.getDestinationOperands(b); + if (dest.size() != 1) return op->emitOpError("only single dest operand supported atm"); SmallVector nonZeroNumThreads = @@ -255,8 +255,7 @@ // version because we require the use of RewriterBase in the body, so we // manually move the insertion point to the body below. scf::ForeachThreadOp foreachThreadOp = b.create( - loc, op->getResultTypes(), ValueRange(materializedNonZeroNumThreads), - threadDimMapping); + loc, dest, ValueRange(materializedNonZeroNumThreads), threadDimMapping); // Fill out the ForeachThreadOp body. b.setInsertionPointToStart(foreachThreadOp.getBody(0)); @@ -317,17 +316,34 @@ ++threadIdIdx; } + // Clone the tileable op and update its destination operands to use the output + // bbArgs of the ForeachThreadOp. + ArrayRef destBbArgs = + foreachThreadOp.getOutputBlockArguments(); + Operation *clonedOp = b.clone(*op.getOperation()); + auto destinationStyleOp = dyn_cast(clonedOp); + if (destinationStyleOp) { + for (OpOperand *outOperand : destinationStyleOp.getOutputOperands()) { + auto it = llvm::find(dest, outOperand->get()); + assert(it != dest.end() && "dest operand not found in dest"); + unsigned destNum = std::distance(dest.begin(), it); + outOperand->set(destBbArgs[destNum]); + } + } + + // Tile the cloned op and delete the clone. SmallVector tiledOps = - op.getTiledImplementation(b, tiledOffsets, tiledSizes); + cast(clonedOp).getTiledImplementation(b, tiledOffsets, + tiledSizes); + b.eraseOp(clonedOp); assert(tiledOps.size() == 1 && "expected a single produced tiled op"); tiledOp = tiledOps.front(); auto tilingInterfaceOp = dyn_cast(tiledOp); assert(tilingInterfaceOp && "Tiled op does not implement TilingInterface"); OpBuilder::InsertPoint insertPt = b.saveInsertionPoint(); - for (auto it : - llvm::zip(llvm::seq(unsigned(0), unsigned(destOperands.size())), - tilingInterfaceOp->getResults(), destOperands)) { + for (auto it : llvm::zip(llvm::seq(unsigned(0), unsigned(dest.size())), + tilingInterfaceOp->getResults(), destBbArgs)) { b.setInsertionPoint(insertPt.getBlock(), insertPt.getPoint()); SmallVector resultOffsets, resultSizes; if (failed(op.getResultTilePosition(b, std::get<0>(it), tiledOffsets, diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -1055,26 +1055,25 @@ if (failed(getTerminator().verify())) return failure(); - // Check that the body defines as single block argument for the thread index. + // Check number of outputs. + if (getNumResults() != getOutputs().size()) + return emitOpError("produces ") + << getNumResults() << " results, but has only " + << getOutputs().size() << " outputs"; + + // Check that the body defines block arguments for thread indices and outputs. 
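+  // Block argument layout: getRank() index arguments (the thread indices),
+  // followed by one tensor argument per `shared_outs` operand.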
auto *body = getBody(); - if (body->getNumArguments() != getRank()) + if (body->getNumArguments() != getRank() + getOutputs().size()) return emitOpError("region expects ") << getRank() << " arguments"; + for (int64_t i = 0; i < getRank(); ++i) + if (!body->getArgument(i).getType().isIndex()) + return emitOpError("expects ") + << i << "-th block argument to be an index"; + for (unsigned i = 0; i < getOutputs().size(); ++i) + if (body->getArgument(i + getRank()).getType() != getOutputs()[i].getType()) + return emitOpError("type mismatch between ") + << i << "-th output and corresponding block argument"; - // Verify consistency between the result types and the terminator. - auto terminatorTypes = getTerminator().getYieldedTypes(); - auto opResults = getResults(); - if (opResults.size() != terminatorTypes.size()) - return emitOpError("produces ") - << opResults.size() << " results, but its terminator yields " - << terminatorTypes.size() << " value(s)"; - unsigned i = 0; - for (auto e : llvm::zip(terminatorTypes, opResults)) { - if (std::get<0>(e) != std::get<1>(e).getType()) - return emitOpError() << "type mismatch between result " << i << " (" - << std::get<1>(e).getType() << ") and terminator (" - << std::get<0>(e) << ")"; - i++; - } return success(); } @@ -1083,11 +1082,16 @@ llvm::interleaveComma(getThreadIndices(), p); p << ") in ("; llvm::interleaveComma(getNumThreads(), p); - p << ") -> (" << getResultTypes() << ") "; + p << ")"; + printInitializationList(p, getRegionOutArgs(), getOutputs(), " shared_outs"); + p << " "; + if (!getRegionOutArgs().empty()) + p << "-> (" << getResultTypes() << ") "; p.printRegion(getRegion(), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/getNumResults() > 0); - p.printOptionalAttrDict(getOperation()->getAttrs()); + p.printOptionalAttrDict(getOperation()->getAttrs(), + {"operand_segment_sizes"}); } ParseResult ForeachThreadOp::parse(OpAsmParser &parser, @@ -1109,15 +1113,34 @@ result.operands)) return failure(); - // Parse optional results. - if (parser.parseOptionalArrowTypeList(result.types)) - return failure(); + // Parse out operands and results. + SmallVector regionOutArgs; + SmallVector outOperands; + SMLoc outOperandsLoc = parser.getCurrentLocation(); + if (succeeded(parser.parseOptionalKeyword("shared_outs"))) { + if (outOperands.size() != result.types.size()) + return parser.emitError(outOperandsLoc, + "mismatch between out operands and types"); + if (parser.parseAssignmentList(regionOutArgs, outOperands) || + parser.parseOptionalArrowTypeList(result.types) || + parser.resolveOperands(outOperands, result.types, outOperandsLoc, + result.operands)) + return failure(); + } // Parse region. + SmallVector regionArgs; std::unique_ptr region = std::make_unique(); - for (auto &idx : threadIndices) + for (auto &idx : threadIndices) { idx.type = builder.getIndexType(); - if (parser.parseRegion(*region, threadIndices)) + regionArgs.push_back(idx); + } + for (const auto &it : llvm::enumerate(regionOutArgs)) { + auto &out = it.value(); + out.type = result.types[it.index()]; + regionArgs.push_back(out); + } + if (parser.parseRegion(*region, regionArgs)) return failure(); // Ensure terminator and move region. @@ -1128,19 +1151,27 @@ // Parse the optional attribute list. 
if (parser.parseOptionalAttrDict(result.attributes)) return failure(); - + result.addAttribute("operand_segment_sizes", + parser.getBuilder().getDenseI32ArrayAttr( + {static_cast(threadNums.size()), + static_cast(outOperands.size())})); return success(); } -// Bodyless builder, result types must be specified. +// Bodyless builder, outputs must be specified. void ForeachThreadOp::build(mlir::OpBuilder &builder, - mlir::OperationState &result, TypeRange resultTypes, + mlir::OperationState &result, ValueRange outputs, ValueRange numThreads, ArrayRef threadDimMapping) { result.addOperands(numThreads); + result.addOperands(outputs); + result.addAttribute(ForeachThreadOp::getThreadDimMappingAttrName(result.name), + builder.getI64ArrayAttr(threadDimMapping)); result.addAttribute( - // TODO: getThreadDimMappingAttrName() but it is not a static member. - "thread_dim_mapping", builder.getI64ArrayAttr(threadDimMapping)); + "operand_segment_sizes", + builder.getDenseI32ArrayAttr({static_cast(numThreads.size()), + static_cast(outputs.size())})); + result.addTypes(TypeRange(outputs)); Region *bodyRegion = result.addRegion(); OpBuilder::InsertionGuard g(builder); @@ -1149,40 +1180,51 @@ // expects it .. builder.createBlock(bodyRegion); Block &bodyBlock = bodyRegion->front(); + // Add block arguments for indices and outputs. bodyBlock.addArguments( SmallVector(numThreads.size(), builder.getIndexType()), SmallVector(numThreads.size(), result.location)); + bodyBlock.addArguments( + TypeRange(outputs), + SmallVector(outputs.size(), result.location)); ForeachThreadOp::ensureTerminator(*bodyRegion, builder, result.location); - result.addTypes(resultTypes); } -// Builder that takes a bodyBuilder lambda, result types are inferred from -// the terminator. +// Builder that takes a bodyBuilder lambda. void ForeachThreadOp::build( - mlir::OpBuilder &builder, mlir::OperationState &result, + mlir::OpBuilder &builder, mlir::OperationState &result, ValueRange outputs, ValueRange numThreads, ArrayRef threadDimMapping, function_ref bodyBuilder) { result.addOperands(numThreads); + result.addOperands(outputs); + result.addAttribute(ForeachThreadOp::getThreadDimMappingAttrName(result.name), + builder.getI64ArrayAttr(threadDimMapping)); result.addAttribute( - // TODO: getThreadDimMappingAttrName() but it is not a static member. - "thread_dim_mapping", builder.getI64ArrayAttr(threadDimMapping)); + "operand_segment_sizes", + builder.getDenseI32ArrayAttr({static_cast(numThreads.size()), + static_cast(outputs.size())})); + result.addTypes(TypeRange(outputs)); - OpBuilder::InsertionGuard g(builder); Region *bodyRegion = result.addRegion(); + OpBuilder::InsertionGuard g(builder); builder.createBlock(bodyRegion); Block &bodyBlock = bodyRegion->front(); + // Add block arguments for indices and outputs. 
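+    // As in the bodyless builder above, the thread index arguments come first,
+    // followed by one block argument per output tensor.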
bodyBlock.addArguments( SmallVector(numThreads.size(), builder.getIndexType()), SmallVector(numThreads.size(), result.location)); + bodyBlock.addArguments( + TypeRange(outputs), + SmallVector(outputs.size(), result.location)); - OpBuilder::InsertionGuard guard(builder); builder.setInsertionPointToStart(&bodyBlock); bodyBuilder(builder, result.location, bodyBlock.getArguments()); +#ifndef NDEBUG auto terminator = llvm::dyn_cast(bodyBlock.getTerminator()); assert(terminator && "expected bodyBuilder to create PerformConcurrentlyOp terminator"); - result.addTypes(terminator.getYieldedTypes()); +#endif // NDEBUG } // The ensureTerminator method generated by SingleBlockImplicitTerminator is @@ -1223,12 +1265,23 @@ } LogicalResult PerformConcurrentlyOp::verify() { + scf::ForeachThreadOp foreachThreadOp = + dyn_cast(getOperation()->getParentOp()); + if (!foreachThreadOp) + return this->emitOpError("expected foreach_thread op parent"); + // TODO: PerformConcurrentlyOpInterface. - for (const Operation &op : getRegion().front().getOperations()) { + for (Operation &op : getRegion().front().getOperations()) { if (!isa(op)) { return this->emitOpError("expected only ") << tensor::ParallelInsertSliceOp::getOperationName() << " ops"; } + + // Verify that inserts are into out block arguments. + Value dest = cast(op).getDest(); + ArrayRef regionOutArgs = foreachThreadOp.getRegionOutArgs(); + if (llvm::find(regionOutArgs, dest) == regionOutArgs.end()) + return op.emitOpError("may only insert into an output block argument"); } return success(); } @@ -1264,11 +1317,12 @@ return getOperation()->getParentOp()->getResult(idx); } -SmallVector PerformConcurrentlyOp::getYieldedTypes() { +SmallVector PerformConcurrentlyOp::getDests() { return llvm::to_vector<4>( llvm::map_range(getYieldingOps(), [](Operation &op) { - auto insertSliceOp = dyn_cast(&op); - return insertSliceOp ? insertSliceOp.yieldedType() : Type(); + // Add new ops here as needed. + auto insertSliceOp = cast(&op); + return insertSliceOp.getDest().cast(); })); } diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp @@ -1054,18 +1054,6 @@ } }; -/// Return the destinations that an ForeachThreadOp is inserting into. One per -/// ParallelInsertSliceOp. -static SmallVector -getInsertionDest(ForeachThreadOp foreachThreadOp) { - PerformConcurrentlyOp terminator = foreachThreadOp.getTerminator(); - SmallVector result; - terminator.walk([&](tensor::ParallelInsertSliceOp insertOp) { - result.push_back(&insertOp->getOpOperand(1) /*dest*/); - }); - return result; -} - /// Bufferization of ForeachThreadOp. This also bufferizes the terminator of the /// region. There are op interfaces for the terminators (PerformConcurrentlyOp /// and ParallelInsertSliceOp), but these are only used during analysis. Not @@ -1073,57 +1061,114 @@ struct ForeachThreadOpInterface : public BufferizableOpInterface::ExternalModel { - SmallVector - getAliasingOpOperand(Operation *op, OpResult opResult, - const AnalysisState &state) const { - // Get OpOperand (dest) from corresponding ParallelInsertSliceOp. + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + // scf::ForeachThreadOp alone doesn't bufferize to a memory read, one of the + // uses of its matching bbArg may. 
auto foreachThreadOp = cast(op); - return {getInsertionDest(foreachThreadOp)[opResult.getResultNumber()]}; + return state.isValueRead(foreachThreadOp.getTiedBlockArgument(&opOperand)); } - bool isMemoryWrite(Operation *op, OpResult opResult, - const AnalysisState &state) const { - // This op is a memory write. Stop lookup here to avoid finding false - // conflicts involving this op and one of the ops in the region. This is - // similar to how scf.if ops are analyzed. + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + // Outputs of scf::ForeachThreadOps are always considered as a write. return true; } + SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + auto foreachThreadOp = cast(op); + return {foreachThreadOp.getTiedOpResult(&opOperand)}; + } + BufferRelation bufferRelation(Operation *op, OpResult opResult, const AnalysisState &state) const { return BufferRelation::Equivalent; } + bool isWritable(Operation *op, Value value, + const AnalysisState &state) const { + return true; + } + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const { + OpBuilder::InsertionGuard guard(rewriter); auto foreachThreadOp = cast(op); + int64_t rank = foreachThreadOp.getRank(); -#ifndef NDEBUG - // ParallelInsertSliceOpInterface replaces all uses. - for (OpResult opResult : foreachThreadOp->getOpResults()) - assert(opResult.getUses().empty() && - "expected that all uses were already replaced"); -#endif // NDEBUG + // Get buffers for all output operands. + SmallVector buffers; + for (Value out : foreachThreadOp.getOutputs()) { + FailureOr buffer = getBuffer(rewriter, out, options); + if (failed(buffer)) + return failure(); + buffers.push_back(*buffer); + } + + // Use buffers instead of block arguments. + rewriter.setInsertionPointToStart(foreachThreadOp.getBody()); + for (const auto &it : + llvm::zip(foreachThreadOp.getBody()->getArguments().drop_front(rank), + buffers)) { + BlockArgument bbArg = std::get<0>(it); + Value buffer = std::get<1>(it); + Value bufferAsTensor = + rewriter.create(foreachThreadOp.getLoc(), buffer); + bbArg.replaceAllUsesWith(bufferAsTensor); + } // Create new ForeachThreadOp without any results and drop the automatically // introduced terminator. - TypeRange newResultTypes; + rewriter.setInsertionPoint(foreachThreadOp); auto newForeachThreadOp = rewriter.create( - foreachThreadOp.getLoc(), newResultTypes, + foreachThreadOp.getLoc(), /*outputs=*/ValueRange(), foreachThreadOp.getNumThreads(), extractFromI64ArrayAttr(foreachThreadOp.getThreadDimMapping())); newForeachThreadOp.getBody()->getTerminator()->erase(); // Move over block contents of the old op. + SmallVector replacementBbArgs; + replacementBbArgs.append( + newForeachThreadOp.getBody()->getArguments().begin(), + newForeachThreadOp.getBody()->getArguments().end()); + replacementBbArgs.append(foreachThreadOp.getOutputs().size(), Value()); rewriter.mergeBlocks(foreachThreadOp.getBody(), - newForeachThreadOp.getBody(), - {newForeachThreadOp.getBody()->getArguments()}); + newForeachThreadOp.getBody(), replacementBbArgs); - // Remove the old op. - rewriter.eraseOp(op); + // Remove the old op and replace all of its uses. 
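+    // Each op result is replaced with the buffer of its corresponding shared
+    // output; the moved body already writes through these buffers in place.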
+ replaceOpWithBufferizedValues(rewriter, op, buffers); return success(); } + + FailureOr + getBufferType(Operation *op, Value value, const BufferizationOptions &options, + const DenseMap &fixedTypes) const { + auto foreachThreadOp = cast(op); + + if (auto bbArg = value.dyn_cast()) + // A tensor block argument has the same bufferized type as the + // corresponding output operand. + return bufferization::getBufferType( + foreachThreadOp.getTiedOpOperand(bbArg)->get(), options, fixedTypes); + + // The bufferized result type is the same as the bufferized type of the + // corresponding output operand. + return bufferization::getBufferType( + foreachThreadOp.getOutputs()[value.cast().getResultNumber()], + options, fixedTypes); + } + + bool isRepetitiveRegion(Operation *op, unsigned index) const { + auto foreachThreadOp = cast(op); + // This op is not repetitive if it has just a single thread. + if (llvm::all_of(foreachThreadOp.getNumThreads(), [](Value v) { + return getConstantIntValue(v) == static_cast(1); + })) + return false; + return true; + } }; /// Nothing to do for PerformConcurrentlyOp. diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -922,12 +922,7 @@ ParallelInsertSliceOpInterface, ParallelInsertSliceOp> { SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { - if (&opOperand != &op->getOpOperand(1) /*dest*/) return {}; - - // ParallelInsertSliceOp itself has no results, query its tied op results. - auto insertOp = cast(op); - return {insertOp.getTiedOpResult()}; } bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, @@ -940,84 +935,21 @@ return &opOperand == &op->getOpOperand(1) /*dest*/; } - BufferRelation bufferRelation(Operation *op, OpResult opResult, - const AnalysisState &state) const { - return BufferRelation::Equivalent; - } - - LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter, - const AnalysisState &state) const { - // This interface method is overridden because we want to set a custom - // insertion point for tensor copies. They should be inserted right before - // the ForeachThreadOp. E.g.: - // - // %r0, %r1 = foreach_thead ... { - // ... - // perform_concurrently { - // parallel_insert_slice %a into %b ... {inplace = ["true", "true"]} - // parallel_insert_slice %c into %d ... {inplace = ["true", "false"]} - // } - // } - // - // After TensorCopyInsertion: - // - // %copy = bufferization.alloc_tensor() copy(%d) - // %r0, %r1 = foreach_thead ... { - // ... - // perform_concurrently { - // parallel_insert_slice %a into %b ... - // parallel_insert_slice %c into %copy ... - // } - // } - - OpBuilder::InsertionGuard g(rewriter); - auto parallelInsertSliceOp = cast(op); - ParallelCombiningOpInterface parallelCombiningParent = - parallelInsertSliceOp.getParallelCombiningParent(); - Operation *parallelIteratingOp = parallelCombiningParent->getParentOp(); - - // Nothing to do if the destination tensor is inplace. - assert(state.isInPlace(op->getOpOperand(0) /*src*/) && - "source is always in-place"); - if (state.isInPlace(op->getOpOperand(1) /*dest*/)) - return success(); - - // Find corresponding OpResult. - OpResult opResult = parallelInsertSliceOp.getTiedOpResult(); - - // Insert tensor allocation right before the ForeachThreadOp. 
- rewriter.setInsertionPoint(parallelIteratingOp); - bool isYielded = state.isTensorYielded(opResult); - FailureOr alloc = allocateTensorForShapedValue( - rewriter, op->getLoc(), parallelInsertSliceOp.getDest(), - /*escape=*/isYielded, state.getOptions()); - if (failed(alloc)) - return failure(); - - // Update destination operand. - rewriter.updateRootInPlace(parallelInsertSliceOp, [&]() { - parallelInsertSliceOp.getDestMutable().assign(*alloc); - }); - - return success(); - } - LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const { OpBuilder::InsertionGuard g(rewriter); auto parallelInsertSliceOp = cast(op); ParallelCombiningOpInterface parallelCombiningParent = parallelInsertSliceOp.getParallelCombiningParent(); - Operation *parallelIteratingOp = parallelCombiningParent->getParentOp(); - // Get destination buffer. + // Bufferize the op outside of the parallel combining terminator. + rewriter.setInsertionPoint(parallelCombiningParent); + + // Get source and destination buffers. FailureOr destBuffer = getBuffer(rewriter, parallelInsertSliceOp.getDest(), options); if (failed(destBuffer)) return failure(); - - // Bufferize the ParallelInsertSliceOp outside of `parallelCombiningParent`. - rewriter.setInsertionPoint(parallelCombiningParent); FailureOr srcBuffer = getBuffer(rewriter, parallelInsertSliceOp.getSource(), options); if (failed(srcBuffer)) @@ -1043,18 +975,7 @@ *srcBuffer, subview))) return failure(); - // Replace all uses of parallelIteratingOp (just the corresponding result). - rewriter.setInsertionPointAfter(parallelIteratingOp); - Value toTensorOp = - rewriter.create(parallelIteratingOp->getLoc(), *destBuffer); - // PerformConcurrentlyOp can have multiple ParallelInsertSliceOps. - SmallVector resultUses = llvm::to_vector( - llvm::map_range(parallelInsertSliceOp.getTiedOpResult().getUses(), - [](OpOperand &use) { return &use; })); - for (OpOperand *use : resultUses) { - rewriter.updateRootInPlace(use->getOwner(), - [&]() { use->set(toTensorOp); }); - } + // Delete the op. 
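+    // The copy into the destination subview above is this op's entire effect;
+    // the op has no results, so erasing it is sufficient.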
rewriter.eraseOp(op); return success(); } diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir --- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir +++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir @@ -835,16 +835,16 @@ %c4 = arith.constant 4 : index %cst = arith.constant 0.000000e+00 : f32 %0 = linalg.init_tensor [4, 2] : tensor<4x2xf32> - %res = scf.foreach_thread (%arg0, %arg1) in (%c4, %c2) -> (tensor<4x2xf32>) { + %res = scf.foreach_thread (%arg0, %arg1) in (%c4, %c2) shared_outs(%o = %0) -> (tensor<4x2xf32>) { %1 = linalg.init_tensor [1, 1] : tensor<1x1xf32> %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<1x1xf32>) -> tensor<1x1xf32> scf.foreach_thread.perform_concurrently { // CHECK: tensor.parallel_insert_slice %{{[0-9a-z]*}} into %{{[0-9a-z]*}} // CHECK-SAME: [%{{.*}}, %{{.*}}] [1, 1] [1, 1] : tensor into tensor<4x2xf32> - tensor.parallel_insert_slice %2 into %0[%arg0, %arg1] [1, 1] [1, 1] : + tensor.parallel_insert_slice %2 into %o[%arg0, %arg1] [1, 1] [1, 1] : tensor<1x1xf32> into tensor<4x2xf32> } - } + } return %res: tensor<4x2xf32> } diff --git a/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir b/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir --- a/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir +++ b/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir @@ -15,15 +15,15 @@ func.func @matmul(%A: tensor, %B: tensor, %C: tensor) -> tensor { // CHECK-DAG: %[[C10:.*]] = arith.constant 10 : index // CHECK-DAG: %[[C20:.*]] = arith.constant 20 : index - // CHECK: scf.foreach_thread ({{.*}}) in (%[[C10]], %[[C20]]) -> (tensor) { + // CHECK: scf.foreach_thread ({{.*}}) in (%[[C10]], %[[C20]]) shared_outs(%[[C_BLK:.*]] = %[[C]]) -> (tensor) { // CHECK: %[[tA:.*]] = tensor.extract_slice %[[A]]{{.*}} : tensor to tensor // CHECK: %[[tB:.*]] = tensor.extract_slice %[[B]]{{.*}} : tensor to tensor - // CHECK: %[[tC:.*]] = tensor.extract_slice %[[C]]{{.*}} : tensor to tensor + // CHECK: %[[tC:.*]] = tensor.extract_slice %[[C_BLK]]{{.*}} : tensor to tensor // CHECK: %[[RES:.*]] = linalg.matmul // CHECK-SAME: ins(%[[tA]], %[[tB]] : tensor, tensor) // CHECK-SAME: outs(%[[tC]] : tensor) -> tensor // CHECK: scf.foreach_thread.perform_concurrently { - // CHECK-NEXT: tensor.parallel_insert_slice %[[RES]] into %[[C]]{{.*}} : + // CHECK-NEXT: tensor.parallel_insert_slice %[[RES]] into %[[C_BLK]]{{.*}} : // CHECK-SAME: tensor into tensor // CHECK-NEXT: } // CHECK-NEXT: } {thread_dim_mapping = [1, 0]} @@ -55,10 +55,10 @@ // CHECK-SAME: %[[A:[0-9a-z]+]]: tensor // CHECK-SAME: %[[B:[0-9a-z]+]]: tensor // CHECK-SAME: %[[C:[0-9a-z]+]]: tensor -func.func @matmul_static(%A: tensor<100x200xf32>, %B: tensor<200x300xf32>, %C: tensor<100x300xf32>) -> tensor<100x300xf32> { +func.func @matmul_static(%A: tensor<100x200xf32>, %B: tensor<200x300xf32>, %C: tensor<100x300xf32>) -> tensor<100x300xf32> { // CHECK-DAG: %[[c10:.+]] = arith.constant 10 : index // CHECK-DAG: %[[c21:.+]] = arith.constant 21 : index - // CHECK: scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]]) in (%[[c10]], %[[c21]]) + // CHECK: scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]]) in (%[[c10]], %[[c21]]) shared_outs(%[[C_BLK:.*]] = %[[C]]) // CHECK: %[[TSMIN:.+]] = affine.min #[[$map0]](%[[IV1]]) // CHECK: %[[TS:.+]] = affine.max #[[$map1]](%[[TSMIN]]) // CHECK-NOT: affine.min @@ -67,7 +67,7 @@ // CHECK: %[[LB1:.+]] = affine.apply #[[$map3]](%[[IV1]]) // CHECK: %[[tA:.+]] = tensor.extract_slice %[[A]][%[[LB0]], 0] [10, 200] [1, 1] : // CHECK: %[[tB:.+]] = 
tensor.extract_slice %[[B]][0, %[[LB1]]] [200, %[[TS]]] [1, 1] : - // CHECK: %[[tC:.+]] = tensor.extract_slice %[[C]][%[[LB0]], %[[LB1]]] [10, %[[TS]]] [1, 1] : + // CHECK: %[[tC:.+]] = tensor.extract_slice %[[C_BLK]][%[[LB0]], %[[LB1]]] [10, %[[TS]]] [1, 1] : // CHECK: linalg.matmul // CHECK: scf.foreach_thread.perform_concurrently // CHECK-NEXT: tensor.parallel_insert_slice @@ -104,14 +104,14 @@ // CHECK: %[[N:.+]] = tensor.dim %[[B]], %c1 : // CHECK: %[[NT0:.+]] = affine.apply #map0()[%[[M]]] // CHECK: %[[NT1:.+]] = affine.apply #map1()[%[[N]]] - // CHECK: scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]]) in (%[[NT0]], %[[NT1]]) + // CHECK: scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]]) in (%[[NT0]], %[[NT1]]) shared_outs(%[[C_BLK:.*]] = %[[C]]) // CHECK: %[[TS0:.+]] = affine.min #[[$map2]](%[[IV0]])[%[[M]]] // CHECK: %[[TS1:.+]] = affine.min #[[$map4]](%[[IV1]])[%[[N]]] // CHECK: %[[LB0:.+]] = affine.apply #[[$map5]](%[[IV0]]) - // CHECK tensor.extract_slice %[[A]] // CHECK: %[[LB1:.+]] = affine.apply #[[$map6]](%[[IV1]]) - // CHECK tensor.extract_slice %[[B]] - // CHECK tensor.extract_slice %[[C]] + // CHECK: tensor.extract_slice %[[A]] + // CHECK: tensor.extract_slice %[[B]] + // CHECK: tensor.extract_slice %[[C_BLK]] // CHECK: linalg.matmul // CHECK: scf.foreach_thread.perform_concurrently // CHECK-NEXT: tensor.parallel_insert_slice @@ -144,7 +144,7 @@ func.func @matmul_tile_size_static(%A: tensor<100x200xf32>, %B: tensor<200x300xf32>, %C: tensor<100x300xf32>) -> tensor<100x300xf32> { // CHECK-DAG: %[[c10:.+]] = arith.constant 10 : // CHECK-DAG: %[[c15:.+]] = arith.constant 15 : - // CHECK: scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]]) in (%[[c10]], %[[c15]]) + // CHECK: scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]]) in (%[[c10]], %[[c15]]) shared_outs(%[[C_BLK:.*]] = %[[C]]) // CHECK: %[[TS:.+]] = affine.min #[[$map0]](%[[IV1]]) // CHECK-NOT: affine.max // CHECK-NOT: affine.min @@ -152,7 +152,7 @@ // CHECK: %[[LB1:.+]] = affine.apply #[[$map3]](%[[IV1]]) // CHECK: %[[tA:.+]] = tensor.extract_slice %[[A]][%[[LB0]], 0] [10, 200] [1, 1] : // CHECK: %[[tB:.+]] = tensor.extract_slice %[[B]][0, %[[LB1]]] [200, %[[TS]]] [1, 1] : - // CHECK: %[[tC:.+]] = tensor.extract_slice %[[C]][%[[LB0]], %[[LB1]]] [10, %[[TS]]] [1, 1] : + // CHECK: %[[tC:.+]] = tensor.extract_slice %[[C_BLK]][%[[LB0]], %[[LB1]]] [10, %[[TS]]] [1, 1] : // CHECK: linalg.matmul // CHECK: scf.foreach_thread.perform_concurrently // CHECK-NEXT: tensor.parallel_insert_slice @@ -199,7 +199,7 @@ // CHECK-LABEL: extract_source( // CHECK: %[[C2:.*]] = arith.constant 2 : index -// CHECK: scf.foreach_thread (%[[ARG:.*]]) in (%[[C2]]) -> (tensor<4xf32>) { +// CHECK: scf.foreach_thread (%[[ARG:.*]]) in (%[[C2]]) shared_outs(%{{.*}} = %{{.*}}) -> (tensor<4xf32>) { // CHECK: %[[OFF:.*]] = affine.apply #[[$map0]](%[[ARG]]) // CHECK: scf.foreach_thread.perform_concurrently { // CHECK: tensor.parallel_insert_slice %{{.*}} into %{{.*}}[%[[OFF]]] [2] [1] : tensor<2xf32> into tensor<4xf32> @@ -227,10 +227,10 @@ // CHECK-DAG: %[[N:.+]] = tensor.dim %[[B]], %c1 : // CHECK-DAG: %[[NT0:.+]] = affine.apply #[[$map0]]()[%[[M]], %[[tile_size]]] // CHECK-DAG: %[[NT1:.+]] = affine.apply #[[$map1]]()[%[[N]]] - // CHECK: scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]]) in (%[[NT0]], %[[NT1]]) - // CHECK tensor.extract_slice %[[A]] - // CHECK tensor.extract_slice %[[B]] - // CHECK tensor.extract_slice %[[C]] + // CHECK: scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]]) in (%[[NT0]], %[[NT1]]) shared_outs(%[[C_BLK:.*]] = %[[C]]) + // CHECK: tensor.extract_slice 
%[[A]] + // CHECK: tensor.extract_slice %[[B]] + // CHECK: tensor.extract_slice %[[C_BLK]] // CHECK: linalg.matmul // CHECK: scf.foreach_thread.perform_concurrently // CHECK-NEXT: tensor.parallel_insert_slice diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir --- a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir @@ -17,10 +17,10 @@ %1 = affine.apply #map0()[%d0, %arg0] // CHECK: scf.foreach_thread {{.*}} { - %2 = scf.foreach_thread (%arg3) in (%1) -> (tensor) { + %2 = scf.foreach_thread (%arg3) in (%1) shared_outs(%o = %arg2) -> (tensor) { %3 = affine.apply #map1(%arg3)[%arg0] %4 = affine.min #map2(%arg3)[%d0, %arg0] - %5 = tensor.extract_slice %arg2[%3] [%4] [1] : tensor to tensor + %5 = tensor.extract_slice %o[%3] [%4] [1] : tensor to tensor // CHECK: %[[T0:.*]] = tensor.extract_slice %[[IN]][%{{.*}}] [%{{.*}}] [{{.*}}] // CHECK: %[[T1:.*]] = linalg.fill {{.*}} outs(%[[T0]] @@ -29,7 +29,7 @@ // CHECK: %[[T2:.*]] = linalg.elemwise_unary ins(%[[T1]] %7 = linalg.elemwise_unary ins(%6 : tensor) outs(%5 : tensor) -> tensor scf.foreach_thread.perform_concurrently { - tensor.parallel_insert_slice %7 into %arg2[%3] [%4] [1] : tensor into tensor + tensor.parallel_insert_slice %7 into %o[%3] [%4] [1] : tensor into tensor } } // CHECK: } @@ -70,16 +70,16 @@ %1 = affine.apply #map0()[%arg0] // CHECK: scf.foreach_thread {{.*}} { - %2 = scf.foreach_thread (%arg3) in (%1) -> (tensor<64xf32>) { + %2 = scf.foreach_thread (%arg3) in (%1) shared_outs(%o = %arg2) -> (tensor<64xf32>) { // CHECK: %[[INIT_TENSOR:.*]] = linalg.init_tensor %3 = affine.apply #map1(%arg3)[%arg0] %4 = affine.min #map2(%arg3)[%arg0] - %5 = tensor.extract_slice %arg2[%3] [%4] [1] : tensor<64xf32> to tensor + %5 = tensor.extract_slice %o[%3] [%4] [1] : tensor<64xf32> to tensor // CHECK: %[[T2:.*]] = linalg.elemwise_unary ins(%[[INIT_TENSOR]] %7 = linalg.elemwise_unary ins(%0 : tensor) outs(%5 : tensor) -> tensor scf.foreach_thread.perform_concurrently { - tensor.parallel_insert_slice %7 into %arg2[%3] [%4] [1] : tensor into tensor<64xf32> + tensor.parallel_insert_slice %7 into %o[%3] [%4] [1] : tensor into tensor<64xf32> } } // CHECK: } diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir --- a/mlir/test/Dialect/SCF/invalid.mlir +++ b/mlir/test/Dialect/SCF/invalid.mlir @@ -527,11 +527,11 @@ %c1 = arith.constant 1 : index %num_threads = arith.constant 100 : index - // expected-error @+1 {{produces 2 results, but its terminator yields 1 value(s)}} - %result:2 = scf.foreach_thread (%thread_idx) in (%num_threads) -> (tensor<100xf32>, tensor<100xf32>) { + // expected-error @+1 {{1 operands present, but expected 2}} + %result:2 = scf.foreach_thread (%thread_idx) in (%num_threads) shared_outs(%o = %out) -> (tensor<100xf32>, tensor<100xf32>) { %1 = tensor.extract_slice %in[%thread_idx][1][1] : tensor<100xf32> to tensor<1xf32> scf.foreach_thread.perform_concurrently { - tensor.parallel_insert_slice %1 into %out[%thread_idx][1][1] : + tensor.parallel_insert_slice %1 into %o[%thread_idx][1][1] : tensor<1xf32> into tensor<100xf32> } } @@ -540,14 +540,14 @@ // ----- -func.func @wrong_type_result(%in: tensor<100xf32>, %out: tensor<100xf32>) { +func.func @invalid_insert_dest(%in: tensor<100xf32>, %out: tensor<100xf32>) { %c1 = arith.constant 1 : index %num_threads = arith.constant 100 : index - // expected-error @+1 {{type mismatch between 
result 0 ('tensor') and terminator ('tensor<100xf32>')}} - %result = scf.foreach_thread (%thread_idx) in (%num_threads) -> (tensor) { + %result = scf.foreach_thread (%thread_idx) in (%num_threads) shared_outs(%o = %out) -> (tensor<100xf32>) { %1 = tensor.extract_slice %in[%thread_idx][1][1] : tensor<100xf32> to tensor<1xf32> scf.foreach_thread.perform_concurrently { + // expected-error @+1 {{may only insert into an output block argument}} tensor.parallel_insert_slice %1 into %out[%thread_idx][1][1] : tensor<1xf32> into tensor<100xf32> } @@ -561,11 +561,11 @@ %c1 = arith.constant 1 : index %num_threads = arith.constant 100 : index - %result = scf.foreach_thread (%thread_idx) in (%num_threads) -> (tensor<100xf32>) { + %result = scf.foreach_thread (%thread_idx) in (%num_threads) shared_outs(%o = %out) -> (tensor<100xf32>) { %1 = tensor.extract_slice %in[%thread_idx][1][1] : tensor<100xf32> to tensor<1xf32> // expected-error @+1 {{expected only tensor.parallel_insert_slice ops}} scf.foreach_thread.perform_concurrently { - tensor.parallel_insert_slice %1 into %out[%thread_idx][1][1] : + tensor.parallel_insert_slice %1 into %o[%thread_idx][1][1] : tensor<1xf32> into tensor<100xf32> %0 = arith.constant 1: index } diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir --- a/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir +++ b/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir @@ -120,14 +120,14 @@ // CHECK-FUNC-NOT: alloc_tensor // CHECK: %[[alloc:.*]] = bufferization.alloc_tensor() copy(%[[arg1]]) {bufferization.escape = [false]} : tensor<100xf32> - // CHECK: scf.foreach_thread - %result = scf.foreach_thread (%thread_idx) in (%num_threads) -> tensor<100xf32> { + // CHECK: scf.foreach_thread {{.*}} shared_outs(%[[o:.*]] = %[[alloc]]) + %result = scf.foreach_thread (%thread_idx) in (%num_threads) shared_outs(%o = %out) -> tensor<100xf32> { // CHECK: tensor.extract_slice // CHECK: scf.foreach_thread.perform_concurrently - // CHECK: tensor.parallel_insert_slice %{{.*}} into %[[alloc]] + // CHECK: tensor.parallel_insert_slice %{{.*}} into %[[o]] %1 = tensor.extract_slice %in[%thread_idx][1][1] : tensor<100xf32> to tensor<1xf32> scf.foreach_thread.perform_concurrently { - tensor.parallel_insert_slice %1 into %out[%thread_idx][1][1] : + tensor.parallel_insert_slice %1 into %o[%thread_idx][1][1] : tensor<1xf32> into tensor<100xf32> } // CHECK: } {thread_dim_mapping = [5]} diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir --- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir +++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir @@ -525,10 +525,10 @@ %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index - // CHECK: scf.foreach_thread (%[[tidx:.*]]) in (%[[idx2]]) -> () - %2 = scf.foreach_thread (%arg3) in (%idx2) -> (tensor) { + // CHECK: scf.foreach_thread (%[[tidx:.*]]) in (%[[idx2]]) + %2 = scf.foreach_thread (%arg3) in (%idx2) shared_outs(%o = %arg2) -> (tensor) { // CHECK: %[[subview:.*]] = memref.subview %[[arg2]][5] [%[[idx]]] [1] - %6 = tensor.extract_slice %arg2[5] [%idx] [%c1] : tensor to tensor + %6 = tensor.extract_slice %o[5] [%idx] [%c1] : tensor to tensor // CHECK: linalg.fill ins(%{{.*}}) outs(%[[subview]] : memref) -> tensor // Self-copy will DCE away later. 
@@ -538,7 +538,7 @@ // CHECK-NOT: scf.foreach_thread.perform_concurrently // CHECK-NOT: parallel_insert_slice scf.foreach_thread.perform_concurrently { - tensor.parallel_insert_slice %8 into %arg2[5] [%idx] [%c1] : + tensor.parallel_insert_slice %8 into %o[5] [%idx] [%c1] : tensor into tensor } } @@ -571,26 +571,22 @@ // CHECK: %[[alloc1:.*]] = memref.alloc // CHECK: memref.copy %[[arg2]], %[[alloc1]] - // CHECK: scf.foreach_thread (%[[tidx:.*]]) in (%[[idx2]]) -> () - %2 = scf.foreach_thread (%arg3) in (%idx2) -> (tensor) { - // Another alloc for the extract_slice op. - // CHECK: %[[alloc2:.*]] = memref.alloc - %6 = tensor.extract_slice %arg2[5] [%idx] [%c1] : tensor to tensor + // CHECK: scf.foreach_thread (%[[tidx:.*]]) in (%[[idx2]]) + %2 = scf.foreach_thread (%arg3) in (%idx2) shared_outs(%o = %arg2) -> (tensor) { + // CHECK: %[[subview1:.*]] = memref.subview %[[alloc1]][5] [%[[idx]]] [1] + %6 = tensor.extract_slice %o[5] [%idx] [%c1] : tensor to tensor - // CHECK: linalg.fill ins(%{{.*}}) outs(%[[alloc2]] : memref) -> tensor - // Now the copy of the actual insert_slice. - // CHECK: %[[subview1:.*]] = memref.subview %[[alloc1]][5] [%[[idx]]] [1] - // - // CHECK: memref.copy %[[alloc2]], %[[subview1]] - // CHECK: memref.dealloc %[[alloc2]] + // Now the copy of the actual insert_slice. (It will fold away.) + // CHECK: memref.copy %[[subview1]], %[[subview1]] // Empty terminator is elided from pretty-printing. // CHECK-NOT: scf.foreach_thread.perform_concurrently // CHECK-NOT: parallel_insert_slice scf.foreach_thread.perform_concurrently { - tensor.parallel_insert_slice %8 into %arg2[5] [%idx] [%c1] : + tensor.parallel_insert_slice %8 into %o[5] [%idx] [%c1] : tensor into tensor } } @@ -617,18 +613,18 @@ %c2 = arith.constant 2 : index %c4 = arith.constant 4 : index - // CHECK: scf.foreach_thread {{.*}} -> () - %0 = scf.foreach_thread (%arg3, %arg4) in (%c2, %c4) -> (tensor<8x8xf32>) { + // CHECK: scf.foreach_thread {{.*}} + %0 = scf.foreach_thread (%arg3, %arg4) in (%c2, %c4) shared_outs(%o = %arg2) -> (tensor<8x8xf32>) { %1 = affine.apply #map0(%arg3) %3 = tensor.extract_slice %arg0[%1, 0] [4, 8] [1, 1] : tensor<8x8xf32> to tensor<4x8xf32> %4 = affine.apply #map1(%arg4) %6 = tensor.extract_slice %arg1[0, %4] [8, 4] [1, 1] : tensor<8x8xf32> to tensor<8x4xf32> - %7 = tensor.extract_slice %arg2[%1, %4] [4, 4] [1, 1] : tensor<8x8xf32> to tensor<4x4xf32> - + %7 = tensor.extract_slice %o[%1, %4] [4, 4] [1, 1] : tensor<8x8xf32> to tensor<4x4xf32> + // CHECK: linalg.matmul ins({{.*}}memref<4x8xf32, #[[$DYN_LAYOUT_MAP]]>, memref<8x4xf32, #[[$DYN_LAYOUT_MAP]]>) outs({{.*}} : memref<4x4xf32, #[[$DYN_LAYOUT_MAP]]>) %8 = linalg.matmul ins(%3, %6 : tensor<4x8xf32>, tensor<8x4xf32>) outs(%7 : tensor<4x4xf32>) -> tensor<4x4xf32> scf.foreach_thread.perform_concurrently { - tensor.parallel_insert_slice %8 into %arg2[%1, %4] [4, 4] [1, 1] : tensor<4x4xf32> into tensor<8x8xf32> + tensor.parallel_insert_slice %8 into %o[%1, %4] [4, 4] [1, 1] : tensor<4x4xf32> into tensor<8x8xf32> } } return %0 : tensor<8x8xf32> @@ -636,6 +632,71 @@ // ----- +// CHECK-LABEL: func @scf_foreach_private_var( +// CHECK-SAME: %[[t:.*]]: memref<10xf32 +func.func @scf_foreach_private_var(%t: tensor<10xf32>) -> f32 { + %c2 = arith.constant 2 : index + %c5 = arith.constant 5 : index + + // A copy is inserted for the uses of %t in the loop. 
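+  // %t is written in place through shared_outs(%o = %t), but it is also read
+  // directly in the body; that read must see the original data, hence the copy.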
+ // CHECK: %[[t_copy:.*]] = memref.alloc() {{.*}} : memref<10xf32> + // CHECK: memref.copy %[[t]], %[[t_copy]] + + // CHECK: scf.foreach_thread (%{{.*}}) in (%{{.*}}) { + + // Load from the copy and store into the shared output. + // CHECK: %[[subview:.*]] = memref.subview %[[t]] + // CHECK: memref.load %[[t_copy]] + // CHECK: memref.store %{{.*}}, %[[subview]] + %0 = scf.foreach_thread (%tid) in (%c2) shared_outs(%o = %t) -> tensor<10xf32> { + %offset = arith.muli %c5, %tid : index + %slice = tensor.extract_slice %o[%offset] [5] [1] + : tensor<10xf32> to tensor<5xf32> + %r2 = tensor.extract %t[%tid] : tensor<10xf32> + %i = tensor.insert %r2 into %slice[%c2] : tensor<5xf32> + scf.foreach_thread.perform_concurrently { + tensor.parallel_insert_slice %i into %o[%offset] [5] [1] + : tensor<5xf32> into tensor<10xf32> + } + } + + %r = tensor.extract %0[%c2] : tensor<10xf32> + return %r : f32 +} + +// ----- + +// CHECK-LABEL: func.func @scf_foreach_privatized_but_not_copied( +// CHECK-SAME: %[[t0:.*]]: memref<10xf32, {{.*}}>, %[[t1:.*]]: memref<10xf32 +func.func @scf_foreach_privatized_but_not_copied( + %t0: tensor<10xf32>, %t1: tensor<10xf32>) -> f32 { + %c2 = arith.constant 2 : index + %c5 = arith.constant 5 : index + + // CHECK-NOT: memref.alloc + // CHECK-NOT: memref.copy + // CHECK: scf.foreach_thread {{.*}} { + %0 = scf.foreach_thread (%tid) in (%c2) shared_outs(%o = %t0) -> tensor<10xf32> { + %offset = arith.muli %c5, %tid : index + %slice = tensor.extract_slice %o[%offset] [5] [1] + : tensor<10xf32> to tensor<5xf32> + + // %t1 is never written in here, so no copy is needed + // CHECK: memref.load %[[t1]] + %r2 = tensor.extract %t1[%tid] : tensor<10xf32> + %i = tensor.insert %r2 into %slice[%c2] : tensor<5xf32> + scf.foreach_thread.perform_concurrently { + tensor.parallel_insert_slice %i into %o[%offset] [5] [1] + : tensor<5xf32> into tensor<10xf32> + } + } + + %r = tensor.extract %0[%c2] : tensor<10xf32> + return %r : f32 +} + +// ----- + // CHECK-LABEL: func @scf_if_memory_space func.func @scf_if_memory_space(%c: i1, %f: f32) -> (f32, f32) { diff --git a/mlir/test/Dialect/SCF/ops.mlir b/mlir/test/Dialect/SCF/ops.mlir --- a/mlir/test/Dialect/SCF/ops.mlir +++ b/mlir/test/Dialect/SCF/ops.mlir @@ -323,10 +323,10 @@ // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: return - %result = scf.foreach_thread (%thread_idx) in (%num_threads) -> tensor<100xf32> { + %result = scf.foreach_thread (%thread_idx) in (%num_threads) shared_outs(%o = %out) -> tensor<100xf32> { %1 = tensor.extract_slice %in[%thread_idx][1][1] : tensor<100xf32> to tensor<1xf32> scf.foreach_thread.perform_concurrently { - tensor.parallel_insert_slice %1 into %out[%thread_idx][1][1] : + tensor.parallel_insert_slice %1 into %o[%thread_idx][1][1] : tensor<1xf32> into tensor<100xf32> } } @@ -340,7 +340,7 @@ // CHECK: scf.foreach_thread // CHECK-NEXT: } {thread_dim_mapping = [42]} // CHECK-NEXT: return - scf.foreach_thread (%thread_idx) in (%num_threads) -> () { + scf.foreach_thread (%thread_idx) in (%num_threads) { scf.foreach_thread.perform_concurrently { } } {thread_dim_mapping = [42]} diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir --- a/mlir/test/Dialect/Tensor/canonicalize.mlir +++ b/mlir/test/Dialect/Tensor/canonicalize.mlir @@ -1455,13 +1455,13 @@ %c1 = arith.constant 1 : index // CHECK-NOT: tensor.cast - // CHECK: scf.foreach_thread (%[[tidx:[0-9a-z]*]]) in (%[[num_threads]]) -> (tensor) { + // CHECK: scf.foreach_thread (%[[tidx:[0-9a-z]*]]) in (%[[num_threads]]) 
shared_outs(%[[o:.*]] = %[[arg1]]) -> (tensor) { // CHECK-NEXT: scf.foreach_thread.perform_concurrently { - // CHECK-NEXT: tensor.parallel_insert_slice %[[arg0]] into %[[arg1]][%[[tidx]], 0] [1, 5] [1, 1] - %2 = scf.foreach_thread (%tidx) in (%num_threads) -> (tensor) { + // CHECK-NEXT: tensor.parallel_insert_slice %[[arg0]] into %[[o]][%[[tidx]], 0] [1, 5] [1, 1] + %2 = scf.foreach_thread (%tidx) in (%num_threads) shared_outs(%o = %arg1) -> (tensor) { %3 = tensor.cast %arg0 : tensor<1x5xf32> to tensor scf.foreach_thread.perform_concurrently { - tensor.parallel_insert_slice %3 into %arg1[%tidx, %c0] [%c1, 5] [%c1, %c1] : tensor into tensor + tensor.parallel_insert_slice %3 into %o[%tidx, %c0] [%c1, 5] [%c1, %c1] : tensor into tensor } } return %2 : tensor @@ -1477,12 +1477,12 @@ { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index - // CHECK: scf.foreach_thread () in () -> (tensor<1x5xf32>) { + // CHECK: scf.foreach_thread () in () shared_outs(%[[o:.*]] = %[[arg1]]) -> (tensor<1x5xf32>) { // CHECK-NEXT: scf.foreach_thread.perform_concurrently { - // CHECK-NEXT: tensor.parallel_insert_slice %[[arg0]] into %[[arg1]][0, 0] [1, 5] [1, 1] : tensor<1x5xf32> into tensor<1x5xf32> - %2 = scf.foreach_thread () in () -> (tensor<1x5xf32>) { + // CHECK-NEXT: tensor.parallel_insert_slice %[[arg0]] into %[[o]][0, 0] [1, 5] [1, 1] : tensor<1x5xf32> into tensor<1x5xf32> + %2 = scf.foreach_thread () in () shared_outs(%o = %arg1) -> (tensor<1x5xf32>) { scf.foreach_thread.perform_concurrently { - tensor.parallel_insert_slice %arg0 into %arg1[%c0, %c0] [1, 5] [%c1, %c1] : tensor<1x5xf32> into tensor<1x5xf32> + tensor.parallel_insert_slice %arg0 into %o[%c0, %c0] [1, 5] [%c1, %c1] : tensor<1x5xf32> into tensor<1x5xf32> } } return %2 : tensor<1x5xf32> diff --git a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir --- a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir +++ b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir @@ -205,12 +205,12 @@ %num_threads = arith.constant 100 : index // CHECK: scf.foreach_thread {{.*}} { - %result = scf.foreach_thread (%thread_idx) in (%num_threads) -> tensor<200x100xf32> { + %result = scf.foreach_thread (%thread_idx) in (%num_threads) shared_outs (%o = %out) -> tensor<200x100xf32> { %1 = tensor.extract_slice %in[%thread_idx][1][1] : tensor<100xf32> to tensor<1xf32> scf.foreach_thread.perform_concurrently { // CHECK: memref.subview %{{.*}}[%{{.*}}] [1] [1] : memref<100xf32, #[[$MAP0]]> to memref<1xf32, #[[$MAP0]]> // CHECK: memref.subview %{{.*}}[1, %{{.*}}] [1, 1] [1, 1] : memref<200x100xf32, #[[$MAP1]]> to memref<1xf32, #[[$MAP0]]> - tensor.parallel_insert_slice %1 into %out[1, %thread_idx][1, 1][1, 1] : + tensor.parallel_insert_slice %1 into %o[1, %thread_idx][1, 1][1, 1] : tensor<1xf32> into tensor<200x100xf32> } }
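For reference, a minimal round-trip example of the new `shared_outs` form, distilled from the updated `ops.mlir` test above (function and value names here are illustrative, not part of the patch): each thread extracts a one-element slice of %in and writes it into a disjoint one-element slice of the shared output through the block argument %o, and the op returns the combined tensor.

  func.func @foreach_thread_shared_out(%in: tensor<100xf32>, %out: tensor<100xf32>) -> tensor<100xf32> {
    %num_threads = arith.constant 100 : index
    %result = scf.foreach_thread (%tid) in (%num_threads) shared_outs(%o = %out) -> tensor<100xf32> {
      // Each thread reads one element of %in ...
      %slice = tensor.extract_slice %in[%tid][1][1] : tensor<100xf32> to tensor<1xf32>
      scf.foreach_thread.perform_concurrently {
        // ... and writes it into a disjoint slice of the shared output %o.
        tensor.parallel_insert_slice %slice into %o[%tid][1][1] :
          tensor<1xf32> into tensor<100xf32>
      }
    }
    return %result : tensor<100xf32>
  }

After One-Shot Bufferize, %out's buffer is written in place by each thread and the terminator region lowers away, as exercised by the updated one-shot-bufferize.mlir tests.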