diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -353,16 +353,16 @@
 def ForeachThreadOp : SCF_Op<"foreach_thread", [
       AttrSizedOperandSegments,
-      SingleBlockImplicitTerminator<"scf::PerformConcurrentlyOp">,
-      RecursiveMemoryEffects,
       AutomaticAllocationScope,
-    ]> {
+      RecursiveMemoryEffects,
+      SingleBlockImplicitTerminator<"scf::PerformConcurrentlyOp">,
+    ]> {
   let summary = "evaluate a block multiple times in parallel";
   let description = [{
     `scf.foreach_thread` is a target-independent multi-dimensional parallel
     region application operation. It has exactly one block that represents the
-    parallel body and it takes index operands that indicate how many parallel
-    instances of that function are created.
+    parallel body and it takes index operands that specify lower bounds, upper
+    bounds and steps.
 
     The op also takes a variadic number of tensor operands (`shared_outs`).
     The future buffers corresponding to these tensors are shared among all
@@ -404,7 +404,12 @@
     When the parallel function body has side effects, their order is
     unspecified across threads.
 
-    Example:
+    `scf.foreach_thread` can be printed in two different ways, depending on
+    whether the loop is normalized or not. The loop is 'normalized' when all
+    lower bounds are equal to zero and all steps are equal to one. In that
+    case, the `lowerBound` and `step` operands are omitted during printing.
+
+    Normalized loop example:
 
     ```mlir
     //
@@ -442,6 +447,38 @@
     //
     ```
 
+    Loop with explicit loop bounds example:
+
+    ```mlir
+    //
+    // Sequential context.
+    //
+    %pointwise = scf.foreach_thread (%i, %j) = (0, 0) to (%dim1, %dim2)
+        step (%tileSize1, %tileSize2) shared_outs(%o = %out)
+        -> (tensor<?x?xT>) {
+      //
+      // Parallel context.
+      //
+      %sA = tensor.extract_slice %A[%i, %j][%tileSize1, %tileSize2][1, 1]
+        : tensor<?x?xT> to tensor<?x?xT>
+      %sB = tensor.extract_slice %B[%i, %j][%tileSize1, %tileSize2][1, 1]
+        : tensor<?x?xT> to tensor<?x?xT>
+      %sC = tensor.extract_slice %o[%i, %j][%tileSize1, %tileSize2][1, 1]
+        : tensor<?x?xT> to tensor<?x?xT>
+
+      %add = map {"arith.addf"} ins(%sA, %sB) outs(%sC)
+
+      scf.foreach_thread.perform_concurrently {
+        scf.foreach_thread.parallel_insert_slice %add into
+          %o[%i, %j][%tileSize1, %tileSize2][1, 1]
+          : tensor<?x?xT> into tensor<?x?xT>
+      }
+    }
+    // Implicit synchronization point.
+    // Sequential context.
+    //
+    ```
+
     Example with mapping attribute:
 
     ```mlir
@@ -481,9 +518,15 @@
     }
     ```
   }];
-  let arguments = (ins Variadic<Index>:$num_threads,
-                       Variadic<AnyRankedTensor>:$outputs,
-                       OptionalAttr<DeviceMappingArrayAttr>:$mapping);
+  let arguments = (ins
+      Variadic<Index>:$dynamicLowerBound,
+      Variadic<Index>:$dynamicUpperBound,
+      Variadic<Index>:$dynamicStep,
+      DenseI64ArrayAttr:$staticLowerBound,
+      DenseI64ArrayAttr:$staticUpperBound,
+      DenseI64ArrayAttr:$staticStep,
+      Variadic<AnyRankedTensor>:$outputs,
+      OptionalAttr<DeviceMappingArrayAttr>:$mapping);
 
   let results = (outs Variadic<AnyType>:$results);
   let regions = (region SizedRegion<1>:$region);
@@ -495,58 +538,114 @@
   // The default builder does not add the proper body BBargs, roll our own.
   let skipDefaultBuilders = 1;
   let builders = [
-    // Bodyless builder, outputs must be specified.
-    OpBuilder<(ins "ValueRange":$outputs, "ValueRange":$num_threads,
-                   "std::optional<ArrayAttr>":$mapping)>,
-    // Builder that takes a bodyBuilder lambda.
-    OpBuilder<(ins "ValueRange":$outputs, "ValueRange":$num_threads,
-                   "ArrayRef<Attribute>":$mapping,
-                   "function_ref<void(OpBuilder &, Location, ValueRange)>":$bodyBuilder)>
+    // Builder that takes loop bounds.
+    OpBuilder<(ins "ArrayRef<OpFoldResult>":$lbs,
+                   "ArrayRef<OpFoldResult>":$ubs, "ArrayRef<OpFoldResult>":$steps,
+                   "ValueRange":$outputs, "std::optional<ArrayAttr>":$mapping,
+                   CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>",
+                        "nullptr">:$bodyBuilderFn)>,
+
+    // Builder for normalized loops that takes only upper bounds.
+    OpBuilder<(ins "ArrayRef<OpFoldResult>":$ubs,
+                   "ValueRange":$outputs, "std::optional<ArrayAttr>":$mapping,
+                   CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>",
+                        "nullptr">:$bodyBuilderFn)>,
   ];
+
   let extraClassDeclaration = [{
-    int64_t getRank() { return getNumThreads().size(); }
+    /// Get lower bounds as OpFoldResult.
+    SmallVector<OpFoldResult> getMixedLowerBound() {
+      Builder b(getOperation()->getContext());
+      return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(), b);
+    }
+
+    /// Get upper bounds as OpFoldResult.
+    SmallVector<OpFoldResult> getMixedUpperBound() {
+      Builder b(getOperation()->getContext());
+      return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(), b);
+    }
+
+    /// Get steps as OpFoldResult.
+    SmallVector<OpFoldResult> getMixedStep() {
+      Builder b(getOperation()->getContext());
+      return getMixedValues(getStaticStep(), getDynamicStep(), b);
+    }
+
+    /// Get lower bounds as values.
+    SmallVector<Value> getLowerBound(OpBuilder &b) {
+      return getAsValues(b, getLoc(), getMixedLowerBound());
+    }
+
+    /// Get upper bounds as values.
+    SmallVector<Value> getUpperBound(OpBuilder &b) {
+      return getAsValues(b, getLoc(), getMixedUpperBound());
+    }
+
+    /// Get steps as values.
+    SmallVector<Value> getStep(OpBuilder &b) {
+      return getAsValues(b, getLoc(), getMixedStep());
+    }
+
+    int64_t getRank() { return getStaticLowerBound().size(); }
+
+    /// Number of operands controlling the loop: lbs, ubs, steps.
+    unsigned getNumControlOperands() { return 3 * getRank(); }
+
+    /// Number of dynamic operands controlling the loop: lbs, ubs, steps.
+    unsigned getNumDynamicControlOperands() {
+      return getODSOperandIndexAndLength(3).first;
+    }
 
     OpResult getTiedOpResult(OpOperand *opOperand) {
-      assert(opOperand->getOperandNumber() >= getRank() && "invalid operand");
+      assert(opOperand->getOperandNumber() >= getNumDynamicControlOperands() &&
+             "invalid operand");
       return getOperation()->getOpResult(
-          opOperand->getOperandNumber() - getRank());
+          opOperand->getOperandNumber() - getNumDynamicControlOperands());
     }
 
    /// Return the shared_outs operand that is tied to the given output
    /// block argument.
    OpOperand *getTiedOpOperand(BlockArgument bbArg) {
      assert(bbArg.getArgNumber() >= getRank() && "invalid bbArg");
-      return &getOperation()->getOpOperand(bbArg.getArgNumber());
+
+      return &getOperation()->getOpOperand(getNumDynamicControlOperands() +
+                                           bbArg.getArgNumber() - getRank());
     }
 
     /// Return the shared_outs operand that is tied to the given OpResult.
     OpOperand *getTiedOpOperand(OpResult opResult) {
       assert(opResult.getDefiningOp() == getOperation() && "invalid OpResult");
-      return &getOperation()->getOpOperand(
-          opResult.getResultNumber() + getRank());
+      return &getOperation()->getOpOperand(getNumDynamicControlOperands() +
+                                           opResult.getResultNumber());
     }
 
     BlockArgument getTiedBlockArgument(OpOperand *opOperand) {
-      assert(opOperand->getOperandNumber() >= getRank() && "invalid operand");
-      return getBody()->getArgument(opOperand->getOperandNumber());
+      assert(opOperand->getOperandNumber() >= getNumDynamicControlOperands() &&
+             "invalid operand");
+
+      return getBody()->getArgument(opOperand->getOperandNumber() -
+                                    getNumDynamicControlOperands() + getRank());
     }
 
     ArrayRef<BlockArgument> getOutputBlockArguments() {
       return getBody()->getArguments().drop_front(getRank());
     }
 
-    ::mlir::ValueRange getThreadIndices() {
+    ::mlir::ValueRange getInductionVars() {
       return getBody()->getArguments().take_front(getRank());
     }
 
-    ::mlir::Value getThreadIndex(int64_t idx) {
-      return getThreadIndices()[idx];
+    ::mlir::Value getInductionVar(int64_t idx) {
+      return getInductionVars()[idx];
     }
 
     ::mlir::Block::BlockArgListType getRegionOutArgs() {
       return getBody()->getArguments().drop_front(getRank());
     }
 
+    /// Checks if the lbs are zeros and the steps are ones.
+    bool isNormalized();
+
     /// Helper to sort `values` according to matching `keys`.
     /// Take a custom `compare` binary comparator which returns true if the first
     /// element is smaller than the second (i.e. compatible with std::sort).
@@ -559,7 +658,8 @@
     // The ensureTerminator method generated by SingleBlockImplicitTerminator is
     // unaware of the fact that our terminator also needs a region to be
     // well-formed. We override it here to ensure that we do the right thing.
-    static void ensureTerminator(Region &region, OpBuilder &builder, Location loc);
+    static void ensureTerminator(Region &region, OpBuilder &builder,
+                                 Location loc);
 
     PerformConcurrentlyOp getTerminator();
   }];
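Aside (illustrative, not part of the patch): with explicit bounds, each dimension contributes `ceil((ub - lb) / step)` parallel instances, as in `scf.for`. A minimal sketch with made-up static values and hypothetical `%in`/`%out` operands:

```mlir
// i takes 0 and 4 (ceildiv(8 - 0, 4) = 2 instances); j takes 0, 3 and 6
// (ceildiv(9 - 0, 3) = 3 instances), so the body is evaluated 2 * 3 = 6
// times in parallel and the 4x3 tiles exactly cover the 8x9 output.
%res = scf.foreach_thread (%i, %j) = (0, 0) to (8, 9) step (4, 3)
    shared_outs(%o = %out) -> (tensor<8x9xf32>) {
  %tile = tensor.extract_slice %in[%i, %j] [4, 3] [1, 1]
      : tensor<8x9xf32> to tensor<4x3xf32>
  scf.foreach_thread.perform_concurrently {
    tensor.parallel_insert_slice %tile into %o[%i, %j] [4, 3] [1, 1]
        : tensor<4x3xf32> into tensor<8x9xf32>
  }
}
```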
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -48,8 +48,10 @@
 /// in `integers` is `dynVal` or (2) the next value otherwise. This allows
 /// idiomatic printing of mixed value and integer attributes in a list. E.g.
 /// `[%arg0, 7, 42, %arg42]`.
-void printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
-                           OperandRange values, ArrayRef<int64_t> integers);
+void printDynamicIndexList(
+    OpAsmPrinter &printer, Operation *op, OperandRange values,
+    ArrayRef<int64_t> integers,
+    AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
 
 /// Parser hook for custom directive in assemblyFormat.
 ///
@@ -64,10 +66,11 @@
 /// E.g. after parsing "[%arg0, 7, 42, %arg42]":
 /// 1. `result` is filled with the i64 ArrayAttr "[`dynVal`, 7, 42, `dynVal`]"
 /// 2. `ssa` is filled with "[%arg0, %arg1]".
-ParseResult
-parseDynamicIndexList(OpAsmParser &parser,
-                      SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
-                      DenseI64ArrayAttr &integers);
+ParseResult parseDynamicIndexList(
+    OpAsmParser &parser,
+    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+    DenseI64ArrayAttr &integers,
+    AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
 
 /// Verify that the `values` has as many elements as the number of entries in
 /// `attr` for which `isDynamic` evaluates to true.
diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
--- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
@@ -180,16 +180,19 @@
   // transform op does not apply to individual ForeachThreadOp.
   Location loc = foreachThreadOp->getLoc();
 
+  if (!foreachThreadOp.isNormalized())
+    return transformOp.emitSilenceableError()
+           << "unsupported non-normalized loops";
   if (foreachThreadOp.getNumResults() > 0)
     return transformOp.emitSilenceableError()
            << "only bufferized scf.foreach_thread lowers to "
              "gpu.block_id";
-  if (foreachThreadOp.getNumThreads().size() > 3)
+  if (foreachThreadOp.getRank() > 3)
     return transformOp.emitSilenceableError()
            << "scf.foreach_thread with rank > 3 does not lower to "
              "gpu.block_id";
-  if (llvm::any_of(foreachThreadOp.getNumThreads(), [](Value v) {
-        return !v.getDefiningOp<arith::ConstantIndexOp>();
+  if (llvm::any_of(foreachThreadOp.getMixedUpperBound(), [](OpFoldResult ofr) {
+        return !getConstantIntValue(ofr).has_value();
       })) {
     return transformOp.emitSilenceableError()
            << "unsupported dynamic griddim size";
@@ -198,8 +201,7 @@
       llvm::to_vector(foreachThreadOp.getMapping()->getValue());
 
   // Step 1. Complete the blockMapping to a full mapping (with 1s) if necessary.
-  SmallVector<Value> numBlocks =
-      llvm::to_vector(foreachThreadOp.getNumThreads());
+  SmallVector<Value> numBlocks = foreachThreadOp.getUpperBound(rewriter);
   // Ensure we have 3 block sizes, one for each id.
   Value one;
   for (auto attr : mappingAttributes) {
@@ -227,7 +229,7 @@
   blockIdGenerator(rewriter, foreachThreadOp, blockOps);
   IRMapping bvm;
   for (auto [blockIdx, blockDim] :
-       llvm::zip(foreachThreadOp.getThreadIndices(), blockMapping)) {
+       llvm::zip(foreachThreadOp.getInductionVars(), blockMapping)) {
     bvm.map(blockIdx,
             blockOps[static_cast<int64_t>(
                 blockDim.cast<DeviceMappingAttrInterface>().getMappingId())]);
@@ -243,7 +245,7 @@
                              sourceBlock.getOperations());
 
   // Step 5. RAUW thread indices to thread ops.
-  for (Value loopIndex : foreachThreadOp.getThreadIndices()) {
+  for (Value loopIndex : foreachThreadOp.getInductionVars()) {
     Value blockIdx = bvm.lookup(loopIndex);
     rewriter.replaceAllUsesWith(loopIndex, blockIdx);
   }
@@ -381,14 +383,16 @@
     return emitDefiniteFailure(foreachThreadOp, message);
   };
   Location loc = foreachThreadOp->getLoc();
+  if (!foreachThreadOp.isNormalized())
+    return failureHelper("unsupported non-normalized loops");
   if (foreachThreadOp.getNumResults() > 0)
     return failureHelper(
         "only bufferized scf.foreach_thread lowers to gpu.thread_id");
-  if (foreachThreadOp.getNumThreads().size() > 3)
+  if (foreachThreadOp.getRank() > 3)
     return failureHelper(
         "scf.foreach_thread with rank > 3 does not lower to gpu.thread_id");
-  if (llvm::any_of(foreachThreadOp.getNumThreads(), [](Value v) {
-        return !v.getDefiningOp<arith::ConstantIndexOp>();
+  if (llvm::any_of(foreachThreadOp.getMixedUpperBound(), [](OpFoldResult ofr) {
+        return !getConstantIntValue(ofr).has_value();
       })) {
     return failureHelper("unsupported dynamic blockdim size");
   }
@@ -399,8 +403,7 @@
 
   // Step 1. Complete the threadMapping to a full mapping (with 1s) if
   // necessary.
-  SmallVector<Value> numThreads =
-      llvm::to_vector(foreachThreadOp.getNumThreads());
+  SmallVector<Value> numThreads = foreachThreadOp.getUpperBound(rewriter);
   // Ensure we have 3 thread sizes, one for each id.
   Value one;
   for (auto attr : threadMappingAttributes) {
@@ -437,7 +440,7 @@
   }
   IRMapping bvm;
   for (auto [blockIdx, blockDim] :
-       llvm::zip(foreachThreadOp.getThreadIndices(), threadMapping)) {
+       llvm::zip(foreachThreadOp.getInductionVars(), threadMapping)) {
     bvm.map(blockIdx, threadOpsUpdated[blockDim.cast<DeviceMappingAttrInterface>()
                                            .getMappingId()]);
@@ -484,7 +487,7 @@
                              sourceBlock.getOperations());
 
   // Step 6. RAUW thread indices to thread ops.
-  for (Value loopIndex : foreachThreadOp.getThreadIndices()) {
+  for (Value loopIndex : foreachThreadOp.getInductionVars()) {
     Value threadIdx = bvm.lookup(loopIndex);
     rewriter.replaceAllUsesWith(loopIndex, threadIdx);
   }
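Aside (illustrative, not part of the patch): both rewrites above substitute a hardware id 1:1 for each induction variable, which is why they now require normalized loops with constant bounds. A sketch of an op the thread-mapping path accepts, and one it rejects:

```mlir
// Accepted: normalized (lb = 0, step = 1 implied by the `in` form) with
// static upper bounds; %tx / %ty are RAUW'ed with gpu.thread_id x / y.
scf.foreach_thread (%tx, %ty) in (32, 4) {
  scf.foreach_thread.perform_concurrently {
  }
} {mapping = [#gpu.thread<x>, #gpu.thread<y>]}

// Rejected with "unsupported non-normalized loops": explicit lb/step.
scf.foreach_thread (%i) = (2) to (34) step (2) {
  scf.foreach_thread.perform_concurrently {
  }
} {mapping = [#gpu.thread<x>]}
```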
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -253,7 +253,7 @@
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPointToStart(foreachThreadOp.getBody(0));
 
-  ValueRange threadIds = foreachThreadOp.getThreadIndices();
+  ValueRange threadIds = foreachThreadOp.getInductionVars();
   SmallVector<OpFoldResult> nonZeroNumThreads =
       llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) {
         return !isConstantIntValue(ofr, 0);
@@ -360,7 +360,7 @@
   // version because we require the use of RewriterBase in the body, so we
   // manually move the insertion point to the body below.
   scf::ForeachThreadOp foreachThreadOp = b.create<scf::ForeachThreadOp>(
-      loc, dest, ValueRange(materializedNonZeroNumThreads), mapping);
+      loc, getAsOpFoldResult(materializedNonZeroNumThreads), dest, mapping);
 
   // 2. Fill out the ForeachThreadOp body.
   SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
@@ -681,8 +681,8 @@
 
   // 2. Create the ForeachThreadOp with an empty region.
   scf::ForeachThreadOp foreachThreadOp = b.create<scf::ForeachThreadOp>(
-      loc, (*identityTensor)->getResults(),
-      ValueRange(materializedNonZeroNumThreads), mapping);
+      loc, getAsOpFoldResult(materializedNonZeroNumThreads),
+      (*identityTensor)->getResults(), mapping);
 
   // 3. Calculate the tile offsets and sizes for the subsequent loop that will
   // be nested under `foreachThreadOp`.
@@ -712,7 +712,7 @@
                                  b.getIndexAttr(0));
     SmallVector<OpFoldResult> sizes = tiledSizes;
     sizes[reductionDim] = b.getIndexAttr(1);
-    outOffsets[reductionDim] = foreachThreadOp.getThreadIndices().front();
+    outOffsets[reductionDim] = foreachThreadOp.getInductionVars().front();
     // TODO: use SubsetExtractOpInterface once it is available.
     tiledDpsInitOperands.push_back(b.create<tensor::ExtractSliceOp>(
         loc, initOperand->get().getType().cast<RankedTensorType>(),
@@ -746,7 +746,7 @@
   if (failed(maybeTiled))
     return b.notifyMatchFailure(op, "failed tileLinalgOpImpl");
 
-  SmallVector<Value> ids = foreachThreadOp.getThreadIndices();
+  SmallVector<Value> ids = foreachThreadOp.getInductionVars();
   mapLoopToProcessorIds(cast<scf::ForOp>(maybeTiled->loops.back()), ids,
                         materializedNonZeroNumThreads);
   assert(maybeTiled->loops.size() == 1 &&
@@ -774,7 +774,7 @@
   int64_t sizeIdx = 0;
   for (int64_t i = 0, e = numThreads.size(); i < e; ++i) {
     if (i == reductionDim) {
-      resultOffsetsRank.push_back(foreachThreadOp.getThreadIndices().front());
+      resultOffsetsRank.push_back(foreachThreadOp.getInductionVars().front());
       resultSizesRank.push_back(b.getIndexAttr(1));
       continue;
     }
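Aside (illustrative): because tile counts are now passed as `OpFoldResult`s, constant counts fold into the op's `staticUpperBound` attribute instead of being materialized as `arith.constant` operands; the test updates further down show the same effect:

```mlir
// Previously the tiling above produced:
//   %c10 = arith.constant 10 : index
//   %c20 = arith.constant 20 : index
//   ... = scf.foreach_thread (%i, %j) in (%c10, %c20) ...
// Now the constants print inline as static upper bounds:
%0 = scf.foreach_thread (%i, %j) in (10, 20)
    shared_outs(%o = %C) -> (tensor<?x?xf32>) {
  scf.foreach_thread.perform_concurrently {
  }
}
```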
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1110,6 +1110,7 @@
 //===----------------------------------------------------------------------===//
 
 LogicalResult ForeachThreadOp::verify() {
+  unsigned numLoops = getRank();
   // Check number of outputs.
   if (getNumResults() != getOutputs().size())
     return emitOpError("produces ")
@@ -1118,18 +1119,18 @@
   // Check that the body defines block arguments for thread indices and outputs.
   auto *body = getBody();
-  if (body->getNumArguments() != getRank() + getOutputs().size())
-    return emitOpError("region expects ") << getRank() << " arguments";
-  for (int64_t i = 0; i < getRank(); ++i)
+  if (body->getNumArguments() != numLoops + getOutputs().size())
+    return emitOpError("region expects ") << numLoops << " arguments";
+  for (int64_t i = 0; i < numLoops; ++i)
     if (!body->getArgument(i).getType().isIndex())
       return emitOpError("expects ")
              << i << "-th block argument to be an index";
   for (unsigned i = 0; i < getOutputs().size(); ++i)
-    if (body->getArgument(i + getRank()).getType() != getOutputs()[i].getType())
+    if (body->getArgument(i + numLoops).getType() != getOutputs()[i].getType())
       return emitOpError("type mismatch between ")
              << i << "-th output and corresponding block argument";
   if (getMapping().has_value() && !getMapping()->empty()) {
-    if (static_cast<int64_t>(getMapping()->size()) != getRank())
+    if (static_cast<int64_t>(getMapping()->size()) != numLoops)
       return emitOpError() << "mapping attribute size must match op rank";
     for (auto map : getMapping()->getValue()) {
       if (!isa<DeviceMappingAttrInterface>(map))
@@ -1138,15 +1139,41 @@
     }
   }
 
+  // Verify mixed static/dynamic control variables.
+  Operation *op = getOperation();
+  if (failed(verifyListOfOperandsOrIntegers(op, "lower bound", numLoops,
+                                            getStaticLowerBound(),
+                                            getDynamicLowerBound())))
+    return failure();
+  if (failed(verifyListOfOperandsOrIntegers(op, "upper bound", numLoops,
+                                            getStaticUpperBound(),
+                                            getDynamicUpperBound())))
+    return failure();
+  if (failed(verifyListOfOperandsOrIntegers(op, "step", numLoops,
+                                            getStaticStep(),
+                                            getDynamicStep())))
+    return failure();
+
   return success();
 }
 
 void ForeachThreadOp::print(OpAsmPrinter &p) {
-  p << " (";
-  llvm::interleaveComma(getThreadIndices(), p);
-  p << ") in (";
-  llvm::interleaveComma(getNumThreads(), p);
-  p << ")";
+  Operation *op = getOperation();
+  p << " (" << getInductionVars();
+  if (isNormalized()) {
+    p << ") in ";
+    printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
+                          OpAsmParser::Delimiter::Paren);
+  } else {
+    p << ") = ";
+    printDynamicIndexList(p, op, getDynamicLowerBound(), getStaticLowerBound(),
+                          OpAsmParser::Delimiter::Paren);
+    p << " to ";
+    printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
+                          OpAsmParser::Delimiter::Paren);
+    p << " step ";
+    printDynamicIndexList(p, op, getDynamicStep(), getStaticStep(),
+                          OpAsmParser::Delimiter::Paren);
+  }
   printInitializationList(p, getRegionOutArgs(), getOutputs(), " shared_outs");
   p << " ";
   if (!getRegionOutArgs().empty())
@@ -1154,28 +1181,60 @@
   p.printRegion(getRegion(),
                 /*printEntryBlockArgs=*/false,
                 /*printBlockTerminators=*/getNumResults() > 0);
-  p.printOptionalAttrDict(getOperation()->getAttrs(),
-                          {"operand_segment_sizes"});
+  p.printOptionalAttrDict(op->getAttrs(), {getOperandSegmentSizesAttrName(),
+                                           getStaticLowerBoundAttrName(),
+                                           getStaticUpperBoundAttrName(),
+                                           getStaticStepAttrName()});
 }
 
 ParseResult ForeachThreadOp::parse(OpAsmParser &parser,
                                    OperationState &result) {
-  auto &builder = parser.getBuilder();
+  OpBuilder b(parser.getContext());
+  auto indexType = b.getIndexType();
+
   // Parse an opening `(` followed by thread index variables followed by `)`
   // TODO: when we can refer to such "induction variable"-like handles from the
   // declarative assembly format, we can implement the parser as a custom
   // hook.
-  SmallVector<OpAsmParser::Argument, 4> threadIndices;
-  if (parser.parseArgumentList(threadIndices, OpAsmParser::Delimiter::Paren))
+  SmallVector<OpAsmParser::Argument, 4> ivs;
+  if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren))
     return failure();
 
-  // Parse `in` threadNums.
-  SmallVector<OpAsmParser::UnresolvedOperand, 4> threadNums;
-  if (parser.parseKeyword("in") ||
-      parser.parseOperandList(threadNums, threadIndices.size(),
+  DenseI64ArrayAttr staticLbs, staticUbs, staticSteps;
+  SmallVector<OpAsmParser::UnresolvedOperand> dynamicLbs, dynamicUbs,
+      dynamicSteps;
+  if (succeeded(parser.parseOptionalKeyword("in"))) {
+    // Parse upper bounds.
+    if (parseDynamicIndexList(parser, dynamicUbs, staticUbs,
                               OpAsmParser::Delimiter::Paren) ||
-      parser.resolveOperands(threadNums, builder.getIndexType(),
-                             result.operands))
-    return failure();
+        parser.resolveOperands(dynamicUbs, indexType, result.operands))
+      return failure();
+
+    unsigned numLoops = ivs.size();
+    staticLbs = b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 0));
+    staticSteps = b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 1));
+  } else {
+    // Parse lower bounds.
+    if (parser.parseEqual() ||
+        parseDynamicIndexList(parser, dynamicLbs, staticLbs,
+                              OpAsmParser::Delimiter::Paren) ||
+        parser.resolveOperands(dynamicLbs, indexType, result.operands))
+      return failure();
+
+    // Parse upper bounds.
+    if (parser.parseKeyword("to") ||
+        parseDynamicIndexList(parser, dynamicUbs, staticUbs,
+                              OpAsmParser::Delimiter::Paren) ||
+        parser.resolveOperands(dynamicUbs, indexType, result.operands))
+      return failure();
+
+    // Parse step values.
+    if (parser.parseKeyword("step") ||
+        parseDynamicIndexList(parser, dynamicSteps, staticSteps,
+                              OpAsmParser::Delimiter::Paren) ||
+        parser.resolveOperands(dynamicSteps, indexType, result.operands))
+      return failure();
+  }
 
   // Parse out operands and results.
   SmallVector<OpAsmParser::Argument, 4> regionOutArgs;
@@ -1195,9 +1254,9 @@
   // Parse region.
   SmallVector<OpAsmParser::Argument, 4> regionArgs;
   std::unique_ptr<Region> region = std::make_unique<Region>();
-  for (auto &idx : threadIndices) {
-    idx.type = builder.getIndexType();
-    regionArgs.push_back(idx);
+  for (auto &iv : ivs) {
+    iv.type = b.getIndexType();
+    regionArgs.push_back(iv);
   }
   for (const auto &it : llvm::enumerate(regionOutArgs)) {
     auto &out = it.value();
@@ -1208,92 +1267,111 @@
     return failure();
 
   // Ensure terminator and move region.
-  OpBuilder b(builder.getContext());
   ForeachThreadOp::ensureTerminator(*region, b, result.location);
   result.addRegion(std::move(region));
 
   // Parse the optional attribute list.
   if (parser.parseOptionalAttrDict(result.attributes))
     return failure();
+
+  result.addAttribute("staticLowerBound", staticLbs);
+  result.addAttribute("staticUpperBound", staticUbs);
+  result.addAttribute("staticStep", staticSteps);
   result.addAttribute("operand_segment_sizes",
                       parser.getBuilder().getDenseI32ArrayAttr(
-                          {static_cast<int32_t>(threadNums.size()),
+                          {static_cast<int32_t>(dynamicLbs.size()),
+                           static_cast<int32_t>(dynamicUbs.size()),
+                           static_cast<int32_t>(dynamicSteps.size()),
                            static_cast<int32_t>(outOperands.size())}));
   return success();
 }
 
-// Bodyless builder, outputs must be specified.
-void ForeachThreadOp::build(mlir::OpBuilder &builder,
-                            mlir::OperationState &result, ValueRange outputs,
-                            ValueRange numThreads,
-                            std::optional<ArrayAttr> mapping) {
-  result.addOperands(numThreads);
+// Builder that takes loop bounds.
+void ForeachThreadOp::build(
+    mlir::OpBuilder &b, mlir::OperationState &result,
+    ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
+    ArrayRef<OpFoldResult> steps, ValueRange outputs,
+    std::optional<ArrayAttr> mapping,
+    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
+  SmallVector<int64_t> staticLbs, staticUbs, staticSteps;
+  SmallVector<Value> dynamicLbs, dynamicUbs, dynamicSteps;
+  dispatchIndexOpFoldResults(lbs, dynamicLbs, staticLbs);
+  dispatchIndexOpFoldResults(ubs, dynamicUbs, staticUbs);
+  dispatchIndexOpFoldResults(steps, dynamicSteps, staticSteps);
+
+  result.addOperands(dynamicLbs);
+  result.addOperands(dynamicUbs);
+  result.addOperands(dynamicSteps);
   result.addOperands(outputs);
+  result.addTypes(TypeRange(outputs));
+
+  result.addAttribute(getStaticLowerBoundAttrName(result.name),
+                      b.getDenseI64ArrayAttr(staticLbs));
+  result.addAttribute(getStaticUpperBoundAttrName(result.name),
+                      b.getDenseI64ArrayAttr(staticUbs));
+  result.addAttribute(getStaticStepAttrName(result.name),
+                      b.getDenseI64ArrayAttr(staticSteps));
+  result.addAttribute(
+      "operand_segment_sizes",
+      b.getDenseI32ArrayAttr({static_cast<int32_t>(dynamicLbs.size()),
+                              static_cast<int32_t>(dynamicUbs.size()),
+                              static_cast<int32_t>(dynamicSteps.size()),
+                              static_cast<int32_t>(outputs.size())}));
   if (mapping.has_value()) {
     result.addAttribute(ForeachThreadOp::getMappingAttrName(result.name),
                         mapping.value());
   }
-  result.addAttribute(
-      "operand_segment_sizes",
-      builder.getDenseI32ArrayAttr({static_cast<int32_t>(numThreads.size()),
-                                    static_cast<int32_t>(outputs.size())}));
-  result.addTypes(TypeRange(outputs));
-
   Region *bodyRegion = result.addRegion();
-  OpBuilder::InsertionGuard g(builder);
-  // createBlock sets the IP inside the block.
-  // Generally we would guard against that but the default ensureTerminator impl
-  // expects it ..
-  builder.createBlock(bodyRegion);
+  OpBuilder::InsertionGuard g(b);
+  b.createBlock(bodyRegion);
   Block &bodyBlock = bodyRegion->front();
-  // Add block arguments for indices and outputs.
-  bodyBlock.addArguments(
-      SmallVector<Type>(numThreads.size(), builder.getIndexType()),
-      SmallVector<Location>(numThreads.size(), result.location));
-  bodyBlock.addArguments(
-      TypeRange(outputs),
-      SmallVector<Location>(outputs.size(), result.location));
-  ForeachThreadOp::ensureTerminator(*bodyRegion, builder, result.location);
-}
-
-// Builder that takes a bodyBuilder lambda.
-void ForeachThreadOp::build(
-    mlir::OpBuilder &builder, mlir::OperationState &result, ValueRange outputs,
-    ValueRange numThreads, ArrayRef<Attribute> mapping,
-    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
-  result.addOperands(numThreads);
-  result.addOperands(outputs);
-  result.addAttribute(ForeachThreadOp::getMappingAttrName(result.name),
-                      builder.getArrayAttr(mapping));
-  result.addAttribute(
-      "operand_segment_sizes",
-      builder.getDenseI32ArrayAttr({static_cast<int32_t>(numThreads.size()),
-                                    static_cast<int32_t>(outputs.size())}));
-  result.addTypes(TypeRange(outputs));
-
-  Region *bodyRegion = result.addRegion();
-  OpBuilder::InsertionGuard g(builder);
-  builder.createBlock(bodyRegion);
-  Block &bodyBlock = bodyRegion->front();
   // Add block arguments for indices and outputs.
   bodyBlock.addArguments(
-      SmallVector<Type>(numThreads.size(), builder.getIndexType()),
-      SmallVector<Location>(numThreads.size(), result.location));
+      SmallVector<Type>(lbs.size(), b.getIndexType()),
+      SmallVector<Location>(lbs.size(), result.location));
   bodyBlock.addArguments(
       TypeRange(outputs),
       SmallVector<Location>(outputs.size(), result.location));
 
-  builder.setInsertionPointToStart(&bodyBlock);
-  bodyBuilder(builder, result.location, bodyBlock.getArguments());
+  b.setInsertionPointToStart(&bodyBlock);
+  if (!bodyBuilderFn) {
+    ForeachThreadOp::ensureTerminator(*bodyRegion, b, result.location);
+    return;
+  }
+  bodyBuilderFn(b, result.location, bodyBlock.getArguments());
 #ifndef NDEBUG
   auto terminator =
       llvm::dyn_cast<PerformConcurrentlyOp>(bodyBlock.getTerminator());
   assert(terminator &&
-         "expected bodyBuilder to create PerformConcurrentlyOp terminator");
+         "expected bodyBuilderFn to create PerformConcurrentlyOp terminator");
 #endif // NDEBUG
 }
 
+// Builder for normalized loops that takes only upper bounds.
+void ForeachThreadOp::build(
+    mlir::OpBuilder &b, mlir::OperationState &result,
+    ArrayRef<OpFoldResult> ubs, ValueRange outputs,
+    std::optional<ArrayAttr> mapping,
+    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
+  unsigned numLoops = ubs.size();
+  SmallVector<OpFoldResult> lbs(numLoops, b.getIndexAttr(0));
+  SmallVector<OpFoldResult> steps(numLoops, b.getIndexAttr(1));
+  build(b, result, lbs, ubs, steps, outputs, mapping, bodyBuilderFn);
+}
+
+// Checks if the lbs are zeros and the steps are ones.
+bool ForeachThreadOp::isNormalized() {
+  auto allEqual = [](ArrayRef<OpFoldResult> results, int64_t val) {
+    return llvm::all_of(results, [&](OpFoldResult ofr) {
+      auto intValue = getConstantIntValue(ofr);
+      return intValue.has_value() && intValue == val;
+    });
+  };
+  return allEqual(getMixedLowerBound(), 0) && allEqual(getMixedStep(), 1);
+}
+
 // The ensureTerminator method generated by SingleBlockImplicitTerminator is
 // unaware of the fact that our terminator also needs a region to be
 // well-formed. We override it here to ensure that we do the right thing.
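Aside (illustrative): the printer above picks between two equivalent spellings, and `printDynamicIndexList` prints constant bounds inline while dynamic ones remain SSA operands:

```mlir
// Normalized: all lbs are 0 and all steps are 1, so the `in` form is
// used and lower bounds and steps are elided.
scf.foreach_thread (%i) in (%n) {
  scf.foreach_thread.perform_concurrently {
  }
}

// Non-normalized: full form; the dynamic lb stays an SSA value while
// the static ub and step print inline.
scf.foreach_thread (%i) = (%lb) to (128) step (2) {
  scf.foreach_thread.perform_concurrently {
  }
}
```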
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -1004,15 +1004,14 @@
 
 /// Return `true` if the given loop may have 0 iterations.
 bool mayHaveZeroIterations(scf::ForeachThreadOp foreachThreadOp) {
-  int64_t p = 1;
-  for (Value v : foreachThreadOp.getNumThreads()) {
-    if (std::optional<int64_t> c = getConstantIntValue(v)) {
-      p *= *c;
-    } else {
+  for (auto [lb, ub] : llvm::zip(foreachThreadOp.getMixedLowerBound(),
+                                 foreachThreadOp.getMixedUpperBound())) {
+    std::optional<int64_t> lbConst = getConstantIntValue(lb);
+    std::optional<int64_t> ubConst = getConstantIntValue(ub);
+    if (!lbConst.has_value() || !ubConst.has_value() || *lbConst >= *ubConst)
       return true;
-    }
   }
-  return p == 0;
+  return false;
 }
 
 /// Bufferization of ForeachThreadOp. This also bufferizes the terminator of the
@@ -1087,8 +1086,9 @@
     rewriter.setInsertionPoint(foreachThreadOp);
     ForeachThreadOp newForeachThreadOp;
     newForeachThreadOp = rewriter.create<ForeachThreadOp>(
-        foreachThreadOp.getLoc(), /*outputs=*/ValueRange(),
-        foreachThreadOp.getNumThreads(), foreachThreadOp.getMapping());
+        foreachThreadOp.getLoc(), foreachThreadOp.getMixedLowerBound(),
+        foreachThreadOp.getMixedUpperBound(), foreachThreadOp.getMixedStep(),
+        /*outputs=*/ValueRange(), foreachThreadOp.getMapping());
 
     newForeachThreadOp.getBody()->getTerminator()->erase();
@@ -1127,10 +1127,28 @@
   bool isRepetitiveRegion(Operation *op, unsigned index) const {
     auto foreachThreadOp = cast<ForeachThreadOp>(op);
-    // This op is not repetitive if it has just a single thread.
-    return !llvm::all_of(foreachThreadOp.getNumThreads(), [](Value v) {
-      return getConstantIntValue(v) == static_cast<int64_t>(1);
-    });
+
+    // The op is repetitive if any dimension runs more than one iteration,
+    // i.e. if lb + step < ub. Dynamic control values are conservatively
+    // treated as repetitive as well.
+    for (auto [lb, ub, step] : llvm::zip(foreachThreadOp.getMixedLowerBound(),
+                                         foreachThreadOp.getMixedUpperBound(),
+                                         foreachThreadOp.getMixedStep())) {
+      std::optional<int64_t> lbConstant = getConstantIntValue(lb);
+      if (!lbConstant)
+        return true;
+
+      std::optional<int64_t> ubConstant = getConstantIntValue(ub);
+      if (!ubConstant)
+        return true;
+
+      std::optional<int64_t> stepConstant = getConstantIntValue(step);
+      if (!stepConstant)
+        return true;
+
+      if (*lbConstant + *stepConstant < *ubConstant)
+        return true;
+    }
+    return false;
   }
 };
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
--- a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
@@ -180,10 +180,10 @@
   if (scf::ForeachThreadOp foreachThreadOp =
           scf::getForeachThreadOpThreadIndexOwner(iv)) {
     for (int64_t idx = 0; idx < foreachThreadOp.getRank(); ++idx) {
-      if (foreachThreadOp.getThreadIndices()[idx] == iv) {
-        lb = OpBuilder(iv.getContext()).getIndexAttr(0);
-        ub = foreachThreadOp.getNumThreads()[idx];
-        step = OpBuilder(iv.getContext()).getIndexAttr(1);
+      if (foreachThreadOp.getInductionVar(idx) == iv) {
+        lb = foreachThreadOp.getMixedLowerBound()[idx];
+        ub = foreachThreadOp.getMixedUpperBound()[idx];
+        step = foreachThreadOp.getMixedStep()[idx];
         return success();
       }
     }
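Aside (illustrative): the two predicates above reduce to simple constant arithmetic on the bounds:

```mlir
// lb + step < ub (0 + 2 < 4): the body runs for %i = 0 and %i = 2, so
// the region is repetitive; with constant lb < ub the loop can also
// never have zero iterations.
scf.foreach_thread (%i) = (0) to (4) step (2) {
  scf.foreach_thread.perform_concurrently {
  }
}

// lb + step == ub (0 + 2 == 2): exactly one iteration (%i = 0), so the
// region is not considered repetitive.
scf.foreach_thread (%i) = (0) to (2) step (2) {
  scf.foreach_thread.perform_concurrently {
  }
}
```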
diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp
--- a/mlir/lib/Interfaces/ViewLikeInterface.cpp
+++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp
@@ -69,12 +69,45 @@
   return success();
 }
 
+static char getLeftDelimiter(AsmParser::Delimiter delimiter) {
+  switch (delimiter) {
+  case AsmParser::Delimiter::Paren:
+    return '(';
+  case AsmParser::Delimiter::LessGreater:
+    return '<';
+  case AsmParser::Delimiter::Square:
+    return '[';
+  case AsmParser::Delimiter::Braces:
+    return '{';
+  default:
+    llvm_unreachable("unsupported delimiter");
+  }
+}
+
+static char getRightDelimiter(AsmParser::Delimiter delimiter) {
+  switch (delimiter) {
+  case AsmParser::Delimiter::Paren:
+    return ')';
+  case AsmParser::Delimiter::LessGreater:
+    return '>';
+  case AsmParser::Delimiter::Square:
+    return ']';
+  case AsmParser::Delimiter::Braces:
+    return '}';
+  default:
+    llvm_unreachable("unsupported delimiter");
+  }
+}
+
 void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
                                  OperandRange values,
-                                 ArrayRef<int64_t> integers) {
-  printer << '[';
+                                 ArrayRef<int64_t> integers,
+                                 AsmParser::Delimiter delimiter) {
+  char leftDelimiter = getLeftDelimiter(delimiter);
+  char rightDelimiter = getRightDelimiter(delimiter);
+  printer << leftDelimiter;
   if (integers.empty()) {
-    printer << "]";
+    printer << rightDelimiter;
     return;
   }
   unsigned idx = 0;
@@ -84,13 +117,13 @@
     else
       printer << integer;
   });
-  printer << ']';
+  printer << rightDelimiter;
 }
 
 ParseResult mlir::parseDynamicIndexList(
     OpAsmParser &parser,
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
-    DenseI64ArrayAttr &integers) {
+    DenseI64ArrayAttr &integers, AsmParser::Delimiter delimiter) {
   SmallVector<int64_t, 4> integerVals;
   auto parseIntegerOrValue = [&]() {
@@ -107,8 +140,7 @@
     }
     return success();
   };
-  if (parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Square,
-                                     parseIntegerOrValue,
+  if (parser.parseCommaSeparatedList(delimiter, parseIntegerOrValue,
                                      " in dynamic index list"))
     return parser.emitError(parser.getNameLoc())
            << "expected SSA value or integer";
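Aside (illustrative): the new `delimiter` parameter only changes the surrounding brackets of a mixed static/dynamic index list; the entries print the same either way:

```mlir
// Square delimiter (the default), e.g. in extract_slice offset lists,
// mixing a dynamic %off with static values:
%s = tensor.extract_slice %t[%off, 0] [4, 4] [1, 1]
    : tensor<8x8xf32> to tensor<4x4xf32>

// Paren delimiter, as selected by scf.foreach_thread for its bounds:
scf.foreach_thread (%i, %j) in (%ub, 16) {
  scf.foreach_thread.perform_concurrently {
  }
}
```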
diff --git a/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir b/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
--- a/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
+++ b/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
@@ -13,9 +13,7 @@
 //  CHECK-SAME:   %[[B:[0-9a-z]+]]: tensor<?x?xf32>
 //  CHECK-SAME:   %[[C:[0-9a-z]+]]: tensor<?x?xf32>
 func.func @matmul(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
-  //  CHECK-DAG: %[[C10:.*]] = arith.constant 10 : index
-  //  CHECK-DAG: %[[C20:.*]] = arith.constant 20 : index
-  //      CHECK: scf.foreach_thread ({{.*}}) in (%[[C10]], %[[C20]]) shared_outs(%[[C_BLK:.*]] = %[[C]]) -> (tensor<?x?xf32>) {
+  //      CHECK: scf.foreach_thread ({{.*}}) in (10, 20) shared_outs(%[[C_BLK:.*]] = %[[C]]) -> (tensor<?x?xf32>) {
   //      CHECK:   %[[tA:.*]] = tensor.extract_slice %[[A]]{{.*}} : tensor<?x?xf32> to tensor<?x?xf32>
   //      CHECK:   %[[tB:.*]] = tensor.extract_slice %[[B]]{{.*}} : tensor<?x?xf32> to tensor<?x?xf32>
   //      CHECK:   %[[tC:.*]] = tensor.extract_slice %[[C_BLK]]{{.*}} : tensor<?x?xf32> to tensor<?x?xf32>
@@ -95,9 +93,7 @@
 //  CHECK-SAME:   %[[B:[0-9a-z]+]]: tensor<200x300xf32>
 //  CHECK-SAME:   %[[C:[0-9a-z]+]]: tensor<100x300xf32>
 func.func @matmul_static(%A: tensor<100x200xf32>, %B: tensor<200x300xf32>, %C: tensor<100x300xf32>) -> tensor<100x300xf32> {
-  //  CHECK-DAG: %[[c10:.+]] = arith.constant 10 : index
-  //  CHECK-DAG: %[[c21:.+]] = arith.constant 21 : index
-  //      CHECK: scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]]) in (%[[c10]], %[[c21]]) shared_outs(%[[C_BLK:.*]] = %[[C]])
+  //      CHECK: scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]]) in (10, 21) shared_outs(%[[C_BLK:.*]] = %[[C]])
   //      CHECK:   %[[TSMIN:.+]] = affine.min #[[$map0]](%[[IV1]])
   //      CHECK:   %[[TS:.+]] = affine.max #[[$map1]](%[[TSMIN]])
   //  CHECK-NOT:   affine.min
@@ -175,9 +171,7 @@
 //  CHECK-SAME:   %[[B:[0-9a-z]+]]: tensor<200x300xf32>
 //  CHECK-SAME:   %[[C:[0-9a-z]+]]: tensor<100x300xf32>
 func.func @matmul_tile_size_static(%A: tensor<100x200xf32>, %B: tensor<200x300xf32>, %C: tensor<100x300xf32>) -> tensor<100x300xf32> {
-  //  CHECK-DAG: %[[c10:.+]] = arith.constant 10 :
-  //  CHECK-DAG: %[[c15:.+]] = arith.constant 15 :
-  //      CHECK: scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]]) in (%[[c10]], %[[c15]]) shared_outs(%[[C_BLK:.*]] = %[[C]])
+  //      CHECK: scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]]) in (10, 15) shared_outs(%[[C_BLK:.*]] = %[[C]])
   //      CHECK:   %[[TS:.+]] = affine.min #[[$map0]](%[[IV1]])
   //  CHECK-NOT:   affine.max
   //  CHECK-NOT:   affine.min
@@ -225,8 +219,7 @@
 //   CHECK-DAG: #[[$map0:.+]] = affine_map<(d0) -> (d0 * 2)>
 // CHECK-LABEL: extract_source(
-//       CHECK: %[[C2:.*]] = arith.constant 2 : index
-//       CHECK: scf.foreach_thread (%[[ARG:.*]]) in (%[[C2]]) shared_outs(%{{.*}} = %{{.*}}) -> (tensor<4xf32>) {
+//       CHECK: scf.foreach_thread (%[[ARG:.*]]) in (2) shared_outs(%{{.*}} = %{{.*}}) -> (tensor<4xf32>) {
 //       CHECK:   %[[OFF:.*]] = affine.apply #[[$map0]](%[[ARG]])
 //       CHECK:   scf.foreach_thread.perform_concurrently {
 //       CHECK:     tensor.parallel_insert_slice %{{.*}} into %{{.*}}[%[[OFF]]] [2] [1] : tensor<2xf32> into tensor<4xf32>
@@ -289,8 +282,7 @@
 func.func @tile_output_multi_1d_static(%IN1: tensor<100xf32>, %IN2: tensor<100xf32>, %OUT1: tensor<100xf32>, %OUT2: tensor<100xf32>) -> (tensor<100xf32>, tensor<100xf32>) {
-//   CHECK-DAG: %[[c0:.+]] = arith.constant 7 :
-//       CHECK: scf.foreach_thread (%[[IV0:.+]]) in (%[[c0]]) shared_outs(%[[OUT1:[0-9a-z]+]] = %[[ORGOUT1]], %[[OUT2:[0-9a-z]+]] = %[[ORGOUT2]])
+//       CHECK: scf.foreach_thread (%[[IV0:.+]]) in (7) shared_outs(%[[OUT1:[0-9a-z]+]] = %[[ORGOUT1]], %[[OUT2:[0-9a-z]+]] = %[[ORGOUT2]])
 //       CHECK:   %[[TSMIN:.+]] = affine.min #[[$map0]](%[[IV0]])
 //       CHECK:   %[[TS:.+]] = affine.max #[[$map1]](%[[TSMIN]])
 //   CHECK-NOT:   affine.min
@@ -345,8 +337,7 @@
 func.func @tile_output_multi_1d2d_static(%IN1: tensor<100xf32>, %IN2: tensor<100x300xf32>, %IN3: tensor<300xf32>, %OUT1: tensor<300x100xf32>, %OUT2: tensor<300xf32>) -> (tensor<300x100xf32>, tensor<300xf32>) {
-//   CHECK-DAG: %[[c0:.+]] = arith.constant 4 :
-//       CHECK: scf.foreach_thread (%[[IV0:.+]]) in (%[[c0]]) shared_outs(%[[OUT1:[0-9a-z]+]] = %[[ORGOUT1]], %[[OUT2:[0-9a-z]+]] = %[[ORGOUT2]])
+//       CHECK: scf.foreach_thread (%[[IV0:.+]]) in (4) shared_outs(%[[OUT1:[0-9a-z]+]] = %[[ORGOUT1]], %[[OUT2:[0-9a-z]+]] = %[[ORGOUT2]])
 //       CHECK:   %[[LB:.+]] = affine.apply #[[$map0]](%[[IV0]])
 //       CHECK:   %[[tIN1:.+]] = tensor.extract_slice %[[IN2]][0, %[[LB]]] [100, 75]
 //       CHECK:   %[[tIN2:.+]] = tensor.extract_slice %[[IN3]][%[[LB]]] [75]
diff --git a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
--- a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
@@ -122,13 +122,12 @@
 //   CHECK-DAG:   %[[I:.*]] = arith.constant 0.000000e+00 : f32
 //   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
 //   CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
-//   CHECK-DAG:   %[[C5:.*]] = arith.constant 5 : index
 //   CHECK-DAG:   %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
 //   CHECK-DAG:   %[[D1:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
 //   CHECK-DAG:   %[[D2:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?xf32>
 //       CHECK:   %[[E:.*]] = tensor.empty(%[[D2]]) : tensor<?x5xf32>
 //       CHECK:   %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor<?x5xf32>) -> tensor<?x5xf32>
-//       CHECK:   %[[L:.*]] = scf.foreach_thread (%[[IV:.+]]) in (%[[C5]]) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<?x5xf32>) {
+//       CHECK:   %[[L:.*]] = scf.foreach_thread (%[[IV:.+]]) in (5) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<?x5xf32>) {
 //   CHECK-DAG:     %[[TS0:.+]] = affine.min #[[MAP0]](%[[IV]])[%[[D1]]]
 //   CHECK-DAG:     %[[TS1:.+]] = affine.max #[[MAP1]](%[[TS0]])
 //   CHECK-DAG:     %[[ET:.+]] = tensor.extract_slice %[[ARG3:.+]][0, %[[IV]]] [%[[D0]], 1] [1, 1] : tensor<?x5xf32> to tensor<?xf32>
@@ -175,7 +174,6 @@
 //   CHECK-DAG:   %[[I:.*]] = arith.constant 0.000000e+00 : f32
 //   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
 //   CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
-//   CHECK-DAG:   %[[C5:.*]] = arith.constant 5 : index
 //   CHECK-DAG:   %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
 //   CHECK-DAG:   %[[D1:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
 //   CHECK-DAG:   %[[D2:.*]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
@@ -183,7 +181,7 @@
 //   CHECK-DAG:   %[[D4:.*]] = tensor.dim %[[ARG2]], %[[C1]] : tensor<?x?xf32>
 //       CHECK:   %[[E:.*]] = tensor.empty(%[[D3]], %[[D4]]) : tensor<?x?x5xf32>
 //       CHECK:   %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor<?x?x5xf32>) -> tensor<?x?x5xf32>
-//       CHECK:   %[[L:.*]] = scf.foreach_thread (%[[IV:.+]]) in (%[[C5]]) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<?x?x5xf32>) {
+//       CHECK:   %[[L:.*]] = scf.foreach_thread (%[[IV:.+]]) in (5) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<?x?x5xf32>) {
 //   CHECK-DAG:     %[[TS0:.+]] = affine.min #[[MAP0]](%[[IV]])[%[[D1]]]
 //   CHECK-DAG:     %[[TS1:.+]] = affine.max #[[MAP1]](%[[TS0]])
 //   CHECK-DAG:     %[[ET:.+]] = tensor.extract_slice %[[ARG3:.+]][0, 0, %[[IV]]] [%[[D0]], %[[D2]], 1] [1, 1, 1] : tensor<?x?x5xf32> to tensor<?x?xf32>
@@ -235,13 +233,12 @@
 //   CHECK-DAG:   %[[I:.*]] = arith.constant 0.000000e+00 : f32
 //   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
 //   CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
-//   CHECK-DAG:   %[[C5:.*]] = arith.constant 5 : index
 //   CHECK-DAG:   %[[C15:.*]] = arith.constant 15 : index
 //   CHECK-DAG:   %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
 //   CHECK-DAG:   %[[D2:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?xf32>
 //       CHECK:   %[[E:.*]] = tensor.empty(%[[D2]]) : tensor<?x5xf32>
 //       CHECK:   %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor<?x5xf32>) -> tensor<?x5xf32>
-//       CHECK:   %[[L:.*]] = scf.foreach_thread (%[[IV:.+]]) in (%[[C5]]) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<?x5xf32>) {
+//       CHECK:   %[[L:.*]] = scf.foreach_thread (%[[IV:.+]]) in (5) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<?x5xf32>) {
 //       CHECK:     %[[ET:.+]] = tensor.extract_slice %[[ARG3:.+]][0, %[[IV]]] [%[[D0]], 1] [1, 1] : tensor<?x5xf32> to tensor<?xf32>
 //       CHECK:     %[[D1:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
 //       CHECK:     %[[LB:.+]] = affine.apply #[[MAP0]]()[%[[IV]]]
diff --git a/mlir/test/Dialect/SCF/ops.mlir b/mlir/test/Dialect/SCF/ops.mlir
--- a/mlir/test/Dialect/SCF/ops.mlir
+++ b/mlir/test/Dialect/SCF/ops.mlir
@@ -311,8 +311,8 @@
   return %res : i64
 }
 
-// CHECK-LABEL: func.func @simple_example
-func.func @simple_example(%in: tensor<100xf32>, %out: tensor<100xf32>) {
+// CHECK-LABEL: func.func @normalized_foreach_thread
+func.func @normalized_foreach_thread(%in: tensor<100xf32>, %out: tensor<100xf32>) {
   %c1 = arith.constant 1 : index
   %num_threads = arith.constant 100 : index
 
@@ -333,8 +333,32 @@
   return
 }
 
-// CHECK-LABEL: func.func @elide_terminator
-func.func @elide_terminator() -> () {
+// CHECK-LABEL: func.func @explicit_loop_bounds_foreach_thread
+func.func @explicit_loop_bounds_foreach_thread(%in: tensor<100xf32>,
+                                               %out: tensor<100xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %num_threads = arith.constant 100 : index
+
+  //      CHECK: scf.foreach_thread
+  // CHECK-NEXT: tensor.extract_slice
+  // CHECK-NEXT: scf.foreach_thread.perform_concurrently
+  // CHECK-NEXT: tensor.parallel_insert_slice
+  // CHECK-NEXT: }
+  // CHECK-NEXT: }
+  // CHECK-NEXT: return
+  %result = scf.foreach_thread (%thread_idx) = (%c0) to (%num_threads) step (%c1) shared_outs(%o = %out) -> tensor<100xf32> {
+    %1 = tensor.extract_slice %in[%thread_idx][1][1] : tensor<100xf32> to tensor<1xf32>
+    scf.foreach_thread.perform_concurrently {
+      tensor.parallel_insert_slice %1 into %o[%thread_idx][1][1] :
+        tensor<1xf32> into tensor<100xf32>
+    }
+  }
+  return
+}
+
+// CHECK-LABEL: func.func @normalized_foreach_thread_elide_terminator
+func.func @normalized_foreach_thread_elide_terminator() -> () {
   %num_threads = arith.constant 100 : index
 
   //      CHECK: scf.foreach_thread
@@ -345,6 +369,23 @@
     }
   } {mapping = [#gpu.thread<x>]}
   return
+
+}
+
+// CHECK-LABEL: func.func @explicit_loop_bounds_foreach_thread_elide_terminator
+func.func @explicit_loop_bounds_foreach_thread_elide_terminator() -> () {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %num_threads = arith.constant 100 : index
+
+  //      CHECK: scf.foreach_thread
+  // CHECK-NEXT: } {mapping = [#gpu.thread<x>]}
+  // CHECK-NEXT: return
+  scf.foreach_thread (%thread_idx) = (%c0) to (%num_threads) step (%c1) {
+    scf.foreach_thread.perform_concurrently {
+    }
+  } {mapping = [#gpu.thread<x>]}
+  return
 }
 
 // CHECK-LABEL: @switch
diff --git a/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir b/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir
--- a/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir
+++ b/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir
@@ -24,12 +24,11 @@
 //       CHECK:   return %[[tile]]
 
 //     FOREACH: func.func @extract_slice_static(%[[arg0:.+]]:
-// FOREACH-DAG:   %[[c20:.+]] = arith.constant 20 : index
 // FOREACH-DAG:   %[[c3:.+]] = arith.constant 3 : index
 // FOREACH-DAG:   %[[c5:.+]] = arith.constant 5 : index
 // FOREACH-DAG:   %[[c7:.+]] = arith.constant 7 : index
 // FOREACH-DAG:   %[[init:.+]] = tensor.empty() : tensor<20x11xf32>
-//     FOREACH:   %[[tile:.+]] = scf.foreach_thread (%[[iv:.+]]) in (%[[c20]]) shared_outs(%[[dest:.+]] = %[[init]])
+//     FOREACH:   %[[tile:.+]] = scf.foreach_thread (%[[iv:.+]]) in (20) shared_outs(%[[dest:.+]] = %[[init]])
 //     FOREACH:     %[[multiIndex:.+]]:3 = affine.delinearize_index %[[iv]] into (%[[c3]], %[[c5]], %[[c7]]
 //     FOREACH:     %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0] [1, 1, 1, 11] [1, 1, 1, 1] :
 //     FOREACH:     %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3]{{\]}} :
diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -247,9 +247,10 @@
                   tensor::ExtractSliceFromCollapseHelper &helper,
                   PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    auto foreachOp = rewriter.create<scf::ForeachThreadOp>(
-        loc, /*outputs=*/dest, /*numThreads=*/helper.getIterationSpaceSizes(),
-        /*mapping=*/ArrayRef<Attribute>{},
+    auto foreachThreadOp = rewriter.create<scf::ForeachThreadOp>(
+        loc, /*numThreads=*/getAsOpFoldResult(helper.getIterationSpaceSizes()),
+        /*outputs=*/dest,
+        /*mapping=*/std::nullopt,
         [&](OpBuilder &nestedBuilder, Location loc, ValueRange regionArgs) {
           unsigned numThreadIdRegionArgs =
               helper.getIterationSpaceSizes().size();
@@ -267,7 +268,7 @@
           nestedBuilder.create<tensor::ParallelInsertSliceOp>(
              loc, tile, outputArgs[0], insertParams);
         });
-    rewriter.replaceOp(op, foreachOp->getResult(0));
+    rewriter.replaceOp(op, foreachThreadOp->getResult(0));
    return success();
  }
};
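Aside (illustrative, assuming the shapes from the FOREACH test above): the updated test pattern builds IR of roughly this shape for a hypothetical 3x5x7x11 source collapsed into 20x11; all value names here are made up:

```mlir
%tile = scf.foreach_thread (%iv) in (20) shared_outs(%dest = %init) -> (tensor<20x11xf32>) {
  %idx:3 = affine.delinearize_index %iv into (%c3, %c5, %c7) : index, index, index
  %slice = tensor.extract_slice %arg0[%idx#0, %idx#1, %idx#2, 0] [1, 1, 1, 11] [1, 1, 1, 1]
      : tensor<3x5x7x11xf32> to tensor<1x1x1x11xf32>
  %flat = tensor.collapse_shape %slice [[0, 1, 2], [3]]
      : tensor<1x1x1x11xf32> into tensor<1x11xf32>
  scf.foreach_thread.perform_concurrently {
    tensor.parallel_insert_slice %flat into %dest[%iv, 0] [1, 11] [1, 1]
        : tensor<1x11xf32> into tensor<20x11xf32>
  }
}
```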