diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -925,15 +925,21 @@ auto targetShape = linalgOp.getStaticLoopRanges(); auto inputShape = cast(extractOp.getTensor().getType()); - // 0. Is this a 0-D vector? If yes then this is a scalar broadcast. + // 0.1 Is this a 0-D vector? If yes then this is a scalar broadcast. if (inputShape.getShape().empty()) return VectorMemoryAccessKind::ScalarBroadcast; + // 0.2 In the case of dynamic shapes just bail-out and assume that it's a + // gather load. + // TODO: Relax this condition. + if (linalgOp.hasDynamicShape()) + return VectorMemoryAccessKind::Gather; // 1. Assume that it's a gather load when reading _into_: // * an n-D vector, like`tensor<1x2x4xi32` or`tensor<2x1x4xi32>`, or // * a 1-D vector with the trailing dim equal 1, e.g. `tensor<1x4x1xi32`. // TODO: Relax these conditions. + // FIXME: This condition assumes non-dynamic sizes. if ((llvm::count_if(targetShape, [](int64_t dimSize) { return dimSize > 1; }) != 1) || targetShape.back() == 1) diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir --- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir +++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir @@ -228,3 +228,44 @@ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op transform.structured.masked_vectorize %0 vector_sizes [3, 3] { vectorize_nd_extract } : !transform.any_op } + +// ----- + +#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +func.func @tensor_extract_dynamic_shape(%arg1: tensor<123x321xf32>, %arg2: tensor<1x?x8xf32>) -> tensor<1x?x8xf32> { + %c0 = arith.constant 1 : index + %c1 = arith.constant 2 : index + %2 = linalg.generic { + indexing_maps = [#map1], + iterator_types = ["parallel", "parallel", "parallel"] + } outs(%arg2 : tensor<1x?x8xf32>) + { + ^bb0(%arg3: f32): + %idx_0 = linalg.index 0 : index + %idx_1 = linalg.index 1 : index + %idx = arith.addi %idx_0, %idx_1 : index + %7 = tensor.extract %arg1[%c0, %idx] : tensor<123x321xf32> + linalg.yield %7 : f32 + } -> tensor<1x?x8xf32> + return %2 : tensor<1x?x8xf32> +} + +// CHECK-LABEL: func.func @tensor_extract_dynamic_shape( +// CHECK-SAME: %[[ARG_1:.*]]: tensor<123x321xf32>, +// CHECK-SAME: %[[ARG_2:.*]]: tensor<1x?x8xf32>) -> tensor<1x?x8xf32> { +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[C1_1:.*]] = arith.constant 1 : index +// CHECK: %[[C1_2:.*]] = arith.constant 1 : index +// CHECK: %[[DIM:.*]] = tensor.dim %[[ARG_2]], %[[C1_2]] : tensor<1x?x8xf32> +// CHECK: %[[C8:.*]] = arith.constant 8 : index +// CHECK: %[[MASK:.*]] = vector.create_mask %[[C1_1]], %[[DIM]], %[[C8]] : vector<1x3x8xi1> +// CHECK: %[[MASK_2:.*]] = arith.constant dense : vector<1x3x8xi1> +// CHECK: %[[FALLTHROUGH:.*]] = arith.constant dense<0.000000e+00> : vector<1x3x8xf32> +// CHECK: %[[C0_1:.*]] = arith.constant 0 : index +// CHECK: vector.mask %[[MASK]] { vector.gather %[[ARG_1]][%[[C0_1]], %[[C0_1]]] [%{{.*}}], %[[MASK_2]], %[[FALLTHROUGH]] : tensor<123x321xf32>, vector<1x3x8xindex>, vector<1x3x8xi1>, vector<1x3x8xf32> into vector<1x3x8xf32> } : vector<1x3x8xi1> -> vector<1x3x8xf32> + +transform.sequence failures(propagate) { + ^bb1(%arg1: !transform.any_op): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.structured.masked_vectorize %0 vector_sizes [1, 3, 8] { vectorize_nd_extract } : !transform.any_op +}