diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h @@ -14,6 +14,7 @@ #include "mlir/IR/OpImplementation.h" namespace mlir { +class TilingInterface; namespace linalg { class GenericOp; class LinalgOp; diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -554,6 +554,60 @@ }]; } +def TileToForeachThreadOp : + Op { + let description = [{ + Tile a TilingInterface `op` to a tiled `scf.foreach_thread`, applying + tiling by `num_threads`. + If non-empty, the `thread_dim_mapping` is added as an attribute to the + resulting `scf.foreach_thread`. + A zero entry in `num_threads` indicates that the dimension is not tiled and + can be thought of as tiling by the full size of data. + It is the user's responsibility to ensure that `num_threads` is a valid + tiling specification (i.e. that it only tiles parallel dimensions, e.g. in + the Linalg case). + + #### Return modes + + This operation ignores ops that do not implement the TilingInterface and + drops them in the return. + + If all the operations referred to by the `target` PDLOperation tile + successfully, the transform succeeds. + Otherwise the transform produces a silenceable failure. + + The 2 returned handles point to only the subset of successfully produced + tiled operations, which can all be empty. + + These 2 returned handles point to: + - the new scf.foreach_thread op, + - the tiled op that implements TilingInterface. + }]; + + let arguments = (ins PDL_Operation:$target, + // TODO: dynamic number of threads.
+ DefaultValuedAttr:$num_threads, + OptionalAttr:$thread_dim_mapping); + let results = (outs PDL_Operation:$foreach_thread_op, + PDL_Operation:$tiled_op); + + let assemblyFormat = [{ + $target $num_threads (`(` `mapped` `to` `dims` $thread_dim_mapping^ `)`)? + attr-dict + }]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::TilingInterface target, + ::llvm::SmallVectorImpl<::mlir::Operation *> &results, + ::mlir::transform::TransformState &state); + }]; +} + def VectorizeOp : Op { diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -517,6 +517,24 @@ OpFoldResult targetSize, OpFoldResult divisor, bool emitAssertions = true); +/// Rewrite a TilingInterface `op` to a tiled `scf.foreach_thread`, applying +/// tiling by `numThreads`. +/// If non-empty, the `threadDimMapping` is added as an attribute to the +/// resulting `scf.foreach_thread`. +/// A zero entry in `numThreads` indicates that the dimension is not tiled and +/// can be thought of as tiling by the full size of data. +/// It is the user's responsibility to ensure that `numThreads` is a +/// valid tiling specification (i.e. that it only tiles parallel +/// dimensions, e.g. in the Linalg case). +struct ForeachThreadTilingResult { + Operation *tileOp; // The new `scf.foreach_thread` op. + Operation *tiledOp; // The tiled op created in its body. +}; +FailureOr +tileToForeachThreadOp(OpBuilder &builder, TilingInterface op, + ArrayRef numThreads, + ArrayRef threadDimMapping = {}); + /// All indices returned by IndexOp should be invariant with respect to tiling. /// Therefore, if an operation is tiled, we have to transform the indices /// accordingly, i.e.
offset them by the values of the corresponding induction diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h --- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h @@ -50,6 +50,9 @@ /// value. If this fails, return the original value. SmallVector getAsOpFoldResult(ArrayRef values); +/// Convert `arrayAttr` to a vector of OpFoldResult. +SmallVector getAsOpFoldResult(ArrayAttr arrayAttr); + /// If ofr is a constant integer or an IntegerAttr, return the integer. Optional getConstantIntValue(OpFoldResult ofr); diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/PDL/IR/PDLTypes.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/Interfaces/TilingInterface.h" #include "mlir/Parser/Parser.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -23,16 +24,6 @@ using namespace mlir::linalg; using namespace mlir::transform; -/// Extracts a vector of int64_t from an array attribute. Asserts if the -/// attribute contains values other than integers. -static SmallVector extractI64Array(ArrayAttr attr) { - SmallVector result; - result.reserve(attr.size()); - for (APInt value : attr.getAsValueRange()) - result.push_back(value.getSExtValue()); - return result; -} - /// Extracts a vector of unsigned from an array attribute. Asserts if the /// attribute contains values other than intergers. May truncate. 
static SmallVector extractUIntArray(ArrayAttr attr) { @@ -160,7 +151,8 @@ << "'" << sizesAttrName << "' attribute must be an array"; Type pdlOpType = parser.getBuilder().getType(); size_t numExpectedLoops = - sizesArrayAttr.size() - llvm::count(extractI64Array(sizesArrayAttr), 0); + sizesArrayAttr.size() - + llvm::count(extractFromI64ArrayAttr(sizesArrayAttr), 0); result.addTypes(SmallVector(numExpectedLoops + 1, pdlOpType)); if (parser.resolveOperand(targetOperand, pdlOpType, result.operands)) return failure(); @@ -171,8 +163,8 @@ transform::FuseOp::apply(mlir::transform::TransformResults &transformResults, mlir::transform::TransformState &state) { LinalgTilingAndFusionOptions fusionOptions; - fusionOptions.tileSizes = extractI64Array(getTileSizes()); - fusionOptions.tileInterchange = extractI64Array(getTileInterchange()); + fusionOptions.tileSizes = extractFromI64ArrayAttr(getTileSizes()); + fusionOptions.tileInterchange = extractFromI64ArrayAttr(getTileInterchange()); LogicalResult result = applyTilingToAll( getOperation(), state.getPayloadOps(getTarget()), @@ -209,7 +201,8 @@ } LogicalResult transform::FuseOp::verify() { - SmallVector permutation = extractI64Array(getTileInterchange()); + SmallVector permutation = + extractFromI64ArrayAttr(getTileInterchange()); auto sequence = llvm::to_vector(llvm::seq(0, permutation.size())); if (!std::is_permutation(sequence.begin(), sequence.end(), permutation.begin(), permutation.end())) { @@ -327,7 +320,7 @@ transform::TransformState &state) { // Convert the integer packing flags to booleans. SmallVector packPaddings; - for (int64_t packPadding : extractI64Array(getPackPaddings())) + for (int64_t packPadding : extractFromI64ArrayAttr(getPackPaddings())) packPaddings.push_back(static_cast(packPadding)); // Convert the padding values to attributes. 
@@ -362,13 +355,14 @@ SmallVector> transposePaddings; for (Attribute transposeVector : getTransposePaddings().cast()) transposePaddings.push_back( - extractI64Array(transposeVector.cast())); + extractFromI64ArrayAttr(transposeVector.cast())); LinalgPaddingOptions paddingOptions; paddingOptions.setPaddingValues(paddingValues); - paddingOptions.setPaddingDimensions(extractI64Array(getPaddingDimensions())); + paddingOptions.setPaddingDimensions( + extractFromI64ArrayAttr(getPaddingDimensions())); paddingOptions.setPackPaddings(packPaddings); - paddingOptions.setHoistPaddings(extractI64Array(getHoistPaddings())); + paddingOptions.setHoistPaddings(extractFromI64ArrayAttr(getHoistPaddings())); paddingOptions.setTransposePaddings(transposePaddings); FailureOr result = @@ -383,7 +377,8 @@ } LogicalResult transform::PadOp::verify() { - SmallVector packPaddings = extractI64Array(getPackPaddings()); + SmallVector packPaddings = + extractFromI64ArrayAttr(getPackPaddings()); if (any_of(packPaddings, [](int64_t packPadding) { return packPadding != 0 && packPadding != 1; })) { @@ -393,7 +388,7 @@ } SmallVector paddingDimensions = - extractI64Array(getPaddingDimensions()); + extractFromI64ArrayAttr(getPaddingDimensions()); if (any_of(paddingDimensions, [](int64_t paddingDimension) { return paddingDimension < 0; })) { return emitOpError() @@ -401,7 +396,8 @@ << getPaddingDimensions(); } - SmallVector hoistPaddings = extractI64Array(getHoistPaddings()); + SmallVector hoistPaddings = + extractFromI64ArrayAttr(getHoistPaddings()); if (any_of(hoistPaddings, [](int64_t hoistPadding) { return hoistPadding < 0; })) { return emitOpError() @@ -657,7 +653,7 @@ transform::TileOp::apply(TransformResults &transformResults, TransformState &state) { LinalgTilingOptions tilingOptions; - SmallVector tileSizes = extractI64Array(getStaticSizes()); + SmallVector tileSizes = extractFromI64ArrayAttr(getStaticSizes()); ArrayRef targets = state.getPayloadOps(getTarget()); SmallVector> 
dynamicSizeProducers; @@ -743,7 +739,7 @@ SmallVector transform::TileOp::getMixedSizes() { ValueRange dynamic = getDynamicSizes(); - SmallVector tileSizes = extractI64Array(getStaticSizes()); + SmallVector tileSizes = extractFromI64ArrayAttr(getStaticSizes()); SmallVector results; results.reserve(tileSizes.size()); unsigned dynamicPos = 0; @@ -773,7 +769,7 @@ result.addAttribute(getStaticSizesAttrName(result.name), staticSizes); size_t numExpectedLoops = - staticSizes.size() - llvm::count(extractI64Array(staticSizes), 0); + staticSizes.size() - llvm::count(extractFromI64ArrayAttr(staticSizes), 0); result.addTypes(SmallVector(numExpectedLoops + 1, pdlOperationType)); return success(); } @@ -794,6 +790,29 @@ modifiesPayload(effects); } +//===----------------------------------------------------------------------===// +// TileToForeachThreadOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure transform::TileToForeachThreadOp::applyToOne( + TilingInterface target, SmallVectorImpl &results, + transform::TransformState &state) { + IRRewriter rewriter(getContext()); + rewriter.setInsertionPoint(target); + auto maybeThreadDimMappingAttr = getThreadDimMapping(); + FailureOr tilingResult = + linalg::tileToForeachThreadOp( + rewriter, target, getAsOpFoldResult(getNumThreads()), + maybeThreadDimMappingAttr + ? 
extractFromI64ArrayAttr(*maybeThreadDimMappingAttr) + : ArrayRef{}); + if (failed(tilingResult)) + return emitDefaultSilenceableFailure(target); + rewriter.replaceOp(target, tilingResult->tileOp->getResults()); + results.assign({tilingResult->tileOp, tilingResult->tiledOp}); + return DiagnosedSilenceableFailure(success()); +} + //===----------------------------------------------------------------------===// // VectorizeOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -169,6 +169,135 @@ return spec; } +/// Given a `subsetExtractOp`, a `source` and a `dest`, create a new +/// `ParallelInsertSlice` op of `source` into `dest` at the same subset location +/// as `subsetExtractOp`. +static void +createMatchingParallelSubsetInsertOp(OpBuilder &b, Location loc, + tensor::ExtractSliceOp subsetExtractOp, + Value source, Value dest) { + b.create( + loc, source, dest, subsetExtractOp.getMixedOffsets(), + subsetExtractOp.getMixedSizes(), subsetExtractOp.getMixedStrides()); +} + +/// Build an `affine_max` of all the `vals`. +static Value buildMax(OpBuilder &b, Location loc, ValueRange vals) { + return b.createOrFold( + loc, AffineMap::getMultiDimIdentityMap(vals.size(), loc.getContext()), + vals); +} + +/// Build an `affine_min` of all the `vals`. 
+static Value buildMin(OpBuilder &b, Location loc, ValueRange vals) { + return b.createOrFold( + loc, AffineMap::getMultiDimIdentityMap(vals.size(), loc.getContext()), + vals); +} + +FailureOr +linalg::tileToForeachThreadOp(OpBuilder &b, TilingInterface op, + ArrayRef numThreads, + ArrayRef threadDimMapping) { + Location loc = op->getLoc(); + OpBuilder::InsertionGuard g(b); + SmallVector loopRanges = op.getIterationDomain(b); + if (loopRanges.empty()) + return op->emitOpError("expected non-empty loop ranges"); + auto hasStrideOne = [](Range r) { return isConstantIntValue(r.stride, 1); }; + if (!llvm::all_of(loopRanges, hasStrideOne)) + return op->emitOpError("only stride-1 supported atm"); + // TODO: support `getTiledImplementation` with >1 produced tiled ops. + auto destOperands = op.getDestinationOperands(b); + if (destOperands.size() != 1) + return op->emitOpError("only single dest operand supported atm"); + + SmallVector nonZeroNumThreads = + llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) { + return !isConstantIntValue(ofr, 0); + })); + SmallVector materializedNonZeroNumThreads = + llvm::to_vector(llvm::map_range(nonZeroNumThreads, [&](OpFoldResult ofr) { + ImplicitLocOpBuilder ilocb(loc, b); + return materializeOpFoldResult(ilocb, ofr); + })); + + Value zero = b.create(loc, 0); + Operation *tiledOp = nullptr; + scf::ForeachThreadOp foreachThreadOp = b.create( + loc, materializedNonZeroNumThreads, threadDimMapping, + [&](OpBuilder &b, Location loc, ValueRange threadIds) { + int64_t nLoops = loopRanges.size(); + SmallVector tiledOffsets, tiledSizes; + tiledOffsets.reserve(nLoops); + tiledSizes.reserve(nLoops); + for (unsigned loopIdx = 0, threadIdIdx = 0; loopIdx < nLoops; + ++loopIdx) { + bool overflow = loopIdx >= numThreads.size(); + bool isZero = !overflow && isConstantIntValue(numThreads[loopIdx], 0); + // Degenerate case: take the whole domain.
+ if (overflow || isZero) { + tiledOffsets.push_back(loopRanges[loopIdx].offset); + tiledSizes.push_back(loopRanges[loopIdx].size); + continue; + } + + // Tiled case: compute the offset and size. + AffineExpr i, j, M, N; + bindDims(b.getContext(), i, j); + bindSymbols(b.getContext(), M, N); + Value size = loopRanges[loopIdx].size; + Value offset = loopRanges[loopIdx].offset; + Value threadId = threadIds[threadIdIdx]; + // TODO: more aggressive foldings. + // Symbolic fixed max size per thread. + // TODO: floor + 0/1 depending on case for better load-balancing. + Value maxSizePerThread = b.createOrFold( + loc, M.ceilDiv(N), + ValueRange{size, materializedNonZeroNumThreads[threadIdIdx]}); + // Dynamic offset shifted by threadId * maxSizePerThread. + Value offsetPerThread = b.createOrFold( + loc, i + j * M, ValueRange{offset, threadId, maxSizePerThread}); + // Upper-bound per threadId. NOTE(review): ignores non-zero offset? + Value sizeMinusOffsetPerThread = b.createOrFold( + loc, -i + M, ValueRange{offsetPerThread, size}); + Value tileSizePerThread = buildMin( + b, loc, ValueRange{sizeMinusOffsetPerThread, maxSizePerThread}); + tiledOffsets.push_back(offsetPerThread); + // TODO: if tileSizePerThread <= 0 early exit. + tiledSizes.push_back( + buildMax(b, loc, ValueRange{zero, tileSizePerThread})); + ++threadIdIdx; + } + + SmallVector tiledOps = + op.getTiledImplementation(b, destOperands, tiledOffsets, tiledSizes, + /*tileDestOperands=*/true); + assert(tiledOps.size() == 1 && "expected a single produced tiled op"); + tiledOp = tiledOps.front(); + + auto tilingInterfaceOp = dyn_cast(tiledOp); + assert(tilingInterfaceOp && + "Tiled op does not implement TilingInterface"); + + auto tiledDestOperands = tilingInterfaceOp.getDestinationOperands(b); + + // Create terminator with parallel subset insert operations.
+ auto performConcurrentlyOp = b.create(loc); + OpBuilder::InsertionGuard g(b); + b.setInsertionPointToStart(performConcurrentlyOp.getBody()); + for (auto it : + llvm::zip(tiledDestOperands, tilingInterfaceOp->getResults(), + destOperands)) { + createMatchingParallelSubsetInsertOp( + b, loc, + cast(std::get<0>(it).getDefiningOp()), + std::get<1>(it), std::get<2>(it)); + } + }); + return ForeachThreadTilingResult{foreachThreadOp, tiledOp}; +} + // Insert a tile `source` into the destination tensor `dest`. The position at // which the tile is inserted (as well as size of tile) is taken from a given // ExtractSliceOp `sliceOp`. diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp --- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp +++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp @@ -65,6 +65,15 @@ llvm::map_range(values, [](Value v) { return getAsOpFoldResult(v); })); } +/// Convert `arrayAttr` to a vector of OpFoldResult. +SmallVector getAsOpFoldResult(ArrayAttr arrayAttr) { + SmallVector res; + res.reserve(arrayAttr.size()); + for (Attribute a : arrayAttr) + res.push_back(a); + return res; +} + /// If ofr is a constant integer or an IntegerAttr, return the integer. Optional getConstantIntValue(OpFoldResult ofr) { // Case 1: Check for Constant integer. diff --git a/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir b/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir @@ -0,0 +1,49 @@ +// RUN: mlir-opt %s --test-transform-dialect-interpreter -canonicalize | FileCheck %s + +// Offset per thread: +// CHECK-DAG: affine_map<(d0)[s0] -> (d0 * (s0 ceildiv 10))> +// Per thread tile size. 
+// CHECK-DAG: affine_map<(d0)[s0] -> (-(d0 * (s0 ceildiv 10)) + s0, s0 ceildiv 10)> +// CHECK-DAG: affine_map<(d0)[s0] -> (d0 * (s0 ceildiv 20))> +// CHECK-DAG: affine_map<(d0)[s0] -> (-(d0 * (s0 ceildiv 20)) + s0, s0 ceildiv 20)> + +module { +// CHECK-LABEL: matmul( +// CHECK-SAME: %[[A:[0-9a-z]+]]: tensor +// CHECK-SAME: %[[B:[0-9a-z]+]]: tensor +// CHECK-SAME: %[[C:[0-9a-z]+]]: tensor + func.func @matmul(%A: tensor, %B: tensor, %C: tensor) -> tensor { + // CHECK-DAG: %[[C10:.*]] = arith.constant 10 : index + // CHECK-DAG: %[[C20:.*]] = arith.constant 20 : index + // CHECK: scf.foreach_thread ({{.*}}) in (%[[C10]], %[[C20]]) -> (tensor) { + // CHECK: %[[tA:.*]] = tensor.extract_slice %[[A]]{{.*}} : tensor to tensor + // CHECK: %[[tB:.*]] = tensor.extract_slice %[[B]]{{.*}} : tensor to tensor + // CHECK: %[[tC:.*]] = tensor.extract_slice %[[C]]{{.*}} : tensor to tensor + // CHECK: %[[RES:.*]] = linalg.matmul + // CHECK-SAME: ins(%[[tA]], %[[tB]] : tensor, tensor) + // CHECK-SAME: outs(%[[tC]] : tensor) -> tensor + // CHECK-NEXT: scf.foreach_thread.perform_concurrently { + // CHECK-NEXT: tensor.parallel_insert_slice %[[RES]] into %[[C]]{{.*}} : + // CHECK-SAME: tensor into tensor + // CHECK-NEXT: } + // CHECK-NEXT: } {thread_dim_mapping = [1, 0]} + %0 = linalg.matmul ins(%A, %B : tensor, tensor) + outs(%C : tensor) -> (tensor) + return %0 : tensor + } + + transform.with_pdl_patterns { + ^bb0(%arg0: !pdl.operation): + pdl.pattern @match_linalg_matmul : benefit(1) { + %0 = operands + %1 = types + %2 = operation "linalg.matmul"(%0 : !pdl.range) -> (%1 : !pdl.range) + rewrite %2 with "transform.dialect" + } + transform.sequence %arg0 { + ^bb1(%arg1: !pdl.operation): + %0 = pdl_match @match_linalg_matmul in %arg1 + %1:2 = transform.structured.tile_to_foreach_thread_op %0 [10, 20] (mapped to dims [1, 0]) + } + } +}