[mlir][spirv] Support 0-D and 1-element vectors in VectorToSPIRV

The SPIR-V type converter now maps both 0-D vectors (vector<T>) and
1-element 1-D vectors (vector<1xT>) to their scalar element type.
vector.broadcast, vector.extractelement and vector.insertelement forward
the scalar directly when the converted vector type is a scalar. Otherwise
they build the SPIR-V op from the adaptor's (already converted) operands
and the converted result type. Broadcasting from a multi-element vector
source is still rejected.

diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -41,7 +41,7 @@
   LogicalResult
   matchAndRewrite(vector::BitCastOp bitcastOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto dstType = getTypeConverter()->convertType(bitcastOp.getType());
+    Type dstType = getTypeConverter()->convertType(bitcastOp.getType());
     if (!dstType)
       return failure();
 
@@ -60,15 +60,22 @@
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(vector::BroadcastOp broadcastOp, OpAdaptor adaptor,
+  matchAndRewrite(vector::BroadcastOp castOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    if (broadcastOp.getSource().getType().isa<VectorType>() ||
-        !spirv::CompositeType::isValid(broadcastOp.getVectorType()))
+    Type resultType = getTypeConverter()->convertType(castOp.getVectorType());
+    if (!resultType || adaptor.getSource().getType().isa<VectorType>())
       return failure();
-    SmallVector<Value, 4> source(broadcastOp.getVectorType().getNumElements(),
+
+    // A 0-D or 1-element result is a scalar: forward the source directly.
+    if (resultType.isa<spirv::ScalarType>()) {
+      rewriter.replaceOp(castOp, adaptor.getSource());
+      return success();
+    }
+
+    SmallVector<Value, 4> source(castOp.getVectorType().getNumElements(),
                                  adaptor.getSource());
     rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
-        broadcastOp, broadcastOp.getVectorType(), source);
+        castOp, resultType, source);
     return success();
   }
 };
@@ -85,7 +92,7 @@
     if (resultVectorType && resultVectorType.getNumElements() > 1)
       return failure();
 
-    auto dstType = getTypeConverter()->convertType(extractOp.getType());
+    Type dstType = getTypeConverter()->convertType(extractOp.getType());
     if (!dstType)
       return failure();
 
@@ -108,7 +115,7 @@
   LogicalResult
   matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto dstType = getTypeConverter()->convertType(extractOp.getType());
+    Type dstType = getTypeConverter()->convertType(extractOp.getType());
     if (!dstType)
       return failure();
 
@@ -183,13 +190,22 @@
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(vector::ExtractElementOp extractElementOp, OpAdaptor adaptor,
+  matchAndRewrite(vector::ExtractElementOp extractOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    if (!spirv::CompositeType::isValid(extractElementOp.getVectorType()))
+    Type vectorType =
+        getTypeConverter()->convertType(adaptor.getVector().getType());
+    if (!vectorType)
       return failure();
+
+    // A 0-D or 1-element vector is a scalar: it is its own only element.
+    if (vectorType.isa<spirv::ScalarType>()) {
+      rewriter.replaceOp(extractOp, adaptor.getVector());
+      return success();
+    }
+
     rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
-        extractElementOp, extractElementOp.getType(), adaptor.getVector(),
-        extractElementOp.getPosition());
+        extractOp, extractOp.getType(), adaptor.getVector(),
+        adaptor.getPosition());
     return success();
   }
 };
@@ -199,13 +215,21 @@
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(vector::InsertElementOp insertElementOp, OpAdaptor adaptor,
+  matchAndRewrite(vector::InsertElementOp insertOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    if (!spirv::CompositeType::isValid(insertElementOp.getDestVectorType()))
+    Type vectorType = getTypeConverter()->convertType(insertOp.getType());
+    if (!vectorType)
       return failure();
+
+    // A 0-D or 1-element vector is a scalar: the inserted value replaces it.
+    if (vectorType.isa<spirv::ScalarType>()) {
+      rewriter.replaceOp(insertOp, adaptor.getSource());
+      return success();
+    }
+
     rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
-        insertElementOp, insertElementOp.getType(), insertElementOp.getDest(),
-        adaptor.getSource(), insertElementOp.getPosition());
+        insertOp, vectorType, adaptor.getDest(), adaptor.getSource(),
+        adaptor.getPosition());
     return success();
   }
 };
