diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1823,7 +1823,9 @@
 
     // Returns true if we have enough static information to catch undefined
     // behavior when the tile size does not divide perfectly the dimension of
-    // the input tensor.
+    // the input tensor. If a given dimension or a tile associated with it is
+    // dynamic, the dimension is not considered as we don't have enough static
+    // information to understand if the tile perfectly divides that dimension.
     static bool requirePaddingValue(ArrayRef<int64_t> inputShape,
                                     ArrayRef<int64_t> innerDimsPos,
                                     ArrayRef<OpFoldResult> innerTiles);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -563,12 +563,29 @@
       Value dest = tensor::PackOp::createDestinationTensor(
           rewriter, loc, operand, innerPackSizes, innerPos,
           /*outerDimsPerm=*/{});
-      // TODO: value of the padding attribute should be determined by consumers.
-      auto zeroAttr =
-          rewriter.getZeroAttr(getElementTypeOrSelf(dest.getType()));
-      Value zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
-      packOps.push_back(rewriter.create<tensor::PackOp>(
-          loc, operand, dest, innerPos, innerPackSizes, zero));
+      // Omit the padding value only when it is provably unnecessary.
+      // requirePaddingValue ignores dynamic dimensions and dynamic tiles, so it
+      // can only prove that every tile divides its dimension when all tiles
+      // are constants and the operand shape is fully static.
+      ShapedType operandType = operand.getType().cast<ShapedType>();
+      bool areConstantTiles =
+          llvm::all_of(innerPackSizes, [](OpFoldResult tile) {
+            return getConstantIntValue(tile).has_value();
+          });
+      if (areConstantTiles && operandType.hasStaticShape() &&
+          !tensor::PackOp::requirePaddingValue(operandType.getShape(), innerPos,
+                                               innerPackSizes)) {
+        packOps.push_back(rewriter.create<tensor::PackOp>(
+            loc, operand, dest, innerPos, innerPackSizes));
+      } else {
+        // TODO: value of the padding attribute should be determined by
+        // consumers.
+        auto zeroAttr =
+            rewriter.getZeroAttr(getElementTypeOrSelf(dest.getType()));
+        Value zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
+        packOps.push_back(rewriter.create<tensor::PackOp>(
+            loc, operand, dest, innerPos, innerPackSizes, zero));
+      }
       inputsAndInits.push_back(packOps.back());
     }
   }
diff --git a/mlir/test/Dialect/Linalg/transform-op-pack.mlir b/mlir/test/Dialect/Linalg/transform-op-pack.mlir
--- a/mlir/test/Dialect/Linalg/transform-op-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pack.mlir
@@ -593,3 +593,38 @@
       : (!transform.op<"tensor.unpack">, !transform.op<"linalg.generic">)
       -> (!transform.op<"linalg.generic">, !transform.op<"tensor.pack">, !transform.op<"tensor.unpack">)
 }
