diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -310,6 +310,55 @@
 
 // -----
 
+def SPV_IAddCarryOp : SPV_BinaryOp<"IAddCarry",
+                                   SPV_AnyStruct, SPV_Integer,
+                                   [Commutative, NoSideEffect]> {
+  let summary = [{
+    Integer addition of Operand 1 and Operand 2, including the carry.
+  }];
+
+  let description = [{
+    Result Type must be from OpTypeStruct. The struct must have two
+    members, and the two members must be the same type. The member type
+    must be a scalar or vector of integer type, whose Signedness operand is
+    0.
+
+    Operand 1 and Operand 2 must have the same type as the members of Result
+    Type. These are consumed as unsigned integers.
+
+    Results are computed per component.
+
+    Member 0 of the result gets the low-order bits (full component width) of
+    the addition.
+
+    Member 1 of the result gets the high-order (carry) bit of the result of
+    the addition. That is, it gets the value 1 if the addition overflowed
+    the component width, and 0 otherwise.
+
+    <!-- End of AutoGen section -->
+
+    #### Example:
+
+    ```mlir
+    %2 = spv.IAddCarry %0, %1 : !spv.struct<(i32, i32)>
+    %3 = spv.IAddCarry %0, %1 : !spv.struct<(vector<2xi32>, vector<2xi32>)>
+    ```
+  }];
+
+  let arguments = (ins
+    SPV_ScalarOrVectorOf<SPV_Integer>:$operand1,
+    SPV_ScalarOrVectorOf<SPV_Integer>:$operand2
+  );
+
+  let results = (outs
+    SPV_AnyStruct:$result
+  );
+
+  let hasVerifier = 1;
+}
+
+// -----
+
 def SPV_IMulOp : SPV_ArithmeticBinaryOp<"IMul",
                                         SPV_Integer,
                                         [Commutative, UsableInSpecConstantOp]> {
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4094,6 +4094,7 @@
 def SPV_OC_OpVectorTimesScalar         : I32EnumAttrCase<"OpVectorTimesScalar", 142>;
 def SPV_OC_OpMatrixTimesScalar         : I32EnumAttrCase<"OpMatrixTimesScalar", 143>;
 def SPV_OC_OpMatrixTimesMatrix         : I32EnumAttrCase<"OpMatrixTimesMatrix", 146>;
+def SPV_OC_OpIAddCarry                 : I32EnumAttrCase<"OpIAddCarry", 149>;
 def SPV_OC_OpISubBorrow                : I32EnumAttrCase<"OpISubBorrow", 150>;
 def SPV_OC_OpIsNan                     : I32EnumAttrCase<"OpIsNan", 156>;
 def SPV_OC_OpIsInf                     : I32EnumAttrCase<"OpIsInf", 157>;
@@ -4219,16 +4220,16 @@
       SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv,
       SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod,
       SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpVectorTimesScalar,
-      SPV_OC_OpMatrixTimesScalar, SPV_OC_OpMatrixTimesMatrix, SPV_OC_OpISubBorrow,
-      SPV_OC_OpIsNan, SPV_OC_OpIsInf, SPV_OC_OpOrdered, SPV_OC_OpUnordered,
-      SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual, SPV_OC_OpLogicalOr,
-      SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect, SPV_OC_OpIEqual,
-      SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan,
-      SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan,
-      SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual,
-      SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual,
-      SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan, SPV_OC_OpFUnordLessThan,
-      SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
+      SPV_OC_OpMatrixTimesScalar, SPV_OC_OpMatrixTimesMatrix, SPV_OC_OpIAddCarry,
+      SPV_OC_OpISubBorrow, SPV_OC_OpIsNan, SPV_OC_OpIsInf, SPV_OC_OpOrdered,
+      SPV_OC_OpUnordered, SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual,
+      SPV_OC_OpLogicalOr, SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect,
+      SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan,
+      SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual,
+      SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual,
+      SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual,
+      SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan,
+      SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
       SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual,
       SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual,
       SPV_OC_OpShiftRightLogical, SPV_OC_OpShiftRightArithmetic,
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -19,7 +19,6 @@
 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
 #include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/FunctionImplementation.h"
 #include "mlir/IR/OpDefinition.h"
@@ -2840,6 +2839,55 @@
   printGroupNonUniformArithmeticOp(*this, p);
 }
 
