Index: mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td =================================================================== --- mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -3177,6 +3177,7 @@ def SPV_OC_OpDecorate : I32EnumAttrCase<"OpDecorate", 71>; def SPV_OC_OpMemberDecorate : I32EnumAttrCase<"OpMemberDecorate", 72>; def SPV_OC_OpVectorExtractDynamic : I32EnumAttrCase<"OpVectorExtractDynamic", 77>; +def SPV_OC_OpVectorInsertDynamic : I32EnumAttrCase<"OpVectorInsertDynamic", 78>; def SPV_OC_OpCompositeConstruct : I32EnumAttrCase<"OpCompositeConstruct", 80>; def SPV_OC_OpCompositeExtract : I32EnumAttrCase<"OpCompositeExtract", 81>; def SPV_OC_OpCompositeInsert : I32EnumAttrCase<"OpCompositeInsert", 82>; @@ -3310,9 +3311,9 @@ SPV_OC_OpSpecConstantFalse, SPV_OC_OpSpecConstant, SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction, SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall, SPV_OC_OpVariable, SPV_OC_OpLoad, - SPV_OC_OpStore, SPV_OC_OpCopyMemory, SPV_OC_OpAccessChain, - SPV_OC_OpDecorate, SPV_OC_OpMemberDecorate, - SPV_OC_OpVectorExtractDynamic, SPV_OC_OpCompositeConstruct, + SPV_OC_OpStore, SPV_OC_OpCopyMemory, SPV_OC_OpAccessChain, SPV_OC_OpDecorate, + SPV_OC_OpMemberDecorate, SPV_OC_OpVectorExtractDynamic, + SPV_OC_OpVectorInsertDynamic, SPV_OC_OpCompositeConstruct, SPV_OC_OpCompositeExtract, SPV_OC_OpCompositeInsert, SPV_OC_OpTranspose, SPV_OC_OpConvertFToU, SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF, SPV_OC_OpConvertUToF, SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert, Index: mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td =================================================================== --- mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td +++ mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td @@ -171,7 +171,7 @@ // ----- def SPV_VectorExtractDynamicOp : SPV_Op<"VectorExtractDynamic", - [NoSideEffect, TypesMatchWith<"type of 'value' matches element type of 'vector'", + 
[NoSideEffect, TypesMatchWith<"type of 'result' matches element type of 'vector'", "vector", "result", "$_self.cast<VectorType>().getElementType()">]> { let summary = [{ @@ -225,4 +225,66 @@ // ----- +def SPV_VectorInsertDynamicOp : SPV_Op<"VectorInsertDynamic", + [NoSideEffect, TypesMatchWith<"type of 'component' matches element type of 'vector'", + "vector", "component", + "$_self.cast<VectorType>().getElementType()">, + AllTypesMatch<["vector", "result"]>]> { + let summary = [{ + Make a copy of a vector, with a single, variably selected, component + modified. + }]; + + let description = [{ + Result Type must be an OpTypeVector. + + Vector must have the same type as Result Type and is the vector that the + non-written components are copied from. + + Component is the value supplied for the component selected by Index. It + must have the same type as the type of components in Result Type. + + Index must be a scalar integer. It is interpreted as a 0-based index of + which component to modify. + + Behavior is undefined if Index's value is less than zero or greater than + or equal to the number of components in Vector. + + <!-- End of AutoGen section --> + + ``` + scalar-type ::= integer-type | float-type | boolean-type + vector-insert-dynamic-op ::= `spv.VectorInsertDynamic ` ssa-use `,` + ssa-use `[` ssa-use `]` + `:` `vector<` integer-literal `x` scalar-type `>` `,` + integer-type + ```mlir + + #### Example: + + ``` + %scalar = ...
: f32 + %2 = spv.VectorInsertDynamic %scalar, %0[%1] : vector<8xf32>, i32 + ``` + }]; + + let arguments = (ins + SPV_Vector:$vector, + SPV_Scalar:$component, + SPV_Integer:$index + ); + + let results = (outs + SPV_Vector:$result + ); + + let verifier = [{ return success(); }]; + + let assemblyFormat = [{ + $component `,` $vector `[` $index `]` attr-dict `:` type($vector) `,` type($index) + }]; +} + +// ----- + #endif // SPIRV_COMPOSITE_OPS Index: mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp =================================================================== --- mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -97,14 +97,36 @@ } }; +/// Lowers `vector.insertelement` to `spv.VectorInsertDynamic`. The pattern +/// fails to match when the destination vector type is not a valid SPIR-V +/// composite type according to `spirv::CompositeType::isValid`. +struct VectorInsertElementOpConvert final + : public SPIRVOpLowering<vector::InsertElementOp> { + using SPIRVOpLowering<vector::InsertElementOp>::SPIRVOpLowering; + LogicalResult + matchAndRewrite(vector::InsertElementOp insertElementOp, + ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + if (!spirv::CompositeType::isValid(insertElementOp.getDestVectorType())) + return failure(); + vector::InsertElementOp::Adaptor adaptor(operands); + // Take every operand from the adaptor, i.e. the remapped values. + Value newInsertElement = rewriter.create<spirv::VectorInsertDynamicOp>( + insertElementOp.getLoc(), insertElementOp.getType(), + adaptor.dest(), adaptor.source(), adaptor.position()); + rewriter.replaceOp(insertElementOp, newInsertElement); + return success(); + } +}; + } // namespace void mlir::populateVectorToSPIRVPatterns(MLIRContext *context, SPIRVTypeConverter &typeConverter, OwningRewritePatternList &patterns) { patterns.insert<VectorBroadcastConvert, VectorExtractOpConvert, - VectorInsertOpConvert, VectorExtractElementOpConvert>( - context, typeConverter); + VectorInsertOpConvert, VectorExtractElementOpConvert, + VectorInsertElementOpConvert>(context, typeConverter); } namespace { Index: mlir/test/Conversion/VectorToSPIRV/simple.mlir =================================================================== --- mlir/test/Conversion/VectorToSPIRV/simple.mlir +++ mlir/test/Conversion/VectorToSPIRV/simple.mlir @@ -39,3 +39,21 @@ %0 = vector.extractelement %arg0[%id : i32] :
vector<5xf32> spv.ReturnValue %0: f32 } + +// ----- + +// CHECK-LABEL: insert_element +// CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32 +// CHECK: spv.VectorInsertDynamic %[[VAL]], %[[V]][%[[ID]]] : vector<4xf32>, i32 +func @insert_element(%val: f32, %arg0 : vector<4xf32>, %id : i32) { + %0 = vector.insertelement %val, %arg0[%id : i32] : vector<4xf32> + spv.ReturnValue %0: vector<4xf32> +} + +// ----- + +func @insert_element_negative(%val: f32, %arg0 : vector<5xf32>, %id : i32) { // Negative case: must stay illegal; 5-element vectors are presumably not valid SPIR-V composites. +// expected-error @+1 {{failed to legalize operation 'vector.insertelement'}} + %0 = vector.insertelement %val, %arg0[%id : i32] : vector<5xf32> + spv.Return +} Index: mlir/test/Dialect/SPIRV/Serialization/composite-op.mlir =================================================================== --- mlir/test/Dialect/SPIRV/Serialization/composite-op.mlir +++ mlir/test/Dialect/SPIRV/Serialization/composite-op.mlir @@ -16,4 +16,9 @@ %0 = spv.VectorExtractDynamic %vec[%id] : vector<4xf32>, i32 spv.ReturnValue %0: f32 } + spv.func @vector_dynamic_insert(%val: f32, %vec: vector<4xf32>, %id : i32) -> vector<4xf32> "None" { + // CHECK: spv.VectorInsertDynamic %{{.*}}, %{{.*}}[%{{.*}}] :
vector<4xf32>, i32 + %0 = spv.VectorInsertDynamic %val, %vec[%id] : vector<4xf32>, i32 + spv.ReturnValue %0: vector<4xf32> + } } Index: mlir/test/Dialect/SPIRV/composite-ops.mlir =================================================================== --- mlir/test/Dialect/SPIRV/composite-ops.mlir +++ mlir/test/Dialect/SPIRV/composite-ops.mlir @@ -273,3 +273,13 @@ %0 = spv.VectorExtractDynamic %vec[%id] : vector<4xf32>, i32 return %0 : f32 } + +//===----------------------------------------------------------------------===// +// spv.VectorInsertDynamic +//===----------------------------------------------------------------------===// + +func @vector_dynamic_insert(%val: f32, %vec: vector<4xf32>, %id : i32) -> vector<4xf32> { + // CHECK: spv.VectorInsertDynamic %{{.*}}, %{{.*}}[%{{.*}}] : vector<4xf32>, i32 + %0 = spv.VectorInsertDynamic %val, %vec[%id] : vector<4xf32>, i32 + return %0 : vector<4xf32> +}