Index: mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp =================================================================== --- mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -113,9 +113,6 @@ if (!dstType) return failure(); - // Extract vector<1xT> not supported yet. - if (dstType.isa()) - return failure(); uint64_t offset = getFirstIntValue(extractOp.offsets()); uint64_t size = getFirstIntValue(extractOp.sizes()); @@ -125,6 +122,14 @@ Value srcVector = operands.front(); + // Extract vector<1xT> case. + if (dstType.isa()) { + int32_t id = getFirstIntValue(extractOp.offsets()); + rewriter.replaceOpWithNewOp(extractOp, + srcVector, id); + return success(); + } + SmallVector indices(size); std::iota(indices.begin(), indices.end(), offset); Index: mlir/test/Conversion/VectorToSPIRV/simple.mlir =================================================================== --- mlir/test/Conversion/VectorToSPIRV/simple.mlir +++ mlir/test/Conversion/VectorToSPIRV/simple.mlir @@ -91,8 +91,10 @@ // CHECK-LABEL: func @extract_strided_slice // CHECK-SAME: %[[ARG:.+]]: vector<4xf32> // CHECK: %{{.+}} = spv.VectorShuffle [1 : i32, 2 : i32] %[[ARG]] : vector<4xf32>, %[[ARG]] : vector<4xf32> -> vector<2xf32> +// CHECK: %{{.+}} = spv.CompositeExtract %[[ARG]][1 : i32] : vector<4xf32> func @extract_strided_slice(%arg0: vector<4xf32>) { %0 = vector.extract_strided_slice %arg0 {offsets = [1], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32> + %1 = vector.extract_strided_slice %arg0 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> spv.Return }