+//===----------------------------------------------------------------------===//
+// spv.IAddCarryOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult spirv::IAddCarryOp::verify() {
+  auto resultType = getType().cast<spirv::StructType>();
+  if (resultType.getNumElements() != 2)
+    return emitOpError("expected result struct type containing two members");
+
+  if (!llvm::is_splat(llvm::makeArrayRef(
+          {operand1().getType(), operand2().getType(),
+           resultType.getElementType(0), resultType.getElementType(1)})))
+    return emitOpError(
+        "expected all operand types and struct member types are the same");
+
+  return success();
+}
+
+ParseResult spirv::IAddCarryOp::parse(OpAsmParser &parser,
+                                      OperationState &result) {
+  SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
+  if (parser.parseOptionalAttrDict(result.attributes) ||
+      parser.parseOperandList(operands) || parser.parseColon())
+    return failure();
+
+  Type resultType;
+  SMLoc loc = parser.getCurrentLocation();
+  if (parser.parseType(resultType))
+    return failure();
+
+  auto structType = resultType.dyn_cast<spirv::StructType>();
+  if (!structType || structType.getNumElements() != 2)
+    return parser.emitError(loc, "expected spv.struct type with two members");
+
+  SmallVector<Type, 2> operandTypes(2, structType.getElementType(0));
+  if (parser.resolveOperands(operands, operandTypes, loc, result.operands))
+    return failure();
+
+  result.addTypes(resultType);
+  return success();
+}
+
+void spirv::IAddCarryOp::print(OpAsmPrinter &printer) {
+  printer << ' ';
+  printer.printOptionalAttrDict((*this)->getAttrs());
+  printer.printOperands((*this)->getOperands());
+  printer << " : " << getType();
+}
+
 //===----------------------------------------------------------------------===//
 // spv.ISubBorrowOp
 //===----------------------------------------------------------------------===//
@@ -2849,12 +2897,9 @@
   if (resultType.getNumElements() != 2)
     return emitOpError("expected result struct type containing two members");
 
-  SmallVector<Type, 4> types;
-  types.push_back(operand1().getType());
-  types.push_back(operand2().getType());
-  types.push_back(resultType.getElementType(0));
-  types.push_back(resultType.getElementType(1));
-  if (!llvm::is_splat(types))
+  if (!llvm::is_splat(llvm::makeArrayRef(
+          {operand1().getType(), operand2().getType(),
+           resultType.getElementType(0), resultType.getElementType(1)})))
     return emitOpError(
         "expected all operand types and struct member types are the same");
 
@@ -2862,9 +2907,9 @@
 }
 
 ParseResult spirv::ISubBorrowOp::parse(OpAsmParser &parser,
-                                       OperationState &state) {
+                                       OperationState &result) {
   SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
-  if (parser.parseOptionalAttrDict(state.attributes) ||
+  if (parser.parseOptionalAttrDict(result.attributes) ||
       parser.parseOperandList(operands) || parser.parseColon())
     return failure();
 
@@ -2878,10 +2923,10 @@
     return parser.emitError(loc, "expected spv.struct type with two members");
 
   SmallVector<Type, 2> operandTypes(2, structType.getElementType(0));
-  if (parser.resolveOperands(operands, operandTypes, loc, state.operands))
+  if (parser.resolveOperands(operands, operandTypes, loc, result.operands))
     return failure();
 
-  state.addTypes(resultType);
+  result.addTypes(resultType);
   return success();
 }
 
diff --git a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
--- a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
@@ -150,6 +150,58 @@
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spv.IAddCarry
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @iadd_carry_scalar
+func.func @iadd_carry_scalar(%arg: i32) -> !spv.struct<(i32, i32)> {
+  // CHECK: spv.IAddCarry %{{.+}}, %{{.+}} : !spv.struct<(i32, i32)>
+  %0 = spv.IAddCarry %arg, %arg : !spv.struct<(i32, i32)>
+  return %0 : !spv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @iadd_carry_vector
+func.func @iadd_carry_vector(%arg: vector<3xi32>) -> !spv.struct<(vector<3xi32>, vector<3xi32>)> {
+  // CHECK: spv.IAddCarry %{{.+}}, %{{.+}} : !spv.struct<(vector<3xi32>, vector<3xi32>)>
+  %0 = spv.IAddCarry %arg, %arg : !spv.struct<(vector<3xi32>, vector<3xi32>)>
+  return %0 : !spv.struct<(vector<3xi32>, vector<3xi32>)>
+}
+
+// -----
+
+func.func @iadd_carry(%arg: i32) -> !spv.struct<(i32, i32, i32)> {
+  // expected-error @+1 {{expected spv.struct type with two members}}
+  %0 = spv.IAddCarry %arg, %arg : !spv.struct<(i32, i32, i32)>
+  return %0 : !spv.struct<(i32, i32, i32)>
+}
+
+// -----
+
+func.func @iadd_carry(%arg: i32) -> !spv.struct<(i32)> {
+  // expected-error @+1 {{expected result struct type containing two members}}
+  %0 = "spv.IAddCarry"(%arg, %arg): (i32, i32) -> !spv.struct<(i32)>
+  return %0 : !spv.struct<(i32)>
+}
+
+// -----
+
+func.func @iadd_carry(%arg: i32) -> !spv.struct<(i32, i64)> {
+  // expected-error @+1 {{expected all operand types and struct member types are the same}}
+  %0 = "spv.IAddCarry"(%arg, %arg): (i32, i32) -> !spv.struct<(i32, i64)>
+  return %0 : !spv.struct<(i32, i64)>
+}
+
+// -----
+
+func.func @iadd_carry(%arg: i64) -> !spv.struct<(i32, i32)> {
+  // expected-error @+1 {{expected all operand types and struct member types are the same}}
+  %0 = "spv.IAddCarry"(%arg, %arg): (i64, i64) -> !spv.struct<(i32, i32)>
+  return %0 : !spv.struct<(i32, i32)>
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spv.ISubBorrow
 //===----------------------------------------------------------------------===//