diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -441,7 +441,8 @@
     return false;
-  // TODO: relax the restrictions on indexing map.
+  // Every output indexing map must be a permutation (the identity is one).
+  // TODO: relax the restrictions on indexing map further.
   for (OpOperand *opOperand : linalgOp.getOutputOperands()) {
-    if (!linalgOp.getTiedIndexingMap(opOperand).isIdentity())
+    if (!linalgOp.getTiedIndexingMap(opOperand).isPermutation())
       return false;
   }
   return hasOnlyScalarElementwiseOp(linalgOp->getRegion(0));
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -121,6 +121,26 @@
 
 // -----
 
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
+// CHECK: func @generic_interchanged_transpose
+func.func @generic_interchanged_transpose(%arg0: tensor<12x128x32xf32>) -> tensor<128x12x32xf32> {
+  // CHECK: %[[IN:.+]] = vector.transfer_read
+  // CHECK: vector.transfer_write %[[IN]], {{.+}} permutation_map = #[[MAP]]
+  %0 = linalg.init_tensor [128, 12, 32] : tensor<128x12x32xf32>
+  %1 = linalg.generic {indexing_maps = [#map0, #map1],
+                       iterator_types = ["parallel", "parallel", "parallel"]}
+    ins(%arg0 : tensor<12x128x32xf32>)
+    outs(%0 : tensor<128x12x32xf32>) {
+  ^bb0(%arg1: f32, %arg2: f32):
+    linalg.yield %arg1 : f32
+  } -> tensor<128x12x32xf32>
+  return %1 : tensor<128x12x32xf32>
+}
+
+// -----
+
 #matmul_trait = {
   args_in = 2,
   args_out = 1,