diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVArithmeticOps.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVArithmeticOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVArithmeticOps.td @@ -452,6 +452,31 @@ // ----- +def SPV_SNegateOp : SPV_ArithmeticUnaryOp<"SNegate", SPV_Integer, []> { + let summary = "Signed-integer subtract of Operand from zero."; + + let description = [{ + Result Type must be a scalar or vector of integer type. + + Operand’s type must be a scalar or vector of integer type. It must + have the same number of components as Result Type. The component width + must equal the component width in Result Type. + + Results are computed per component. + + <!-- End of AutoGen section --> + + #### Example: + + ```mlir + %1 = spv.SNegate %0 : i32 + %3 = spv.SNegate %2 : vector<4xi32> + ``` + }]; +} + +// ----- + def SPV_SRemOp : SPV_ArithmeticBinaryOp<"SRem", SPV_Integer, []> { let summary = [{ Signed remainder operation for the remainder whose sign matches the sign diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -3150,6 +3150,7 @@ def SPV_OC_OpSConvert : I32EnumAttrCase<"OpSConvert", 114>; def SPV_OC_OpFConvert : I32EnumAttrCase<"OpFConvert", 115>; def SPV_OC_OpBitcast : I32EnumAttrCase<"OpBitcast", 124>; +def SPV_OC_OpSNegate : I32EnumAttrCase<"OpSNegate", 126>; def SPV_OC_OpFNegate : I32EnumAttrCase<"OpFNegate", 127>; def SPV_OC_OpIAdd : I32EnumAttrCase<"OpIAdd", 128>; def SPV_OC_OpFAdd : I32EnumAttrCase<"OpFAdd", 129>; @@ -3271,41 +3272,42 @@ SPV_OC_OpCompositeInsert, SPV_OC_OpTranspose, SPV_OC_OpConvertFToU, SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF, SPV_OC_OpConvertUToF, SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert, SPV_OC_OpBitcast, - SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, - SPV_OC_OpIMul, SPV_OC_OpFMul,
SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, - SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, - SPV_OC_OpMatrixTimesScalar, SPV_OC_OpMatrixTimesMatrix, SPV_OC_OpLogicalEqual, - SPV_OC_OpLogicalNotEqual, SPV_OC_OpLogicalOr, SPV_OC_OpLogicalAnd, - SPV_OC_OpLogicalNot, SPV_OC_OpSelect, SPV_OC_OpIEqual, SPV_OC_OpINotEqual, - SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, - SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan, SPV_OC_OpSLessThan, - SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual, - SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual, - SPV_OC_OpFOrdLessThan, SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan, - SPV_OC_OpFUnordGreaterThan, SPV_OC_OpFOrdLessThanEqual, - SPV_OC_OpFUnordLessThanEqual, SPV_OC_OpFOrdGreaterThanEqual, - SPV_OC_OpFUnordGreaterThanEqual, SPV_OC_OpShiftRightLogical, - SPV_OC_OpShiftRightArithmetic, SPV_OC_OpShiftLeftLogical, SPV_OC_OpBitwiseOr, - SPV_OC_OpBitwiseXor, SPV_OC_OpBitwiseAnd, SPV_OC_OpNot, - SPV_OC_OpBitFieldInsert, SPV_OC_OpBitFieldSExtract, SPV_OC_OpBitFieldUExtract, - SPV_OC_OpBitReverse, SPV_OC_OpBitCount, SPV_OC_OpControlBarrier, - SPV_OC_OpMemoryBarrier, SPV_OC_OpAtomicCompareExchangeWeak, - SPV_OC_OpAtomicIIncrement, SPV_OC_OpAtomicIDecrement, SPV_OC_OpAtomicIAdd, - SPV_OC_OpAtomicISub, SPV_OC_OpAtomicSMin, SPV_OC_OpAtomicUMin, - SPV_OC_OpAtomicSMax, SPV_OC_OpAtomicUMax, SPV_OC_OpAtomicAnd, - SPV_OC_OpAtomicOr, SPV_OC_OpAtomicXor, SPV_OC_OpPhi, SPV_OC_OpLoopMerge, - SPV_OC_OpSelectionMerge, SPV_OC_OpLabel, SPV_OC_OpBranch, - SPV_OC_OpBranchConditional, SPV_OC_OpReturn, SPV_OC_OpReturnValue, - SPV_OC_OpUnreachable, SPV_OC_OpNoLine, SPV_OC_OpModuleProcessed, - SPV_OC_OpGroupNonUniformElect, SPV_OC_OpGroupNonUniformBallot, - SPV_OC_OpGroupNonUniformIAdd, SPV_OC_OpGroupNonUniformFAdd, - SPV_OC_OpGroupNonUniformIMul, SPV_OC_OpGroupNonUniformFMul, - SPV_OC_OpGroupNonUniformSMin, SPV_OC_OpGroupNonUniformUMin, -
SPV_OC_OpGroupNonUniformFMin, SPV_OC_OpGroupNonUniformSMax, - SPV_OC_OpGroupNonUniformUMax, SPV_OC_OpGroupNonUniformFMax, - SPV_OC_OpSubgroupBallotKHR, SPV_OC_OpTypeCooperativeMatrixNV, - SPV_OC_OpCooperativeMatrixLoadNV, SPV_OC_OpCooperativeMatrixStoreNV, - SPV_OC_OpCooperativeMatrixMulAddNV, SPV_OC_OpCooperativeMatrixLengthNV + SPV_OC_OpSNegate, SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd, + SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, + SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, + SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpMatrixTimesScalar, + SPV_OC_OpMatrixTimesMatrix, SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual, + SPV_OC_OpLogicalOr, SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect, + SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, + SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, + SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, + SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, + SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan, + SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan, + SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual, + SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual, + SPV_OC_OpShiftRightLogical, SPV_OC_OpShiftRightArithmetic, + SPV_OC_OpShiftLeftLogical, SPV_OC_OpBitwiseOr, SPV_OC_OpBitwiseXor, + SPV_OC_OpBitwiseAnd, SPV_OC_OpNot, SPV_OC_OpBitFieldInsert, + SPV_OC_OpBitFieldSExtract, SPV_OC_OpBitFieldUExtract, SPV_OC_OpBitReverse, + SPV_OC_OpBitCount, SPV_OC_OpControlBarrier, SPV_OC_OpMemoryBarrier, + SPV_OC_OpAtomicCompareExchangeWeak, SPV_OC_OpAtomicIIncrement, + SPV_OC_OpAtomicIDecrement, SPV_OC_OpAtomicIAdd, SPV_OC_OpAtomicISub, + SPV_OC_OpAtomicSMin, SPV_OC_OpAtomicUMin, SPV_OC_OpAtomicSMax, + SPV_OC_OpAtomicUMax, SPV_OC_OpAtomicAnd, SPV_OC_OpAtomicOr, SPV_OC_OpAtomicXor, + SPV_OC_OpPhi, SPV_OC_OpLoopMerge,
SPV_OC_OpSelectionMerge, SPV_OC_OpLabel, + SPV_OC_OpBranch, SPV_OC_OpBranchConditional, SPV_OC_OpReturn, + SPV_OC_OpReturnValue, SPV_OC_OpUnreachable, SPV_OC_OpNoLine, + SPV_OC_OpModuleProcessed, SPV_OC_OpGroupNonUniformElect, + SPV_OC_OpGroupNonUniformBallot, SPV_OC_OpGroupNonUniformIAdd, + SPV_OC_OpGroupNonUniformFAdd, SPV_OC_OpGroupNonUniformIMul, + SPV_OC_OpGroupNonUniformFMul, SPV_OC_OpGroupNonUniformSMin, + SPV_OC_OpGroupNonUniformUMin, SPV_OC_OpGroupNonUniformFMin, + SPV_OC_OpGroupNonUniformSMax, SPV_OC_OpGroupNonUniformUMax, + SPV_OC_OpGroupNonUniformFMax, SPV_OC_OpSubgroupBallotKHR, + SPV_OC_OpTypeCooperativeMatrixNV, SPV_OC_OpCooperativeMatrixLoadNV, + SPV_OC_OpCooperativeMatrixStoreNV, SPV_OC_OpCooperativeMatrixMulAddNV, + SPV_OC_OpCooperativeMatrixLengthNV ]>; // End opcode section. Generated from SPIR-V spec; DO NOT MODIFY! diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td @@ -26,6 +26,12 @@ SameOperandsAndResultShape])> { let parser = [{ return ::parseLogicalBinaryOp(parser, result); }]; let printer = [{ return ::printLogicalOp(getOperation(), p); }]; + + let builders = [ + OpBuilder< + "OpBuilder &builder, OperationState &state, Value lhs, Value rhs", + "::buildLogicalBinaryOp(builder, state, lhs, rhs);"> + ]; } class SPV_LogicalUnaryOp(loc, type, lhs); + Value rhsAbs = builder.create<spirv::GLSLSAbsOp>(loc, type, rhs); + Value abs = builder.create<spirv::UModOp>(loc, lhsAbs, rhsAbs); + + // Fix the sign. NOTE(review): check INT_MIN; SAbs may wrap it to itself. + Value isPositive; + if (lhs == signOperand) { + isPositive = builder.create<spirv::IEqualOp>(loc, lhs, lhsAbs); + } else { + isPositive = builder.create<spirv::IEqualOp>(loc, rhs, rhsAbs); + } + Value absNegate = builder.create<spirv::SNegateOp>(loc, type, abs); + return builder.create<spirv::SelectOp>(loc, type, isPositive, abs, absNegate); +} + /// Returns the offset of the value in `targetBits` representation.
`srcIdx` is /// an index into a 1-D array with each element having `sourceBits`. When /// accessing an element in the array treating as having elements of @@ -308,6 +338,15 @@ } }; +class SignedRemIOpPattern final : public SPIRVOpLowering<SignedRemIOp> { +public: + using SPIRVOpLowering<SignedRemIOp>::SPIRVOpLowering; + + LogicalResult + matchAndRewrite(SignedRemIOp remOp, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override; +}; + /// Converts bitwise standard operations to SPIR-V operations. This is a special /// pattern other than the BinaryOpPatternPattern because if the operands are /// boolean values, SPIR-V uses different operations (`SPIRVLogicalOp`). For @@ -506,6 +545,24 @@ } // namespace +//===----------------------------------------------------------------------===// +// SignedRemIOpPattern +//===----------------------------------------------------------------------===// + +LogicalResult SignedRemIOpPattern::matchAndRewrite( + SignedRemIOp remOp, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const { + auto dstType = this->typeConverter.convertType(remOp.getType()); + if (!dstType) + return failure(); + + Value result = emulateSignedRemainder(remOp.getLoc(), operands[0], + operands[1], operands[0], rewriter); + rewriter.replaceOp(remOp, result); + + return success(); +} + //===----------------------------------------------------------------------===// // ConstantOp with composite type.
//===----------------------------------------------------------------------===// @@ -1005,6 +1062,9 @@ SPIRVTypeConverter &typeConverter, OwningRewritePatternList &patterns) { patterns.insert< + // Unary and binary patterns + BitwiseOpPattern, + BitwiseOpPattern, UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, @@ -1020,7 +1080,6 @@ UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, @@ -1031,19 +1090,28 @@ UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, - AllocOpPattern, DeallocOpPattern, - BitwiseOpPattern, - BitwiseOpPattern, - BoolCmpIOpPattern, ConstantCompositeOpPattern, ConstantScalarOpPattern, - CmpFOpPattern, CmpIOpPattern, IntLoadOpPattern, LoadOpPattern, - ReturnOpPattern, SelectOpPattern, IntStoreOpPattern, StoreOpPattern, + SignedRemIOpPattern, XOrOpPattern, + + // Comparison patterns + BoolCmpIOpPattern, CmpFOpPattern, CmpIOpPattern, + + // Constant patterns + ConstantCompositeOpPattern, ConstantScalarOpPattern, + + // Memory patterns + AllocOpPattern, DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern, + LoadOpPattern, StoreOpPattern, + + ReturnOpPattern, SelectOpPattern, + + // Type cast patterns ZeroExtendI1Pattern, TypeCastingOpPattern, TypeCastingOpPattern, TypeCastingOpPattern, TypeCastingOpPattern, TypeCastingOpPattern, TypeCastingOpPattern, - TypeCastingOpPattern, XOrOpPattern>( - context, typeConverter); + TypeCastingOpPattern>(context, + typeConverter); } } // namespace mlir diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -844,6 +844,18 @@ return success(); } +static void buildLogicalBinaryOp(OpBuilder &builder, OperationState &state, + Value lhs, Value rhs) { + assert(lhs.getType() == rhs.getType()); + + Type boolType = builder.getI1Type(); + if (auto vecType =
lhs.getType().dyn_cast<VectorType>()) + boolType = VectorType::get(vecType.getShape(), boolType); + state.addTypes(boolType); + + state.addOperands({lhs, rhs}); +} + //===----------------------------------------------------------------------===// // spv.AccessChainOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir --- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir @@ -22,12 +22,23 @@ %2 = muli %lhs, %rhs: i32 // CHECK: spv.SDiv %{{.*}}, %{{.*}}: i32 %3 = divi_signed %lhs, %rhs: i32 - // CHECK: spv.SRem %{{.*}}, %{{.*}}: i32 - %4 = remi_signed %lhs, %rhs: i32 // CHECK: spv.UDiv %{{.*}}, %{{.*}}: i32 - %5 = divi_unsigned %lhs, %rhs: i32 + %4 = divi_unsigned %lhs, %rhs: i32 // CHECK: spv.UMod %{{.*}}, %{{.*}}: i32 - %6 = remi_unsigned %lhs, %rhs: i32 + %5 = remi_unsigned %lhs, %rhs: i32 + return +} + +// CHECK-LABEL: @scalar_srem +// CHECK-SAME: (%[[LHS:.+]]: i32, %[[RHS:.+]]: i32) +func @scalar_srem(%lhs: i32, %rhs: i32) { + // CHECK: %[[LABS:.+]] = spv.GLSL.SAbs %[[LHS]] : i32 + // CHECK: %[[RABS:.+]] = spv.GLSL.SAbs %[[RHS]] : i32 + // CHECK: %[[ABS:.+]] = spv.UMod %[[LABS]], %[[RABS]] : i32 + // CHECK: %[[POS:.+]] = spv.IEqual %[[LHS]], %[[LABS]] : i32 + // CHECK: %[[NEG:.+]] = spv.SNegate %[[ABS]] : i32 + // CHECK: %{{.+}} = spv.Select %[[POS]], %[[ABS]], %[[NEG]] : i1, i32 + %0 = remi_signed %lhs, %rhs: i32 return } @@ -75,13 +86,24 @@ // Check int vector types.
// CHECK-LABEL: @int_vector234 -func @int_vector234(%arg0: vector<2xi8>, %arg1: vector<3xi16>, %arg2: vector<4xi64>) { +func @int_vector234(%arg0: vector<2xi8>, %arg1: vector<4xi64>) { // CHECK: spv.SDiv %{{.*}}, %{{.*}}: vector<2xi8> %0 = divi_signed %arg0, %arg0: vector<2xi8> - // CHECK: spv.SRem %{{.*}}, %{{.*}}: vector<3xi16> - %1 = remi_signed %arg1, %arg1: vector<3xi16> // CHECK: spv.UDiv %{{.*}}, %{{.*}}: vector<4xi64> - %2 = divi_unsigned %arg2, %arg2: vector<4xi64> + %1 = divi_unsigned %arg1, %arg1: vector<4xi64> + return +} + +// CHECK-LABEL: @vector_srem +// CHECK-SAME: (%[[LHS:.+]]: vector<3xi16>, %[[RHS:.+]]: vector<3xi16>) +func @vector_srem(%arg0: vector<3xi16>, %arg1: vector<3xi16>) { + // CHECK: %[[LABS:.+]] = spv.GLSL.SAbs %[[LHS]] : vector<3xi16> + // CHECK: %[[RABS:.+]] = spv.GLSL.SAbs %[[RHS]] : vector<3xi16> + // CHECK: %[[ABS:.+]] = spv.UMod %[[LABS]], %[[RABS]] : vector<3xi16> + // CHECK: %[[POS:.+]] = spv.IEqual %[[LHS]], %[[LABS]] : vector<3xi16> + // CHECK: %[[NEG:.+]] = spv.SNegate %[[ABS]] : vector<3xi16> + // CHECK: %{{.+}} = spv.Select %[[POS]], %[[ABS]], %[[NEG]] : vector<3xi1>, vector<3xi16> + %0 = remi_signed %arg0, %arg1: vector<3xi16> return } @@ -132,8 +154,8 @@ func @int_vector23(%arg0: vector<2xi8>, %arg1: vector<3xi16>) { // CHECK: spv.SDiv %{{.*}}, %{{.*}}: vector<2xi32> %0 = divi_signed %arg0, %arg0: vector<2xi8> - // CHECK: spv.SRem %{{.*}}, %{{.*}}: vector<3xi32> - %1 = remi_signed %arg1, %arg1: vector<3xi16> + // CHECK: spv.SDiv %{{.*}}, %{{.*}}: vector<3xi32> + %1 = divi_signed %arg1, %arg1: vector<3xi16> return } diff --git a/mlir/test/Dialect/SPIRV/Serialization/arithmetic-ops.mlir b/mlir/test/Dialect/SPIRV/Serialization/arithmetic-ops.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/arithmetic-ops.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/arithmetic-ops.mlir @@ -71,6 +71,11 @@ %0 = spv.SMod %arg0, %arg1 : vector<4xi32> spv.Return } + spv.func @snegate(%arg0 : vector<4xi32>) "None" { + // CHECK: {{%.*}} =
spv.SNegate {{%.*}} : vector<4xi32> + %0 = spv.SNegate %arg0 : vector<4xi32> + spv.Return + } spv.func @srem(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>) "None" { // CHECK: {{%.*}} = spv.SRem {{%.*}}, {{%.*}} : vector<4xi32> %0 = spv.SRem %arg0, %arg1 : vector<4xi32> diff --git a/mlir/test/Dialect/SPIRV/arithmetic-ops.mlir b/mlir/test/Dialect/SPIRV/arithmetic-ops.mlir --- a/mlir/test/Dialect/SPIRV/arithmetic-ops.mlir +++ b/mlir/test/Dialect/SPIRV/arithmetic-ops.mlir @@ -174,6 +174,17 @@ // ----- +//===----------------------------------------------------------------------===// +// spv.SNegate +//===----------------------------------------------------------------------===// + +func @snegate_scalar(%arg: i32) -> i32 { + // CHECK: spv.SNegate + %0 = spv.SNegate %arg : i32 + return %0 : i32 +} + +// ----- //===----------------------------------------------------------------------===// // spv.SRem //===----------------------------------------------------------------------===//