@@ -354,7 +378,7 @@
     auto oldResultType = shuffleOp.getVectorType();
     if (!spirv::CompositeType::isValid(oldResultType))
       return failure();
-    auto newResultType = getTypeConverter()->convertType(oldResultType);
+    Type newResultType = getTypeConverter()->convertType(oldResultType);
 
     auto oldSourceType = shuffleOp.getV1VectorType();
     if (oldSourceType.getNumElements() > 1) {
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -235,7 +235,8 @@
 convertVectorType(const spirv::TargetEnv &targetEnv,
                   const SPIRVTypeConverter::Options &options, VectorType type,
                   Optional<spirv::StorageClass> storageClass = {}) {
-  if (type.getRank() == 1 && type.getNumElements() == 1)
+  // 0-D vectors and 1-element 1-D vectors are both lowered to a scalar.
+  if (type.getRank() <= 1 && type.getNumElements() == 1)
     return type.getElementType();
 
   if (!spirv::CompositeType::isValid(type)) {
diff --git a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
--- a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
+++ b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
@@ -240,6 +240,10 @@
 // CHECK-SAME: %{{.+}}: i32
 func.func @one_element_vector(%arg0: vector<1xi32>) { return }
 
+// CHECK-LABEL: spv.func @zerod_vector
+// CHECK-SAME: %{{.+}}: f32
+func.func @zerod_vector(%arg0: vector<f32>) { return }
+
 } // end module
 
 // -----
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -121,6 +121,28 @@
 
 // -----
 
+// CHECK-LABEL: @extract_element_size1_vector
+// CHECK-SAME: (%[[S:.+]]: f32
+func.func @extract_element_size1_vector(%arg0 : f32, %i: index) -> f32 {
+  %bcast = vector.broadcast %arg0 : f32 to vector<1xf32>
+  %0 = vector.extractelement %bcast[%i : index] : vector<1xf32>
+  // CHECK: return %[[S]]
+  return %0: f32
+}
+
+// -----
+
+// CHECK-LABEL: @extract_element_0d_vector
+// CHECK-SAME: (%[[S:.+]]: f32)
+func.func @extract_element_0d_vector(%arg0 : f32) -> f32 {
+  %bcast = vector.broadcast %arg0 : f32 to vector<f32>
+  %0 = vector.extractelement %bcast[] : vector<f32>
+  // CHECK: return %[[S]]
+  return %0: f32
+}
+
+// -----
+
 // CHECK-LABEL: @extract_strided_slice
 // CHECK-SAME: %[[ARG:.+]]: vector<4xf32>
 // CHECK: spv.VectorShuffle [1 : i32, 2 : i32] %[[ARG]] : vector<4xf32>, %[[ARG]] : vector<4xf32> -> vector<2xf32>
@@ -161,6 +183,28 @@
 
 // -----
 
+// CHECK-LABEL: @insert_element_size1_vector
+// CHECK-SAME: (%[[S:[a-z0-9]+]]: f32
+func.func @insert_element_size1_vector(%scalar: f32, %vector : vector<1xf32>, %i: index) -> vector<1xf32> {
+  %0 = vector.insertelement %scalar, %vector[%i : index] : vector<1xf32>
+  // CHECK: %[[V:.+]] = builtin.unrealized_conversion_cast %[[S]] : f32 to vector<1xf32>
+  // CHECK: return %[[V]]
+  return %0: vector<1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @insert_element_0d_vector
+// CHECK-SAME: (%[[S:[a-z0-9]+]]: f32
+func.func @insert_element_0d_vector(%scalar: f32, %vector : vector<f32>) -> vector<f32> {
+  %0 = vector.insertelement %scalar, %vector[] : vector<f32>
+  // CHECK: %[[V:.+]] = builtin.unrealized_conversion_cast %[[S]] : f32 to vector<f32>
+  // CHECK: return %[[V]]
+  return %0: vector<f32>
+}
+
+// -----
+
 // CHECK-LABEL: @insert_strided_slice
 // CHECK-SAME: %[[PART:.+]]: vector<2xf32>, %[[ALL:.+]]: vector<4xf32>
 // CHECK: spv.VectorShuffle [0 : i32, 4 : i32, 5 : i32, 3 : i32] %[[ALL]] : vector<4xf32>, %[[PART]] : vector<2xf32> -> vector<4xf32>