diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -527,15 +527,20 @@
                                   rewriter.getIndexAttr(0));
   for (auto [pos, innerSize] :
        llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getMixedTiles())) {
+    int outerPos =
+        packedToStripMinedShapePerm[packingMetadata.outerPositions[pos]];
     OpFoldResult origSize = rewriter.createOrFold<tensor::DimOp>(
         loc, packOp.getSource(),
         rewriter.create<arith::ConstantIndexOp>(loc, pos));
-    AffineExpr s0, d0;
-    bindDims(rewriter.getContext(), d0);
+    OpFoldResult outerSize = rewriter.createOrFold<tensor::DimOp>(
+        loc, packOp.getDest(),
+        rewriter.create<arith::ConstantIndexOp>(loc, outerPos));
+    AffineExpr s0, d0, d1;
+    bindDims(rewriter.getContext(), d0, d1);
     bindSymbols(rewriter.getContext(), s0);
-    auto map = AffineMap::get(1, 1, d0.ceilDiv(s0) * s0 - d0);
-    highs[pos] = affine::makeComposedFoldedAffineApply(rewriter, loc, map,
-                                                       {origSize, innerSize});
+    auto map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/1, d0 * s0 - d1);
+    highs[pos] = affine::makeComposedFoldedAffineApply(
+        rewriter, loc, map, {outerSize, origSize, innerSize});
   }
   RankedTensorType collapsed = tensor::CollapseShapeOp::inferCollapsedType(
       RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -212,6 +212,36 @@

 // -----

+// CHECK-LABEL: func.func @pack_with_pad(
+func.func @pack_with_pad(%src: tensor<4225x12xf32>, %dest: tensor<265x16x16x1xf32>)
+    -> tensor<265x16x16x1xf32> {
+  // CHECK:      tensor.pad {{.*}} low[0, 0]
+  // CHECK:        : tensor<4225x12xf32> to tensor<4240x16xf32>
+  // CHECK:      tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2, 3]]
+  // CHECK-SAME:   : tensor<4240x16xf32> into tensor<265x16x16x1xf32>
+  // CHECK:      linalg.transpose
+  // CHECK-SAME:   ins(%{{[a-zA-Z0-9]*}} : tensor<265x16x16x1xf32>)
+  // CHECK-SAME:   outs(%{{[a-zA-Z0-9]*}} : tensor<265x16x16x1xf32>)
+  // CHECK-SAME:   permutation = [0, 2, 1, 3]
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.pack %src
+    padding_value(%cst : f32)
+    inner_dims_pos = [0, 1]
+    inner_tiles = [16, 1] into %dest
+    : tensor<4225x12xf32> -> tensor<265x16x16x1xf32>
+  return %0 : tensor<265x16x16x1xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !transform.any_op):
+  %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
+    : (!transform.any_op) -> !transform.op<"tensor.pack">
+  transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">)
+    -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
+}
+
+// -----
+
 // CHECK-LABEL: func.func @pack_with_pad_and_outer_dims_perm(
 func.func @pack_with_pad_and_outer_dims_perm(%src: tensor<100x200x127x255xi32>,
                                              %dest: tensor<200x4x16x100x16x32xi32>)
@@ -244,8 +244,8 @@

 // -----

-// CHECK-DAG:   #[[MAP0:.+]] = affine_map<()[s0] -> (-s0 + (s0 ceildiv 16) * 16)>
-// CHECK-DAG:   #[[MAP1:.+]] = affine_map<()[s0] -> (-s0 + (s0 ceildiv 32) * 32)>
+// CHECK-DAG:   #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 * 16 - s1)>
+// CHECK-DAG:   #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 * 32 - s1)>
 // CHECK:       func.func @dynamic_pack_pad_transpose_inner_and_outer_dims(
 // CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
 func.func @dynamic_pack_pad_transpose_inner_and_outer_dims(%source: tensor<?x?xf32>) -> tensor<?x?x16x32xf32> {
@@ -258,8 +258,10 @@
   // CHECK-DAG: %[[OUT_D0:.+]] = arith.ceildivui %[[D1]], %[[C16]] : index
   // CHECK-DAG: %[[OUT_D1:.+]] = arith.ceildivui %[[D0]], %[[C32]] : index
   // CHECK-DAG: %[[EMPTY:.+]] = tensor.empty(%[[OUT_D0]], %[[OUT_D1]]) : tensor<?x?x16x32xf32>
-  // CHECK-DAG: %[[H1:.+]] = affine.apply #[[MAP0]]()[%[[D1]]]
-  // CHECK-DAG: %[[H0:.+]] = affine.apply #[[MAP1]]()[%[[D0]]]
+  // CHECK-DAG: %[[DEST_D0:.+]] = tensor.dim %[[EMPTY]], %[[C0]]
+  // CHECK-DAG: %[[DEST_D1:.+]] = tensor.dim %[[EMPTY]], %[[C1]]
+  // CHECK-DAG: %[[H1:.+]] = affine.apply #[[MAP0]]()[%[[DEST_D0]], %[[D1]]]
+  // CHECK-DAG: %[[H0:.+]] = affine.apply #[[MAP1]]()[%[[DEST_D1]], %[[D0]]]
   // CHECK:     %[[PAD:.+]] = tensor.pad %[[SRC]] low[0, 0] high[%[[H0]], %[[H1]]]
   // CHECK:       : tensor<?x?xf32> to tensor<?x?xf32>
   // CHECK:     %[[EXPAND:.+]] = tensor.expand_shape %[[PAD]] {{\[}}[0, 1], [2, 3]]