[mlir][spirv] Add spv.VectorShuffle

Add the SPIR-V OpVectorShuffle instruction (opcode 79) as spv.VectorShuffle:
add the opcode to the SPIR-V opcode enum and its case list, define the op in
SPIRVCompositeOps.td, verify its component selectors in SPIRVOps.cpp, and add
parser/verifier tests plus a serialization round-trip test. Also reflow the
trait lists of spv.VectorExtractDynamic / spv.VectorInsertDynamic and drop the
stale hand-written syntax block from spv.VectorExtractDynamic's description.

NOTE(review): intra-line whitespace of context lines (e.g. the column
alignment of the SPV_OC_* defs) and context lines whose text was not visible
in the reviewed copy (blank lines and the `<!-- End of AutoGen section -->`
marker) were reconstructed; if a hunk is rejected, retry with
`git apply --ignore-whitespace` and confirm those context lines against the tree.

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -3187,6 +3187,7 @@
 def SPV_OC_OpMemberDecorate : I32EnumAttrCase<"OpMemberDecorate", 72>;
 def SPV_OC_OpVectorExtractDynamic : I32EnumAttrCase<"OpVectorExtractDynamic", 77>;
 def SPV_OC_OpVectorInsertDynamic : I32EnumAttrCase<"OpVectorInsertDynamic", 78>;
+def SPV_OC_OpVectorShuffle : I32EnumAttrCase<"OpVectorShuffle", 79>;
 def SPV_OC_OpCompositeConstruct : I32EnumAttrCase<"OpCompositeConstruct", 80>;
 def SPV_OC_OpCompositeExtract : I32EnumAttrCase<"OpCompositeExtract", 81>;
 def SPV_OC_OpCompositeInsert : I32EnumAttrCase<"OpCompositeInsert", 82>;
@@ -3327,7 +3328,7 @@
       SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpCopyMemory,
       SPV_OC_OpAccessChain, SPV_OC_OpDecorate, SPV_OC_OpMemberDecorate,
       SPV_OC_OpVectorExtractDynamic, SPV_OC_OpVectorInsertDynamic,
-      SPV_OC_OpCompositeConstruct, SPV_OC_OpCompositeExtract,
+      SPV_OC_OpVectorShuffle, SPV_OC_OpCompositeConstruct, SPV_OC_OpCompositeExtract,
       SPV_OC_OpCompositeInsert, SPV_OC_OpTranspose, SPV_OC_OpConvertFToU,
       SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF, SPV_OC_OpConvertUToF,
       SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert, SPV_OC_OpBitcast,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td
@@ -165,17 +165,17 @@
 
   let builders = [
     OpBuilderDAG<(ins "Value":$object, "Value":$composite,
-      "ArrayRef<int32_t>":$indices)>
+                      "ArrayRef<int32_t>":$indices)>
   ];
 }
 
 // -----
 
