Linalg data layout propagation: tolerate constant results in indexing maps.

Both loops over the operand's indexing-map results previously cast every
result to a dim expression unconditionally. They now handle dim expressions
via dyn_cast and assert that any other result is a constant expression.

NOTE(review): the angle-bracket template arguments below (AffineDimExpr,
AffineConstantExpr, and those on SmallVector / llvm::seq) were missing from
the garbled copy of this patch and were restored from the assert messages and
surrounding usage -- confirm against the upstream file before applying.

diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -163,8 +163,13 @@
   // Step 1. Construct the information of packing data dimensions; append inner
   // dimensions to the indexing maps for the operand.
   for (auto [index, expr] : llvm::enumerate(exprs)) {
-    int64_t dimPos = expr.cast<AffineDimExpr>().getPosition();
-    domainDimToOperandDim[dimPos] = index;
+    if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) {
+      int64_t dimPos = dimExpr.getPosition();
+      domainDimToOperandDim[dimPos] = index;
+      continue;
+    }
+    assert(expr.isa<AffineConstantExpr>() &&
+           "Found non-constant and non-affine dim expression");
   }
   SmallVector<int64_t> innerDimsPos;
   SmallVector<OpFoldResult> innerTileSizes;
@@ -186,8 +191,13 @@
     SmallVector<int64_t> inversedOuterPerm =
         invertPermutationVector(packInfo.outerDimsOnDomainPerm);
     for (auto i : llvm::seq<unsigned>(0, origIndexingMap.getNumResults())) {
-      int64_t dimPos = exprs[i].cast<AffineDimExpr>().getPosition();
-      exprs[i] = b.getAffineDimExpr(inversedOuterPerm[dimPos]);
+      if (auto dimExpr = exprs[i].dyn_cast<AffineDimExpr>()) {
+        int64_t dimPos = dimExpr.getPosition();
+        exprs[i] = b.getAffineDimExpr(inversedOuterPerm[dimPos]);
+        continue;
+      }
+      assert(exprs[i].isa<AffineConstantExpr>() &&
+             "Attempted to permute non-constant and non-affine dim expression");
     }
     // Step 2.2: Undo the transposition on `exprs` and propagate the
     // transposition on the pack using outerDimsPerm.
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -231,9 +231,6 @@
 
 // -----
 
-#map0 = affine_map<(d0, d1) -> (d0, d1)>
-#map1 = affine_map<(d0, d1) -> (d0)>
-#map2 = affine_map<(d0, d1) -> (d1)>
 func.func @transpose_pack(%arg0: tensor<100x128x200x256xi32>, %arg1: tensor<100xi32>, %arg2: tensor<128xi32>, %dest: tensor<100x200x4x16x16x32xi32>) -> tensor<100x200x4x16x16x32xi32>
 {
   %init_transpose = tensor.empty() : tensor<100x200x128x256xi32>
@@ -280,6 +277,52 @@
 
 // -----
 
+func.func @affine_constant_expr_pack(%arg0: tensor<100x128x200x256xi32>, %arg1: tensor<100x1x1x1xi32>, %arg2: tensor<1x128x1x1xi32>, %dest: tensor<100x200x4x16x16x32xi32>) -> tensor<100x200x4x16x16x32xi32>
+{
+  %init_transpose = tensor.empty() : tensor<100x200x128x256xi32>
+  %transpose = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, 0)>,
+                       affine_map<(d0, d1, d2, d3) -> (0, d1, 0, 0)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>],
+      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+      ins(%arg0, %arg1, %arg2 : tensor<100x128x200x256xi32>, tensor<100x1x1x1xi32>, tensor<1x128x1x1xi32>)
+      outs(%init_transpose : tensor<100x200x128x256xi32>) {
+    ^bb0(%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32):
+      %0 = arith.addi %b0, %b1 : i32
+      %1 = arith.addi %0, %b2 : i32
+      linalg.yield %1 : i32
+    } -> tensor<100x200x128x256xi32>
+  %4 = tensor.pack %transpose
+    inner_dims_pos = [3, 2]
+    inner_tiles = [16, 32]
+    into %dest : tensor<100x200x128x256xi32> -> tensor<100x200x4x16x16x32xi32>
+  return %4 : tensor<100x200x4x16x16x32xi32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, 0, 0, 0)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (0, d1, 0, 0, d5)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d1, d3, d4, d5)>
+// CHECK: func.func @affine_constant_expr_pack
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<100x4x200x16x16x32xi32>
+// CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME: inner_dims_pos = [3, 1] inner_tiles = [16, 32]
+// CHECK-SAME: into %[[ARG0_EMPTY]]
+// CHECK: %[[ARG2_EMPTY:.+]] = tensor.empty() : tensor<1x4x1x1x32xi32>
+// CHECK: %[[PACKED_ARG2:.+]] = tensor.pack %[[ARG2]]
+// CHECK-SAME: inner_dims_pos = [1] inner_tiles = [32]
+// CHECK-SAME: into %[[ARG2_EMPTY]]
+// CHECK: %[[RES:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]]]
+// CHECK-SAME: ins(%[[PACKED_ARG0]], %[[ARG1]], %[[PACKED_ARG2]]
+// CHECK-SAME: outs(%[[DEST]]
+
+// -----
+
 #map0 = affine_map<(d0, d1) -> (d0, d1)>
 #map1 = affine_map<(d0, d1) -> (d0)>
 #map2 = affine_map<(d0, d1) -> (d1)>