diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h @@ -17,11 +17,16 @@ namespace mlir { class TilingInterface; class RewriterBase; + namespace linalg { class GenericOp; class LinalgOp; } // namespace linalg +namespace tensor { +class PackOp; +} // namespace tensor + namespace transform { class TransformHandleTypeInterface; // Types needed for builders. diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -215,6 +215,43 @@ }]; } +//===----------------------------------------------------------------------===// +// LowerPackOp +//===----------------------------------------------------------------------===// +def LowerPackOp : Op { + let description = [{ + Rewrite a tensor.pack into tensor.pad + tensor.expand_shape + transpose. + + #### Return modes + + This operation ignores non-pack ops and drops them in the return. + This operation produces a silenceableFailure if the rewrite fails for any + reason (e.g. a non-empty `outer_dims_perm` or a non-static packed shape). + If all the operations referred to by the `target` are rewritten, the + transform succeeds. + Return handles to the newly produced pad, expand_shape and transpose ops. 
+ }]; + + let arguments = (ins Transform_ConcreteOpType<"tensor.pack">:$target); + let results = (outs Transform_ConcreteOpType<"tensor.pad">:$pad_op, + Transform_ConcreteOpType<"tensor.expand_shape">:$expand_shape_op, + Transform_ConcreteOpType<"linalg.generic">:$transpose_op); + let assemblyFormat = [{ + $target attr-dict `:` functional-type(operands, results) + }]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::tensor::PackOp target, + ::mlir::transform::ApplyToEachResultList &transformResults, + ::mlir::transform::TransformState &state); + }]; +} + //===----------------------------------------------------------------------===// // MatchOp //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -1142,12 +1142,17 @@ GenericOp genericOp, ArrayRef foldedIterationDims, RewriterBase &rewriter); +/// Struct to hold the result of a `pack` call. +struct PackResult { + SmallVector packOps; + linalg::LinalgOp packedLinalgOp; + SmallVector unPackOps; +}; /// Implement packing of a single LinalgOp by `packedSizes`. /// There must be one packedSizes entry per `linalgOp` iterator. -/// Return the packed Linalg op on success, failure otherwise. +/// Return a PackResult on success, failure otherwise. -FailureOr pack(RewriterBase &rewriter, - linalg::LinalgOp linalgOp, - ArrayRef packedSizes); +FailureOr pack(RewriterBase &rewriter, linalg::LinalgOp linalgOp, + ArrayRef packedSizes); /// Struct to hold the result of a `packTranspose` call.
struct PackTransposeResult { diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -118,7 +118,7 @@ Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type, Value source, Value pad, bool nofold); -/// Returns a GenericOp that tansposes `inputTensor` into `outputTensor` using +/// Returns a GenericOp that transposes `inputTensor` into `outputTensor` using /// `transposeVector` to permute the `inputTensor` dimensions. GenericOp makeTransposeOp(OpBuilder &b, Location loc, Value inputTensor, Value outputTensor, diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -1113,7 +1113,13 @@ }]> ]; - let extraClassDeclaration = commonExtraClassDeclaration; + let extraClassDeclaration = commonExtraClassDeclaration # [{ + static RankedTensorType + inferCollapsedType(RankedTensorType type, ArrayRef reassociation); + static RankedTensorType + inferCollapsedType(RankedTensorType type, + SmallVector reassociation); + }]; let hasVerifier = 1; } diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -14,10 +14,12 @@ #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/PDL/IR/PDL.h" #include "mlir/Dialect/PDL/IR/PDLTypes.h" #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include 
"mlir/Dialect/Tensor/Utils/Utils.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/Dialect/Transform/IR/TransformTypes.h" @@ -30,11 +32,11 @@ #include "mlir/IR/OpDefinition.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/TilingInterface.h" -#include "mlir/Support/LLVM.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/SetOperations.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringSet.h" #include "llvm/Support/Debug.h" @@ -131,6 +133,73 @@ return DiagnosedSilenceableFailure::success(); } +/// Return a permutation vector of size permSize that would result in moving +/// positions into desiredPositions. +/// +/// For example, permSize == 5, positions = {2, 4}, desiredPositions = {1, 0} +/// would result in a {4, 2, 0, 1, 3} permutation vector. +static SmallVector +computePermutationVector(int64_t permSize, ArrayRef positions, + ArrayRef desiredPositions) { + SmallVector res(permSize, -1); + DenseSet seen; + for (auto [pos, desiredPos] : llvm::zip(positions, desiredPositions)) { + res[desiredPos] = pos; + seen.insert(pos); + } + int64_t nextPos = 0; + for (int64_t &entry : res) { + if (entry != -1) + continue; + while (seen.contains(nextPos)) + ++nextPos; + entry = nextPos; + ++nextPos; + } + return res; +} + +struct PackingMetadata { + SmallVector insertPositions; + SmallVector reassociations; +}; +/// Given `innerDimPos` indices representing desired packing insertion points +/// into a target vector (i.e. pack/unpack `inner_dims_pos`), compute the +/// final positions in the target shape as well as the reshape reassociations. +// Note: This should not be called with a large `innerDimPos` array (or the +// implementation needs to be updated to use an N log N sort instead of +// repeated N^2 counts).
+static PackingMetadata computePackingMetadata(int64_t packedRank, + ArrayRef innerDimPos) { + PackingMetadata res; + res.insertPositions.reserve(innerDimPos.size()); + // The pack insert position is the position + the number of packed dims + // with a smaller position (each of which is inserted before it) + offset. + // The offset controls whether the packing dimension is the first or last. + // For example: if we packed 30 by 8, an offset of 0 would yield an 8x4 shape + // whereas an offset of 1 yields a 4x8 shape. + // The latter is what we expect from packing. + int64_t offset = 1; + for (int64_t pos : innerDimPos) { + int64_t numInsertedBefore = llvm::count_if( + innerDimPos, [&pos](int64_t pos2) { return pos > pos2; }); + res.insertPositions.push_back(pos + numInsertedBefore + offset); + } + + DenseSet posSet(res.insertPositions.begin(), + res.insertPositions.end()); + res.reassociations.reserve(packedRank); + for (int64_t i = 1; i <= packedRank; ++i) { + if (!posSet.contains(i)) { + res.reassociations.push_back(ReassociationIndices{i - 1}); + continue; + } + res.reassociations.push_back(ReassociationIndices{i - 1, i}); + ++i; // Dim `i` is already grouped with `i - 1`: skip it. + } + return res; +} + //===----------------------------------------------------------------------===// // DecomposeOp //===----------------------------------------------------------------------===// @@ -323,7 +392,7 @@ Diagnostic &diag, Operation *producerOp, Operation *containingOp) { - LLVM_DEBUG(llvm::dbgs() << "Try to fuse a direct extract use\n"); + LLVM_DEBUG(DBGS() << "Try to fuse a direct extract use\n"); auto tileableProducer = dyn_cast(producerOp); if (!tileableProducer) { diag.attachNote(producerOp->getLoc()) @@ -354,7 +423,7 @@ // Tile the producer.
int64_t resultNumber = sliceOpToTile.getSource().cast().getResultNumber(); - LLVM_DEBUG(llvm::dbgs() << "resultNumber: " << resultNumber << "\n"); + LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n"); FailureOr tiledProducer = tileableProducer.generateResultTileValue( rewriter, resultNumber, sliceOpToTile.getMixedOffsets(), @@ -364,7 +433,7 @@ << "failed to tile producer op: " << *tileableProducer; return nullptr; } - LLVM_DEBUG(llvm::dbgs() << "tiledProducer: " << *tiledProducer << "\n"); + LLVM_DEBUG(DBGS() << "tiledProducer: " << *tiledProducer << "\n"); // Replace the extract op. Operation *fusedOp = tiledProducer->getDefiningOp(); @@ -388,8 +457,7 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument( RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp) { - LLVM_DEBUG( - llvm::dbgs() << "Try to fuse an extract use through block argument\n"); + LLVM_DEBUG(DBGS() << "Try to fuse an extract use through block argument\n"); auto tileableProducer = dyn_cast(producerOp); if (!tileableProducer) { @@ -442,7 +510,7 @@ // Replace the use in the tileableProducer before tiling: clone, replace and // then tile. int64_t resultNumber = pUse->get().cast().getResultNumber(); - LLVM_DEBUG(llvm::dbgs() << "resultNumber: " << resultNumber << "\n"); + LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n"); // Gather destination tensors. SmallVector destinationTensors; @@ -471,7 +539,7 @@ << "failed to tile producer op: " << *tileableProducer; return nullptr; } - LLVM_DEBUG(llvm::dbgs() << "tiledProducer: " << *tiledProducer << "\n"); + LLVM_DEBUG(DBGS() << "tiledProducer: " << *tiledProducer << "\n"); // Replace the extract op.
Operation *fusedOp = tiledProducer->getDefiningOp(); @@ -496,7 +564,7 @@ static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp) { - LLVM_DEBUG(llvm::dbgs() << "Try to fuse an use by cloning\n"); + LLVM_DEBUG(DBGS() << "Try to fuse an use by cloning\n"); // Gather all uses inside the containing op. SmallVector uses; @@ -530,7 +598,7 @@ assert(!isa(use->getOwner()) && "Parallel insert slice is not a valid clone destination"); unsigned resultNumber = use->get().cast().getResultNumber(); - LLVM_DEBUG(llvm::dbgs() << "resultNumber: " << resultNumber << "\n"); + LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n"); OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(use->getOwner()); @@ -607,8 +675,7 @@ Operation *tiled = tileAndFuseFirstExtractUse(rewriter, diag, producerOp, containingOp); if (tiled) { - LLVM_DEBUG(llvm::dbgs() << "\nFused a direct extract use\n" - << *containingOp); + LLVM_DEBUG(DBGS() << "\nFused a direct extract use\n" << *containingOp); fusedOps.push_back(tiled); continue; } @@ -617,9 +684,8 @@ tileAndFuseFirstExtractUseThroughContainingOpBlockArgument( rewriter, diag, producerOp, containingOp); if (tiledContainingOpOperand) { - LLVM_DEBUG(llvm::dbgs() - << "\nFused an extract use through block argument\n" - << *containingOp); + LLVM_DEBUG(DBGS() << "\nFused an extract use through block argument\n" + << *containingOp); fusedOps.push_back(tiledContainingOpOperand); continue; } @@ -627,8 +693,7 @@ Operation *cloned = cloneAndFuseFirstUse(rewriter, diag, producerOp, containingOp); if (cloned) { - LLVM_DEBUG(llvm::dbgs() << "\nFused an use by cloning\n" - << *containingOp); + LLVM_DEBUG(DBGS() << "\nFused an use by cloning\n" << *containingOp); fusedOps.push_back(cloned); continue; } @@ -697,6 +762,119 @@ return success(); } +//===----------------------------------------------------------------------===// +// LowerPackOp
+//===----------------------------------------------------------------------===// +/// Handles to the ops produced by `rewriteLowerPack`. +struct LowerPackResult { + tensor::PadOp padOp; + tensor::ExpandShapeOp expandShapeOp; + linalg::GenericOp transposeOp; +}; + +/// Rewrite `packOp` as tensor.pad + tensor.expand_shape + transpose. +static FailureOr rewriteLowerPack(RewriterBase &rewriter, + tensor::PackOp packOp) { + // 1. Filter out NYI cases. + if (!packOp.getOuterDimsPerm().empty()) + return rewriter.notifyMatchFailure(packOp, "outer dims perm NYI"); + + auto packedTensorType = + packOp->getResultTypes().front().cast(); + if (!packedTensorType.hasStaticShape()) { + return rewriter.notifyMatchFailure( + packOp, + "non-static shape NYI, needs a more powerful tensor.expand_shape op"); + } + + Location loc = packOp->getLoc(); + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(packOp); + + // 2. Compute the permutation vector to move the last `numPackedDims` into the + // `innerDimsPos` of a shape of rank `packedRank`. + int64_t numPackedDims = packOp.getInnerDimsPos().size(); + int64_t packedRank = packedTensorType.getRank(); + auto lastDims = llvm::to_vector( + llvm::seq(packedRank - numPackedDims, packedRank)); + PackingMetadata packingMetadata = + computePackingMetadata(packedRank, packOp.getInnerDimsPos()); + SmallVector lastDimsToInsertPositionsPerm = computePermutationVector( + packedRank, lastDims, packingMetadata.insertPositions); + + // 3. Compute the stripMinedShape: this is the packed shape before any outer + // or inner permutations have been applied. + SmallVector stripMinedShape(packedTensorType.getShape()); + applyPermutationToVector(stripMinedShape, lastDimsToInsertPositionsPerm); + + // 4. Pad the source of packOp to a shape we can expand into stripMinedShape.
+ RankedTensorType collapsed = tensor::CollapseShapeOp::inferCollapsedType( + RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape), + packingMetadata.reassociations); + if (!packOp.getPaddingValue()) // TODO: create a constant 0 op. + return rewriter.notifyMatchFailure(packOp, "missing padding value NYI"); + auto padOp = tensor::createPadHighOp(collapsed, packOp.getSource(), + packOp.getPaddingValue(), /*nofold=*/false, loc, rewriter); + + LLVM_DEBUG( + DBGSNL(); DBGSNL(); llvm::interleaveComma(packingMetadata.insertPositions, + DBGS() << "insertPositions: "); + DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(), + DBGS() << "packedShape: "); + DBGSNL(); + llvm::interleaveComma(lastDimsToInsertPositionsPerm, + DBGS() << "lastDimsToInsertPositionsPerm: "); + DBGSNL(); llvm::interleaveComma( + packingMetadata.reassociations, DBGS() << "reassociations: ", + [&](const ReassociationIndices &ri) { + llvm::interleaveComma(ri, llvm::dbgs() << "|"); + }); + DBGSNL(); + llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: "); + DBGSNL(); DBGS() << "collapsed type: " << collapsed; DBGSNL();); + + // 5. Expand from the padded result to the stripMinedShape. + auto reshapeOp = rewriter.create( + loc, + RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape), + padOp.getResult(), packingMetadata.reassociations); + + // 6. Transpose stripMinedShape to packedShape. + SmallVector insertPositionsToLastDimsPerm = computePermutationVector( + packedRank, packingMetadata.insertPositions, lastDims); + auto transposeOp = + makeTransposeOp(rewriter, loc, reshapeOp.getResult(), packOp.getDest(), + insertPositionsToLastDimsPerm); + + LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL(); + DBGS() << "reshape op: " << reshapeOp; DBGSNL(); + llvm::interleaveComma(insertPositionsToLastDimsPerm, + DBGS() << "insertPositionsToLastDimsPerm: "); + DBGSNL(); DBGS() << "transpose op: " << transposeOp; DBGSNL();); + + // 7. Replace packOp by transposeOp.
+ rewriter.replaceOp(packOp, transposeOp->getResults()); + + return LowerPackResult{padOp, reshapeOp, transposeOp}; +} + +DiagnosedSilenceableFailure transform::LowerPackOp::applyToOne( + tensor::PackOp target, transform::ApplyToEachResultList &transformResults, + transform::TransformState &state) { + IRRewriter rewriter(target->getContext()); + rewriter.setInsertionPoint(target); + FailureOr res = rewriteLowerPack(rewriter, target); + if (failed(res)) { + Diagnostic diag(target->getLoc(), DiagnosticSeverity::Error); + diag << "cannot lower to pad + expand + transpose"; + return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); + } + transformResults.push_back(res->padOp); + transformResults.push_back(res->expandShapeOp); + transformResults.push_back(res->transposeOp); + return DiagnosedSilenceableFailure::success(); +} + //===---------------------------------------------------------------------===// // MatchOp //===---------------------------------------------------------------------===// @@ -931,12 +1109,12 @@ IRRewriter rewriter(linalgOp->getContext()); rewriter.setInsertionPoint(linalgOp); - FailureOr maybeResult = pack(rewriter, linalgOp, packedSizes); + FailureOr maybeResult = pack(rewriter, linalgOp, packedSizes); if (failed(maybeResult)) return emitDefiniteFailure("data tiling failed"); transformResults.set(getPackedOp().cast(), - maybeResult->getOperation()); + maybeResult->packedLinalgOp.getOperation()); return DiagnosedSilenceableFailure::success(); } @@ -1045,35 +1223,8 @@ return GemmDimsForPacking{*ac.begin(), *bc.begin(), *ra.begin()}; } -/// Return a permutation vector of size permSize that would result in moving -/// positions into desiredPositions. -/// -/// For example, permSize == 5, positions = {2, 4}, desiredPositions = {1, 0} -/// would result in a {4, 2, 0, 1, 3} permutation vector.
-static SmallVector -computePermutationVector(int64_t permSize, ArrayRef positions, - ArrayRef desiredPositions) { - SmallVector res(permSize, -1); - DenseSet seen; - for (auto [pos, desiredPos] : llvm::zip(positions, desiredPositions)) { - res[desiredPos] = pos; - seen.insert(pos); - } - int64_t nextPos = 0; - for (int64_t &entry : res) { - if (entry != -1) - continue; - while (seen.contains(nextPos)) - ++nextPos; - entry = nextPos; - ++nextPos; - } - return res; -} - -/// Pack a LinalgOp by greedily inferring gemm dimensions (m, n, k) -/// where m and n are proper parallel dimensions and k is a proper reduction -/// dimension. +/// Pack a LinalgOp by greedily inferring gemm dimensions (m, n, k) where m and +/// n are proper parallel dimensions and k is a proper reduction dimension. /// Packing occurs by rewriting the op as a linalg.generic and calling /// linalg::pack by `mnkPackedSizes`. /// The order of the packed dimensions is customizable: the `mnkOrder` is a @@ -1081,7 +1232,7 @@ /// forms. /// The outer dimensions of the operands are not permuted at this time, this is /// left for future work. -static FailureOr +static FailureOr packGemmGreedily(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef mnkPackedSizes, ArrayRef mnkOrder) { @@ -1182,13 +1333,13 @@ rewriter.setInsertionPointAfter(linalgOp); // Failing to pack greedily is perfectly fine. // In the future we will want to order packings according to some metric.
- FailureOr gemm = packGemmGreedily( + FailureOr packResult = packGemmGreedily( /*rewriter=*/rewriter, /*linalgOp=*/linalgOp, /*mnkPackedSizes=*/getMixedGemmPackedSizes(), /*mnkOrder=*/getGemmInnerDimsOrder()); - if (succeeded(gemm)) { - results.push_back(*gemm); + if (succeeded(packResult)) { + results.push_back(packResult->packedLinalgOp); continue; } results.push_back(linalgOp); @@ -1235,9 +1386,9 @@ enum class OuterOrInnerPerm { Outer = 0, Inner = 1 }; } // namespace -/// Return true if `permutation` is a valid permutation of the `outer_dims_perm` -/// (case OuterOrInnerPerm::Outer) or `inner_dims_pos` (OuterOrInnerPerm::Inner) -/// of the `tensor.pack` or `tensor.unpack` `op. +/// Return true if `permutation` is a valid permutation of the +/// `outer_dims_perm` (case OuterOrInnerPerm::Outer) or `inner_dims_pos` +/// (OuterOrInnerPerm::Inner) of the `tensor.pack` or `tensor.unpack` `op`. /// This is the case when the `permutation` rank matches the rank expected by /// `op` and `permutation` is itself a permutation vector. /// Return true if either `op` or `permutation` are empty to allow a simpler @@ -1281,10 +1432,10 @@ // Step 2. Bunch of runtime sanity check and error messages. // Step 2.1. Fail on multi-op handles. if (packOrUnpackOps.size() != 1 || linalgOps.size() != 1) { - return emitSilenceableError() - << "requires target to map to exactly 1 packing op and 1 packed op (" - << "got " << packOrUnpackOps.size() << " and " << linalgOps.size() - << ")"; + return emitSilenceableError() << "requires target to map to exactly 1 " + "packing op and 1 packed op (" + << "got " << packOrUnpackOps.size() << " and " + << linalgOps.size() << ")"; } // Step 2.2. Fail on wrong type. @@ -1311,7 +1462,8 @@ return emitSilenceableError() << errorMsg; } - // Step 2.4. If we have an UnPackOp, we need to fetch the symmetrical PackOp. + // Step 2.4. If we have an UnPackOp, we need to fetch the symmetrical + // PackOp.
if (unPackOp) { assert(!packOp && "packOp must be null on entry when unPackOp is not null"); OpOperand *packUse = linalgOp.getDpsInitOperand( @@ -1700,9 +1852,9 @@ } if (second.size() != first.size() && !second.empty()) { - auto diag = - emitSilenceableError() - << "splitting does not produce the second part for a subset of targets"; + auto diag = emitSilenceableError() + << "splitting does not produce the second part for a subset " + "of targets"; diag.attachNote() << "expected splitting to produce the second part of all " "or none of the targets"; diag.attachNote(noSecondPart->getLoc()) @@ -1965,7 +2117,8 @@ Value target, ArrayRef mixedTileSizes, ArrayRef interchange) { - // Loop types are automaticaly splat by the callee, setting up one is enough. + // Loop types are automatically splat by the callee, setting up one is + // enough. SmallVector loopTypes(1, builder.getType()); build(builder, result, loopTypes, target, mixedTileSizes, interchange); } @@ -1978,8 +2131,8 @@ SmallVector dynamicTileSizes; dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes); // Call the default builder which sets up the proper operands segment sizes - // attributes for multiple variadic operands. In the absence of this, horrible - // bugs ensue. + // attributes for multiple variadic operands. In the absence of this, + // horrible bugs ensue. auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes); unsigned numExpectedLoops = staticTileSizes.size() - llvm::count(staticTileSizes, 0); @@ -2247,8 +2400,8 @@ SmallVector dynamicTileSizes; dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes); // Call the default builder which sets up the proper operands segment sizes - // attributes for multiple variadic operands. In the absence of this, horrible - // bugs ensue. + // attributes for multiple variadic operands. In the absence of this, + // horrible bugs ensue.
MLIRContext *ctx = builder.getContext(); auto operationType = pdl::OperationType::get(ctx); auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes); @@ -2284,8 +2437,8 @@ dispatchIndexOpFoldResults(mixedNumThreads, dynamicNumThreads, staticNumThreads); // Call the default builder which sets up the proper operands segment sizes - // attributes for multiple variadic operands. In the absence of this, horrible - // bugs ensue. + // attributes for multiple variadic operands. In the absence of this, + // horrible bugs ensue. MLIRContext *ctx = builder.getContext(); auto operationType = pdl::OperationType::get(ctx); auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads); @@ -2415,8 +2568,8 @@ return emitOpError( "tile_sizes and packed_tile_sizes are mutually exclusive"); if (numThreadsSpec == 0 && tileSizesSpec == 0) - return emitOpError( - "either (packed_)num_threads or (packed_)tile_sizes must be specified"); + return emitOpError("either (packed_)num_threads or (packed_)tile_sizes " + "must be specified"); return success(); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -1053,9 +1053,9 @@ /// Implement packing of a single LinalgOp by performing packing by /// `packedSizes`. There must be one packedSizes entry per `linalgOp` iterator. -/// Return the packed Linalg op on success, failure otherwise. +/// Return a PackResult on success, failure otherwise.
-FailureOr linalg::pack(RewriterBase &rewriter, - linalg::LinalgOp linalgOp, - ArrayRef packedSizes) { +FailureOr linalg::pack(RewriterBase &rewriter, + linalg::LinalgOp linalgOp, + ArrayRef packedSizes) { if (packedSizes.size() != linalgOp.getNumLoops()) { return rewriter.notifyMatchFailure(linalgOp, "incorrect number of pack sizes"); @@ -1070,6 +1070,8 @@ llvm::interleaveComma(iteratorTypes, DBGS() << "iterators: "); DBGSNL();); + SmallVector packOps; + SmallVector unPackOps; // Step 1. Pack each dim of the LinalgOp metadata by packedSizes[i]. PackedOperandsDimList listOfPackedOperandsDim; for (int64_t i = 0, e = packedSizes.size(); i < e; ++i) { @@ -1125,8 +1127,9 @@ Attribute zeroAttr = rewriter.getZeroAttr(getElementTypeOrSelf(dest.getType())); Value zero = rewriter.create(loc, zeroAttr); - inputsAndInits.push_back(rewriter.create( + packOps.push_back(rewriter.create( loc, operand, dest, innerPos, innerPackSizes, zero)); + inputsAndInits.push_back(packOps.back()); } } @@ -1150,16 +1153,19 @@ continue; } // Build the symmetrical UnPackOp to the existing PackOp. - results.push_back(rewriter.create( + unPackOps.push_back(rewriter.create( packedLinalgOp->getLoc(), result, maybePackedInit.getSource(), maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles())); + results.push_back(unPackOps.back()); } // Step 5. Replace `linalgOp`. rewriter.replaceOp(linalgOp, results); - // Return packedLinalgOp. + // Return the created pack ops, the packed op and the unpack ops.
- return cast(packedLinalgOp.getOperation()); + return PackResult{std::move(packOps), + cast(packedLinalgOp.getOperation()), + std::move(unPackOps)}; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -1272,11 +1272,18 @@ getReassociationIndices()); } +RankedTensorType CollapseShapeOp::inferCollapsedType( + RankedTensorType type, SmallVector reassociation) { + return inferCollapsedType( + type, getSymbolLessAffineMaps(convertReassociationIndicesToExprs( + type.getContext(), reassociation))); +} + /// Compute the RankedTensorType obtained by applying `reassociation` to /// `type`. -static RankedTensorType -computeTensorReshapeCollapsedType(RankedTensorType type, - ArrayRef reassociation) { +RankedTensorType +CollapseShapeOp::inferCollapsedType(RankedTensorType type, + ArrayRef reassociation) { auto shape = type.getShape(); SmallVector newShape; newShape.reserve(reassociation.size()); @@ -1304,7 +1311,7 @@ void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src, ArrayRef reassociation, ArrayRef attrs) { - auto resultType = computeTensorReshapeCollapsedType( + auto resultType = inferCollapsedType( src.getType().cast(), getSymbolLessAffineMaps( convertReassociationIndicesToExprs(b.getContext(), reassociation))); @@ -1336,7 +1343,7 @@ auto maps = op.getReassociationMaps(); RankedTensorType expectedType = - computeTensorReshapeCollapsedType(expandedType, maps); + CollapseShapeOp::inferCollapsedType(expandedType, maps); if (!isSameTypesWithoutEncoding(collapsedType, expectedType)) return op.emitOpError("expected collapsed type to be ") << expectedType << ", but got " << collapsedType; @@ -1436,7 +1443,7 @@ RankedTensorType srcType = castOp.getSource().getType().cast(); - RankedTensorType newResultType = computeTensorReshapeCollapsedType( + 
RankedTensorType newResultType = CollapseShapeOp::inferCollapsedType( srcType, collapseShapeOp.getReassociationMaps()); if (newResultType == collapseShapeOp.getResultType()) { diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir @@ -0,0 +1,27 @@ +// RUN: mlir-opt %s -test-transform-dialect-interpreter --split-input-file | FileCheck %s + +func.func @pack(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<17x2x16x16x32x8xf32>) -> tensor<17x2x16x16x32x8xf32> { + %cst_0 = arith.constant 0.0 : f32 + + // The tensor.pack is lowered to tensor.pad + tensor.expand_shape + transpose (linalg.generic) + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: tensor.pad {{.*}} low[%[[C0]], %[[C0]], %[[C0]], %[[C0]]] + // CHECK: : tensor<129x47x16x16xf32> to tensor<136x64x16x16xf32> + // CHECK: tensor.expand_shape %{{.*}} [{{.*}}[0, 1], [2, 3], [4], [5]] + // CHECK-SAME: : tensor<136x64x16x16xf32> into tensor<17x8x2x32x16x16xf32> + // CHECK: linalg.generic {{.*}} iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"] + // CHECK-SAME: ins(%{{.*}} : tensor<17x8x2x32x16x16xf32>) + // CHECK-SAME: outs(%{{.*}} : tensor<17x2x16x16x32x8xf32>) + %pack = tensor.pack %arg0 padding_value(%cst_0 : f32) inner_dims_pos = [1, 0] inner_tiles = [32, 8] into %arg1 + : tensor<129x47x16x16xf32> -> tensor<17x2x16x16x32x8xf32> + return %pack : tensor<17x2x16x16x32x8xf32> +} + +transform.sequence failures(propagate) { +^bb1(%module_op: !pdl.operation): + %pack = transform.structured.match ops{["tensor.pack"]} in %module_op + : (!pdl.operation) -> !transform.op<"tensor.pack"> + transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">) + -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.generic">) +} + diff --git
a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -8375,11 +8375,13 @@ ":LinalgDialect", ":LinalgTransformOpsIncGen", ":LinalgTransforms", + ":LinalgUtils", ":PDLDialect", ":Parser", ":SCFTransforms", ":SideEffectInterfaces", ":TensorDialect", + ":TensorUtils", ":TilingInterface", ":TransformDialect", ":TransformDialectUtils",