diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -3187,6 +3187,7 @@
 def SPV_OC_OpMemberDecorate            : I32EnumAttrCase<"OpMemberDecorate", 72>;
 def SPV_OC_OpVectorExtractDynamic      : I32EnumAttrCase<"OpVectorExtractDynamic", 77>;
 def SPV_OC_OpVectorInsertDynamic       : I32EnumAttrCase<"OpVectorInsertDynamic", 78>;
+def SPV_OC_OpVectorShuffle             : I32EnumAttrCase<"OpVectorShuffle", 79>;
 def SPV_OC_OpCompositeConstruct        : I32EnumAttrCase<"OpCompositeConstruct", 80>;
 def SPV_OC_OpCompositeExtract          : I32EnumAttrCase<"OpCompositeExtract", 81>;
 def SPV_OC_OpCompositeInsert           : I32EnumAttrCase<"OpCompositeInsert", 82>;
@@ -3327,7 +3328,8 @@
       SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpCopyMemory,
       SPV_OC_OpAccessChain, SPV_OC_OpDecorate, SPV_OC_OpMemberDecorate,
       SPV_OC_OpVectorExtractDynamic, SPV_OC_OpVectorInsertDynamic,
-      SPV_OC_OpCompositeConstruct, SPV_OC_OpCompositeExtract,
+      SPV_OC_OpVectorShuffle, SPV_OC_OpCompositeConstruct,
+      SPV_OC_OpCompositeExtract,
       SPV_OC_OpCompositeInsert, SPV_OC_OpTranspose, SPV_OC_OpConvertFToU,
       SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF, SPV_OC_OpConvertUToF,
       SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert, SPV_OC_OpBitcast,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td
@@ -289,4 +289,61 @@
 
 // -----
 
