diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -321,28 +321,36 @@
       DBGSNL(); DBGS() << "collapsed type: " << collapsed; DBGSNL(););
 
   if (packOp.isLikePad()) {
-    // This pack is just a plain pad.
-    // Just insert the pad in the higher ranked tensor.
-    auto emptyOp =
-        rewriter.create<tensor::EmptyOp>(loc, packedTensorType, ValueRange{});
-    // Offsets.
-    SmallVector<OpFoldResult> zeros(packedRank, rewriter.getIndexAttr(0));
-    // Strides.
-    SmallVector<OpFoldResult> ones(packedRank, rewriter.getIndexAttr(1));
-    SmallVector<OpFoldResult> sizes =
-        tensor::getMixedSizes(rewriter, loc, packOp.getDest());
-
-    auto insertSliceOp = rewriter.create<tensor::InsertSliceOp>(
-        loc, /*source=*/padOp, /*dest=*/emptyOp,
-        /*offsets=*/zeros, sizes,
-        /*strides=*/ones);
-
-    LLVM_DEBUG(DBGS() << "insert_slice op: " << insertSliceOp; DBGSNL(););
-
-    rewriter.replaceOp(packOp, insertSliceOp->getResults());
-
-    return LowerPackResult{padOp, /*reshapeOp=*/nullptr,
-                           /*transposeOp=*/nullptr};
+    // Pack ops which operate as simple pads may not produce legal
+    // tensor.insert_slice operations when the packed type does not rank reduce
+    // to the padded type.
+    SliceVerificationResult rankReduces =
+        isRankReducedType(packedTensorType, padOp.getResultType());
+
+    if (rankReduces == SliceVerificationResult::Success) {
+      // This pack is just a plain pad.
+      // Just insert the pad in the higher ranked tensor.
+      auto emptyOp =
+          rewriter.create<tensor::EmptyOp>(loc, packedTensorType, ValueRange{});
+      // Offsets.
+      SmallVector<OpFoldResult> zeros(packedRank, rewriter.getIndexAttr(0));
+      // Strides.
+      SmallVector<OpFoldResult> ones(packedRank, rewriter.getIndexAttr(1));
+      SmallVector<OpFoldResult> sizes =
+          tensor::getMixedSizes(rewriter, loc, packOp.getDest());
+
+      auto insertSliceOp = rewriter.create<tensor::InsertSliceOp>(
+          loc, /*source=*/padOp, /*dest=*/emptyOp,
+          /*offsets=*/zeros, sizes,
+          /*strides=*/ones);
+
+      LLVM_DEBUG(DBGS() << "insert_slice op: " << insertSliceOp; DBGSNL(););
+
+      rewriter.replaceOp(packOp, insertSliceOp->getResults());
+
+      return LowerPackResult{padOp, /*reshapeOp=*/nullptr,
+                             /*transposeOp=*/nullptr};
+    }
   }
   // 5. Expand from the padded result to the stripMinedShape.
   auto reshapeOp = rewriter.create<tensor::ExpandShapeOp>(
diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -356,6 +356,40 @@
   return %pack : tensor<1x1x1x1x136x64x16x16xf32>
 }
 
+transform.sequence failures(propagate) {
+^bb1(%module_op: !transform.any_op):
+  %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
+    : (!transform.any_op) -> !transform.op<"tensor.pack">
+  transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">)
+    -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
+}
+
+// -----
+
+// CHECK-LABEL: func.func @pack_as_pad_with_unit_dims(
+// CHECK: %[[SRC:.+]]: tensor<3x1x1x1xf32>,
+// CHECK: %[[OUT:.+]]: tensor<1x1x1x1x8x1xf32>)
+func.func @pack_as_pad_with_unit_dims(%arg0: tensor<3x1x1x1xf32>, %arg1: tensor<1x1x1x1x8x1xf32>) -> (tensor<1x1x1x1x8x1xf32>) {
+  %zero = arith.constant 0.0 : f32
+
+  // CHECK: %[[PAD:.+]] = tensor.pad %[[SRC]] low[0, 0, 0, 0] high[5, 0, 0, 0] {
+  // CHECK: : tensor<3x1x1x1xf32> to tensor<8x1x1x1xf32>
+  // CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[PAD]] [{{.*}}[0, 1], [2, 3], [4], [5]]
+  // CHECK-SAME: tensor<8x1x1x1xf32> into tensor<1x8x1x1x1x1xf32>
+  // CHECK: %[[TRANSPOSED:.+]] = linalg.transpose
+  // CHECK-SAME: ins(%[[EXPAND]] : tensor<1x8x1x1x1x1xf32>)
+  // CHECK-SAME: outs(%[[OUT]] : tensor<1x1x1x1x8x1xf32>)
+  // CHECK-SAME: permutation = [0, 2, 4, 5, 1, 3]
+  // CHECK: return %[[TRANSPOSED]] : tensor<1x1x1x1x8x1xf32>
+  %pack = tensor.pack %arg0
+    padding_value(%zero : f32)
+    inner_dims_pos = [0, 1]
+    inner_tiles = [8, 1] into %arg1 : tensor<3x1x1x1xf32> -> tensor<1x1x1x1x8x1xf32>
+
+  return %pack : tensor<1x1x1x1x8x1xf32>
+}
+
+
 transform.sequence failures(propagate) {
 ^bb1(%module_op: !transform.any_op):
   %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
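
Note (reviewer context, not part of the patch): the new rank-reduction check only gates the tensor.insert_slice fast path; pack-as-pad cases whose packed type does rank reduce to the padded type keep lowering as before. A minimal sketch of that succeeding case, reusing the shapes from the existing @pack_as_pad test in transform-lower-pack.mlir (%src, %dest, and %cst are illustrative names, and the tensor.pad body is elided):

  // All outer dims are 1, so the pack only pads 129x47x16x16 up to
  // 136x64x16x16. The packed type 1x1x1x1x136x64x16x16 rank-reduces to
  // the padded type 136x64x16x16 by dropping its four leading unit dims,
  // which is what isRankReducedType verifies.
  %pack = tensor.pack %src padding_value(%cst : f32)
      inner_dims_pos = [0, 1, 2, 3] inner_tiles = [136, 64, 16, 16]
      into %dest : tensor<129x47x16x16xf32> -> tensor<1x1x1x1x136x64x16x16xf32>

  // lowerPack then emits, schematically:
  %padded = tensor.pad %src low[0, 0, 0, 0] high[7, 17, 0, 0] { ... }
      : tensor<129x47x16x16xf32> to tensor<136x64x16x16xf32>
  %empty = tensor.empty() : tensor<1x1x1x1x136x64x16x16xf32>
  %res = tensor.insert_slice %padded into %empty[0, 0, 0, 0, 0, 0, 0, 0]
      [1, 1, 1, 1, 136, 64, 16, 16] [1, 1, 1, 1, 1, 1, 1, 1]
      : tensor<136x64x16x16xf32> into tensor<1x1x1x1x136x64x16x16xf32>

In the new @pack_as_pad_with_unit_dims test, by contrast, the padded type 8x1x1x1 cannot be recovered from 1x1x1x1x8x1 by dropping unit dims in order (the 8 sits at position 4, leaving only one trailing unit dim to account for the source's three trailing ones), so the check fails and the pad + expand_shape + transpose path is taken instead, as the CHECK lines above expect.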