diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
@@ -25,6 +25,7 @@
 namespace tensor {
 class PackOp;
+class UnPackOp;
 } // namespace tensor
 
 namespace transform {
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -229,7 +229,7 @@
     #### Return modes
 
     This operation ignores non-pack ops and drops them in the return.
-    This operation produces a silenceableFailure if the padding fails for any
+    This operation produces a silenceableFailure if the rewrite fails for any
     reason.
     If all the operations referred to by the `target` are rewritten, the
     transform succeeds.
@@ -252,6 +252,45 @@
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// LowerUnPackOp
+//===----------------------------------------------------------------------===//
+def LowerUnPackOp : Op<Transform_Dialect, "structured.lower_unpack", [
+                       FunctionalStyleTransformOpTrait,
+                       MemoryEffectsOpInterface,
+                       TransformEachOpTrait,
+                       TransformOpInterface]> {
+  let description = [{
+    Lower a tensor.unpack into empty + linalg.transpose +
+    tensor.collapse_shape + tensor.extract_slice.
+
+    #### Return modes
+
+    This operation ignores non-unpack ops and drops them in the return.
+    This operation produces a silenceableFailure if the rewrite fails for any
+    reason.
+    If all the operations referred to by the `target` are rewritten, the
+    transform succeeds.
+    Return handles to the newly produced empty, transpose, collapse_shape and
+    extract_slice ops.
+  }];
+
+  let arguments = (ins Transform_ConcreteOpType<"tensor.unpack">:$target);
+  let results = (outs Transform_ConcreteOpType<"tensor.empty">:$empty_op,
+                      Transform_ConcreteOpType<"linalg.transpose">:$transpose_op,
+                      Transform_ConcreteOpType<"tensor.collapse_shape">:$collapse_shape_op,
+                      Transform_ConcreteOpType<"tensor.extract_slice">:$extract_slice_op);
+  let assemblyFormat = [{
+    $target attr-dict `:` functional-type(operands, results)
+  }];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::tensor::UnPackOp target,
+        ::mlir::transform::ApplyToEachResultList &transformResults,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // MatchOp
 //===----------------------------------------------------------------------===//
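
A usage sketch of the assembly format declared above (the operand and result
handle types are the ones exercised by the new test at the end of this patch;
SSA names are illustrative):

  %empty, %transpose, %collapse, %slice =
      transform.structured.lower_unpack %target
        : (!transform.op<"tensor.unpack">)
       -> (!transform.op<"tensor.empty">, !transform.op<"linalg.transpose">,
           !transform.op<"tensor.collapse_shape">,
           !transform.op<"tensor.extract_slice">)
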
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -782,8 +782,8 @@
 };
 
 /// Rewrite pack as pad + reshape + transpose.
-static FailureOr<LowerPackResult> rewriteLowerPack(RewriterBase &rewriter,
-                                                   tensor::PackOp packOp) {
+static FailureOr<LowerPackResult> lowerPack(RewriterBase &rewriter,
+                                            tensor::PackOp packOp) {
   // 1. Filter out NYI cases.
   if (!packOp.getOuterDimsPerm().empty())
     return rewriter.notifyMatchFailure(packOp, "outer dims perm NYI");
@@ -822,7 +822,7 @@
       packingMetadata.reassociations);
   Value paddingValue = packOp.getPaddingValue();
   if (!paddingValue) {
-    rewriter.create<arith::ConstantOp>(
+    paddingValue = rewriter.create<arith::ConstantOp>(
         loc, rewriter.getZeroAttr(getElementTypeOrSelf(collapsed)));
   }
   auto padOp =
@@ -876,7 +876,7 @@
                                              transform::TransformState &state) {
   IRRewriter rewriter(target->getContext());
   rewriter.setInsertionPoint(target);
-  FailureOr<LowerPackResult> res = rewriteLowerPack(rewriter, target);
+  FailureOr<LowerPackResult> res = lowerPack(rewriter, target);
   if (failed(res)) {
     Diagnostic diag(target->getLoc(), DiagnosticSeverity::Error);
     diag << "cannot lower to pad + expand + transpose";
@@ -888,6 +888,117 @@
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// LowerUnPackOp
+//===----------------------------------------------------------------------===//
+
+struct LowerUnPackOpResult {
+  tensor::EmptyOp emptyOp;
+  linalg::TransposeOp transposeOp;
+  tensor::CollapseShapeOp collapseShapeOp;
+  tensor::ExtractSliceOp extractSliceOp;
+};
+
+/// Rewrite unpack as empty + transpose + reshape + extract_slice.
+static FailureOr<LowerUnPackOpResult> lowerUnPack(RewriterBase &rewriter,
+                                                  tensor::UnPackOp unPackOp) {
+  // 1. Filter out NYI cases.
+  if (!unPackOp.getOuterDimsPerm().empty())
+    return rewriter.notifyMatchFailure(unPackOp, "outer dims perm NYI");
+
+  RankedTensorType packedTensorType = unPackOp.getSourceType();
+  if (!packedTensorType.hasStaticShape()) {
+    return rewriter.notifyMatchFailure(
+        unPackOp,
+        "non-static shape NYI, needs a more powerful tensor.expand_shape op");
+  }
+
+  Location loc = unPackOp->getLoc();
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(unPackOp);
+
+  // 2. Compute the permutation vector to move the last `numPackedDims` into
+  // the `innerPosDims` of a shape of rank `packedRank`.
+  int64_t numPackedDims = unPackOp.getInnerDimsPos().size();
+  int64_t packedRank = packedTensorType.getRank();
+  auto lastDims = llvm::to_vector(
+      llvm::seq<int64_t>(packedRank - numPackedDims, packedRank));
+  PackingMetadata packingMetadata =
+      computePackingMetadata(packedRank, unPackOp.getInnerDimsPos());
+  SmallVector<int64_t> lastDimsToInsertPositionsPerm = computePermutationVector(
+      packedRank, lastDims, packingMetadata.insertPositions);
+
+  // 3. Compute the stripMinedShape: this is the packed shape without outer and
+  // inner permutations.
+  SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
+  applyPermutationToVector(stripMinedShape, lastDimsToInsertPositionsPerm);
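+  // Example, with the shapes from the test added in this patch: for a packed
+  // source of type tensor<17x2x16x16x32x8xf32> with innerDimsPos = [1, 0],
+  // lastDims is [4, 5] and the tiles are inserted at positions [3, 1], which
+  // yields the permutation [0, 5, 1, 4, 2, 3] and the stripMinedShape
+  // 17x8x2x32x16x16.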
+
+  // 4. Transpose packedShape to stripMinedShape.
+  RankedTensorType stripMinedTensorType =
+      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
+  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
+      stripMinedTensorType, packingMetadata.reassociations);
+  auto emptyOp = rewriter.create<tensor::EmptyOp>(loc, stripMinedTensorType,
+                                                  ValueRange{});
+  auto transposeOp = rewriter.create<linalg::TransposeOp>(
+      loc, unPackOp.getSource(), emptyOp, lastDimsToInsertPositionsPerm);
+
+  LLVM_DEBUG(
+      DBGSNL(); DBGSNL(); llvm::interleaveComma(packingMetadata.insertPositions,
+                                                DBGS() << "insertPositions: ");
+      DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
+                                      DBGS() << "packedShape: ");
+      DBGSNL();
+      llvm::interleaveComma(lastDimsToInsertPositionsPerm,
+                            DBGS() << "lastDimsToInsertPositionsPerm: ");
+      DBGSNL(); llvm::interleaveComma(
+          packingMetadata.reassociations, DBGS() << "reassociations: ",
+          [&](ReassociationIndices ri) {
+            llvm::interleaveComma(ri, llvm::dbgs() << "|");
+          });
+      DBGSNL();
+      llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
+      DBGSNL(); DBGS() << "collapsed type: " << collapsedType; DBGSNL(););
+
+  // 5. Collapse from the stripMinedShape to the padded result.
+  auto reshapeOp = rewriter.create<tensor::CollapseShapeOp>(
+      loc, collapsedType, transposeOp->getResult(0),
+      packingMetadata.reassociations);
+
+  // 6. ExtractSlice.
+  auto destTensorType = unPackOp.getDest().getType().cast<RankedTensorType>();
+  int64_t destRank = destTensorType.getRank();
+  OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
+  auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
+      loc, destTensorType, reshapeOp->getResult(0),
+      SmallVector<OpFoldResult>(destRank, zero),
+      tensor::getMixedSizes(rewriter, loc, unPackOp->getResult(0)),
+      SmallVector<OpFoldResult>(destRank, one));
+
+  // 7. Replace unPackOp by extractSliceOp.
+  rewriter.replaceOp(unPackOp, extractSliceOp->getResults());
+
+  return LowerUnPackOpResult{emptyOp, transposeOp, reshapeOp, extractSliceOp};
+}
+
+DiagnosedSilenceableFailure transform::LowerUnPackOp::applyToOne(
+    tensor::UnPackOp target, transform::ApplyToEachResultList &transformResults,
+    transform::TransformState &state) {
+  IRRewriter rewriter(target->getContext());
+  rewriter.setInsertionPoint(target);
+  FailureOr<LowerUnPackOpResult> res = lowerUnPack(rewriter, target);
+  if (failed(res)) {
+    Diagnostic diag(target->getLoc(), DiagnosticSeverity::Error);
+    diag << "cannot lower to empty + transpose + reshape + extract_slice";
+    return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
+  }
+  transformResults.push_back(res->emptyOp);
+  transformResults.push_back(res->transposeOp);
+  transformResults.push_back(res->collapseShapeOp);
+  transformResults.push_back(res->extractSliceOp);
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===---------------------------------------------------------------------===//
 // MatchOp
 //===---------------------------------------------------------------------===//
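
For concreteness, this is the sequence produced for the shapes exercised by the
new test below (assembled from its CHECK lines; SSA names are illustrative).
The strip-mined tensor<17x8x2x32x16x16xf32> collapses pairwise to
tensor<136x64x16x16xf32> (17 * 8 = 136, 2 * 32 = 64), and the final
tensor.extract_slice carves the original 129x47x16x16 shape out of that padded
envelope:

  %empty = tensor.empty() : tensor<17x8x2x32x16x16xf32>
  %transposed = linalg.transpose
      ins(%arg0 : tensor<17x2x16x16x32x8xf32>)
      outs(%empty : tensor<17x8x2x32x16x16xf32>)
      permutation = [0, 5, 1, 4, 2, 3]
  %collapsed = tensor.collapse_shape %transposed [[0, 1], [2, 3], [4], [5]]
      : tensor<17x8x2x32x16x16xf32> into tensor<136x64x16x16xf32>
  %unpacked = tensor.extract_slice
      %collapsed[0, 0, 0, 0] [129, 47, 16, 16] [1, 1, 1, 1]
      : tensor<136x64x16x16xf32> to tensor<129x47x16x16xf32>
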
diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -1,5 +1,6 @@
-// RUN: mlir-opt %s -test-transform-dialect-interpreter | FileCheck %s
+// RUN: mlir-opt %s -test-transform-dialect-interpreter --split-input-file | FileCheck %s
+
 // CHECK-LABEL: func.func @pack(
 func.func @pack(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<17x2x16x16x32x8xf32>) -> tensor<17x2x16x16x32x8xf32> {
   %cst_0 = arith.constant 0.0 : f32
@@ -26,3 +27,33 @@
       -> (!transform.op<"tensor.pad">,
           !transform.op<"tensor.expand_shape">,
           !transform.op<"linalg.transpose">)
 }
+
+// -----
+
+// CHECK-LABEL: func.func @unpack(
+func.func @unpack(%arg0: tensor<17x2x16x16x32x8xf32>, %arg1: tensor<129x47x16x16xf32>) -> tensor<129x47x16x16xf32> {
+  %cst_0 = arith.constant 0.0 : f32
+
+  // CHECK: tensor.empty() : tensor<17x8x2x32x16x16xf32>
+  // CHECK: linalg.transpose
+  // CHECK-SAME:   ins(%{{.*}} : tensor<17x2x16x16x32x8xf32>)
+  // CHECK-SAME:   outs(%{{.*}} : tensor<17x8x2x32x16x16xf32>)
+  // CHECK-SAME:   permutation = [0, 5, 1, 4, 2, 3]
+  // CHECK: tensor.collapse_shape {{.*}}[0, 1], [2, 3], [4], [5]]
+  // CHECK-SAME:   : tensor<17x8x2x32x16x16xf32> into tensor<136x64x16x16xf32>
+  // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [129, 47, 16, 16] [1, 1, 1, 1]
+  // CHECK-SAME:   : tensor<136x64x16x16xf32> to tensor<129x47x16x16xf32>
+  %unpacked = tensor.unpack %arg0 inner_dims_pos = [1, 0] inner_tiles = [32, 8] into %arg1
+    : tensor<17x2x16x16x32x8xf32> -> tensor<129x47x16x16xf32>
+  return %unpacked : tensor<129x47x16x16xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !pdl.operation):
+  %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
+    : (!pdl.operation) -> !transform.op<"tensor.unpack">
+  transform.structured.lower_unpack %unpack : (!transform.op<"tensor.unpack">)
+    -> (!transform.op<"tensor.empty">,
+        !transform.op<"linalg.transpose">,
+        !transform.op<"tensor.collapse_shape">,
+        !transform.op<"tensor.extract_slice">)
+}
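
For completeness, a case the new lowering intentionally rejects (step 1. of
lowerUnPack above) is any unpack carrying an outer_dims_perm. A hypothetical
negative test, not part of this patch, would look as follows; with
failures(propagate), the interpreter reports the silenceable failure
"cannot lower to empty + transpose + reshape + extract_slice" as an error:

  %0 = tensor.unpack %arg0 outer_dims_perm = [1, 0, 2, 3]
      inner_dims_pos = [1, 0] inner_tiles = [32, 8] into %arg1
    : tensor<2x17x16x16x32x8xf32> -> tensor<129x47x16x16xf32>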