Lower vector.extract_strided_slice producing vector<1xT> to spv.CompositeExtract
instead of rejecting the pattern with failure().

diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -113,9 +113,5 @@
     if (!dstType)
       return failure();
-
-    // Extract vector<1xT> not supported yet.
-    if (dstType.isa<spirv::ScalarType>())
-      return failure();
 
     uint64_t offset = getFirstIntValue(extractOp.offsets());
     uint64_t size = getFirstIntValue(extractOp.sizes());
@@ -125,6 +121,13 @@
 
     Value srcVector = operands.front();
 
+    // Extract vector<1xT> case: the result converts to a scalar element.
+    if (dstType.isa<spirv::ScalarType>()) {
+      rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(extractOp,
+                                                             srcVector, offset);
+      return success();
+    }
+
     SmallVector<int32_t, 2> indices(size);
     std::iota(indices.begin(), indices.end(), offset);
 
diff --git a/mlir/test/Conversion/VectorToSPIRV/simple.mlir b/mlir/test/Conversion/VectorToSPIRV/simple.mlir
--- a/mlir/test/Conversion/VectorToSPIRV/simple.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/simple.mlir
@@ -91,8 +91,10 @@
 // CHECK-LABEL: func @extract_strided_slice
 // CHECK-SAME: %[[ARG:.+]]: vector<4xf32>
 // CHECK: %{{.+}} = spv.VectorShuffle [1 : i32, 2 : i32] %[[ARG]] : vector<4xf32>, %[[ARG]] : vector<4xf32> -> vector<2xf32>
+// CHECK: %{{.+}} = spv.CompositeExtract %[[ARG]][1 : i32] : vector<4xf32>
 func @extract_strided_slice(%arg0: vector<4xf32>) {
   %0 = vector.extract_strided_slice %arg0 {offsets = [1], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
+  %1 = vector.extract_strided_slice %arg0 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
   spv.Return
 }
 