diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -196,6 +196,12 @@ LogicalResult matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + if (isa(insertOp.getSourceType())) + return rewriter.notifyMatchFailure(insertOp, "unsupported vector source"); + if (!getTypeConverter()->convertType(insertOp.getDestVectorType())) + return rewriter.notifyMatchFailure(insertOp, + "unsupported dest vector type"); + // Special case for inserting scalar values into size-1 vectors. if (insertOp.getSourceType().isIntOrFloat() && insertOp.getDestVectorType().getNumElements() == 1) { @@ -203,9 +209,6 @@ return success(); } - if (isa(insertOp.getSourceType()) || - !spirv::CompositeType::isValid(insertOp.getDestVectorType())) - return failure(); int32_t id = getFirstIntValue(insertOp.getPosition()); rewriter.replaceOpWithNewOp( insertOp, adaptor.getSource(), adaptor.getDest(), id); @@ -413,9 +416,10 @@ matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto oldResultType = shuffleOp.getResultVectorType(); - if (!spirv::CompositeType::isValid(oldResultType)) - return failure(); Type newResultType = getTypeConverter()->convertType(oldResultType); + if (!newResultType) + return rewriter.notifyMatchFailure(shuffleOp, + "unsupported result vector type"); auto oldSourceType = shuffleOp.getV1VectorType(); if (oldSourceType.getNumElements() > 1) { diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir --- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir +++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir @@ -183,6 +183,15 @@ // ----- +// CHECK-LABEL: @insert_index_vector +// CHECK: spirv.CompositeInsert %{{.+}}, %{{.+}}[2 : i32] : i32 into vector<4xi32> +func.func @insert_index_vector(%arg0 : vector<4xindex>, %arg1: index) -> vector<4xindex> { + %1 = vector.insert %arg1, %arg0[2] : index into vector<4xindex> + return %1: vector<4xindex> +} + +// ----- + // CHECK-LABEL: @insert_size1_vector // CHECK-SAME: %[[V:.*]]: vector<1xf32>, %[[S:.*]]: f32 // CHECK: %[[R:.+]] = builtin.unrealized_conversion_cast %[[S]] @@ -402,6 +411,18 @@ // ----- +// CHECK-LABEL: func @shuffle_index_vector +// CHECK-SAME: %[[ARG0:.+]]: vector<1xindex>, %[[ARG1:.+]]: vector<1xindex> +// CHECK: %[[V0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] +// CHECK: %[[V1:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] +// CHECK: spirv.CompositeConstruct %[[V0]], %[[V1]], %[[V1]], %[[V0]] : (i32, i32, i32, i32) -> vector<4xi32> +func.func @shuffle_index_vector(%v0 : vector<1xindex>, %v1: vector<1xindex>) -> vector<4xindex> { + %shuffle = vector.shuffle %v0, %v1 [0, 1, 1, 0] : vector<1xindex>, vector<1xindex> + return %shuffle : vector<4xindex> +} + +// ----- + // CHECK-LABEL: func @shuffle // CHECK-SAME: %[[V0:.+]]: vector<3xf32>, %[[V1:.+]]: vector<3xf32> // CHECK: spirv.VectorShuffle [3 : i32, 2 : i32, 5 : i32, 1 : i32] %[[V0]] : vector<3xf32>, %[[V1]] : vector<3xf32> -> vector<4xf32>