diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h --- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h +++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h @@ -522,6 +522,7 @@ struct PackingMetadata { SmallVector insertPositions; + SmallVector outerPositions; SmallVector reassociations; }; diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -480,9 +480,6 @@ FailureOr linalg::lowerPack(RewriterBase &rewriter, tensor::PackOp packOp) { // 1. Filter out NYI cases. - if (!packOp.getOuterDimsPerm().empty()) - return rewriter.notifyMatchFailure(packOp, "outer dims perm NYI"); - auto packedTensorType = packOp->getResultTypes().front().cast(); if (!packedTensorType.hasStaticShape()) { @@ -495,21 +492,37 @@ OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(packOp); - // 2. Compute the permutation vector to move the last `numPackedDims` into the - // `innerPosDims` of a shape of rank `packedRank`. + // 2. Compute the permutation vector to shuffle packed shape into the shape + // before any outer or inner permutations have been applied. The permutation + // can be obtained from two permutations: + // a) Compute the permutation vector to move the last `numPackedDims` into + // the `innerPosDims` of a shape of rank `packedRank`. + // b) Compute the permutation vector to move outer dims if the pack op + // has outer_dims_perm. + // Apply (b) permutation on (a) permutation to get the final permutation. 
int64_t numPackedDims = packOp.getInnerDimsPos().size(); int64_t packedRank = packedTensorType.getRank(); auto lastDims = llvm::to_vector( llvm::seq(packedRank - numPackedDims, packedRank)); PackingMetadata packingMetadata = computePackingMetadata( packedTensorType.getRank(), packOp.getInnerDimsPos()); - SmallVector lastDimsToInsertPositionsPerm = computePermutationVector( + SmallVector innerPositionsPerm = computePermutationVector( packedRank, lastDims, packingMetadata.insertPositions); + SmallVector outerPos = packingMetadata.outerPositions; + ArrayRef outerPerm = packOp.getOuterDimsPerm(); + if (!outerPerm.empty()) + applyPermutationToVector(outerPos, outerPerm); + SmallVector outerPositionPerm = computePermutationVector( + packedRank, packingMetadata.outerPositions, outerPos); + + SmallVector packedToStripMinedShapePerm = innerPositionsPerm; + applyPermutationToVector(packedToStripMinedShapePerm, outerPositionPerm); + // 3. Compute the stripMinedShape: this is the packed shape before any outer // or inner permutations have been applied. SmallVector stripMinedShape(packedTensorType.getShape()); - applyPermutationToVector(stripMinedShape, lastDimsToInsertPositionsPerm); + applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm); // 4. Pad the source of packOp to a shape we can expand into stripMinedShape. 
RankedTensorType collapsed = tensor::CollapseShapeOp::inferCollapsedType( @@ -527,11 +540,17 @@ LLVM_DEBUG( DBGSNL(); DBGSNL(); llvm::interleaveComma(packingMetadata.insertPositions, DBGS() << "insertPositions: "); + DBGSNL(); llvm::interleaveComma(packingMetadata.outerPositions, + DBGS() << "outerPositions: "); DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(), DBGS() << "packedShape: "); DBGSNL(); - llvm::interleaveComma(lastDimsToInsertPositionsPerm, - DBGS() << "lastDimsToInsertPositionsPerm: "); + llvm::interleaveComma(outerPositionPerm, DBGS() << "outerPositionPerm: "); + DBGSNL(); llvm::interleaveComma(innerPositionsPerm, + DBGS() << "innerPositionsPerm: "); + DBGSNL(); + llvm::interleaveComma(packedToStripMinedShapePerm, + DBGS() << "packedToStripMinedShapePerm: "); DBGSNL(); llvm::interleaveComma( packingMetadata.reassociations, DBGS() << "reassociations: ", [&](ReassociationIndices ri) { @@ -572,16 +591,14 @@ padOp.getResult(), packingMetadata.reassociations); // 6. Transpose stripMinedShape to packedShape. - SmallVector insertPositionsToLastDimsPerm = computePermutationVector( - packedRank, packingMetadata.insertPositions, lastDims); + SmallVector transpPerm = + invertPermutationVector(packedToStripMinedShapePerm); auto transposeOp = rewriter.create( - loc, reshapeOp.getResult(), packOp.getDest(), - insertPositionsToLastDimsPerm); + loc, reshapeOp.getResult(), packOp.getDest(), transpPerm); LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL(); DBGS() << "reshape op: " << reshapeOp; DBGSNL(); - llvm::interleaveComma(insertPositionsToLastDimsPerm, - DBGS() << "insertPositionsToLastDimsPerm: "); + llvm::interleaveComma(transpPerm, DBGS() << "transpPerm: "); DBGSNL(); DBGS() << "transpose op: " << transposeOp; DBGSNL();); // 7. Replace packOp by transposeOp. 
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp --- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp +++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp @@ -480,6 +480,7 @@ res.insertPositions.end()); res.reassociations.reserve(packedRank); for (int64_t i = 1; i <= packedRank; ++i) { + res.outerPositions.push_back(i - 1); if (!posSet.contains(i)) { res.reassociations.push_back(ReassociationIndices{i - 1}); continue; diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir --- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir +++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir @@ -183,3 +183,67 @@ !transform.op<"tensor.collapse_shape">, !transform.op<"tensor.extract_slice">) } + +// ----- + +// CHECK-LABEL: func.func @pack_with_outer_dims_perm( +func.func @pack_with_outer_dims_perm(%src: tensor<100x200x128x256xi32>, + %dest: tensor<200x4x16x100x16x32xi32>) + -> tensor<200x4x16x100x16x32xi32> { + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: tensor.pad {{.*}} low[%[[C0]], %[[C0]], %[[C0]], %[[C0]]] + // CHECK: : tensor<100x200x128x256xi32> to tensor<100x200x128x256xi32> + // CHECK: tensor.expand_shape %{{.*}} [{{.*}}[0], [1], [2, 3], [4, 5]] + // CHECK-SAME: : tensor<100x200x128x256xi32> into tensor<100x200x4x32x16x16xi32> + // CHECK: linalg.transpose + // CHECK-SAME: ins(%{{.*}} : tensor<100x200x4x32x16x16xi32>) + // CHECK-SAME: outs(%{{.*}} : tensor<200x4x16x100x16x32xi32>) + // CHECK-SAME: permutation = [1, 2, 4, 0, 5, 3] + %0 = tensor.pack %src + outer_dims_perm = [1, 2, 3, 0] + inner_dims_pos = [3, 2] + inner_tiles = [16, 32] + into %dest : tensor<100x200x128x256xi32> -> tensor<200x4x16x100x16x32xi32> + return %0 : tensor<200x4x16x100x16x32xi32> +} + +transform.sequence failures(propagate) { +^bb1(%module_op: !pdl.operation): + %pack = transform.structured.match ops{["tensor.pack"]} in %module_op + : (!pdl.operation) -> 
!transform.op<"tensor.pack"> + transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">) + -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">) +} + +// ----- + +// CHECK-LABEL: func.func @pack_with_pad_and_outer_dims_perm( +func.func @pack_with_pad_and_outer_dims_perm(%src: tensor<100x200x127x255xi32>, + %dest: tensor<200x4x16x100x16x32xi32>) + -> tensor<200x4x16x100x16x32xi32> { + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: tensor.pad {{.*}} low[%[[C0]], %[[C0]], %[[C0]], %[[C0]]] + // CHECK: : tensor<100x200x127x255xi32> to tensor<100x200x128x256xi32> + // CHECK: tensor.expand_shape %{{.*}} [{{.*}}[0], [1], [2, 3], [4, 5]] + // CHECK-SAME: : tensor<100x200x128x256xi32> into tensor<100x200x4x32x16x16xi32> + // CHECK: linalg.transpose + // CHECK-SAME: ins(%{{.*}} : tensor<100x200x4x32x16x16xi32>) + // CHECK-SAME: outs(%{{.*}} : tensor<200x4x16x100x16x32xi32>) + // CHECK-SAME: permutation = [1, 2, 4, 0, 5, 3] + %cst_0 = arith.constant 0 : i32 + %0 = tensor.pack %src + padding_value(%cst_0 : i32) + outer_dims_perm = [1, 2, 3, 0] + inner_dims_pos = [3, 2] + inner_tiles = [16, 32] + into %dest : tensor<100x200x127x255xi32> -> tensor<200x4x16x100x16x32xi32> + return %0 : tensor<200x4x16x100x16x32xi32> +} + +transform.sequence failures(propagate) { +^bb1(%module_op: !pdl.operation): + %pack = transform.structured.match ops{["tensor.pack"]} in %module_op + : (!pdl.operation) -> !transform.op<"tensor.pack"> + transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">) + -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">) +}