diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -564,6 +564,40 @@
 
 // -----
 
+def SPV_VectorTimesScalarOp : SPV_Op<"VectorTimesScalar", [NoSideEffect]> {
+  let summary = "Scale a floating-point vector.";
+
+  let description = [{
+    Result Type must be a vector of floating-point type.
+
+    The type of Vector must be the same as Result Type. Each component of
+    Vector is multiplied by Scalar.
+
+    Scalar must have the same type as the Component Type in Result Type.
+
+
+
+    #### Example:
+
+    ```mlir
+    %0 = spv.VectorTimesScalar %vector, %scalar : (vector<4xf32>, f32) -> vector<4xf32>
+    ```
+  }];
+
+  let arguments = (ins
+    VectorOfLengthAndType<[2, 3, 4], [SPV_Float]>:$vector,
+    SPV_Float:$scalar
+  );
+
+  let results = (outs
+    VectorOfLengthAndType<[2, 3, 4], [SPV_Float]>:$result
+  );
+
+  let assemblyFormat = "operands attr-dict `:` `(` type(operands) `)` `->` type($result)";
+}
+
+// -----
+
 def SPV_UModOp : SPV_ArithmeticBinaryOp<"UMod",
                                         SPV_Integer,
                                         [UnsignedOp, UsableInSpecConstantOp]> {
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4078,6 +4078,7 @@
 def SPV_OC_OpSMod : I32EnumAttrCase<"OpSMod", 139>;
 def SPV_OC_OpFRem : I32EnumAttrCase<"OpFRem", 140>;
 def SPV_OC_OpFMod : I32EnumAttrCase<"OpFMod", 141>;
+def SPV_OC_OpVectorTimesScalar : I32EnumAttrCase<"OpVectorTimesScalar", 142>;
 def SPV_OC_OpMatrixTimesScalar : I32EnumAttrCase<"OpMatrixTimesScalar", 143>;
 def SPV_OC_OpMatrixTimesMatrix : I32EnumAttrCase<"OpMatrixTimesMatrix", 146>;
 def SPV_OC_OpIsNan : I32EnumAttrCase<"OpIsNan", 156>;
@@ -4202,32 +4203,33 @@
       SPV_OC_OpSNegate, SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd,
       SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv,
       SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod,
-      SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpMatrixTimesScalar,
-      SPV_OC_OpMatrixTimesMatrix, SPV_OC_OpIsNan, SPV_OC_OpIsInf, SPV_OC_OpOrdered,
-      SPV_OC_OpUnordered, SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual,
-      SPV_OC_OpLogicalOr, SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect,
-      SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan,
-      SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual,
-      SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual,
-      SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual,
-      SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan,
-      SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
-      SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual,
-      SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual,
-      SPV_OC_OpShiftRightLogical, SPV_OC_OpShiftRightArithmetic,
-      SPV_OC_OpShiftLeftLogical, SPV_OC_OpBitwiseOr, SPV_OC_OpBitwiseXor,
-      SPV_OC_OpBitwiseAnd, SPV_OC_OpNot, SPV_OC_OpBitFieldInsert,
-      SPV_OC_OpBitFieldSExtract, SPV_OC_OpBitFieldUExtract, SPV_OC_OpBitReverse,
-      SPV_OC_OpBitCount, SPV_OC_OpControlBarrier, SPV_OC_OpMemoryBarrier,
-      SPV_OC_OpAtomicExchange, SPV_OC_OpAtomicCompareExchange,
-      SPV_OC_OpAtomicCompareExchangeWeak, SPV_OC_OpAtomicIIncrement,
-      SPV_OC_OpAtomicIDecrement, SPV_OC_OpAtomicIAdd, SPV_OC_OpAtomicISub,
-      SPV_OC_OpAtomicSMin, SPV_OC_OpAtomicUMin, SPV_OC_OpAtomicSMax,
-      SPV_OC_OpAtomicUMax, SPV_OC_OpAtomicAnd, SPV_OC_OpAtomicOr, SPV_OC_OpAtomicXor,
-      SPV_OC_OpPhi, SPV_OC_OpLoopMerge, SPV_OC_OpSelectionMerge, SPV_OC_OpLabel,
-      SPV_OC_OpBranch, SPV_OC_OpBranchConditional, SPV_OC_OpReturn,
-      SPV_OC_OpReturnValue, SPV_OC_OpUnreachable, SPV_OC_OpGroupBroadcast,
-      SPV_OC_OpNoLine, SPV_OC_OpModuleProcessed, SPV_OC_OpGroupNonUniformElect,
+      SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpVectorTimesScalar,
+      SPV_OC_OpMatrixTimesScalar, SPV_OC_OpMatrixTimesMatrix, SPV_OC_OpIsNan,
+      SPV_OC_OpIsInf, SPV_OC_OpOrdered, SPV_OC_OpUnordered, SPV_OC_OpLogicalEqual,
+      SPV_OC_OpLogicalNotEqual, SPV_OC_OpLogicalOr, SPV_OC_OpLogicalAnd,
+      SPV_OC_OpLogicalNot, SPV_OC_OpSelect, SPV_OC_OpIEqual, SPV_OC_OpINotEqual,
+      SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual,
+      SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan, SPV_OC_OpSLessThan,
+      SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual,
+      SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual,
+      SPV_OC_OpFOrdLessThan, SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan,
+      SPV_OC_OpFUnordGreaterThan, SPV_OC_OpFOrdLessThanEqual,
+      SPV_OC_OpFUnordLessThanEqual, SPV_OC_OpFOrdGreaterThanEqual,
+      SPV_OC_OpFUnordGreaterThanEqual, SPV_OC_OpShiftRightLogical,
+      SPV_OC_OpShiftRightArithmetic, SPV_OC_OpShiftLeftLogical, SPV_OC_OpBitwiseOr,
+      SPV_OC_OpBitwiseXor, SPV_OC_OpBitwiseAnd, SPV_OC_OpNot,
+      SPV_OC_OpBitFieldInsert, SPV_OC_OpBitFieldSExtract, SPV_OC_OpBitFieldUExtract,
+      SPV_OC_OpBitReverse, SPV_OC_OpBitCount, SPV_OC_OpControlBarrier,
+      SPV_OC_OpMemoryBarrier, SPV_OC_OpAtomicExchange,
+      SPV_OC_OpAtomicCompareExchange, SPV_OC_OpAtomicCompareExchangeWeak,
+      SPV_OC_OpAtomicIIncrement, SPV_OC_OpAtomicIDecrement, SPV_OC_OpAtomicIAdd,
+      SPV_OC_OpAtomicISub, SPV_OC_OpAtomicSMin, SPV_OC_OpAtomicUMin,
+      SPV_OC_OpAtomicSMax, SPV_OC_OpAtomicUMax, SPV_OC_OpAtomicAnd,
+      SPV_OC_OpAtomicOr, SPV_OC_OpAtomicXor, SPV_OC_OpPhi, SPV_OC_OpLoopMerge,
+      SPV_OC_OpSelectionMerge, SPV_OC_OpLabel, SPV_OC_OpBranch,
+      SPV_OC_OpBranchConditional, SPV_OC_OpReturn, SPV_OC_OpReturnValue,
+      SPV_OC_OpUnreachable, SPV_OC_OpGroupBroadcast, SPV_OC_OpNoLine,
+      SPV_OC_OpModuleProcessed, SPV_OC_OpGroupNonUniformElect,
       SPV_OC_OpGroupNonUniformBroadcast, SPV_OC_OpGroupNonUniformBallot,
       SPV_OC_OpGroupNonUniformIAdd, SPV_OC_OpGroupNonUniformFAdd,
       SPV_OC_OpGroupNonUniformIMul, SPV_OC_OpGroupNonUniformFMul,
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -4424,6 +4424,21 @@
   return verifyAccessChain(*this, indices());
 }
 
