diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td @@ -53,7 +53,32 @@ SPIRV_ScalarOrVectorOrCoopMatrixOf:$result ); let assemblyFormat = "operands attr-dict `:` type($result)"; - } +} + +class SPIRV_ArithmeticExtendedBinaryOp traits = []> : + // Result type is a struct with two operand-typed elements. + SPIRV_BinaryOp { + let arguments = (ins + SPIRV_ScalarOrVectorOf:$operand1, + SPIRV_ScalarOrVectorOf:$operand2 + ); + + let results = (outs + SPIRV_AnyStruct:$result + ); + + let builders = [ + OpBuilder<(ins "Value":$operand1, "Value":$operand2), [{ + build($_builder, $_state, + ::mlir::spirv::StructType::get({operand1.getType(), operand1.getType()}), + operand1, operand2); + }]> + ]; + + // These op require a custom verifier. + let hasVerifier = 1; +} // ----- @@ -321,9 +346,8 @@ // ----- -def SPIRV_IAddCarryOp : SPIRV_BinaryOp<"IAddCarry", - SPIRV_AnyStruct, SPIRV_Integer, - [Commutative, Pure]> { +def SPIRV_IAddCarryOp : SPIRV_ArithmeticExtendedBinaryOp<"IAddCarry", + [Commutative, Pure]> { let summary = [{ Integer addition of Operand 1 and Operand 2, including the carry. }]; @@ -355,25 +379,6 @@ %2 = spirv.IAddCarry %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)> ``` }]; - - let arguments = (ins - SPIRV_ScalarOrVectorOf:$operand1, - SPIRV_ScalarOrVectorOf:$operand2 - ); - - let results = (outs - SPIRV_AnyStruct:$result - ); - - let builders = [ - OpBuilder<(ins "Value":$operand1, "Value":$operand2), [{ - build($_builder, $_state, - ::mlir::spirv::StructType::get({operand1.getType(), operand1.getType()}), - operand1, operand2); - }]> - ]; - - let hasVerifier = 1; } // ----- @@ -418,6 +423,75 @@ // ----- +def SPIRV_SMulExtendedOp : SPIRV_ArithmeticExtendedBinaryOp<"SMulExtended", + [Pure, Commutative]> { + let summary = [{ + Result is the full value of the signed integer multiplication of Operand + 1 and Operand 2. + }]; + + let description = [{ + Result Type must be from OpTypeStruct. The struct must have two + members, and the two members must be the same type. The member type + must be a scalar or vector of integer type. + + Operand 1 and Operand 2 must have the same type as the members of Result + Type. These are consumed as signed integers. + + Results are computed per component. + + Member 0 of the result gets the low-order bits of the multiplication. + + Member 1 of the result gets the high-order bits of the multiplication. + + + + #### Example: + + ```mlir + %2 = spirv.SMulExtended %0, %1 : !spirv.struct<(i32, i32)> + %2 = spirv.SMulExtended %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)> + ``` + }]; +} + +// ----- + +def SPIRV_UMulExtendedOp : SPIRV_ArithmeticExtendedBinaryOp<"UMulExtended", + [Pure, Commutative]> { + let summary = [{ + Result is the full value of the unsigned integer multiplication of + Operand 1 and Operand 2. + }]; + + let description = [{ + Result Type must be from OpTypeStruct. The struct must have two + members, and the two members must be the same type. The member type + must be a scalar or vector of integer type, whose Signedness operand is + 0. + + Operand 1 and Operand 2 must have the same type as the members of Result + Type. These are consumed as unsigned integers. + + Results are computed per component. + + Member 0 of the result gets the low-order bits of the multiplication. + + Member 1 of the result gets the high-order bits of the multiplication. + + + + #### Example: + + ```mlir + %2 = spirv.UMulExtended %0, %1 : !spirv.struct<(i32, i32)> + %2 = spirv.UMulExtended %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)> + ``` + }]; +} + +// ----- + def SPIRV_ISubOp : SPIRV_ArithmeticBinaryOp<"ISub", SPIRV_Integer, [UsableInSpecConstantOp]> { @@ -458,8 +532,8 @@ // ----- -def SPIRV_ISubBorrowOp : SPIRV_BinaryOp<"ISubBorrow", SPIRV_AnyStruct, SPIRV_Integer, - [Pure]> { +def SPIRV_ISubBorrowOp : SPIRV_ArithmeticExtendedBinaryOp<"ISubBorrow", + [Pure]> { let summary = [{ Result is the unsigned integer subtraction of Operand 2 from Operand 1, and what it needed to borrow. @@ -494,25 +568,6 @@ %2 = spirv.ISubBorrow %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)> ``` }]; - - let arguments = (ins - SPIRV_ScalarOrVectorOf:$operand1, - SPIRV_ScalarOrVectorOf:$operand2 - ); - - let results = (outs - SPIRV_AnyStruct:$result - ); - - let builders = [ - OpBuilder<(ins "Value":$operand1, "Value":$operand2), [{ - build($_builder, $_state, - ::mlir::spirv::StructType::get({operand1.getType(), operand1.getType()}), - operand1, operand2); - }]> - ]; - - let hasVerifier = 1; } // ----- diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -4243,6 +4243,8 @@ def SPIRV_OC_OpMatrixTimesMatrix : I32EnumAttrCase<"OpMatrixTimesMatrix", 146>; def SPIRV_OC_OpIAddCarry : I32EnumAttrCase<"OpIAddCarry", 149>; def SPIRV_OC_OpISubBorrow : I32EnumAttrCase<"OpISubBorrow", 150>; +def SPIRV_OC_OpUMulExtended : I32EnumAttrCase<"OpUMulExtended", 151>; +def SPIRV_OC_OpSMulExtended : I32EnumAttrCase<"OpSMulExtended", 152>; def SPIRV_OC_OpIsNan : I32EnumAttrCase<"OpIsNan", 156>; def SPIRV_OC_OpIsInf : I32EnumAttrCase<"OpIsInf", 157>; def SPIRV_OC_OpOrdered : I32EnumAttrCase<"OpOrdered", 162>; @@ -4372,17 +4374,17 @@ SPIRV_OC_OpCompositeInsert, SPIRV_OC_OpTranspose, SPIRV_OC_OpImageDrefGather, SPIRV_OC_OpImage, SPIRV_OC_OpImageQuerySize, SPIRV_OC_OpConvertFToU, SPIRV_OC_OpConvertFToS, SPIRV_OC_OpConvertSToF, SPIRV_OC_OpConvertUToF, - SPIRV_OC_OpUConvert, SPIRV_OC_OpSConvert, SPIRV_OC_OpFConvert, SPIRV_OC_OpPtrCastToGeneric, + SPIRV_OC_OpUConvert, SPIRV_OC_OpSConvert, SPIRV_OC_OpFConvert, SPIRV_OC_OpPtrCastToGeneric, SPIRV_OC_OpGenericCastToPtr, SPIRV_OC_OpGenericCastToPtrExplicit, SPIRV_OC_OpBitcast, SPIRV_OC_OpSNegate, SPIRV_OC_OpFNegate, SPIRV_OC_OpIAdd, SPIRV_OC_OpFAdd, SPIRV_OC_OpISub, SPIRV_OC_OpFSub, SPIRV_OC_OpIMul, SPIRV_OC_OpFMul, SPIRV_OC_OpUDiv, SPIRV_OC_OpSDiv, SPIRV_OC_OpFDiv, SPIRV_OC_OpUMod, SPIRV_OC_OpSRem, SPIRV_OC_OpSMod, SPIRV_OC_OpFRem, SPIRV_OC_OpFMod, SPIRV_OC_OpVectorTimesScalar, SPIRV_OC_OpMatrixTimesScalar, SPIRV_OC_OpMatrixTimesMatrix, SPIRV_OC_OpIAddCarry, - SPIRV_OC_OpISubBorrow, SPIRV_OC_OpIsNan, SPIRV_OC_OpIsInf, SPIRV_OC_OpOrdered, - SPIRV_OC_OpUnordered, SPIRV_OC_OpLogicalEqual, SPIRV_OC_OpLogicalNotEqual, - SPIRV_OC_OpLogicalOr, SPIRV_OC_OpLogicalAnd, SPIRV_OC_OpLogicalNot, SPIRV_OC_OpSelect, - SPIRV_OC_OpIEqual, SPIRV_OC_OpINotEqual, SPIRV_OC_OpUGreaterThan, + SPIRV_OC_OpISubBorrow, SPIRV_OC_OpUMulExtended, SPIRV_OC_OpSMulExtended, SPIRV_OC_OpIsNan, + SPIRV_OC_OpIsInf, SPIRV_OC_OpOrdered, SPIRV_OC_OpUnordered, SPIRV_OC_OpLogicalEqual, + SPIRV_OC_OpLogicalNotEqual, SPIRV_OC_OpLogicalOr, SPIRV_OC_OpLogicalAnd, SPIRV_OC_OpLogicalNot, + SPIRV_OC_OpSelect, SPIRV_OC_OpIEqual, SPIRV_OC_OpINotEqual, SPIRV_OC_OpUGreaterThan, SPIRV_OC_OpSGreaterThan, SPIRV_OC_OpUGreaterThanEqual, SPIRV_OC_OpSGreaterThanEqual, SPIRV_OC_OpULessThan, SPIRV_OC_OpSLessThan, SPIRV_OC_OpULessThanEqual, SPIRV_OC_OpSLessThanEqual, SPIRV_OC_OpFOrdEqual, SPIRV_OC_OpFUnordEqual, diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -24,6 +24,7 @@ #include "mlir/IR/FunctionImplementation.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/IR/Operation.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Interfaces/CallInterfaces.h" #include "llvm/ADT/APFloat.h" @@ -31,6 +32,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/bit.h" +#include #include using namespace mlir; @@ -763,6 +765,53 @@ isa(block.front()); } +template +static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op) { + auto resultType = op.getType().template cast(); + if (resultType.getNumElements() != 2) + return op.emitOpError("expected result struct type containing two members"); + + if (!llvm::all_equal({op.getOperand1().getType(), op.getOperand2().getType(), + resultType.getElementType(0), + resultType.getElementType(1)})) + return op.emitOpError( + "expected all operand types and struct member types are the same"); + + return success(); +} + +static ParseResult parseArithmeticExtendedBinaryOp(OpAsmParser &parser, + OperationState &result) { + SmallVector operands; + if (parser.parseOptionalAttrDict(result.attributes) || + parser.parseOperandList(operands) || parser.parseColon()) + return failure(); + + Type resultType; + SMLoc loc = parser.getCurrentLocation(); + if (parser.parseType(resultType)) + return failure(); + + auto structType = resultType.dyn_cast(); + if (!structType || structType.getNumElements() != 2) + return parser.emitError(loc, "expected spirv.struct type with two members"); + + SmallVector operandTypes(2, structType.getElementType(0)); + if (parser.resolveOperands(operands, operandTypes, loc, result.operands)) + return failure(); + + result.addTypes(resultType); + return success(); +} + +static void printArithmeticExtendedBinaryOp(Operation *op, + OpAsmPrinter &printer) { + printer << ' '; + printer.printOptionalAttrDict(op->getAttrs()); + printer.printOperands(op->getOperands()); + printer << " : " << op->getResultTypes().front(); +} + //===----------------------------------------------------------------------===// // Common parsers and printers //===----------------------------------------------------------------------===// @@ -2990,48 +3039,16 @@ //===----------------------------------------------------------------------===// LogicalResult spirv::IAddCarryOp::verify() { - auto resultType = getType().cast(); - if (resultType.getNumElements() != 2) - return emitOpError("expected result struct type containing two members"); - - if (!llvm::all_equal({getOperand1().getType(), getOperand2().getType(), - resultType.getElementType(0), - resultType.getElementType(1)})) - return emitOpError( - "expected all operand types and struct member types are the same"); - - return success(); + return ::verifyArithmeticExtendedBinaryOp(*this); } ParseResult spirv::IAddCarryOp::parse(OpAsmParser &parser, OperationState &result) { - SmallVector operands; - if (parser.parseOptionalAttrDict(result.attributes) || - parser.parseOperandList(operands) || parser.parseColon()) - return failure(); - - Type resultType; - SMLoc loc = parser.getCurrentLocation(); - if (parser.parseType(resultType)) - return failure(); - - auto structType = resultType.dyn_cast(); - if (!structType || structType.getNumElements() != 2) - return parser.emitError(loc, "expected spirv.struct type with two members"); - - SmallVector operandTypes(2, structType.getElementType(0)); - if (parser.resolveOperands(operands, operandTypes, loc, result.operands)) - return failure(); - - result.addTypes(resultType); - return success(); + return ::parseArithmeticExtendedBinaryOp(parser, result); } void spirv::IAddCarryOp::print(OpAsmPrinter &printer) { - printer << ' '; - printer.printOptionalAttrDict((*this)->getAttrs()); - printer.printOperands((*this)->getOperands()); - printer << " : " << getType(); + ::printArithmeticExtendedBinaryOp(*this, printer); } //===----------------------------------------------------------------------===// @@ -3039,48 +3056,50 @@ //===----------------------------------------------------------------------===// LogicalResult spirv::ISubBorrowOp::verify() { - auto resultType = getType().cast(); - if (resultType.getNumElements() != 2) - return emitOpError("expected result struct type containing two members"); - - if (!llvm::all_equal({getOperand1().getType(), getOperand2().getType(), - resultType.getElementType(0), - resultType.getElementType(1)})) - return emitOpError( - "expected all operand types and struct member types are the same"); - - return success(); + return ::verifyArithmeticExtendedBinaryOp(*this); } ParseResult spirv::ISubBorrowOp::parse(OpAsmParser &parser, OperationState &result) { - SmallVector operands; - if (parser.parseOptionalAttrDict(result.attributes) || - parser.parseOperandList(operands) || parser.parseColon()) - return failure(); + return ::parseArithmeticExtendedBinaryOp(parser, result); +} - Type resultType; - auto loc = parser.getCurrentLocation(); - if (parser.parseType(resultType)) - return failure(); +void spirv::ISubBorrowOp::print(OpAsmPrinter &printer) { + ::printArithmeticExtendedBinaryOp(*this, printer); +} - auto structType = resultType.dyn_cast(); - if (!structType || structType.getNumElements() != 2) - return parser.emitError(loc, "expected spirv.struct type with two members"); +//===----------------------------------------------------------------------===// +// spirv.SMulExtended +//===----------------------------------------------------------------------===// - SmallVector operandTypes(2, structType.getElementType(0)); - if (parser.resolveOperands(operands, operandTypes, loc, result.operands)) - return failure(); +LogicalResult spirv::SMulExtendedOp::verify() { + return ::verifyArithmeticExtendedBinaryOp(*this); +} - result.addTypes(resultType); - return success(); +ParseResult spirv::SMulExtendedOp::parse(OpAsmParser &parser, + OperationState &result) { + return ::parseArithmeticExtendedBinaryOp(parser, result); } -void spirv::ISubBorrowOp::print(OpAsmPrinter &printer) { - printer << ' '; - printer.printOptionalAttrDict((*this)->getAttrs()); - printer.printOperands((*this)->getOperands()); - printer << " : " << getType(); +void spirv::SMulExtendedOp::print(OpAsmPrinter &printer) { + ::printArithmeticExtendedBinaryOp(*this, printer); +} + +//===----------------------------------------------------------------------===// +// spirv.UMulExtended +//===----------------------------------------------------------------------===// + +LogicalResult spirv::UMulExtendedOp::verify() { + return ::verifyArithmeticExtendedBinaryOp(*this); +} + +ParseResult spirv::UMulExtendedOp::parse(OpAsmParser &parser, + OperationState &result) { + return ::parseArithmeticExtendedBinaryOp(parser, result); +} + +void spirv::UMulExtendedOp::print(OpAsmPrinter &printer) { + ::printArithmeticExtendedBinaryOp(*this, printer); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir --- a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir @@ -254,6 +254,110 @@ // ----- +//===----------------------------------------------------------------------===// +// spirv.SMulExtended +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: @smul_extended_scalar +func.func @smul_extended_scalar(%arg: i32) -> !spirv.struct<(i32, i32)> { + // CHECK: spirv.SMulExtended %{{.+}}, %{{.+}} : !spirv.struct<(i32, i32)> + %0 = spirv.SMulExtended %arg, %arg : !spirv.struct<(i32, i32)> + return %0 : !spirv.struct<(i32, i32)> +} + +// CHECK-LABEL: @smul_extended_vector +func.func @smul_extended_vector(%arg: vector<3xi32>) -> !spirv.struct<(vector<3xi32>, vector<3xi32>)> { + // CHECK: spirv.SMulExtended %{{.+}}, %{{.+}} : !spirv.struct<(vector<3xi32>, vector<3xi32>)> + %0 = spirv.SMulExtended %arg, %arg : !spirv.struct<(vector<3xi32>, vector<3xi32>)> + return %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)> +} + +// ----- + +func.func @smul_extended(%arg: i32) -> !spirv.struct<(i32, i32, i32)> { + // expected-error @+1 {{expected spirv.struct type with two members}} + %0 = spirv.SMulExtended %arg, %arg : !spirv.struct<(i32, i32, i32)> + return %0 : !spirv.struct<(i32, i32, i32)> +} + +// ----- + +func.func @smul_extended(%arg: i32) -> !spirv.struct<(i32)> { + // expected-error @+1 {{expected result struct type containing two members}} + %0 = "spirv.SMulExtended"(%arg, %arg): (i32, i32) -> !spirv.struct<(i32)> + return %0 : !spirv.struct<(i32)> +} + +// ----- + +func.func @smul_extended(%arg: i32) -> !spirv.struct<(i32, i64)> { + // expected-error @+1 {{expected all operand types and struct member types are the same}} + %0 = "spirv.SMulExtended"(%arg, %arg): (i32, i32) -> !spirv.struct<(i32, i64)> + return %0 : !spirv.struct<(i32, i64)> +} + +// ----- + +func.func @smul_extended(%arg: i64) -> !spirv.struct<(i32, i32)> { + // expected-error @+1 {{expected all operand types and struct member types are the same}} + %0 = "spirv.SMulExtended"(%arg, %arg): (i64, i64) -> !spirv.struct<(i32, i32)> + return %0 : !spirv.struct<(i32, i32)> +} + +// ----- + +//===----------------------------------------------------------------------===// +// spirv.UMulExtended +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: @umul_extended_scalar +func.func @umul_extended_scalar(%arg: i32) -> !spirv.struct<(i32, i32)> { + // CHECK: spirv.UMulExtended %{{.+}}, %{{.+}} : !spirv.struct<(i32, i32)> + %0 = spirv.UMulExtended %arg, %arg : !spirv.struct<(i32, i32)> + return %0 : !spirv.struct<(i32, i32)> +} + +// CHECK-LABEL: @umul_extended_vector +func.func @umul_extended_vector(%arg: vector<3xi32>) -> !spirv.struct<(vector<3xi32>, vector<3xi32>)> { + // CHECK: spirv.UMulExtended %{{.+}}, %{{.+}} : !spirv.struct<(vector<3xi32>, vector<3xi32>)> + %0 = spirv.UMulExtended %arg, %arg : !spirv.struct<(vector<3xi32>, vector<3xi32>)> + return %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)> +} + +// ----- + +func.func @umul_extended(%arg: i32) -> !spirv.struct<(i32, i32, i32)> { + // expected-error @+1 {{expected spirv.struct type with two members}} + %0 = spirv.UMulExtended %arg, %arg : !spirv.struct<(i32, i32, i32)> + return %0 : !spirv.struct<(i32, i32, i32)> +} + +// ----- + +func.func @umul_extended(%arg: i32) -> !spirv.struct<(i32)> { + // expected-error @+1 {{expected result struct type containing two members}} + %0 = "spirv.UMulExtended"(%arg, %arg): (i32, i32) -> !spirv.struct<(i32)> + return %0 : !spirv.struct<(i32)> +} + +// ----- + +func.func @umul_extended(%arg: i32) -> !spirv.struct<(i32, i64)> { + // expected-error @+1 {{expected all operand types and struct member types are the same}} + %0 = "spirv.UMulExtended"(%arg, %arg): (i32, i32) -> !spirv.struct<(i32, i64)> + return %0 : !spirv.struct<(i32, i64)> +} + +// ----- + +func.func @umul_extended(%arg: i64) -> !spirv.struct<(i32, i32)> { + // expected-error @+1 {{expected all operand types and struct member types are the same}} + %0 = "spirv.UMulExtended"(%arg, %arg): (i64, i64) -> !spirv.struct<(i32, i32)> + return %0 : !spirv.struct<(i32, i32)> +} + +// ----- + //===----------------------------------------------------------------------===// // spirv.SDiv //===----------------------------------------------------------------------===//