diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -925,15 +925,21 @@
 
   auto targetShape = linalgOp.getStaticLoopRanges();
   auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());
 
-  // 0. Is this a 0-D vector? If yes then this is a scalar broadcast.
+  // 0.1 Is this a 0-D vector? If yes then this is a scalar broadcast.
   if (inputShape.getShape().empty())
     return VectorMemoryAccessKind::ScalarBroadcast;
 
+  // 0.2 In the case of dynamic shapes just bail out and assume that it's a
+  // gather load.
+  // TODO: Relax this condition.
+  if (linalgOp.hasDynamicShape())
+    return VectorMemoryAccessKind::Gather;
   // 1. Assume that it's a gather load when reading _into_:
   //    * an n-D vector, like `tensor<1x2x4xi32>` or `tensor<2x1x4xi32>`, or
   //    * a 1-D vector with the trailing dim equal 1, e.g. `tensor<1x4x1xi32>`.
   // TODO: Relax these conditions.
+  // FIXME: This condition assumes non-dynamic sizes.
   if ((llvm::count_if(targetShape,
                       [](int64_t dimSize) { return dimSize > 1; }) != 1) ||
       targetShape.back() == 1)
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir
@@ -228,3 +228,65 @@
   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
   transform.structured.masked_vectorize %0 vector_sizes [3, 3] { vectorize_nd_extract } : !transform.any_op
 }
+
+// -----
+
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func.func @extract_masked_vectorize(%in: tensor<123x321xf32>, %arg1: tensor<1x?x8xf32>) -> tensor<1x?x8xf32> {
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %2 = linalg.generic {
+    indexing_maps = [#map1],
+    iterator_types = ["parallel", "parallel", "parallel"]
+  } outs(%arg1 : tensor<1x?x8xf32>)
+  {
+  ^bb0(%arg3: f32):
+    %idx_0 = linalg.index 0 : index
+    %idx_1 = linalg.index 1 : index
+    %idx = arith.addi %idx_0, %idx_1 : index
+    %7 = tensor.extract %in[%c1, %idx] : tensor<123x321xf32>
+    linalg.yield %7 : f32
+  } -> tensor<1x?x8xf32>
+  return %2 : tensor<1x?x8xf32>
+}
+
+// CHECK-LABEL: func.func @extract_masked_vectorize(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<123x321xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x?x8xf32>) -> tensor<1x?x8xf32> {
+// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_3:.*]] = arith.constant 2 : index
+// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_6:.*]] = tensor.dim %[[VAL_1]], %[[VAL_5]] : tensor<1x?x8xf32>
+// CHECK: %[[VAL_7:.*]] = arith.constant 8 : index
+// CHECK: %[[VAL_8:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_10:.*]] = vector.create_mask %[[VAL_4]], %[[VAL_6]], %[[VAL_7]] : vector<1x3x8xi1>
+// CHECK: %[[VAL_11:.*]] = vector.mask %[[VAL_10]] { vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_8]], %[[VAL_8]], %[[VAL_8]]], %[[VAL_9]] {in_bounds = [true, true, true]} : tensor<1x?x8xf32>, vector<1x3x8xf32> } : vector<1x3x8xi1> -> vector<1x3x8xf32>
+// CHECK: %[[VAL_12:.*]] = arith.constant dense<0> : vector<1xindex>
+// CHECK: %[[VAL_13:.*]] = vector.broadcast %[[VAL_12]] : vector<1xindex> to vector<8x3x1xindex>
+// CHECK: %[[VAL_14:.*]] = vector.transpose %[[VAL_13]], [2, 1, 0] : vector<8x3x1xindex> to vector<1x3x8xindex>
+// CHECK: %[[VAL_15:.*]] = arith.constant dense<[0, 1, 2]> : vector<3xindex>
+// CHECK: %[[VAL_16:.*]] = vector.broadcast %[[VAL_15]] : vector<3xindex> to vector<1x8x3xindex>
+// CHECK: %[[VAL_17:.*]] = vector.transpose %[[VAL_16]], [0, 2, 1] : vector<1x8x3xindex> to vector<1x3x8xindex>
+// CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_14]], %[[VAL_17]] : vector<1x3x8xindex>
+// CHECK: %[[VAL_19:.*]] = arith.constant dense<true> : vector<1x3x8xi1>
+// CHECK: %[[VAL_20:.*]] = arith.constant dense<0.000000e+00> : vector<1x3x8xf32>
+// CHECK: %[[VAL_21:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_22:.*]] = arith.constant dense<1> : vector<1x3x8xindex>
+// CHECK: %[[VAL_23:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_24:.*]] = tensor.dim %[[VAL_0]], %[[VAL_23]] : tensor<123x321xf32>
+// CHECK: %[[VAL_25:.*]] = vector.broadcast %[[VAL_24]] : index to vector<1x3x8xindex>
+// CHECK: %[[VAL_26:.*]] = arith.muli %[[VAL_22]], %[[VAL_25]] : vector<1x3x8xindex>
+// CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_18]], %[[VAL_26]] : vector<1x3x8xindex>
+// CHECK: %[[VAL_28:.*]] = vector.mask %[[VAL_10]] { vector.gather %[[VAL_0]]{{\[}}%[[VAL_21]], %[[VAL_21]]] {{\[}}%[[VAL_27]]], %[[VAL_19]], %[[VAL_20]] : tensor<123x321xf32>, vector<1x3x8xindex>, vector<1x3x8xi1>, vector<1x3x8xf32> into vector<1x3x8xf32> } : vector<1x3x8xi1> -> vector<1x3x8xf32>
+// CHECK: %[[VAL_29:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_30:.*]] = vector.mask %[[VAL_10]] { vector.transfer_write %[[VAL_28]], %[[VAL_1]]{{\[}}%[[VAL_29]], %[[VAL_29]], %[[VAL_29]]] {in_bounds = [true, true, true]} : vector<1x3x8xf32>, tensor<1x?x8xf32> } : vector<1x3x8xi1> -> tensor<1x?x8xf32>
+// CHECK: return %[[VAL_30]] : tensor<1x?x8xf32>
+// CHECK: }
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  transform.structured.masked_vectorize %0 vector_sizes [1, 3, 8] { vectorize_nd_extract } : !transform.any_op
+}
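---

Reviewer note (commentary, not part of the patch): the gist of the change is that the
access-pattern classification now gives up before the static-shape heuristic whenever the
linalg op has any dynamic dimension. A minimal standalone C++ sketch of that gating order
follows; classifyAccess, inputIs0D, hasDynamicShape, and the trailing Contiguous
placeholder are invented stand-ins for this sketch, not the actual MLIR signatures.

    #include <cassert>
    #include <cstdint>
    #include <vector>

    enum class VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };

    // Mirrors the gating order after this patch. `inputIs0D` stands in for
    // `inputShape.getShape().empty()` and `hasDynamicShape` for
    // `linalgOp.hasDynamicShape()`; both names are invented for this sketch.
    VectorMemoryAccessKind
    classifyAccess(const std::vector<int64_t> &targetShape, bool inputIs0D,
                   bool hasDynamicShape) {
      // 0.1 A 0-D input is a scalar broadcast.
      if (inputIs0D)
        return VectorMemoryAccessKind::ScalarBroadcast;

      // 0.2 New early exit: with dynamic shapes the `dimSize > 1` comparisons
      // below would see dynamic-size sentinels, so conservatively gather.
      if (hasDynamicShape)
        return VectorMemoryAccessKind::Gather;

      // 1. Gather when reading into an n-D vector, or a 1-D vector whose
      // trailing dim equals 1 (same heuristic as the patched code).
      int64_t nonUnitDims = 0;
      for (int64_t dimSize : targetShape)
        if (dimSize > 1)
          ++nonUnitDims;
      if (nonUnitDims != 1 || targetShape.back() == 1)
        return VectorMemoryAccessKind::Gather;

      // The real function continues with a finer contiguity analysis here;
      // collapsed to a placeholder result for brevity.
      return VectorMemoryAccessKind::Contiguous;
    }

    int main() {
      // Static shape, one non-unit dim, trailing dim > 1: passes the gates.
      assert(classifyAccess({1, 1, 8}, false, false) ==
             VectorMemoryAccessKind::Contiguous);
      // Same target shape but a dynamic dim anywhere: gather, per this patch.
      assert(classifyAccess({1, 1, 8}, false, true) ==
             VectorMemoryAccessKind::Gather);
      return 0;
    }

The new test exercises exactly the 0.2 path: the tensor<1x?x8xf32> output makes
hasDynamicShape() return true, so the tensor.extract is vectorized as a masked
vector.gather (VAL_28 above) rather than a contiguous transfer read.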