+def SPV_VectorShuffleOp : SPV_Op<"VectorShuffle", [
+    NoSideEffect, AllElementTypesMatch<["vector1", "vector2", "result"]>]> {
+  let summary = [{
+    Select arbitrary components from two vectors to make a new vector.
+  }];
+
+  let description = [{
+    Result Type must be an OpTypeVector. The number of components in Result
+    Type must be the same as the number of Component operands.
+
+    Vector 1 and Vector 2 must both have vector types, with the same
+    Component Type as Result Type. They do not have to have the same number
+    of components as Result Type or with each other. They are logically
+    concatenated, forming a single vector with Vector 1’s components
+    appearing before Vector 2’s. The components of this logical vector are
+    logically numbered with a single consecutive set of numbers from 0 to
+    N - 1, where N is the total number of components.
+
+    Components are these logical numbers (see above), selecting which of the
+    logically numbered components form the result. Each component is an
+    unsigned 32-bit integer. They can select the components in any order
+    and can repeat components. The first component of the result is selected
+    by the first Component operand, the second component of the result is
+    selected by the second Component operand, etc. A Component literal may
+    also be FFFFFFFF, which means the corresponding result component has no
+    source and is undefined. All Component literals must either be FFFFFFFF
+    or in [0, N - 1] (inclusive).
+
+    Note: A vector “swizzle” can be done by using the vector for both Vector
+    operands, or using an OpUndef for one of the Vector operands.
+
+    <!-- End of AutoGen section -->
+
+    #### Example:
+
+    ```mlir
+    %0 = spv.VectorShuffle [1: i32, 3: i32, 5: i32] %vector1: vector<4xf32>, %vector2: vector<2xf32> -> vector<3xf32>
+    ```
+  }];
+
+  let arguments = (ins
+    SPV_Vector:$vector1,
+    SPV_Vector:$vector2,
+    I32ArrayAttr:$components
+  );
+
+  let results = (outs
+    SPV_Vector:$result
+  );
+
+  let assemblyFormat = [{
+    attr-dict $components $vector1 `:` type($vector1) `,` $vector2 `:` type($vector2) `->` type($result)
+  }];
+}
+
+// -----
+
 #endif // MLIR_DIALECT_SPIRV_IR_COMPOSITE_OPS
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -3036,6 +3036,44 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// spv.VectorShuffle
+//===----------------------------------------------------------------------===//
+
+/// Verifies the spv.VectorShuffle invariants not covered by ODS: the result
+/// has exactly one component per selector, and each selector either indexes
+/// into the logical concatenation of vector1 and vector2 or is the 0xffffffff
+/// "undefined component" sentinel. Vector-ness of operands and result and
+/// matching element types are enforced by the ODS constraints and traits.
+static LogicalResult verify(spirv::VectorShuffleOp shuffleOp) {
+  // The result is constrained to SPV_Vector in ODS, and the generated
+  // constraint checks run before this custom verifier, so the cast is safe.
+  VectorType resultType = shuffleOp.getType().cast<VectorType>();
+
+  size_t numResultElements = resultType.getNumElements();
+  if (numResultElements != shuffleOp.components().size())
+    return shuffleOp.emitOpError("result type element count (")
+           << numResultElements
+           << ") mismatch with the number of component selectors ("
+           << shuffleOp.components().size() << ")";
+
+  size_t totalSrcElements =
+      shuffleOp.vector1().getType().cast<VectorType>().getNumElements() +
+      shuffleOp.vector2().getType().cast<VectorType>().getNumElements();
+
+  for (const auto &selector :
+       shuffleOp.components().getAsValueRange<IntegerAttr>()) {
+    uint32_t index = selector.getZExtValue();
+    // 0xffffffff marks a result component with no source (undefined value).
+    if (index >= totalSrcElements &&
+        index != std::numeric_limits<uint32_t>::max())
+      return shuffleOp.emitOpError("component selector ")
+             << index << " out of range: expected to be in [0, "
+             << totalSrcElements << ") or 0xffffffff";
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // spv.CooperativeMatrixLoadNV
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
--- a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
@@ -283,3 +283,47 @@
   %0 = spv.VectorInsertDynamic %val, %vec[%id] : vector<4xf32>, i32
   return %0 : vector<4xf32>
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.VectorShuffle
+//===----------------------------------------------------------------------===//
+
+func @vector_shuffle(%vector1: vector<4xf32>, %vector2: vector<2xf32>) -> vector<3xf32> {
+  // CHECK: %{{.+}} = spv.VectorShuffle [1 : i32, 3 : i32, -1 : i32] %{{.+}} : vector<4xf32>, %{{.+}} : vector<2xf32> -> vector<3xf32>
+  %0 = spv.VectorShuffle [1: i32, 3: i32, 0xffffffff: i32] %vector1: vector<4xf32>, %vector2: vector<2xf32> -> vector<3xf32>
+  return %0: vector<3xf32>
+}
+
+// -----
+
+func @vector_shuffle_extra_selector(%vector1: vector<4xf32>, %vector2: vector<2xf32>) -> vector<3xf32> {
+  // expected-error @+1 {{result type element count (3) mismatch with the number of component selectors (4)}}
+  %0 = spv.VectorShuffle [1: i32, 3: i32, 5: i32, 2: i32] %vector1: vector<4xf32>, %vector2: vector<2xf32> -> vector<3xf32>
+  return %0: vector<3xf32>
+}
+
+// -----
+
+func @vector_shuffle_out_of_range_selector(%vector1: vector<4xf32>, %vector2: vector<2xf32>) -> vector<3xf32> {
+  // expected-error @+1 {{component selector 7 out of range: expected to be in [0, 6) or 0xffffffff}}
+  %0 = spv.VectorShuffle [1: i32, 7: i32, 5: i32] %vector1: vector<4xf32>, %vector2: vector<2xf32> -> vector<3xf32>
+  return %0: vector<3xf32>
+}
+
+// -----
+
+func @vector_shuffle_scalar_result(%vector1: vector<4xf32>, %vector2: vector<2xf32>) -> f32 {
+  // expected-error @+1 {{result #0 must be vector of}}
+  %0 = spv.VectorShuffle [1: i32] %vector1: vector<4xf32>, %vector2: vector<2xf32> -> f32
+  return %0: f32
+}
+
+// -----
+
+func @vector_shuffle_element_type_mismatch(%vector1: vector<4xf32>, %vector2: vector<2xf16>) -> vector<3xf32> {
+  // expected-error @+1 {{have same element type}}
+  %0 = spv.VectorShuffle [1: i32, 3: i32, 5: i32] %vector1: vector<4xf32>, %vector2: vector<2xf16> -> vector<3xf32>
+  return %0: vector<3xf32>
+}
diff --git a/mlir/test/Target/SPIRV/composite-op.mlir b/mlir/test/Target/SPIRV/composite-op.mlir
--- a/mlir/test/Target/SPIRV/composite-op.mlir
+++ b/mlir/test/Target/SPIRV/composite-op.mlir
@@ -21,4 +21,9 @@
     %0 = spv.VectorInsertDynamic %val, %vec[%id] : vector<4xf32>, i32
     spv.ReturnValue %0: vector<4xf32>
   }
+  spv.func @vector_shuffle(%vector1: vector<4xf32>, %vector2: vector<2xf32>) -> vector<3xf32> "None" {
+    // CHECK: %{{.+}} = spv.VectorShuffle [1 : i32, 3 : i32, -1 : i32] %{{.+}} : vector<4xf32>, %{{.+}} : vector<2xf32> -> vector<3xf32>
+    %0 = spv.VectorShuffle [1: i32, 3: i32, 0xffffffff: i32] %vector1: vector<4xf32>, %vector2: vector<2xf32> -> vector<3xf32>
+    spv.ReturnValue %0: vector<3xf32>
+  }
 }