diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -477,7 +477,8 @@ // 1. Filter out NYI cases. auto packedTensorType = packOp->getResultTypes().front().cast<RankedTensorType>(); - if (!packedTensorType.hasStaticShape()) { + if (llvm::any_of(packOp.getStaticInnerTiles(), + [](int64_t size) { return ShapedType::isDynamic(size); })) { return rewriter.notifyMatchFailure( packOp, "non-static shape NYI, needs a more powerful tensor.expand_shape op"); @@ -520,6 +521,22 @@ applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm); // 4. Pad the source of packOp to a shape we can expand into stripMinedShape. + SmallVector<OpFoldResult> lows(packOp.getSourceRank(), + rewriter.getIndexAttr(0)); + SmallVector<OpFoldResult> highs(packOp.getSourceRank(), + rewriter.getIndexAttr(0)); + for (auto [pos, innerSize] : + llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getMixedTiles())) { + OpFoldResult origSize = rewriter.createOrFold<tensor::DimOp>( + loc, packOp.getSource(), + rewriter.create<arith::ConstantIndexOp>(loc, pos)); + AffineExpr s0, d0; + bindDims(rewriter.getContext(), d0); + bindSymbols(rewriter.getContext(), s0); + auto map = AffineMap::get(1, 1, d0.ceilDiv(s0) * s0 - d0); + highs[pos] = affine::makeComposedFoldedAffineApply(rewriter, loc, map, + {origSize, innerSize}); + } RankedTensorType collapsed = tensor::CollapseShapeOp::inferCollapsedType( RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape), packingMetadata.reassociations); @@ -529,8 +546,8 @@ loc, rewriter.getZeroAttr(getElementTypeOrSelf(collapsed))); } auto padOp = - tensor::createPadHighOp(collapsed, packOp.getSource(), paddingValue, - /*nofold=*/false, loc, rewriter); + rewriter.create<tensor::PadOp>(loc, collapsed, packOp.getSource(), lows, + highs, paddingValue, /*nofold=*/false); LLVM_DEBUG( DBGSNL(); DBGSNL(); llvm::interleaveComma(packingMetadata.insertPositions, diff --git 
a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir --- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir +++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir @@ -1,12 +1,11 @@ -// RUN: mlir-opt %s -test-transform-dialect-interpreter --split-input-file | FileCheck %s +// RUN: mlir-opt %s -test-transform-dialect-interpreter -cse --split-input-file | FileCheck %s // CHECK-LABEL: func.func @pack( func.func @pack(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<17x2x16x16x32x8xf32>) -> tensor<17x2x16x16x32x8xf32> { %cst_0 = arith.constant 0.0 : f32 // tensor.pack is lowered to tensor.pad + tensor.expand_shape + linalg.transpose - // CHECK: %[[C0:.*]] = arith.constant 0 : index - // CHECK: tensor.pad {{.*}} low[%[[C0]], %[[C0]], %[[C0]], %[[C0]]] + // CHECK: tensor.pad {{.*}} low[0, 0, 0, 0] // CHECK: : tensor<129x47x16x16xf32> to tensor<136x64x16x16xf32> // CHECK: tensor.expand_shape %{{.*}} [{{.*}}[0, 1], [2, 3], [4], [5]] // CHECK-SAME: : tensor<136x64x16x16xf32> into tensor<17x8x2x32x16x16xf32> @@ -33,8 +32,7 @@ func.func @pack(%arg0: tensor<128x8xf32>, %arg1: tensor<8x8x16x1xf32>) -> tensor<8x8x16x1xf32> { // tensor.pack is lowered to tensor.pad + tensor.expand_shape + linalg.transpose - // CHECK: %[[C0:.*]] = arith.constant 0 : index - // CHECK: tensor.pad {{.*}} low[%[[C0]], %[[C0]]] + // CHECK: tensor.pad {{.*}} low[0, 0] // CHECK: : tensor<128x8xf32> to tensor<128x8xf32> // CHECK: tensor.expand_shape %{{.*}} [{{.*}}[0, 1], [2, 3]] // CHECK-SAME: : tensor<128x8xf32> into tensor<8x16x8x1xf32> @@ -64,8 +62,7 @@ %cst_0 = arith.constant 0.0 : f32 // tensor.pack is lowered to tensor.pad + tensor.insert_slice - // CHECK: %[[C0:.*]] = arith.constant 0 : index - // CHECK: %[[PAD:.*]] = tensor.pad {{.*}} low[%[[C0]], %[[C0]], %[[C0]], %[[C0]]] + // CHECK: %[[PAD:.*]] = tensor.pad {{.*}} low[0, 0, 0, 0] // CHECK: : tensor<129x47x16x16xf32> to tensor<136x64x16x16xf32> // CHECK: %[[EMPTY:.*]] = tensor.empty() : 
tensor<1x1x1x1x136x64x16x16xf32> // CHECK: %[[RES:.*]] = tensor.insert_slice %[[PAD]] into %[[EMPTY]] @@ -100,8 +97,7 @@ func.func @pack_not_a_pad(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x16x16x136x64xf32>) -> tensor<1x1x16x16x136x64xf32> { %cst_0 = arith.constant 0.0 : f32 - // CHECK: %[[C0:.*]] = arith.constant 0 : index - // CHECK: tensor.pad {{.*}} low[%[[C0]], %[[C0]], %[[C0]], %[[C0]]] + // CHECK: tensor.pad {{.*}} low[0, 0, 0, 0] // CHECK: : tensor<129x47x16x16xf32> to tensor<136x64x16x16xf32> // CHECK: tensor.expand_shape %{{.*}} [{{.*}}[0, 1], [2, 3], [4], [5]] // CHECK-SAME: : tensor<136x64x16x16xf32> into tensor<1x136x1x64x16x16xf32> @@ -190,8 +186,7 @@ func.func @pack_with_outer_dims_perm(%src: tensor<100x200x128x256xi32>, %dest: tensor<200x4x16x100x16x32xi32>) -> tensor<200x4x16x100x16x32xi32> { - // CHECK: %[[C0:.*]] = arith.constant 0 : index - // CHECK: tensor.pad {{.*}} low[%[[C0]], %[[C0]], %[[C0]], %[[C0]]] + // CHECK: tensor.pad {{.*}} low[0, 0, 0, 0] // CHECK: : tensor<100x200x128x256xi32> to tensor<100x200x128x256xi32> // CHECK: tensor.expand_shape %{{.*}} [{{.*}}[0], [1], [2, 3], [4, 5]] // CHECK-SAME: : tensor<100x200x128x256xi32> into tensor<100x200x4x32x16x16xi32> @@ -221,8 +216,7 @@ func.func @pack_with_pad_and_outer_dims_perm(%src: tensor<100x200x127x255xi32>, %dest: tensor<200x4x16x100x16x32xi32>) -> tensor<200x4x16x100x16x32xi32> { - // CHECK: %[[C0:.*]] = arith.constant 0 : index - // CHECK: tensor.pad {{.*}} low[%[[C0]], %[[C0]], %[[C0]], %[[C0]]] + // CHECK: tensor.pad {{.*}} low[0, 0, 0, 0] // CHECK: : tensor<100x200x127x255xi32> to tensor<100x200x128x256xi32> // CHECK: tensor.expand_shape %{{.*}} [{{.*}}[0], [1], [2, 3], [4, 5]] // CHECK-SAME: : tensor<100x200x128x256xi32> into tensor<100x200x4x32x16x16xi32> @@ -250,13 +244,64 @@ // ----- +// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (-s0 + (s0 ceildiv 16) * 16)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (-s0 + (s0 ceildiv 32) * 32)> +// CHECK: func.func 
@dynamic_pack_pad_transpose_inner_and_outer_dims( +// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]] +func.func @dynamic_pack_pad_transpose_inner_and_outer_dims(%source: tensor<?x?xf32>) -> tensor<?x?x16x32xf32> { + // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index + // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index + // CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index + // CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index + // CHECK-DAG: %[[D0:.+]] = tensor.dim %[[SRC]], %[[C0]] + // CHECK-DAG: %[[D1:.+]] = tensor.dim %[[SRC]], %[[C1]] + // CHECK-DAG: %[[OUT_D0:.+]] = arith.ceildivui %[[D1]], %[[C16]] : index + // CHECK-DAG: %[[OUT_D1:.+]] = arith.ceildivui %[[D0]], %[[C32]] : index + // CHECK-DAG: %[[EMPTY:.+]] = tensor.empty(%[[OUT_D0]], %[[OUT_D1]]) : tensor<?x?x16x32xf32> + // CHECK-DAG: %[[H1:.+]] = affine.apply #[[MAP0]]()[%[[D1]]] + // CHECK-DAG: %[[H0:.+]] = affine.apply #[[MAP1]]()[%[[D0]]] + // CHECK: %[[PAD:.+]] = tensor.pad %[[SRC]] low[0, 0] high[%[[H0]], %[[H1]]] + // CHECK: : tensor<?x?xf32> to tensor<?x?xf32> + // CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[PAD]] {{\[}}[0, 1], [2, 3]] + // CHECK-SAME: : tensor<?x?xf32> into tensor<?x32x?x16xf32> + // CHECK: %[[TRANSP:.+]] = linalg.transpose + // CHECK-SAME: ins(%[[EXPAND]] : tensor<?x32x?x16xf32>) + // CHECK-SAME: outs(%[[EMPTY]] : tensor<?x?x16x32xf32>) + // CHECK-SAME: permutation = [2, 0, 3, 1] + // CHECK: return %[[TRANSP]] + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %d0 = tensor.dim %source, %c0 : tensor<?x?xf32> + %d1 = tensor.dim %source, %c1 : tensor<?x?xf32> + %padding_value = arith.constant 0.0 : f32 + + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %tiled_d0 = arith.ceildivui %d0, %c32 : index + %tiled_d1 = arith.ceildivui %d1, %c16 : index + %init_pack = tensor.empty(%tiled_d1, %tiled_d0) : tensor<?x?x16x32xf32> + %pack = tensor.pack %source padding_value(%padding_value : f32) + outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 32] into %init_pack + : tensor<?x?xf32> -> tensor<?x?x16x32xf32> + return %pack : tensor<?x?x16x32xf32> +} + +transform.sequence failures(propagate) { +^bb1(%module_op: 
!pdl.operation): + %pack = transform.structured.match ops{["tensor.pack"]} in %module_op + : (!pdl.operation) -> !transform.op<"tensor.pack"> + transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">) + -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">) +} + +// ----- + // CHECK-LABEL: func.func @pack_as_pad_with_outer_dims_perm( func.func @pack_as_pad_with_outer_dims_perm(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x1x1x136x64x16x16xf32>) -> tensor<1x1x1x1x136x64x16x16xf32> { %cst_0 = arith.constant 0.0 : f32 // tensor.pack is lowered to tensor.pad + tensor.insert_slice - // CHECK: %[[C0:.*]] = arith.constant 0 : index - // CHECK: %[[PAD:.*]] = tensor.pad {{.*}} low[%[[C0]], %[[C0]], %[[C0]], %[[C0]]] + // CHECK: %[[PAD:.*]] = tensor.pad {{.*}} low[0, 0, 0, 0] // CHECK: : tensor<129x47x16x16xf32> to tensor<136x64x16x16xf32> // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<1x1x1x1x136x64x16x16xf32> // CHECK: %[[RES:.*]] = tensor.insert_slice %[[PAD]] into %[[EMPTY]]