diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
@@ -3176,6 +3176,7 @@
 def SPV_OC_OpAccessChain : I32EnumAttrCase<"OpAccessChain", 65>;
 def SPV_OC_OpDecorate : I32EnumAttrCase<"OpDecorate", 71>;
 def SPV_OC_OpMemberDecorate : I32EnumAttrCase<"OpMemberDecorate", 72>;
+def SPV_OC_OpVectorExtractDynamic : I32EnumAttrCase<"OpVectorExtractDynamic", 77>;
 def SPV_OC_OpCompositeConstruct : I32EnumAttrCase<"OpCompositeConstruct", 80>;
 def SPV_OC_OpCompositeExtract : I32EnumAttrCase<"OpCompositeExtract", 81>;
 def SPV_OC_OpCompositeInsert : I32EnumAttrCase<"OpCompositeInsert", 82>;
@@ -3309,8 +3310,9 @@
       SPV_OC_OpSpecConstantFalse, SPV_OC_OpSpecConstant,
       SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction, SPV_OC_OpFunctionParameter,
       SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall, SPV_OC_OpVariable, SPV_OC_OpLoad,
-      SPV_OC_OpStore, SPV_OC_OpCopyMemory, SPV_OC_OpAccessChain, SPV_OC_OpDecorate,
-      SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeConstruct,
+      SPV_OC_OpStore, SPV_OC_OpCopyMemory, SPV_OC_OpAccessChain,
+      SPV_OC_OpDecorate, SPV_OC_OpMemberDecorate,
+      SPV_OC_OpVectorExtractDynamic, SPV_OC_OpCompositeConstruct,
       SPV_OC_OpCompositeExtract, SPV_OC_OpCompositeInsert, SPV_OC_OpTranspose,
       SPV_OC_OpConvertFToU, SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF,
       SPV_OC_OpConvertUToF, SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert,
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td
@@ -168,4 +168,61 @@
   ];
 }
 
+// -----
+
+def SPV_VectorExtractDynamicOp : SPV_Op<"VectorExtractDynamic",
+    [NoSideEffect, TypesMatchWith<"type of 'result' matches element type of 'vector'",
+                                  "vector", "result",
+                                  "$_self.cast<VectorType>().getElementType()">]> {
+  let summary = [{
+    Extract a single, dynamically selected, component of a vector.
+  }];
+
+  let description = [{
+    Result Type must be a scalar type.
+
+    Vector must have a type OpTypeVector whose Component Type is Result
+    Type.
+
+    Index must be a scalar integer. It is interpreted as a 0-based index of
+    which component of Vector to extract.
+
+    Behavior is undefined if Index's value is less than zero or greater than
+    or equal to the number of components in Vector.
+
+    <!-- End of AutoGen section -->
+
+    ```
+    scalar-type ::= integer-type | float-type | boolean-type
+    vector-extract-dynamic-op ::= `spv.VectorExtractDynamic ` ssa-use `[` ssa-use `]`
+                                    `:` `vector<` integer-literal `x` scalar-type `>` `,`
+                                    integer-type
+    ```
+
+    #### Example:
+
+    ```mlir
+    %2 = spv.VectorExtractDynamic %0[%1] : vector<8xf32>, i32
+    ```
+  }];
+
+  let arguments = (ins
+    SPV_Vector:$vector,
+    SPV_Integer:$index
+  );
+
+  let results = (outs
+    SPV_Scalar:$result
+  );
+
+  let verifier = [{ return success(); }];
+
+  let assemblyFormat = [{
+    $vector `[` $index `]` attr-dict `:` type($vector) `,` type($index)
+  }];
+
+}
+
+// -----
+
 #endif // SPIRV_COMPOSITE_OPS
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -78,13 +78,37 @@
     return success();
   }
 };
