diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -425,6 +425,20 @@
     resultSizes = llvm::to_vector(sizes);
     return success();
   }
+
+  /// Produces the value of result `resultNumber` for the tile described by
+  /// `offsets` and `sizes` by delegating to `getTiledImplementation` and
+  /// reading the result from the last operation that call returns.
+  FailureOr<Value> generateResultTileValue(Operation *op, OpBuilder &b,
+                                           unsigned resultNumber,
+                                           ArrayRef<OpFoldResult> offsets,
+                                           ArrayRef<OpFoldResult> sizes) const {
+    auto tiledOps = getTiledImplementation(op, b, offsets, sizes);
+    // Report failure instead of calling `.back()` on an empty container.
+    if (tiledOps.empty())
+      return failure();
+    return tiledOps.back()->getResult(resultNumber);
+  }
 };
 
 } // namespace
diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
--- a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
@@ -91,3 +91,26 @@
   %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [5, 0, 7], tile_interchange = [0, 2, 1]}
   %2, %loops_2 = transform.structured.tile %1 [0, 4]
 }
+
+// -----
+
+// CHECK-LABEL: func.func @unpack_elemwise
+// CHECK: %[[RES:.*]] = scf.for
+// CHECK: scf.for
+// CHECK: tensor.unpack
+// CHECK: linalg.elemwise_unary
+// CHECK: return %[[RES]]
+func.func @unpack_elemwise(%arg0: tensor<16x48x8x8xf32>, %arg1: tensor<128x384xf32>) -> tensor<128x384xf32> {
+  %0 = tensor.empty() : tensor<128x384xf32>
+  %1 = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %0
+    : tensor<16x48x8x8xf32> -> tensor<128x384xf32>
+  %2 = linalg.elemwise_unary ins(%1: tensor<128x384xf32>)
+                             outs(%arg1: tensor<128x384xf32>) -> tensor<128x384xf32>
+  return %2 : tensor<128x384xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.elemwise_unary"]} in %arg1
+  %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [16, 32], tile_interchange = [0, 1]}
+}