-def SPV_VectorExtractDynamicOp : SPV_Op<"VectorExtractDynamic",
-    [NoSideEffect,
-     TypesMatchWith<"type of 'result' matches element type of 'vector'",
-                    "vector", "result",
-                    "$_self.cast<VectorType>().getElementType()">]> {
+def SPV_VectorExtractDynamicOp : SPV_Op<"VectorExtractDynamic", [
+    NoSideEffect,
+    TypesMatchWith<"type of 'result' matches element type of 'vector'",
+                   "vector", "result",
+                   "$_self.cast<VectorType>().getElementType()">]> {
   let summary = [{
     Extract a single, dynamically selected, component of a vector.
   }];
@@ -194,13 +194,6 @@
 
     <!-- End of AutoGen section -->
 
-    ```
-    scalar-type ::= integer-type | float-type | boolean-type
-    vector-extract-dynamic-op ::= `spv.VectorExtractDynamic ` ssa-use `[` ssa-use `]`
-                                    `:` `vector<` integer-literal `x` scalar-type `>` `,`
-                                    integer-type
-    ```mlir
-
     #### Example:
 
     ```
@@ -226,12 +219,13 @@
 
 // -----
 
-def SPV_VectorInsertDynamicOp : SPV_Op<"VectorInsertDynamic",
-    [NoSideEffect,
-     TypesMatchWith<"type of 'component' matches element type of 'vector'",
-                    "vector", "component",
-                    "$_self.cast<VectorType>().getElementType()">,
-     AllTypesMatch<["vector", "result"]>]> {
+def SPV_VectorInsertDynamicOp : SPV_Op<"VectorInsertDynamic", [
+    NoSideEffect,
+    TypesMatchWith<
+      "type of 'component' matches element type of 'vector'",
+      "vector", "component",
+      "$_self.cast<VectorType>().getElementType()">,
+    AllTypesMatch<["vector", "result"]>]> {
   let summary = [{
     Make a copy of a vector, with a single, variably selected, component
     modified.
@@ -289,4 +283,64 @@
 
 // -----
 
+def SPV_VectorShuffleOp : SPV_Op<"VectorShuffle", [
+    NoSideEffect, AllElementTypesMatch<["vector1", "vector2", "result"]>]> {
+  let summary = [{
+    Select arbitrary components from two vectors to make a new vector.
+  }];
+
+  let description = [{
+    Result Type must be an OpTypeVector. The number of components in Result
+    Type must be the same as the number of Component operands.
+
+    Vector 1 and Vector 2 must both have vector types, with the same
+    Component Type as Result Type. They do not have to have the same number
+    of components as Result Type or with each other. They are logically
+    concatenated, forming a single vector with Vector 1's components
+    appearing before Vector 2's. The components of this logical vector are
+    logically numbered with a single consecutive set of numbers from 0 to N
+    - 1, where N is the total number of components.
+
+    Components are these logical numbers (see above), selecting which of the
+    logically numbered components form the result. Each component is an
+    unsigned 32-bit integer. They can select the components in any order
+    and can repeat components. The first component of the result is selected
+    by the first Component operand, the second component of the result is
+    selected by the second Component operand, etc. A Component literal may
+    also be FFFFFFFF, which means the corresponding result component has no
+    source and is undefined. All Component literals must either be FFFFFFFF
+    or in [0, N - 1] (inclusive).
+
+    Note: A vector "swizzle" can be done by using the vector for both Vector
+    operands, or using an OpUndef for one of the Vector operands.
+
+    <!-- End of AutoGen section -->
+
+    #### Example:
+
+    ```mlir
+    %0 = spv.VectorShuffle [1: i32, 3: i32, 5: i32]
+                           %vector1: vector<4xf32>, %vector2: vector<2xf32>
+                        -> vector<3xf32>
+    ```
+  }];
+
+  let arguments = (ins
+    SPV_Vector:$vector1,
+    SPV_Vector:$vector2,
+    I32ArrayAttr:$components
+  );
+
+  let results = (outs
+    SPV_Vector:$result
+  );
+
+  let assemblyFormat = [{
+    attr-dict $components $vector1 `:` type($vector1) `,`
+                          $vector2 `:` type($vector2) `->` type($result)
+  }];
+}
+
+// -----
+
 #endif // MLIR_DIALECT_SPIRV_IR_COMPOSITE_OPS
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -24,6 +24,7 @@
 #include "mlir/IR/FunctionImplementation.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/CallInterfaces.h"
 #include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/APInt.h"
@@ -3036,6 +3037,39 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// spv.VectorShuffle
+//===----------------------------------------------------------------------===//
+
+/// Verifies that spv.VectorShuffle has one component selector per result
+/// element and that every selector either indexes into the logical
+/// concatenation of `vector1` and `vector2` or is 0xffffffff (undefined).
+static LogicalResult verify(spirv::VectorShuffleOp shuffleOp) {
+  VectorType resultType = shuffleOp.getType().cast<VectorType>();
+
+  size_t numResultElements = resultType.getNumElements();
+  if (numResultElements != shuffleOp.components().size())
+    return shuffleOp.emitOpError("result type element count (")
+           << numResultElements
+           << ") mismatch with the number of component selectors ("
+           << shuffleOp.components().size() << ")";
+
+  size_t totalSrcElements =
+      shuffleOp.vector1().getType().cast<VectorType>().getNumElements() +
+      shuffleOp.vector2().getType().cast<VectorType>().getNumElements();
+
+  for (const auto &selector :
+       shuffleOp.components().getAsValueRange<IntegerAttr>()) {
+    uint32_t index = selector.getZExtValue();
+    if (index >= totalSrcElements &&
+        index != std::numeric_limits<uint32_t>::max())
+      return shuffleOp.emitOpError("component selector ")
+             << index << " out of range: expected to be in [0, "
+             << totalSrcElements << ") or 0xffffffff";
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // spv.CooperativeMatrixLoadNV
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
--- a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
@@ -283,3 +283,31 @@
   %0 = spv.VectorInsertDynamic %val, %vec[%id] : vector<4xf32>, i32
   return %0 : vector<4xf32>
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.VectorShuffle
+//===----------------------------------------------------------------------===//
+
+func @vector_shuffle(%vector1: vector<4xf32>, %vector2: vector<2xf32>) -> vector<3xf32> {
+  // CHECK: %{{.+}} = spv.VectorShuffle [1 : i32, 3 : i32, -1 : i32] %{{.+}} : vector<4xf32>, %{{.+}} : vector<2xf32> -> vector<3xf32>
+  %0 = spv.VectorShuffle [1: i32, 3: i32, 0xffffffff: i32] %vector1: vector<4xf32>, %vector2: vector<2xf32> -> vector<3xf32>
+  return %0: vector<3xf32>
+}
+
+// -----
+
+func @vector_shuffle_extra_selector(%vector1: vector<4xf32>, %vector2: vector<2xf32>) -> vector<3xf32> {
+  // expected-error @+1 {{result type element count (3) mismatch with the number of component selectors (4)}}
+  %0 = spv.VectorShuffle [1: i32, 3: i32, 5: i32, 2: i32] %vector1: vector<4xf32>, %vector2: vector<2xf32> -> vector<3xf32>
+  return %0: vector<3xf32>
+}
+
+// -----
+
+func @vector_shuffle_selector_out_of_range(%vector1: vector<4xf32>, %vector2: vector<2xf32>) -> vector<3xf32> {
+  // expected-error @+1 {{component selector 7 out of range: expected to be in [0, 6) or 0xffffffff}}
+  %0 = spv.VectorShuffle [1: i32, 7: i32, 5: i32] %vector1: vector<4xf32>, %vector2: vector<2xf32> -> vector<3xf32>
+  return %0: vector<3xf32>
+}
diff --git a/mlir/test/Target/SPIRV/composite-op.mlir b/mlir/test/Target/SPIRV/composite-op.mlir
--- a/mlir/test/Target/SPIRV/composite-op.mlir
+++ b/mlir/test/Target/SPIRV/composite-op.mlir
@@ -21,4 +21,9 @@
     %0 = spv.VectorInsertDynamic %val, %vec[%id] : vector<4xf32>, i32
     spv.ReturnValue %0: vector<4xf32>
   }
+  spv.func @vector_shuffle(%vector1: vector<4xf32>, %vector2: vector<2xf32>) -> vector<3xf32> "None" {
+    // CHECK: %{{.+}} = spv.VectorShuffle [1 : i32, 3 : i32, -1 : i32] %{{.+}} : vector<4xf32>, %{{.+}} : vector<2xf32> -> vector<3xf32>
+    %0 = spv.VectorShuffle [1: i32, 3: i32, 0xffffffff: i32] %vector1: vector<4xf32>, %vector2: vector<2xf32> -> vector<3xf32>
+    spv.ReturnValue %0: vector<3xf32>
+  }
 }