Index: mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp =================================================================== --- mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -89,6 +89,11 @@ return failure(); vector::ExtractOp::Adaptor adaptor(operands); + if (adaptor.vector().getType().isa()) { + rewriter.replaceOp(extractOp, adaptor.vector()); + return success(); + } + int32_t id = getFirstIntValue(extractOp.position()); rewriter.replaceOpWithNewOp( extractOp, adaptor.vector(), id); Index: mlir/test/Conversion/VectorToSPIRV/simple.mlir =================================================================== --- mlir/test/Conversion/VectorToSPIRV/simple.mlir +++ mlir/test/Conversion/VectorToSPIRV/simple.mlir @@ -40,6 +40,24 @@ // ----- +module attributes { spv.target_env = #spv.target_env<#spv.vce, {}> } { + +// CHECK-LABEL: func @extract_scalar +// CHECK-SAME: %[[ARG0:.+]]: vector<2xf16> +// CHECK-SAME: %[[ARG1:.+]]: vector<4xf32> +// CHECK: %[[S:.+]] = spv.Bitcast %[[ARG0]] : vector<2xf16> to f32 +// CHECK: spv.CompositeInsert %[[S]], %[[ARG1]][0 : i32] : f32 into vector<4xf32> +func @extract_scalar(%arg0 : vector<2xf16>, %arg1 : vector<4xf32>) { + %0 = vector.bitcast %arg0 : vector<2xf16> to vector<1xf32> + %1 = vector.extract %0[0] : vector<1xf32> + %2 = vector.insert %1, %arg1[0] : f32 into vector<4xf32> + spv.Return +} + +} // end module + +// ----- + // CHECK-LABEL: extract_insert // CHECK-SAME: %[[V:.*]]: vector<4xf32> // CHECK: %[[S:.*]] = spv.CompositeExtract %[[V]][1 : i32] : vector<4xf32>