diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -796,7 +796,7 @@ scf.foreach_thread.perform_concurrently { tensor.parallel_insert_slice %7 into %arg3[0, %arg2] [%dim, 1] [1, 1] : tensor into tensor } - } {thread_dim_mapping = []} + } {mapping = []} %3 = linalg.generic {indexing_maps = [#map3, #map4], iterator_types = ["parallel", "reduction"]} ins(%2 : tensor) outs(%arg1 : tensor) { ^bb0(%in: f32, %out: f32): %4 = arith.addf %in, %out : f32 @@ -807,7 +807,8 @@ let arguments = (ins PDL_Operation:$target, DefaultValuedAttr:$num_threads, - DefaultValuedAttr:$tile_sizes); + DefaultValuedAttr:$tile_sizes, + OptionalAttr:$mapping); let results = (outs PDL_Operation:$fill_op, PDL_Operation:$split_linalg_op, PDL_Operation:$combining_linalg_op); diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -1222,7 +1222,7 @@ FailureOr result = linalg::tileReductionUsingForeachThread( rewriter, cast(target.getOperation()), - numThreads, tileSizes, /*mapping=*/std::nullopt); + numThreads, tileSizes, getMapping()); if (failed(result)) { results.assign(3, nullptr); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -25,8 +25,10 @@ #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" +#include "mlir/IR/ValueRange.h" #include "mlir/Transforms/FoldUtils.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 
+#include "llvm/ADT/STLExtras.h" #include "llvm/Support/CommandLine.h" #include @@ -221,6 +223,9 @@ Optional> nominalTileSizes, SmallVector &tiledOffsets, SmallVector &tiledSizes) { + OpBuilder::InsertionGuard g(b); + b.setInsertionPointToStart(foreachThreadOp.getBody(0)); + ValueRange threadIds = foreachThreadOp.getThreadIndices(); SmallVector nonZeroNumThreads = llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) { @@ -300,6 +305,7 @@ Optional mapping, bool omitTileOffsetBoundsCheck) { Location loc = op->getLoc(); OpBuilder::InsertionGuard g(b); + SmallVector loopRanges = op.getIterationDomain(b); if (loopRanges.empty()) return op->emitOpError("expected non-empty loop ranges"); @@ -330,7 +336,6 @@ loc, dest, ValueRange(materializedNonZeroNumThreads), mapping); // Fill out the ForeachThreadOp body. - b.setInsertionPointToStart(foreachThreadOp.getBody(0)); SmallVector tiledOffsets, tiledSizes; calculateTileOffsetsAndSizes(b, loc, foreachThreadOp, numThreads, loopRanges, omitTileOffsetBoundsCheck, nominalTileSizes, @@ -361,16 +366,15 @@ auto tilingInterfaceOp = dyn_cast(tiledOp); assert(tilingInterfaceOp && "Tiled op does not implement TilingInterface"); - OpBuilder::InsertPoint insertPt = b.saveInsertionPoint(); for (auto it : llvm::zip(llvm::seq(unsigned(0), unsigned(dest.size())), tilingInterfaceOp->getResults(), destBbArgs)) { - b.setInsertionPoint(insertPt.getBlock(), insertPt.getPoint()); SmallVector resultOffsets, resultSizes; if (failed(op.getResultTilePosition(b, std::get<0>(it), tiledOffsets, tiledSizes, resultOffsets, resultSizes))) return op->emitOpError("output offsets couldn't be calculated"); SmallVector strides(resultSizes.size(), b.getIndexAttr(1)); + OpBuilder::InsertionGuard g(b); b.setInsertionPointToEnd(foreachThreadOp.getTerminator().getBody()); b.create(loc, std::get<1>(it), std::get<2>(it), resultOffsets, @@ -415,6 +419,8 @@ static FailureOr tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ArrayRef tileSizes, const 
LinalgTilingOptions &options) { + OpBuilder::InsertionGuard g(b); + auto nLoops = op.getNumLoops(); // Initial tile sizes may be too big, only take the first nLoops. tileSizes = tileSizes.take_front(nLoops); @@ -570,17 +576,35 @@ Optional mapping) { Location loc = op.getLoc(); OpBuilder::InsertionGuard g(b); + // Ops implementing PartialReductionOpInterface are expected to implement // TilingInterface. + // TODO: proper core mechanism to tie interfaces together. auto tilingInterfaceOp = cast(op.getOperation()); + + // Ops implementing PartialReductionOpInterface are not necessarily expected + // to implement TilingInterface. This cast is unsafe atm. + // TODO: proper core mechanism to tie interfaces together. + // TODO: this function requires a pair of interfaces .. + auto destinationStyleOp = + dyn_cast(op.getOperation()); + if (!destinationStyleOp) + return b.notifyMatchFailure(op, "not a destination style op"); + + // Actually this only works for Linalg ops atm. + auto linalgOp = dyn_cast(op.getOperation()); + if (!linalgOp) + return b.notifyMatchFailure(op, "not a linalg op"); + SmallVector iterationDomain = tilingInterfaceOp.getIterationDomain(b); if (op->getNumResults() != 1) return b.notifyMatchFailure( op, "don't support ops with multiple results for now"); + SmallVector iterators = tilingInterfaceOp.getLoopIteratorTypes(); SmallVector redDims; - cast(op.getOperation()).getReductionDims(redDims); + linalgOp.getReductionDims(redDims); if (redDims.size() != 1) return b.notifyMatchFailure( op, "only support ops with one reduction dimension."); @@ -588,7 +612,8 @@ return b.notifyMatchFailure(op, "if tile sizes are present it must have as " "many elements as number of threads"); int reductionDim = static_cast(redDims.front()); - // 1. create the inital tensor value. + + // 1. Create the initial tensor value.
FailureOr identityTensor = op.generateInitialTensorForPartialReduction(b, loc, numThreads, reductionDim); @@ -615,8 +640,8 @@ loc, identityTensor.value()->getResults(), ValueRange(materializedNonZeroNumThreads), mapping); - // 3. calculate the tile offsets and sizes. - b.setInsertionPointToStart(foreachThreadOp.getBody(0)); + // 3. Calculate the tile offsets and sizes for the subsequent loop that will + // be nested under `foreachThreadOp`. SmallVector tiledOffsets, tiledSizes; calculateTileOffsetsAndSizes( b, loc, foreachThreadOp, numThreads, iterationDomain, @@ -625,54 +650,77 @@ // 4. Clone the tileable op and update its destination operands to use the // output bbArgs of the ForeachThreadOp. + ValueRange tilingResults; ArrayRef destBbArgs = foreachThreadOp.getOutputBlockArguments(); - Operation *clonedOp = b.clone(*op.getOperation()); - b.setInsertionPointToStart(foreachThreadOp.getBody(0)); - auto destinationStyleOp = cast(clonedOp); - for (OpOperand *initOperand : destinationStyleOp.getDpsInitOperands()) { - auto *it = llvm::find(dest, initOperand->get()); - assert(it != dest.end() && "dest operand not found in dest"); - unsigned destNum = std::distance(dest.begin(), it); - SmallVector strides(numThreads.size(), b.getIndexAttr(1)); - SmallVector outOffsets(numThreads.size(), b.getIndexAttr(0)); - SmallVector sizes = tiledSizes; - sizes[reductionDim] = b.getIndexAttr(1); - outOffsets[reductionDim] = foreachThreadOp.getThreadIndices().front(); - // TODO: use SubsetExtractOpInterface once it is available. - Value patial = b.create( - loc, initOperand->get().getType().cast(), - destBbArgs[destNum], outOffsets, sizes, strides); - initOperand->set(patial); - } - b.setInsertionPoint(clonedOp); - - // 5. Tile the cloned op and delete the clone. 
- if (tileSizes.empty()) { - SmallVector tiledOps = - cast(clonedOp).getTiledImplementation(b, tiledOffsets, - tiledSizes); - assert(tiledOps.size() == 1 && "expected a single produced tiled op"); - tiledOp = tiledOps.front(); - } else { - LinalgTilingOptions options; - auto tiled = tileLinalgOpImpl(b, cast(clonedOp), - tileSizes, options); - SmallVector ids = foreachThreadOp.getThreadIndices(); - mapLoopToProcessorIds(cast(tiled->loops.back()), ids, - materializedNonZeroNumThreads); - assert(tiled->loops.size() == 1 && "expected a single produced loop"); - tiledOp = tiled->loops.front(); + { + // 4.a. RAII guard, inserting within foreachThreadOp, before terminator. + OpBuilder::InsertionGuard g(b); + b.setInsertionPoint(foreachThreadOp.getTerminator()); + + SmallVector tiledDpsInitOperands; + for (OpOperand *initOperand : destinationStyleOp.getDpsInitOperands()) { + auto *it = llvm::find(dest, initOperand->get()); + assert(it != dest.end() && "dest operand not found in dest"); + unsigned destNum = std::distance(dest.begin(), it); + SmallVector strides(numThreads.size(), b.getIndexAttr(1)); + SmallVector outOffsets(numThreads.size(), + b.getIndexAttr(0)); + SmallVector sizes = tiledSizes; + sizes[reductionDim] = b.getIndexAttr(1); + outOffsets[reductionDim] = foreachThreadOp.getThreadIndices().front(); + // TODO: use SubsetExtractOpInterface once it is available. + tiledDpsInitOperands.push_back(b.create( + loc, initOperand->get().getType().cast(), + destBbArgs[destNum], outOffsets, sizes, strides)); + } + + // 4.b. Clone the op and update init operands. + // We cannot use a BlockAndValueMapping here because it can replace + // different OpOperands with the same value. + Operation *clonedOp = b.clone(*op.getOperation()); + b.updateRootInPlace(clonedOp, [&]() { + for (auto [initOperandPtr, tiledInitValue] : llvm::zip_equal( + cast(clonedOp).getDpsInitOperands(), + tiledDpsInitOperands)) { + initOperandPtr->set(tiledInitValue); + } + }); + + // 5. 
Tile the cloned op and delete the clone. + if (tileSizes.empty()) { + SmallVector tiledOps = + cast(clonedOp).getTiledImplementation( + b, tiledOffsets, tiledSizes); + assert(tiledOps.size() == 1 && "expected a single produced tiled op"); + tiledOp = tiledOps.front(); + tilingResults = tiledOp->getResults(); + } else { + LinalgTilingOptions options; + FailureOr maybeTiled = tileLinalgOpImpl( + b, cast(clonedOp), tileSizes, options); + if (failed(maybeTiled)) + return b.notifyMatchFailure(op, "failed tileLinalgOpImpl"); + + SmallVector ids = foreachThreadOp.getThreadIndices(); + mapLoopToProcessorIds(cast(maybeTiled->loops.back()), ids, + materializedNonZeroNumThreads); + assert(maybeTiled->loops.size() == 1 && + "expected a single produced loop"); + tiledOp = maybeTiled->op; + tilingResults = maybeTiled->loops.front()->getResults(); + } + + b.eraseOp(clonedOp); } - b.eraseOp(clonedOp); // 6. Insert the partial reductions back into a new tensor. - b.setInsertionPointAfter(tiledOp); - OpBuilder::InsertPoint insertPt = b.saveInsertionPoint(); - for (auto [index, result, bbArg] : - llvm::zip(llvm::seq(0, dest.size()), tiledOp->getResults(), - destBbArgs)) { - b.setInsertionPoint(insertPt.getBlock(), insertPt.getPoint()); + for (auto [index, result, bbArg] : llvm::zip( + llvm::seq(0, dest.size()), tilingResults, destBbArgs)) { + OpBuilder::InsertionGuard g(b); + b.setInsertionPoint(foreachThreadOp.getTerminator()); + + // 6.a. Partial subset information is inserted just before the terminator. SmallVector resultOffsets, resultSizes; if (failed(tilingInterfaceOp.getResultTilePosition( b, index, tiledOffsets, tiledSizes, resultOffsets, resultSizes))) @@ -689,18 +737,23 @@ resultOffsetsRank.push_back(resultOffsets[offIdx++]); resultSizesRank.push_back(resultSizes[sizeIdx++]); } - SmallVector strides(resultSizesRank.size(), b.getIndexAttr(1)); + + // 6.b. Parallel insertions are inserted at the end of the combining + // terminator. 
b.setInsertionPointToEnd(foreachThreadOp.getTerminator().getBody()); b.create( loc, result, bbArg, resultOffsetsRank, resultSizesRank, strides); } + // 7. Merge the partial reductions. b.setInsertionPointAfter(foreachThreadOp); Operation *mergeOp = op.mergeReductions(b, loc, foreachThreadOp->getResults(), reductionDim); b.replaceOp(op, mergeOp->getResults()); + + // 8. Return. ForeachThreadReductionTilingResult results; results.initialOp = identityTensor.value(); results.loops = foreachThreadOp; diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -874,19 +874,19 @@ DiagnosedSilenceableFailure transform::PrintOp::apply(transform::TransformResults &results, transform::TransformState &state) { - llvm::errs() << "[[[ IR printer: "; + llvm::outs() << "[[[ IR printer: "; if (getName().has_value()) - llvm::errs() << *getName() << " "; + llvm::outs() << *getName() << " "; if (!getTarget()) { - llvm::errs() << "top-level ]]]\n" << *state.getTopLevel() << "\n"; + llvm::outs() << "top-level ]]]\n" << *state.getTopLevel() << "\n"; return DiagnosedSilenceableFailure::success(); } - llvm::errs() << "]]]\n"; + llvm::outs() << "]]]\n"; ArrayRef targets = state.getPayloadOps(getTarget()); for (Operation *target : targets) - llvm::errs() << *target << "\n"; + llvm::outs() << *target << "\n"; return DiagnosedSilenceableFailure::success(); } diff --git a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir --- a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir +++ b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir @@ -218,7 +218,8 @@ transform.sequence failures(propagate) { ^bb0(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 - %1, %2, %3 = transform.structured.tile_reduction_using_foreach_thread %0 { 
num_threads = [0, 5], tile_sizes = [0, 3] } + %1, %2, %3 = transform.structured.tile_reduction_using_foreach_thread %0 + { num_threads = [0, 5], tile_sizes = [0, 3], mapping = [#gpu.thread] } } // CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 3)> @@ -262,3 +263,41 @@ // CHECK: linalg.yield // CHECK: } -> tensor // CHECK: return %[[R]] : tensor + +// ----- + +func.func @reduction_tile_parallel_cyclic_dist( + %arg0: tensor, %out: tensor) -> tensor { + %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0)>], + iterator_types = ["parallel", "reduction"]} + ins(%arg0 : tensor) + outs(%out : tensor) { + ^bb0(%arg7: f32, %arg9: f32): + %1 = arith.mulf %arg7, %arg7 : f32 + %2 = arith.addf %1, %arg9 : f32 + linalg.yield %2 : f32 + } -> tensor + return %red : tensor +} + +transform.sequence failures(propagate) { +^bb0(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %1, %2, %3 = transform.structured.tile_reduction_using_foreach_thread %0 + { num_threads = [0, 5], tile_sizes = [0, 3], mapping = [#gpu.thread] } + + // CHECK: expecting fill + // CHECK-NEXT: linalg.fill + transform.print %1 {name = "expecting fill"} : !pdl.operation + // CHECK: expecting parallel reduction + // CHECK-NEXT: linalg.generic + // CHECK: parallel + // CHECK-SAME: reduction + transform.print %2 {name = "expecting parallel reduction"} : !pdl.operation + // CHECK: expecting parallel reduction + // CHECK-NEXT: linalg.generic + // CHECK: parallel + // CHECK-SAME: reduction + transform.print %3 {name = "expecting parallel reduction"} : !pdl.operation +}