+
+// -----
+
+func.func @no_padding_on_packs(%A: tensor<32x32xf32>, %B: tensor<32x32xf32>, %C: tensor<32x32xf32>)
+    -> tensor<32x32xf32> {
+  %0 = linalg.matmul ins(%A, %B: tensor<32x32xf32>, tensor<32x32xf32>)
+                     outs(%C: tensor<32x32xf32>)
+                     -> tensor<32x32xf32>
+  return %0 : tensor<32x32xf32>
+}
+
+// The pack source is matched with `[^ ]+` rather than `.+` so that a
+// `padding_value(...)` clause after the source makes the check fail.
+// CHECK-LABEL: no_padding_on_packs
+// CHECK: tensor.pack %{{[^ ]+}} inner_dims_pos = [0, 1] inner_tiles = [4, 8]
+// CHECK-SAME: into %{{.+}} : tensor<32x32xf32> -> tensor<8x4x4x8xf32>
+// CHECK: tensor.pack %{{[^ ]+}} outer_dims_perm = [1, 0]
+// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [8, 8]
+// CHECK-SAME: into %{{.+}} : tensor<32x32xf32> -> tensor<4x4x8x8xf32>
+// CHECK: tensor.pack %{{[^ ]+}} inner_dims_pos = [0, 1] inner_tiles = [4, 8]
+// CHECK-SAME: into %{{.+}} : tensor<32x32xf32> -> tensor<8x4x4x8xf32>
+
+transform.sequence failures(propagate) {
+  ^bb0(%arg1: !transform.any_op):
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.pack %0 packed_sizes = [4, 8, 8]
+      : (!transform.any_op) -> (!transform.op<"linalg.generic">)
+    %pack = transform.get_producer_of_operand %1[1]
+      : (!transform.op<"linalg.generic">) -> (!transform.op<"tensor.pack">)
+    %2, %pack_2, %empty_unpack_2 =
+    transform.structured.pack_transpose %pack with_compute_op(%1)
+    outer_perm = [1, 0] inner_perm = [1, 0]
+     : (!transform.op<"tensor.pack">, !transform.op<"linalg.generic">)
+    -> (!transform.op<"linalg.generic">, !transform.op<"tensor.pack">, !transform.any_op)
+}
diff --git a/mlir/test/Dialect/Linalg/transform-pack-greedily.mlir b/mlir/test/Dialect/Linalg/transform-pack-greedily.mlir
--- a/mlir/test/Dialect/Linalg/transform-pack-greedily.mlir
+++ b/mlir/test/Dialect/Linalg/transform-pack-greedily.mlir
@@ -348,3 +348,39 @@
       matmul_packed_sizes = [8, 16, 32] matmul_inner_dims_order = [1, 2, 0]
       : (!transform.op<"linalg.matvec">) -> !transform.any_op
 }
+
+// -----
+
+func.func @no_padding_on_packs(%A: tensor<32x32xf32>, %B: tensor<32x32xf32>, %C: tensor<32x32xf32>)
+    -> tensor<32x32xf32> {
+  %0 = linalg.matmul ins(%A, %B: tensor<32x32xf32>, tensor<32x32xf32>)
+                     outs(%C: tensor<32x32xf32>)
+                     -> tensor<32x32xf32>
+  return %0 : tensor<32x32xf32>
+}
+
+// The pack source is matched with `[^ ]+` rather than `.+` so that a
+// `padding_value(...)` clause after the source makes the check fail.
+// CHECK-LABEL: no_padding_on_packs
+// CHECK: tensor.pack %{{[^ ]+}} inner_dims_pos = [0, 1] inner_tiles = [8, 4]
+// CHECK-SAME: into %{{.+}} : tensor<32x32xf32> -> tensor<4x8x8x4xf32>
+// CHECK: tensor.pack %{{[^ ]+}} outer_dims_perm = [1, 0]
+// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [4, 16] into %{{.+}} : tensor<32x32xf32> -> tensor<2x8x4x16xf32>
+// CHECK: tensor.pack %{{[^ ]+}} inner_dims_pos = [0, 1] inner_tiles = [8, 16]
+// CHECK-SAME: into %{{.+}} : tensor<32x32xf32> -> tensor<4x2x8x16xf32>
+
+transform.sequence failures(propagate) {
+  ^bb0(%arg1: !transform.any_op):
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
+      : (!transform.any_op) -> !transform.op<"linalg.matmul">
+    %1 = transform.structured.pack_greedily %0
+        matmul_packed_sizes = [8, 16, 4] matmul_inner_dims_order = [0, 1, 2]
+      : (!transform.op<"linalg.matmul">) -> !transform.op<"linalg.generic">
+    %pack = transform.get_producer_of_operand %1[1]
+      : (!transform.op<"linalg.generic">) -> (!transform.op<"tensor.pack">)
+    %2, %pack_2, %empty_unpack_2 =
+    transform.structured.pack_transpose %pack with_compute_op(%1)
+    outer_perm = [1, 0] inner_perm = [1, 0]
+     : (!transform.op<"tensor.pack">, !transform.op<"linalg.generic">)
+    -> (!transform.op<"linalg.generic">, !transform.op<"tensor.pack">, !transform.any_op)
+}