diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -907,6 +907,27 @@
     GenericOp genericOp, ArrayRef<int64_t> foldedIterationDims,
     RewriterBase &rewriter);
 
+struct LowerPackResult {
+  tensor::PadOp padOp;
+  tensor::ExpandShapeOp expandShapeOp;
+  linalg::TransposeOp transposeOp;
+};
+
+/// Rewrite pack as pad + reshape + transpose.
+FailureOr<LowerPackResult> lowerPack(RewriterBase &rewriter,
+                                     tensor::PackOp packOp);
+
+struct LowerUnPackOpResult {
+  tensor::EmptyOp emptyOp;
+  linalg::TransposeOp transposeOp;
+  tensor::CollapseShapeOp collapseShapeOp;
+  tensor::ExtractSliceOp extractSliceOp;
+};
+
+/// Rewrite unpack as empty + transpose + reshape + extract_slice.
+FailureOr<LowerUnPackOpResult> lowerUnPack(RewriterBase &rewriter,
+                                           tensor::UnPackOp unPackOp);
+
 /// Struct to hold the result of a `pack` call.
 struct PackResult {
   SmallVector<tensor::PackOp> packOps;
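Note on the new public API: with the declarations above exported from Transforms.h, clients outside the transform dialect can drive the lowering directly. The sketch below is a hypothetical rewrite pattern, not code from this patch; the pattern name and includes are assumptions. Since linalg::lowerPack replaces the tensor.pack op itself on success, the pattern only needs to propagate failure.

// Hypothetical client of the new linalg::lowerPack entry point (sketch only).
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

namespace {
struct LowerPackPattern : public OpRewritePattern<tensor::PackOp> {
  using OpRewritePattern<tensor::PackOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    // A PatternRewriter is-a RewriterBase, so it can be passed through as-is.
    FailureOr<linalg::LowerPackResult> res =
        linalg::lowerPack(rewriter, packOp);
    if (failed(res))
      return rewriter.notifyMatchFailure(
          packOp, "could not lower to pad + expand_shape + transpose");
    return success(); // lowerPack already replaced packOp.
  }
};
} // namespace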
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -748,126 +748,6 @@
 // LowerPackOp
 //===----------------------------------------------------------------------===//
 
-struct LowerPackResult {
-  tensor::PadOp padOp;
-  tensor::ExpandShapeOp expandShapeOp;
-  linalg::TransposeOp transposeOp;
-};
-
-/// Rewrite pack as pad + reshape + transpose.
-static FailureOr<LowerPackResult> lowerPack(RewriterBase &rewriter,
-                                            tensor::PackOp packOp) {
-  // 1. Filter out NYI cases.
-  if (!packOp.getOuterDimsPerm().empty())
-    return rewriter.notifyMatchFailure(packOp, "outer dims perm NYI");
-
-  auto packedTensorType =
-      packOp->getResultTypes().front().cast<RankedTensorType>();
-  if (!packedTensorType.hasStaticShape()) {
-    return rewriter.notifyMatchFailure(
-        packOp,
-        "non-static shape NYI, needs a more powerful tensor.expand_shape op");
-  }
-
-  Location loc = packOp->getLoc();
-  OpBuilder::InsertionGuard g(rewriter);
-  rewriter.setInsertionPoint(packOp);
-
-  // 2. Compute the permutation vector to move the last `numPackedDims` into
-  // the `innerPosDims` of a shape of rank `packedRank`.
-  int64_t numPackedDims = packOp.getInnerDimsPos().size();
-  int64_t packedRank = packedTensorType.getRank();
-  auto lastDims = llvm::to_vector(
-      llvm::seq<int64_t>(packedRank - numPackedDims, packedRank));
-  PackingMetadata packingMetadata = computePackingMetadata(
-      packedTensorType.getRank(), packOp.getInnerDimsPos());
-  SmallVector<int64_t> lastDimsToInsertPositionsPerm = computePermutationVector(
-      packedRank, lastDims, packingMetadata.insertPositions);
-
-  // 3. Compute the stripMinedShape: this is the packed shape before any outer
-  // or inner permutations have been applied.
-  SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
-  applyPermutationToVector(stripMinedShape, lastDimsToInsertPositionsPerm);
-
-  // 4. Pad the source of packOp to a shape we can expand into stripMinedShape.
-  RankedTensorType collapsed = tensor::CollapseShapeOp::inferCollapsedType(
-      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
-      packingMetadata.reassociations);
-  Value paddingValue = packOp.getPaddingValue();
-  if (!paddingValue) {
-    paddingValue = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getZeroAttr(getElementTypeOrSelf(collapsed)));
-  }
-  auto padOp =
-      tensor::createPadHighOp(collapsed, packOp.getSource(), paddingValue,
-                              /*nofold=*/false, loc, rewriter);
-
-  LLVM_DEBUG(
-      DBGSNL(); DBGSNL(); llvm::interleaveComma(packingMetadata.insertPositions,
-                                                DBGS() << "insertPositions: ");
-      DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
-                                      DBGS() << "packedShape: ");
-      DBGSNL();
-      llvm::interleaveComma(lastDimsToInsertPositionsPerm,
-                            DBGS() << "lastDimsToInsertPositionsPerm: ");
-      DBGSNL(); llvm::interleaveComma(
-          packingMetadata.reassociations, DBGS() << "reassociations: ",
-          [&](ReassociationIndices ri) {
-            llvm::interleaveComma(ri, llvm::dbgs() << "|");
-          });
-      DBGSNL();
-      llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
-      DBGSNL(); DBGS() << "collapsed type: " << collapsed; DBGSNL(););
-
-  if (packOp.isLikePad()) {
-    // This pack is just a plain pad.
-    // Just insert the pad in the higher ranked tensor.
-    auto emptyOp =
-        rewriter.create<tensor::EmptyOp>(loc, packedTensorType, ValueRange{});
-    // Offsets.
-    SmallVector<OpFoldResult> zeros(packedRank, rewriter.getIndexAttr(0));
-    // Strides.
-    SmallVector<OpFoldResult> ones(packedRank, rewriter.getIndexAttr(1));
-    SmallVector<OpFoldResult> sizes =
-        getMixedDimensions(rewriter, loc, packOp.getDest());
-
-    auto insertSliceOp = rewriter.create<tensor::InsertSliceOp>(
-        loc, /*source=*/padOp, /*dest=*/emptyOp,
-        /*offsets=*/zeros, sizes,
-        /*strides=*/ones);
-
-    LLVM_DEBUG(DBGS() << "insert_slice op: " << insertSliceOp; DBGSNL(););
-
-    rewriter.replaceOp(packOp, insertSliceOp->getResults());
-
-    return LowerPackResult{padOp, /*reshapeOp=*/nullptr,
-                           /*transposeOp=*/nullptr};
-  }
-  // 5. Expand from the padded result to the stripMinedShape.
-  auto reshapeOp = rewriter.create<tensor::ExpandShapeOp>(
-      loc,
-      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
-      padOp.getResult(), packingMetadata.reassociations);
-
-  // 6. Transpose stripMinedShape to packedShape.
-  SmallVector<int64_t> insertPositionsToLastDimsPerm = computePermutationVector(
-      packedRank, packingMetadata.insertPositions, lastDims);
-  auto transposeOp = rewriter.create<linalg::TransposeOp>(
-      loc, reshapeOp.getResult(), packOp.getDest(),
-      insertPositionsToLastDimsPerm);
-
-  LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
-             DBGS() << "reshape op: " << reshapeOp; DBGSNL();
-             llvm::interleaveComma(insertPositionsToLastDimsPerm,
-                                   DBGS() << "insertPositionsToLastDimsPerm: ");
-             DBGSNL(); DBGS() << "transpose op: " << transposeOp; DBGSNL(););
-
-  // 7. Replace packOp by transposeOp.
-  rewriter.replaceOp(packOp, transposeOp->getResults());
-
-  return LowerPackResult{padOp, reshapeOp, transposeOp};
-}
-
 DiagnosedSilenceableFailure transform::LowerPackOp::applyToOne(
     tensor::PackOp target, transform::ApplyToEachResultList &transformResults,
     transform::TransformState &state) {
@@ -889,115 +769,6 @@
 // LowerUnPackOp
 //===----------------------------------------------------------------------===//
 
-struct LowerUnPackOpResult {
-  tensor::EmptyOp emptyOp;
-  linalg::TransposeOp transposeOp;
-  tensor::CollapseShapeOp collapseShapeOp;
-  tensor::ExtractSliceOp extractSliceOp;
-};
-
-/// Rewrite pack as empty + transpose + reshape + extract_slice.
-static FailureOr<LowerUnPackOpResult> lowerUnPack(RewriterBase &rewriter,
-                                                  tensor::UnPackOp unPackOp) {
-  // 1. Filter out NYI cases.
-  if (!unPackOp.getOuterDimsPerm().empty())
-    return rewriter.notifyMatchFailure(unPackOp, "outer dims perm NYI");
-
-  RankedTensorType packedTensorType = unPackOp.getSourceType();
-  if (!packedTensorType.hasStaticShape()) {
-    return rewriter.notifyMatchFailure(
-        unPackOp,
-        "non-static shape NYI, needs a more powerful tensor.expand_shape op");
-  }
-
-  Location loc = unPackOp->getLoc();
-  OpBuilder::InsertionGuard g(rewriter);
-  rewriter.setInsertionPoint(unPackOp);
-
-  int64_t packedRank = packedTensorType.getRank();
-
-  OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
-  auto destTensorType = unPackOp.getDest().getType().cast<RankedTensorType>();
-  if (unPackOp.isLikeUnPad()) {
-    // This unpack is just a plain unpad.
-    // Just extract the slice from the higher ranked tensor.
-    ArrayRef<int64_t> destShape = destTensorType.getShape();
-    // The inner dimensions stay the same as the destination tensor, but the
-    // outer ones are additional 1s.
-    SmallVector<OpFoldResult> sizes(packedRank - destShape.size(), one);
-    sizes.append(getMixedDimensions(rewriter, loc, unPackOp.getDest()));
-
-    auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
-        loc, destTensorType, unPackOp.getSource(),
-        SmallVector<OpFoldResult>(packedRank, zero), sizes,
-        SmallVector<OpFoldResult>(packedRank, one));
-
-    rewriter.replaceOp(unPackOp, extractSliceOp->getResults());
-
-    return LowerUnPackOpResult{/*emptyOp=*/nullptr, /*transposeOp=*/nullptr,
-                               /*reshapeOp=*/nullptr, extractSliceOp};
-  }
-  // 2. Compute the permutation vector to move the last `numPackedDims` into
-  // the `innerPosDims` of a shape of rank `packedRank`.
-  int64_t numPackedDims = unPackOp.getInnerDimsPos().size();
-  auto lastDims = llvm::to_vector(
-      llvm::seq<int64_t>(packedRank - numPackedDims, packedRank));
-  PackingMetadata packingMetadata =
-      computePackingMetadata(packedRank, unPackOp.getInnerDimsPos());
-  SmallVector<int64_t> lastDimsToInsertPositionsPerm = computePermutationVector(
-      packedRank, lastDims, packingMetadata.insertPositions);
-
-  // 3. Compute the stripMinedShape: this is the packed shape without outer and
-  // inner permutations.
-  SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
-  applyPermutationToVector(stripMinedShape, lastDimsToInsertPositionsPerm);
-
-  // 4. Transpose packedShape to stripMinedShape.
-  RankedTensorType stripMinedTensorType =
-      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
-  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
-      stripMinedTensorType, packingMetadata.reassociations);
-  auto emptyOp =
-      rewriter.create<tensor::EmptyOp>(loc, stripMinedTensorType, ValueRange{});
-  auto transposeOp = rewriter.create<linalg::TransposeOp>(
-      loc, unPackOp.getSource(), emptyOp, lastDimsToInsertPositionsPerm);
-
-  LLVM_DEBUG(
-      DBGSNL(); DBGSNL(); llvm::interleaveComma(packingMetadata.insertPositions,
-                                                DBGS() << "insertPositions: ");
-      DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
-                                      DBGS() << "packedShape: ");
-      DBGSNL();
-      llvm::interleaveComma(lastDimsToInsertPositionsPerm,
-                            DBGS() << "lastDimsToInsertPositionsPerm: ");
-      DBGSNL(); llvm::interleaveComma(
-          packingMetadata.reassociations, DBGS() << "reassociations: ",
-          [&](ReassociationIndices ri) {
-            llvm::interleaveComma(ri, llvm::dbgs() << "|");
-          });
-      DBGSNL();
-      llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
-      DBGSNL(); DBGS() << "collapsed type: " << collapsedType; DBGSNL(););
-
-  // 5. Collapse from the stripMinedShape to the padded result.
-  auto reshapeOp = rewriter.create<tensor::CollapseShapeOp>(
-      loc, collapsedType, transposeOp->getResult(0),
-      packingMetadata.reassociations);
-
-  // 6. ExtractSlice
-  int64_t destRank = destTensorType.getRank();
-  auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
-      loc, destTensorType, reshapeOp->getResult(0),
-      SmallVector<OpFoldResult>(destRank, zero),
-      tensor::getMixedSizes(rewriter, loc, unPackOp->getResult(0)),
-      SmallVector<OpFoldResult>(destRank, one));
-
-  // 7. Replace unPackOp by extractSliceOp.
-  rewriter.replaceOp(unPackOp, extractSliceOp->getResults());
-
-  return LowerUnPackOpResult{emptyOp, transposeOp, reshapeOp, extractSliceOp};
-}
-
 DiagnosedSilenceableFailure transform::LowerUnPackOp::applyToOne(
     tensor::UnPackOp target, transform::ApplyToEachResultList &transformResults,
     transform::TransformState &state) {
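For the unpack side, here is a hypothetical stand-alone driver over a function; lowerAllUnPacks, the func-dialect walk, and the includes are illustrative assumptions, not part of this patch. lowerUnPack sets its own insertion point and replaces the op on success, so the caller only iterates and checks for failure.

// Hypothetical driver that lowers every tensor.unpack in a function (sketch).
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

static LogicalResult lowerAllUnPacks(func::FuncOp funcOp) {
  IRRewriter rewriter(funcOp->getContext());
  // Collect first: lowerUnPack replaces (and erases) each op it lowers.
  SmallVector<tensor::UnPackOp> targets;
  funcOp.walk([&](tensor::UnPackOp op) { targets.push_back(op); });
  for (tensor::UnPackOp unPackOp : targets) {
    FailureOr<linalg::LowerUnPackOpResult> res =
        linalg::lowerUnPack(rewriter, unPackOp);
    if (failed(res))
      return failure();
    // Per LowerUnPackOpResult, the like-unpad fast path leaves emptyOp,
    // transposeOp, and collapseShapeOp null; only extractSliceOp is set.
  }
  return success();
}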
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -477,6 +477,220 @@
 } // namespace
 
+FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
+                                             tensor::PackOp packOp) {
+  // 1. Filter out NYI cases.
+  if (!packOp.getOuterDimsPerm().empty())
+    return rewriter.notifyMatchFailure(packOp, "outer dims perm NYI");
+
+  auto packedTensorType =
+      packOp->getResultTypes().front().cast<RankedTensorType>();
+  if (!packedTensorType.hasStaticShape()) {
+    return rewriter.notifyMatchFailure(
+        packOp,
+        "non-static shape NYI, needs a more powerful tensor.expand_shape op");
+  }
+
+  Location loc = packOp->getLoc();
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(packOp);
+
+  // 2. Compute the permutation vector to move the last `numPackedDims` into
+  // the `innerPosDims` of a shape of rank `packedRank`.
+  int64_t numPackedDims = packOp.getInnerDimsPos().size();
+  int64_t packedRank = packedTensorType.getRank();
+  auto lastDims = llvm::to_vector(
+      llvm::seq<int64_t>(packedRank - numPackedDims, packedRank));
+  PackingMetadata packingMetadata = computePackingMetadata(
+      packedTensorType.getRank(), packOp.getInnerDimsPos());
+  SmallVector<int64_t> lastDimsToInsertPositionsPerm = computePermutationVector(
+      packedRank, lastDims, packingMetadata.insertPositions);
+
+  // 3. Compute the stripMinedShape: this is the packed shape before any outer
+  // or inner permutations have been applied.
+  SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
+  applyPermutationToVector(stripMinedShape, lastDimsToInsertPositionsPerm);
+
+  // 4. Pad the source of packOp to a shape we can expand into stripMinedShape.
+  RankedTensorType collapsed = tensor::CollapseShapeOp::inferCollapsedType(
+      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
+      packingMetadata.reassociations);
+  Value paddingValue = packOp.getPaddingValue();
+  if (!paddingValue) {
+    paddingValue = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getZeroAttr(getElementTypeOrSelf(collapsed)));
+  }
+  auto padOp =
+      tensor::createPadHighOp(collapsed, packOp.getSource(), paddingValue,
+                              /*nofold=*/false, loc, rewriter);
+
+  LLVM_DEBUG(
+      DBGSNL(); DBGSNL(); llvm::interleaveComma(packingMetadata.insertPositions,
+                                                DBGS() << "insertPositions: ");
+      DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
+                                      DBGS() << "packedShape: ");
+      DBGSNL();
+      llvm::interleaveComma(lastDimsToInsertPositionsPerm,
+                            DBGS() << "lastDimsToInsertPositionsPerm: ");
+      DBGSNL(); llvm::interleaveComma(
+          packingMetadata.reassociations, DBGS() << "reassociations: ",
+          [&](ReassociationIndices ri) {
+            llvm::interleaveComma(ri, llvm::dbgs() << "|");
+          });
+      DBGSNL();
+      llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
+      DBGSNL(); DBGS() << "collapsed type: " << collapsed; DBGSNL(););
+
+  if (packOp.isLikePad()) {
+    // This pack is just a plain pad.
+    // Just insert the pad in the higher ranked tensor.
+    auto emptyOp =
+        rewriter.create<tensor::EmptyOp>(loc, packedTensorType, ValueRange{});
+    // Offsets.
+    SmallVector<OpFoldResult> zeros(packedRank, rewriter.getIndexAttr(0));
+    // Strides.
+    SmallVector<OpFoldResult> ones(packedRank, rewriter.getIndexAttr(1));
+    SmallVector<OpFoldResult> sizes =
+        getMixedDimensions(rewriter, loc, packOp.getDest());
+
+    auto insertSliceOp = rewriter.create<tensor::InsertSliceOp>(
+        loc, /*source=*/padOp, /*dest=*/emptyOp,
+        /*offsets=*/zeros, sizes,
+        /*strides=*/ones);
+
+    LLVM_DEBUG(DBGS() << "insert_slice op: " << insertSliceOp; DBGSNL(););
+
+    rewriter.replaceOp(packOp, insertSliceOp->getResults());
+
+    return LowerPackResult{padOp, /*reshapeOp=*/nullptr,
+                           /*transposeOp=*/nullptr};
+  }
+  // 5. Expand from the padded result to the stripMinedShape.
+  auto reshapeOp = rewriter.create<tensor::ExpandShapeOp>(
+      loc,
+      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
+      padOp.getResult(), packingMetadata.reassociations);
+
+  // 6. Transpose stripMinedShape to packedShape.
+  SmallVector<int64_t> insertPositionsToLastDimsPerm = computePermutationVector(
+      packedRank, packingMetadata.insertPositions, lastDims);
+  auto transposeOp = rewriter.create<linalg::TransposeOp>(
+      loc, reshapeOp.getResult(), packOp.getDest(),
+      insertPositionsToLastDimsPerm);
+
+  LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
+             DBGS() << "reshape op: " << reshapeOp; DBGSNL();
+             llvm::interleaveComma(insertPositionsToLastDimsPerm,
+                                   DBGS() << "insertPositionsToLastDimsPerm: ");
+             DBGSNL(); DBGS() << "transpose op: " << transposeOp; DBGSNL(););
+
+  // 7. Replace packOp by transposeOp.
+  rewriter.replaceOp(packOp, transposeOp->getResults());
+
+  return LowerPackResult{padOp, reshapeOp, transposeOp};
+}
+
+FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
+                                                   tensor::UnPackOp unPackOp) {
+  // 1. Filter out NYI cases.
+  if (!unPackOp.getOuterDimsPerm().empty())
+    return rewriter.notifyMatchFailure(unPackOp, "outer dims perm NYI");
+
+  RankedTensorType packedTensorType = unPackOp.getSourceType();
+  if (!packedTensorType.hasStaticShape()) {
+    return rewriter.notifyMatchFailure(
+        unPackOp,
+        "non-static shape NYI, needs a more powerful tensor.expand_shape op");
+  }
+
+  Location loc = unPackOp->getLoc();
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(unPackOp);
+
+  int64_t packedRank = packedTensorType.getRank();
+
+  OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
+  auto destTensorType = unPackOp.getDest().getType().cast<RankedTensorType>();
+  if (unPackOp.isLikeUnPad()) {
+    // This unpack is just a plain unpad.
+    // Just extract the slice from the higher ranked tensor.
+    ArrayRef<int64_t> destShape = destTensorType.getShape();
+    // The inner dimensions stay the same as the destination tensor, but the
+    // outer ones are additional 1s.
+    SmallVector<OpFoldResult> sizes(packedRank - destShape.size(), one);
+    sizes.append(getMixedDimensions(rewriter, loc, unPackOp.getDest()));
+
+    auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
+        loc, destTensorType, unPackOp.getSource(),
+        SmallVector<OpFoldResult>(packedRank, zero), sizes,
+        SmallVector<OpFoldResult>(packedRank, one));
+
+    rewriter.replaceOp(unPackOp, extractSliceOp->getResults());
+
+    return LowerUnPackOpResult{/*emptyOp=*/nullptr, /*transposeOp=*/nullptr,
+                               /*reshapeOp=*/nullptr, extractSliceOp};
+  }
+  // 2. Compute the permutation vector to move the last `numPackedDims` into
+  // the `innerPosDims` of a shape of rank `packedRank`.
+  int64_t numPackedDims = unPackOp.getInnerDimsPos().size();
+  auto lastDims = llvm::to_vector(
+      llvm::seq<int64_t>(packedRank - numPackedDims, packedRank));
+  PackingMetadata packingMetadata =
+      computePackingMetadata(packedRank, unPackOp.getInnerDimsPos());
+  SmallVector<int64_t> lastDimsToInsertPositionsPerm = computePermutationVector(
+      packedRank, lastDims, packingMetadata.insertPositions);
+
+  // 3. Compute the stripMinedShape: this is the packed shape without outer and
+  // inner permutations.
+  SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
+  applyPermutationToVector(stripMinedShape, lastDimsToInsertPositionsPerm);
+
+  // 4. Transpose packedShape to stripMinedShape.
+  RankedTensorType stripMinedTensorType =
+      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
+  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
+      stripMinedTensorType, packingMetadata.reassociations);
+  auto emptyOp =
+      rewriter.create<tensor::EmptyOp>(loc, stripMinedTensorType, ValueRange{});
+  auto transposeOp = rewriter.create<linalg::TransposeOp>(
+      loc, unPackOp.getSource(), emptyOp, lastDimsToInsertPositionsPerm);
+
+  LLVM_DEBUG(
+      DBGSNL(); DBGSNL(); llvm::interleaveComma(packingMetadata.insertPositions,
+                                                DBGS() << "insertPositions: ");
+      DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
+                                      DBGS() << "packedShape: ");
+      DBGSNL();
+      llvm::interleaveComma(lastDimsToInsertPositionsPerm,
+                            DBGS() << "lastDimsToInsertPositionsPerm: ");
+      DBGSNL(); llvm::interleaveComma(
+          packingMetadata.reassociations, DBGS() << "reassociations: ",
+          [&](ReassociationIndices ri) {
+            llvm::interleaveComma(ri, llvm::dbgs() << "|");
+          });
+      DBGSNL();
+      llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
+      DBGSNL(); DBGS() << "collapsed type: " << collapsedType; DBGSNL(););
+
+  // 5. Collapse from the stripMinedShape to the padded result.
+  auto reshapeOp = rewriter.create<tensor::CollapseShapeOp>(
+      loc, collapsedType, transposeOp->getResult(0),
+      packingMetadata.reassociations);
+
+  // 6. ExtractSlice.
+  int64_t destRank = destTensorType.getRank();
+  auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
+      loc, destTensorType, reshapeOp->getResult(0),
+      SmallVector<OpFoldResult>(destRank, zero),
+      tensor::getMixedSizes(rewriter, loc, unPackOp->getResult(0)),
+      SmallVector<OpFoldResult>(destRank, one));
+
+  // 7. Replace unPackOp by extractSliceOp.
+  rewriter.replaceOp(unPackOp, extractSliceOp->getResults());
+
+  return LowerUnPackOpResult{emptyOp, transposeOp, reshapeOp, extractSliceOp};
+}
+
 SmallVector<int64_t>
 PackedOperandsDimList::extractPackedDimsForOperand(int64_t operandPos) {
   SmallVector<int64_t> res;
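Finally, a self-contained sanity check of the shape bookkeeping in steps 2-5 above. This is plain C++, deliberately independent of the MLIR helpers (computePermutationVector, applyPermutationToVector), and the pack geometry (a tensor<129x47xf32> tiled by inner_tiles = [8, 16]) is an assumed example, not taken from this patch.

#include <cassert>
#include <cstdint>
#include <vector>

// Move the sizes at the trailing positions `lastDims` of `packedShape` into
// `insertPositions`, keeping the outer sizes in their relative order. This
// mirrors the effect of steps 2-3, assuming lastDims are exactly the
// trailing (inner tile) dims, as computed in the lowering.
static std::vector<int64_t>
stripMineShape(const std::vector<int64_t> &packedShape,
               const std::vector<int64_t> &lastDims,
               const std::vector<int64_t> &insertPositions) {
  std::vector<int64_t> result(packedShape.size(), -1);
  for (size_t k = 0; k < lastDims.size(); ++k)
    result[insertPositions[k]] = packedShape[lastDims[k]];
  size_t src = 0; // Walks the leading (outer) dims of packedShape.
  for (int64_t &d : result)
    if (d == -1)
      d = packedShape[src++];
  return result;
}

int main() {
  // tensor.pack inner_dims_pos = [0, 1] inner_tiles = [8, 16]
  //   : tensor<129x47xf32> -> tensor<17x3x8x16xf32>
  // packedShape = [17, 3, 8, 16]; the inner tiles are the trailing dims
  // [2, 3]; their strip-mined (insert) positions are [1, 3].
  std::vector<int64_t> stripMined =
      stripMineShape({17, 3, 8, 16}, /*lastDims=*/{2, 3},
                     /*insertPositions=*/{1, 3});
  // [17, 8, 3, 16] collapses via reassociations [[0, 1], [2, 3]] to
  // [136, 48]: exactly the shape the source is padded to (129 -> 136,
  // 47 -> 48), and transposing it by [0, 2, 1, 3] recovers the packed shape.
  assert((stripMined == std::vector<int64_t>{17, 8, 3, 16}));
  return 0;
}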