+//===----------------------------------------------------------------------===//
+// spv.VectorTimesScalarOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult spirv::VectorTimesScalarOp::verify() {
+  // The vector operand and the result must have exactly the same type.
+  if (vector().getType() != getType())
+    return emitOpError("vector operand and result type mismatch");
+  // The scalar operand must have the element type of the result vector.
+  auto scalarType = getType().cast<VectorType>().getElementType();
+  if (scalar().getType() != scalarType)
+    return emitOpError("scalar operand and result element type mismatch");
+  return success();
+}
+
 // TableGen'erated operation interfaces for querying versions, extensions, and
 // capabilities.
 #include "mlir/Dialect/SPIRV/IR/SPIRVAvailability.cpp.inc"
diff --git a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
--- a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
@@ -219,3 +219,29 @@
   return %0 : i32
 }
 
+// -----
+//===----------------------------------------------------------------------===//
+// spv.VectorTimesScalar
+//===----------------------------------------------------------------------===//
+
+func @vector_times_scalar(%vector: vector<4xf32>, %scalar: f32) -> vector<4xf32> {
+  // CHECK: spv.VectorTimesScalar %{{.+}}, %{{.+}} : (vector<4xf32>, f32) -> vector<4xf32>
+  %0 = spv.VectorTimesScalar %vector, %scalar : (vector<4xf32>, f32) -> vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
+// -----
+
+func @vector_times_scalar(%vector: vector<4xf32>, %scalar: f16) -> vector<4xf32> {
+  // expected-error @+1 {{scalar operand and result element type mismatch}}
+  %0 = spv.VectorTimesScalar %vector, %scalar : (vector<4xf32>, f16) -> vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
+// -----
+
+func @vector_times_scalar(%vector: vector<4xf32>, %scalar: f32) -> vector<3xf32> {
+  // expected-error @+1 {{vector operand and result type mismatch}}
+  %0 = spv.VectorTimesScalar %vector, %scalar : (vector<4xf32>, f32) -> vector<3xf32>
+  return %0 : vector<3xf32>
+}
diff --git a/mlir/test/Target/SPIRV/arithmetic-ops.mlir b/mlir/test/Target/SPIRV/arithmetic-ops.mlir
--- a/mlir/test/Target/SPIRV/arithmetic-ops.mlir
+++ b/mlir/test/Target/SPIRV/arithmetic-ops.mlir
@@ -81,4 +81,9 @@
     %0 = spv.SRem %arg0, %arg1 : vector<4xi32>
     spv.Return
   }
+  spv.func @vector_times_scalar(%arg0 : vector<4xf32>, %arg1 : f32) "None" {
+    // CHECK: {{%.*}} = spv.VectorTimesScalar {{%.*}}, {{%.*}} : (vector<4xf32>, f32) -> vector<4xf32>
+    %0 = spv.VectorTimesScalar %arg0, %arg1 : (vector<4xf32>, f32) -> vector<4xf32>
+    spv.Return
+  }
 }