+
+/// Converts vector.extractelement to spv.VectorExtractDynamic. Only vectors
+/// that form a valid SPIR-V composite type are handled; anything else (e.g.
+/// vector<5xf32>) is rejected so legalization reports the failure.
+struct VectorExtractElementOpConvert final
+    : public SPIRVOpLowering<vector::ExtractElementOp> {
+  using SPIRVOpLowering<vector::ExtractElementOp>::SPIRVOpLowering;
+  LogicalResult
+  matchAndRewrite(vector::ExtractElementOp extractElementOp,
+                  ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (!spirv::CompositeType::isValid(extractElementOp.getVectorType()))
+      return failure();
+    // Take both the vector and the index from the adaptor so the new op uses
+    // the already-remapped operands rather than the original ones.
+    vector::ExtractElementOp::Adaptor adaptor(operands);
+    rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
+        extractElementOp, extractElementOp.getType(), adaptor.vector(),
+        adaptor.position());
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::populateVectorToSPIRVPatterns(MLIRContext *context,
                                          SPIRVTypeConverter &typeConverter,
                                          OwningRewritePatternList &patterns) {
   patterns.insert<VectorBroadcastConvert, VectorExtractOpConvert,
-                  VectorInsertOpConvert>(context, typeConverter);
+                  VectorInsertOpConvert, VectorExtractElementOpConvert>(
+      context, typeConverter);
 }
 
 namespace {
diff --git a/mlir/test/Conversion/VectorToSPIRV/simple.mlir b/mlir/test/Conversion/VectorToSPIRV/simple.mlir
--- a/mlir/test/Conversion/VectorToSPIRV/simple.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/simple.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -split-input-file -convert-vector-to-spirv %s -o - | FileCheck %s
+// RUN: mlir-opt -split-input-file -convert-vector-to-spirv -verify-diagnostics %s -o - | FileCheck %s
 
 // CHECK-LABEL: broadcast
 // CHECK-SAME: %[[A:.*]]: f32
@@ -21,3 +21,21 @@
   %1 = vector.insert %0, %arg0[0] : f32 into vector<4xf32>
   spv.Return
 }
+
+// -----
+
+// CHECK-LABEL: extract_element
+// CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32
+// CHECK: spv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32
+func @extract_element(%arg0 : vector<4xf32>, %id : i32) -> f32 {
+  %0 = vector.extractelement %arg0[%id : i32] : vector<4xf32>
+  spv.ReturnValue %0: f32
+}
+
+// -----
+
+func @extract_element_negative(%arg0 : vector<5xf32>, %id : i32) -> f32 {
+// expected-error @+1 {{failed to legalize operation 'vector.extractelement'}}
+  %0 = vector.extractelement %arg0[%id : i32] : vector<5xf32>
+  spv.ReturnValue %0: f32
+}
diff --git a/mlir/test/Dialect/SPIRV/Serialization/composite-op.mlir b/mlir/test/Dialect/SPIRV/Serialization/composite-op.mlir
--- a/mlir/test/Dialect/SPIRV/Serialization/composite-op.mlir
+++ b/mlir/test/Dialect/SPIRV/Serialization/composite-op.mlir
@@ -11,4 +11,9 @@
     %0 = spv.CompositeConstruct %arg0, %arg1, %arg2 : vector<3xf32>
     spv.ReturnValue %0: vector<3xf32>
   }
+  spv.func @vector_dynamic_extract(%vec: vector<4xf32>, %id : i32) -> f32 "None" {
+    // CHECK: spv.VectorExtractDynamic %{{.*}}[%{{.*}}] : vector<4xf32>, i32
+    %0 = spv.VectorExtractDynamic %vec[%id] : vector<4xf32>, i32
+    spv.ReturnValue %0: f32
+  }
 }
diff --git a/mlir/test/Dialect/SPIRV/composite-ops.mlir b/mlir/test/Dialect/SPIRV/composite-ops.mlir
--- a/mlir/test/Dialect/SPIRV/composite-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/composite-ops.mlir
@@ -261,3 +261,23 @@
   %0 = "spv.CompositeInsert"(%arg1, %arg0) {indices = [0: i32]} : (f32, !spv.array<4xf32>) -> !spv.array<4xf64>
   return %0: !spv.array<4xf64>
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.VectorExtractDynamic
+//===----------------------------------------------------------------------===//
+
+func @vector_dynamic_extract(%vec: vector<4xf32>, %id : i32) -> f32 {
+  // CHECK: spv.VectorExtractDynamic %{{.*}}[%{{.*}}] : vector<4xf32>, i32
+  %0 = spv.VectorExtractDynamic %vec[%id] : vector<4xf32>, i32
+  return %0 : f32
+}
+
+// -----
+
+func @vector_dynamic_extract_mismatched_result(%vec: vector<4xf32>, %id : i32) -> f64 {
+  // expected-error @+1 {{type of 'result' matches element type of 'vector'}}
+  %0 = "spv.VectorExtractDynamic"(%vec, %id) : (vector<4xf32>, i32) -> f64
+  return %0 : f64
+}