diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -220,6 +220,36 @@
 
     return success();
   }
+
+  /// Producer-fusion hook: produces the tile of the pack result described by
+  /// `offsets`/`sizes`. Fails unless each of the trailing inner-tile dims has
+  /// offset 0 and a size equal to the op's inner tile size; otherwise forwards
+  /// the remaining leading entries to getTiledImplementation. `resultNumber`
+  /// is not consulted.
+  FailureOr<TilingResult>
+  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
+                          ArrayRef<OpFoldResult> offsets,
+                          ArrayRef<OpFoldResult> sizes) const {
+    auto packOp = cast<PackOp>(op);
+    int64_t numTiles = packOp.getInnerDimsPos().size();
+
+    // tensor.pack op is fusible (as a producer) only if full inner tiles are
+    // iterated or inner dims are not tiled. Otherwise, it will generate a
+    // sequence of non-trivial ops (for partial tiles).
+    for (auto offset : offsets.take_back(numTiles))
+      if (!isConstantIntValue(offset, 0))
+        return failure();
+
+    for (auto iter :
+         llvm::zip_equal(packOp.getMixedTiles(), sizes.take_back(numTiles)))
+      if (!isEqualConstantIntOrValue(std::get<0>(iter), std::get<1>(iter)))
+        return failure();
+
+    FailureOr<TilingResult> tilingResult = getTiledImplementation(
+        op, b, offsets.drop_back(numTiles), sizes.drop_back(numTiles));
+    if (failed(tilingResult))
+      return failure();
+    return tilingResult.value();
+  }
 };
 
 struct UnpackTileDimInfo {
diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
--- a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
@@ -118,3 +118,51 @@
   %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [16, 32], tile_interchange = [0, 1]}
     : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
 }
+
+// -----
+
+// CHECK-LABEL: func.func @pack_elemwise
+// CHECK: %[[RES:.*]] = scf.for
+// CHECK: scf.for
+// CHECK: tensor.pack
+// CHECK: linalg.elemwise_unary
+// CHECK: return %[[RES]]
+func.func @pack_elemwise(%arg0: tensor<128x384xf32>, %arg1: tensor<16x48x8x8xf32>) -> tensor<16x48x8x8xf32> {
+  %0 = tensor.empty() : tensor<16x48x8x8xf32>
+  %1 = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %0
+      : tensor<128x384xf32> -> tensor<16x48x8x8xf32>
+  %2 = linalg.elemwise_unary ins(%1: tensor<16x48x8x8xf32>)
+                             outs(%arg1: tensor<16x48x8x8xf32>) -> tensor<16x48x8x8xf32>
+  return %2 : tensor<16x48x8x8xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["linalg.elemwise_unary"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [3, 5, 0, 0]}
+    : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+}
+
+// -----
+
+// CHECK-LABEL: func.func @nofuse_pack_elemwise
+// CHECK: tensor.pack
+// CHECK: %[[RES:.*]] = scf.for
+// CHECK: scf.for
+// CHECK: linalg.elemwise_unary
+// CHECK: return %[[RES]]
+func.func @nofuse_pack_elemwise(%arg0: tensor<128x384xf32>, %arg1: tensor<16x48x8x8xf32>) -> tensor<16x48x8x8xf32> {
+  %0 = tensor.empty() : tensor<16x48x8x8xf32>
+  %1 = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %0
+      : tensor<128x384xf32> -> tensor<16x48x8x8xf32>
+  %2 = linalg.elemwise_unary ins(%1: tensor<16x48x8x8xf32>)
+                             outs(%arg1: tensor<16x48x8x8xf32>) -> tensor<16x48x8x8xf32>
+  return %2 : tensor<16x48x8x8xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["linalg.elemwise_unary"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %1, %loops:3 = transform.structured.fuse %0 {tile_sizes = [3, 5, 2, 0]}
+    : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+}