diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -745,8 +745,12 @@
   if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)
     return failure();
 
-  if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
-    return failure();
+  // Check the index type, but only for non-0-d tensors (for which we do need
+  // access indices).
+  if (!extractOp.getIndices().empty()) {
+    if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
+      return failure();
+  }
 
   if (llvm::any_of(extractOp->getResultTypes(), [](Type type) {
         return !VectorType::isValidElementType(type);
@@ -919,6 +923,12 @@
                                     LinalgOp &linalgOp) {
 
   auto targetShape = linalgOp.getStaticLoopRanges();
+  auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());
+
+  // 0. Is this a 0-D tensor? If yes, then this is a scalar broadcast.
+  if (inputShape.getShape().empty())
+    return VectorMemoryAccessKind::ScalarBroadcast;
+
   // 1. Assume that it's a gather load when reading _into_:
   //    * an n-D vector, like `tensor<1x2x4xi32>` or `tensor<2x1x4xi32>`, or
@@ -929,7 +939,6 @@
       targetShape.back() == 1)
     return VectorMemoryAccessKind::Gather;
 
-  auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());
   // 2. Assume that it's a gather load when reading _from_ a tensor for which
   //    the trailing dimension is 1, e.g. `tensor<1x4x1xi32>`.
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
@@ -499,3 +499,30 @@
   %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
   %2 = transform.structured.vectorize %1 { vectorize_nd_extract } : (!transform.any_op) -> !transform.any_op
 }
+
+// -----
+
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func.func @vectorize_0d_tensor_extract(%arg0: tensor<f32>, %arg2: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
+  %2 = linalg.generic {
+    indexing_maps = [#map1],
+    iterator_types = ["parallel", "parallel", "parallel"]
+  } outs(%arg2 : tensor<1x1x3xf32>) {
+  ^bb0(%arg4: f32):
+    %7 = tensor.extract %arg0[] : tensor<f32>
+    linalg.yield %7 : f32
+  } -> tensor<1x1x3xf32>
+  return %2 : tensor<1x1x3xf32>
+}
+
+// CHECK-LABEL: func.func @vectorize_0d_tensor_extract(
+// CHECK-SAME:  %[[ARG_0:.*]]: tensor<f32>
+// CHECK:       %[[EXTRACT:.*]] = tensor.extract %[[ARG_0]][] : tensor<f32>
+// CHECK:       vector.broadcast %[[EXTRACT]] : f32 to vector<1x1x3xf32>
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+  %2 = transform.structured.vectorize %1 { vectorize_nd_extract } : (!transform.any_op) -> !transform.any_op
+}
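
Note: for context, below is a minimal sketch of the full IR that vectorizing the new 0-d test is expected to produce. Only the tensor.extract and vector.broadcast lines are pinned down by the CHECK lines above; the index constant and the trailing vector.transfer_write are assumptions based on the usual shape of linalg vectorization output, not something this patch verifies.

// Hypothetical post-vectorization IR for @vectorize_0d_tensor_extract.
func.func @vectorize_0d_tensor_extract(%arg0: tensor<f32>, %arg2: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
  %c0 = arith.constant 0 : index
  // The 0-d input needs no access indices; it is read once as a scalar.
  %0 = tensor.extract %arg0[] : tensor<f32>
  // ScalarBroadcast: the scalar is broadcast to the loop-iteration shape,
  // instead of being treated as a gather load.
  %1 = vector.broadcast %0 : f32 to vector<1x1x3xf32>
  // Assumed: the result is written back via vector.transfer_write.
  %2 = vector.transfer_write %1, %arg2[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<1x1x3xf32>, tensor<1x1x3xf32>
  return %2 : tensor<1x1x3xf32>
}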