Index: mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
===================================================================
--- mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
+++ mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
@@ -3176,6 +3176,7 @@
 def SPV_OC_OpAccessChain : I32EnumAttrCase<"OpAccessChain", 65>;
 def SPV_OC_OpDecorate : I32EnumAttrCase<"OpDecorate", 71>;
 def SPV_OC_OpMemberDecorate : I32EnumAttrCase<"OpMemberDecorate", 72>;
+def SPV_OC_OpVectorExtractDynamic : I32EnumAttrCase<"OpVectorExtractDynamic", 77>;
 def SPV_OC_OpCompositeConstruct : I32EnumAttrCase<"OpCompositeConstruct", 80>;
 def SPV_OC_OpCompositeExtract : I32EnumAttrCase<"OpCompositeExtract", 81>;
 def SPV_OC_OpCompositeInsert : I32EnumAttrCase<"OpCompositeInsert", 82>;
@@ -3309,8 +3310,9 @@
       SPV_OC_OpSpecConstantFalse, SPV_OC_OpSpecConstant,
       SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction, SPV_OC_OpFunctionParameter,
       SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall, SPV_OC_OpVariable, SPV_OC_OpLoad,
-      SPV_OC_OpStore, SPV_OC_OpCopyMemory, SPV_OC_OpAccessChain, SPV_OC_OpDecorate,
-      SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeConstruct,
+      SPV_OC_OpStore, SPV_OC_OpCopyMemory, SPV_OC_OpAccessChain,
+      SPV_OC_OpDecorate, SPV_OC_OpMemberDecorate,
+      SPV_OC_OpVectorExtractDynamic, SPV_OC_OpCompositeConstruct,
       SPV_OC_OpCompositeExtract, SPV_OC_OpCompositeInsert, SPV_OC_OpTranspose,
       SPV_OC_OpConvertFToU, SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF,
       SPV_OC_OpConvertUToF, SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert,
Index: mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
===================================================================
--- mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
+++ mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
@@ -529,4 +529,59 @@
 
 // -----
 
+def SPV_VectorExtractDynamicOp : SPV_Op<"VectorExtractDynamic",
+  [NoSideEffect,
+   TypesMatchWith<"type of 'result' matches element type of 'vector'",
+                  "vector", "result",
+                  "$_self.cast<VectorType>().getElementType()">]> {
+  let summary = [{
+    Extract a single, dynamically selected, component of a vector.
+  }];
+
+  let description = [{
+    Result Type must be a scalar type.
+
+    Vector must have a type OpTypeVector whose Component Type is Result
+    Type.
+
+    Index must be a scalar integer. It is interpreted as a 0-based index of
+    which component of Vector to extract.
+
+    Behavior is undefined if Index's value is less than zero or greater than
+    or equal to the number of components in Vector.
+
+    <!-- End of AutoGen section -->
+
+    ```
+    scalar-type ::= integer-type | float-type | boolean-type
+    vector-extract-dynamic-op ::= `spv.VectorExtractDynamic ` ssa-use `[` ssa-use `]`
+                                  `:` `vector<` integer-literal `x` scalar-type `>` `,`
+                                  integer-type
+    ```
+
+    #### Example:
+
+    ```mlir
+    %2 = spv.VectorExtractDynamic %0[%1] : vector<4xf32>, i32
+    ```
+  }];
+
+  let arguments = (ins
+    SPV_Vector:$vector,
+    SPV_Integer:$index
+  );
+
+  let results = (outs
+    SPV_Scalar:$result
+  );
+
+  let verifier = [{ return success(); }];
+
+  let assemblyFormat = [{
+    $vector `[` $index `]` attr-dict `:` type($vector) `,` type($index)
+  }];
+}
+
+// -----
+
 #endif // SPIRV_OPS
Index: mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
===================================================================
--- mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -78,13 +78,35 @@
     return success();
   }
 };
+
+/// Lowers vector.extractelement to spv.VectorExtractDynamic, which selects
+/// the vector component with an index SSA value instead of a constant.
+struct VectorExtractElementOpConvert final
+    : public SPIRVOpLowering<vector::ExtractElementOp> {
+  using SPIRVOpLowering<vector::ExtractElementOp>::SPIRVOpLowering;
+  LogicalResult
+  matchAndRewrite(vector::ExtractElementOp extractElementOp,
+                  ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Take both the vector and the index from the adaptor: these are the
+    // remapped (already converted) operands, not the original op's values.
+    vector::ExtractElementOp::Adaptor adaptor(operands);
+    Value newExtractElement = rewriter.create<spirv::VectorExtractDynamicOp>(
+        extractElementOp.getLoc(), extractElementOp.getType(), adaptor.vector(),
+        adaptor.position());
+    rewriter.replaceOp(extractElementOp, newExtractElement);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::populateVectorToSPIRVPatterns(MLIRContext *context,
                                          SPIRVTypeConverter &typeConverter,
                                          OwningRewritePatternList &patterns) {
   patterns.insert<VectorBroadcastConvert, VectorExtractOpConvert,
-                  VectorInsertOpConvert>(context, typeConverter);
+                  VectorInsertOpConvert, VectorExtractElementOpConvert>(
+      context, typeConverter);
 }
 
 namespace {
Index: mlir/test/Conversion/VectorToSPIRV/simple.mlir
===================================================================
--- mlir/test/Conversion/VectorToSPIRV/simple.mlir
+++ mlir/test/Conversion/VectorToSPIRV/simple.mlir
@@ -21,3 +21,13 @@
   %1 = vector.insert %0, %arg0[0] : f32 into vector<4xf32>
   spv.Return
 }
+
+// -----
+
+// CHECK-LABEL: extract_element
+//  CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32
+//       CHECK:   spv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32
+func @extract_element(%arg0 : vector<4xf32>, %id : i32) -> f32 {
+  %0 = vector.extractelement %arg0[%id : i32] : vector<4xf32>
+  spv.ReturnValue %0: f32
+}
Index: mlir/test/Dialect/SPIRV/Serialization/vector-ops.mlir
===================================================================
--- /dev/null
+++ mlir/test/Dialect/SPIRV/Serialization/vector-ops.mlir
@@ -0,0 +1,10 @@
+// RUN: mlir-translate -test-spirv-roundtrip -split-input-file %s | FileCheck %s
+
+spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
+  // CHECK-LABEL: @vector_dynamic_extract
+  spv.func @vector_dynamic_extract(%vec: vector<4xf32>, %id : i32) -> f32 "None" {
+    // CHECK: spv.VectorExtractDynamic %{{.*}}[%{{.*}}] : vector<4xf32>, i32
+    %0 = spv.VectorExtractDynamic %vec[%id] : vector<4xf32>, i32
+    spv.ReturnValue %0: f32
+  }
+}
Index: mlir/test/Dialect/SPIRV/vector-ops.mlir
===================================================================
--- /dev/null
+++ mlir/test/Dialect/SPIRV/vector-ops.mlir
@@ -0,0 +1,13 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// spv.VectorExtractDynamic
+//===----------------------------------------------------------------------===//
+
+func @vector_dynamic_extract(%vec: vector<4xf32>, %id : i32) -> f32 {
+  // CHECK: spv.VectorExtractDynamic %{{.*}}[%{{.*}}] : vector<4xf32>, i32
+  %0 = spv.VectorExtractDynamic %vec[%id] : vector<4xf32>, i32
+  return %0 : f32
+}
+
+// -----