diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -402,9 +402,13 @@
                                     unpackOp.getDestType().getElementType());
     }
 
-    Operation *tiledUnpackOp =
-        b.create<UnPackOp>(loc, TypeRange{sliceDest.getType()},
-                           ValueRange{sliceSource, sliceDest}, op->getAttrs());
+    // Dynamic inner tile sizes are SSA operands of the op; forward them to the
+    // tiled op together with the sliced source and destination.
+    SmallVector<Value> tiledOperands = {sliceSource, sliceDest};
+    llvm::append_range(tiledOperands, unpackOp.getInnerTiles());
+
+    Operation *tiledUnpackOp = b.create<UnPackOp>(
+        loc, TypeRange{sliceDest.getType()}, tiledOperands, op->getAttrs());
 
     if (isPerfectTilingCase)
       return {tiledUnpackOp};
diff --git a/mlir/test/Dialect/Tensor/tiling.mlir b/mlir/test/Dialect/Tensor/tiling.mlir
--- a/mlir/test/Dialect/Tensor/tiling.mlir
+++ b/mlir/test/Dialect/Tensor/tiling.mlir
@@ -580,8 +580,8 @@
 // CHECK: %[[SLICE_DEST:.+]] = tensor.extract_slice %{{.+}}[0, %[[P]], %[[Q]], %[[K]]]
 // CHECK: %[[UNPACK:.+]] = tensor.unpack
 // CHECK-SAME: %[[SLICE_SOURCE]] outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [2]
-// CHECK-SAME: into %[[SLICE_DEST]] 
-// CHECK: %[[RES:.+]] = tensor.insert_slice %[[UNPACK]] 
+// CHECK-SAME: into %[[SLICE_DEST]]
+// CHECK: %[[RES:.+]] = tensor.insert_slice %[[UNPACK]]
 // CHECK-SAME: into %{{.+}}[0, %[[P]], %[[Q]], %[[K]]]
 // CHECK: scf.yield %[[RES]]
 
@@ -598,6 +598,32 @@
 
 // -----
 
+func.func private @get_dynamic_tile_size() -> index
+
+// CHECK-LABEL: func.func @fully_dynamic_unpack
+// CHECK-SAME: %[[SRC:[0-9a-zA-Z]+]]
+// CHECK-SAME: %[[DST:[0-9a-zA-Z]+]]
+// CHECK: %[[INNER_TS:.+]] = call @get_dynamic_tile_size() : () -> index
+// CHECK: %[[TD0:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[DST]])
+// CHECK: %[[TD1:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC1:.*]] = %[[TC0]])
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[SRC]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[SLICE]]
+// CHECK-SAME: inner_dims_pos = [1, 0] inner_tiles = [%[[INNER_TS]], %[[INNER_TS]]] into %[[EMPTY]]
+func.func @fully_dynamic_unpack(%source: tensor<?x?x?x?xf32>, %dest: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = func.call @get_dynamic_tile_size() : () -> index
+  %1 = tensor.unpack %source inner_dims_pos = [1, 0] inner_tiles = [%0, %0] into %dest : tensor<?x?x?x?xf32> -> tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+transform.sequence failures(propagate) {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["tensor.unpack"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+    %1, %loops:2 = transform.structured.tile_to_scf_for %0 [4, 8]
+}
+
+// -----
+
 // CHECK-DAG: #[[MAP:.+]] = affine_map<(d0) -> (-d0 + 6, 1)>
 // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0 * 2)>
 // CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> (d0 * -2 + 8, 2)>