Tensor pack/unpack tiling: apply the permutation the caller passes in.

The file-local helper applyInversePermToRange is renamed to applyPermToRange
and no longer inverts its permutation argument; it applies the permutation
exactly as given (and is still a no-op for an empty permutation).

  * PackOp tiling now passes invertPermutationVector(outer_dims_perm)
    explicitly, so the permutation applied to the offsets/sizes is the same
    as before this patch.
  * UnPackOp tiling now passes outer_dims_perm directly; before this patch
    the helper inverted it first, so this is the behavioral change.

Two tiling tests using outer_dims_perm = [0, 3, 1, 2] are added, one for
tensor.unpack and one for tensor.pack.

NOTE(review): this patch text was recovered from a whitespace-collapsed copy
in which angle-bracket arguments had been dropped. The template/type
arguments (SmallVector, ArrayRef, DenseMap, applyPermutationToVector, and the
dynamic tensor types of @dynamic_perfect_CKkc_to_KC), the indentation, and
the column alignment of the CHECK lines are reconstructed. Verify the context
lines against the target tree before applying.

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -94,14 +94,13 @@
   return loopBounds;
 }
 
-static void applyInversePermToRange(SmallVector<OpFoldResult> &offsets,
-                                    SmallVector<OpFoldResult> &sizes,
-                                    ArrayRef<int64_t> permutation) {
+static void applyPermToRange(SmallVector<OpFoldResult> &offsets,
+                             SmallVector<OpFoldResult> &sizes,
+                             ArrayRef<int64_t> permutation) {
   if (permutation.empty())
     return;
-  SmallVector<int64_t> inversedPerm = invertPermutationVector(permutation);
-  applyPermutationToVector<OpFoldResult>(offsets, inversedPerm);
-  applyPermutationToVector<OpFoldResult>(sizes, inversedPerm);
+  applyPermutationToVector<OpFoldResult>(offsets, permutation);
+  applyPermutationToVector<OpFoldResult>(sizes, permutation);
 }
 
 struct PackOpTiling
@@ -133,7 +132,8 @@
     int64_t inputRank = packOp.getSourceRank();
     SmallVector<OpFoldResult> origOffsets(offsets.begin(), offsets.end());
     SmallVector<OpFoldResult> origSizes(sizes.begin(), sizes.end());
-    applyInversePermToRange(origOffsets, origSizes, packOp.getOuterDimsPerm());
+    applyPermToRange(origOffsets, origSizes,
+                     invertPermutationVector(packOp.getOuterDimsPerm()));
 
     DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
         packOp.getDimAndTileMapping();
@@ -382,8 +382,8 @@
 
     // The tiling is applied on destination dimensions. We have to apply the
     // interchange on source dimensions if outer_dims_perm is set.
-    applyInversePermToRange(sliceSrcIndices, sliceSrcSizes,
-                            unpackOp.getOuterDimsPerm());
+    applyPermToRange(sliceSrcIndices, sliceSrcSizes,
+                     unpackOp.getOuterDimsPerm());
     Attribute zeroAttr = b.getIndexAttr(0);
     sliceSrcIndices.append(numInnerTiles, zeroAttr);
     sliceSrcSizes.append(unpackOp.getMixedTiles());
diff --git a/mlir/test/Dialect/Tensor/tiling.mlir b/mlir/test/Dialect/Tensor/tiling.mlir
--- a/mlir/test/Dialect/Tensor/tiling.mlir
+++ b/mlir/test/Dialect/Tensor/tiling.mlir
@@ -549,6 +549,7 @@
 // CHECK:           %[[RES:.+]] = tensor.insert_slice %[[UNPACK]]
 // CHECK-SAME:        into %{{.+}}[%[[K]], %[[C]]] [%[[OUT_K_SZ]], %[[OUT_C_SZ]]]
 // CHECK:           scf.yield %[[RES]]
+
 func.func @dynamic_perfect_CKkc_to_KC(%source: tensor<?x?x2x2xf32>, %dest: tensor<?x?xf32>) -> tensor<?x?xf32> {
   %0 = tensor.unpack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [2, 2] into %dest : tensor<?x?x2x2xf32> -> tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
@@ -559,3 +560,77 @@
     %0 = transform.structured.match ops{["tensor.unpack"]} in %arg1
     %1, %loops:2 = transform.structured.tile_to_scf_for %0 [2, 4]
 }
+
+// -----
+
+// CHECK:           #[[MAP:.+]] = affine_map<(d0) -> (d0 floordiv 2)>
+// CHECK:           func.func @perfect_NKPQk_to_NPQK(
+// CHECK-SAME:        %[[SOURCE:.+]]: tensor<1x4x6x6x2xf32>,
+// CHECK-SAME:        %{{.+}}: tensor<1x6x6x8xf32>)
+// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[C6:.*]] = arith.constant 6 : index
+// CHECK-DAG:       %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG:       %[[C4:.*]] = arith.constant 4 : index
+// CHECK:           %{{.+}} = scf.for %[[P:.+]] = %[[C0]] to %[[C6]] step %[[C1]]
+// CHECK:           %{{.+}} = scf.for %[[Q:.+]] = %[[C0]] to %[[C6]] step %[[C1]]
+// CHECK:           %{{.+}} = scf.for %[[K:.+]] = %[[C0]] to %[[C8]] step %[[C4]]
+// CHECK:           %[[K_SZ:.+]] = affine.apply #[[MAP]](%[[K]])
+// CHECK:           %[[SLICE_SOURCE:.+]] = tensor.extract_slice %[[SOURCE]][0, %[[K_SZ]], %[[P]], %[[Q]], 0]
+// CHECK:           %[[SLICE_DEST:.+]] = tensor.extract_slice %{{.+}}[0, %[[P]], %[[Q]], %[[K]]]
+// CHECK:           %[[UNPACK:.+]] = tensor.unpack
+// CHECK-SAME:        %[[SLICE_SOURCE]] outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [2]
+// CHECK-SAME:        into %[[SLICE_DEST]]
+// CHECK:           %[[RES:.+]] = tensor.insert_slice %[[UNPACK]]
+// CHECK-SAME:        into %{{.+}}[0, %[[P]], %[[Q]], %[[K]]]
+// CHECK:           scf.yield %[[RES]]
+
+func.func @perfect_NKPQk_to_NPQK(%source: tensor<1x4x6x6x2xf32>, %dest: tensor<1x6x6x8xf32>) -> tensor<1x6x6x8xf32> {
+  %0 = tensor.unpack %source outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [2] into %dest : tensor<1x4x6x6x2xf32> -> tensor<1x6x6x8xf32>
+  return %0 : tensor<1x6x6x8xf32>
+}
+
+transform.sequence failures(propagate) {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["tensor.unpack"]} in %arg1
+    %1, %loops:4 = transform.structured.tile_to_scf_for %0 [1, 1, 1, 4]
+}
+
+// -----
+
+// CHECK-DAG:       #[[MAP:.+]] = affine_map<(d0) -> (-d0 + 6, 1)>
+// CHECK-DAG:       #[[MAP1:.+]] = affine_map<(d0) -> (d0 * 2)>
+// CHECK-DAG:       #[[MAP2:.+]] = affine_map<(d0) -> (d0 * -2 + 8, 2)>
+// CHECK:           func.func @perfect_NPQK_to_NKPQk
+// CHECK-SAME:        %[[SOURCE:.+]]: tensor<1x6x6x8xf32>,
+// CHECK-SAME:        %{{.+}}: tensor<1x4x6x6x2xf32>)
+// CHECK-DAG:       %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG:       %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG:       %[[C4:.+]] = arith.constant 4 : index
+// CHECK-DAG:       %[[C6:.+]] = arith.constant 6 : index
+// CHECK:           %{{.+}} = scf.for %[[ARG2:.+]] = %[[C0]] to %[[C4]] step %[[C1]]
+// CHECK:           %{{.+}} = scf.for %[[ARG4:.+]] = %[[C0]] to %[[C6]] step %[[C1]]
+// CHECK:           %{{.+}} = scf.for %[[ARG6:.+]] = %[[C0]] to %[[C6]] step %[[C1]]
+// CHECK:           %[[MIN_ARG4:.+]] = affine.min #[[MAP]](%[[ARG4]])
+// CHECK:           %[[MIN_ARG6:.+]] = affine.min #[[MAP]](%[[ARG6]])
+// CHECK:           %[[APPLY:.+]] = affine.apply #[[MAP1]](%[[ARG2]])
+// CHECK:           %[[MIN_ARG2:.+]] = affine.min #[[MAP2]](%[[ARG2]])
+// CHECK:           %[[SLICE_SOURCE:.+]] = tensor.extract_slice %[[SOURCE]][0, %[[ARG4]], %[[ARG6]], %[[APPLY]]]
+// CHECK:           %[[SLICE_DEST:.+]] = tensor.extract_slice %{{.+}}[0, %[[ARG2]], %[[ARG4]], %[[ARG6]], 0]
+// CHECK:           %[[PACK:.+]] = tensor.pack
+// CHECK-SAME:        %[[SLICE_SOURCE]] outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [2]
+// CHECK-SAME:        into %[[SLICE_DEST]]
+// CHECK:           %[[RES:.+]] = tensor.insert_slice %[[PACK]]
+// CHECK-SAME:        into %{{.+}}[0, %[[ARG2]], %[[ARG4]], %[[ARG6]], 0]
+// CHECK:           scf.yield %[[RES]]
+
+func.func @perfect_NPQK_to_NKPQk(%source: tensor<1x6x6x8xf32>, %dest: tensor<1x4x6x6x2xf32>) -> tensor<1x4x6x6x2xf32> {
+  %0 = tensor.pack %source outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [2] into %dest : tensor<1x6x6x8xf32> -> tensor<1x4x6x6x2xf32>
+  return %0 : tensor<1x4x6x6x2xf32>
+}
+
+transform.sequence failures(propagate) {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["tensor.pack"]} in %arg1
+    %1, %loops:4 = transform.structured.tile_to_scf_for %0 [1, 1, 1, 1]
+}