diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntegerDotProductOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntegerDotProductOps.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntegerDotProductOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntegerDotProductOps.td @@ -187,4 +187,154 @@ }]; } +// ----- + +def SPIRV_SDotAccSatOp : SPIRV_IntegerDotProductTernaryOp<"SDotAccSat", + [SignedOp]> { + let summary = [{ + Signed integer dot product of Vector 1 and Vector 2 and signed + saturating addition of the result with Accumulator. + }]; + + let description = [{ + Result Type must be an integer type whose Width must be greater than or + equal to that of the components of Vector 1 and Vector 2. + + Vector 1 and Vector 2 must have the same type. + + Vector 1 and Vector 2 must be either 32-bit integers (enabled by the + DotProductInput4x8BitPacked capability) or vectors of integer type + (enabled by the DotProductInput4x8Bit or DotProductInputAll capability). + + The type of Accumulator must be the same as Result Type. + + When Vector 1 and Vector 2 are scalar integer types, Packed Vector + Format must be specified to select how the integers are to be + interpreted as vectors. + + All components of the input vectors are sign-extended to the bit width + of the result's type. The sign-extended input vectors are then + multiplied component-wise and all components of the vector resulting + from the component-wise multiplication are added together. Finally, the + resulting sum is added to the input accumulator. This final addition is + saturating. + + If any of the multiplications or additions, with the exception of the + final accumulation, overflow or underflow, the result of the instruction + is undefined. 
+ + + + #### Example: + + ```mlir + %r = spirv.SDotAccSat %a, %b, %acc {format = #spirv.packed_vector_format}: (i32, i32, i32) -> i32 + %r = spirv.SDotAccSat %a, %b, %acc {format = #spirv.packed_vector_format}: (i32, i32, i64) -> i64 + %r = spirv.SDotAccSat %a, %b, %acc : (vector<4xi8>, vector<4xi8>, i32) -> i32 + ``` + }]; +} + +// ----- + +def SPIRV_SUDotAccSatOp : SPIRV_IntegerDotProductTernaryOp<"SUDotAccSat", + [SignedOp, + UnsignedOp]> { + let summary = [{ + Mixed-signedness integer dot product of Vector 1 and Vector 2 and signed + saturating addition of the result with Accumulator. Components of Vector + 1 are treated as signed, components of Vector 2 are treated as unsigned. + }]; + + let description = [{ + Result Type must be an integer type whose Width must be greater than or + equal to that of the components of Vector 1 and Vector 2. + + Vector 1 and Vector 2 must be either 32-bit integers (enabled by the + DotProductInput4x8BitPacked capability) or vectors of integer type with + the same number of components and same component Width (enabled by the + DotProductInput4x8Bit or DotProductInputAll capability). When Vector 1 + and Vector 2 are vectors, the components of Vector 2 must have a + Signedness of 0. + + The type of Accumulator must be the same as Result Type. + + When Vector 1 and Vector 2 are scalar integer types, Packed Vector + Format must be specified to select how the integers are to be + interpreted as vectors. + + All components of Vector 1 are sign-extended to the bit width of the + result's type. All components of Vector 2 are zero-extended to the bit + width of the result's type. The sign- or zero-extended input vectors are + then multiplied component-wise and all components of the vector + resulting from the component-wise multiplication are added together. + Finally, the resulting sum is added to the input accumulator. This final + addition is saturating. 
+ + If any of the multiplications or additions, with the exception of the + final accumulation, overflow or underflow, the result of the instruction + is undefined. + + + + #### Example: + + ```mlir + %r = spirv.SUDotAccSat %a, %b, %acc {format = #spirv.packed_vector_format}: (i32, i32, i32) -> i32 + %r = spirv.SUDotAccSat %a, %b, %acc {format = #spirv.packed_vector_format}: (i32, i32, i64) -> i64 + %r = spirv.SUDotAccSat %a, %b, %acc : (vector<4xi8>, vector<4xi8>, i32) -> i32 + ``` + }]; +} + +// ----- + +def SPIRV_UDotAccSatOp : + SPIRV_IntegerDotProductTernaryOp<"UDotAccSat", [UnsignedOp]> { + let summary = [{ + Unsigned integer dot product of Vector 1 and Vector 2 and unsigned + saturating addition of the result with Accumulator. + }]; + + let description = [{ + Result Type must be an integer type with Signedness of 0 whose Width + must be greater than or equal to that of the components of Vector 1 and + Vector 2. + + Vector 1 and Vector 2 must have the same type. + + Vector 1 and Vector 2 must be either 32-bit integers (enabled by the + DotProductInput4x8BitPacked capability) or vectors of integer type with + Signedness of 0 (enabled by the DotProductInput4x8Bit or + DotProductInputAll capability). + + The type of Accumulator must be the same as Result Type. + + When Vector 1 and Vector 2 are scalar integer types, Packed Vector + Format must be specified to select how the integers are to be + interpreted as vectors. + + All components of the input vectors are zero-extended to the bit width + of the result's type. The zero-extended input vectors are then + multiplied component-wise and all components of the vector resulting + from the component-wise multiplication are added together. Finally, the + resulting sum is added to the input accumulator. This final addition is + saturating. + + If any of the multiplications or additions, with the exception of the + final accumulation, overflow or underflow, the result of the instruction + is undefined. 
+ + + + #### Example: + + ```mlir + %r = spirv.UDotAccSat %a, %b, %acc {format = #spirv.packed_vector_format}: (i32, i32, i32) -> i32 + %r = spirv.UDotAccSat %a, %b, %acc {format = #spirv.packed_vector_format}: (i32, i32, i64) -> i64 + %r = spirv.UDotAccSat %a, %b, %acc : (vector<4xi8>, vector<4xi8>, i32) -> i32 + ``` + }]; +} + #endif // MLIR_DIALECT_SPIRV_IR_INTEGER_DOT_PRODUCT_OPS diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -4829,6 +4829,11 @@ "op only supports the 'format' #spirv.packed_vector_format attribute"); Type resultTy = op->getResultTypes().front(); + bool hasAccumulator = op->getNumOperands() == 3; + if (hasAccumulator && op->getOperand(2).getType() != resultTy) + return op->emitOpError( + "requires the same accumulator operand and result types"); + unsigned factorBitWidth = getBitWidth(factorTy); unsigned resultBitWidth = getBitWidth(resultTy); if (factorBitWidth > resultBitWidth) @@ -4909,6 +4914,9 @@ SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::SDotOp) SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::SUDotOp) SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::UDotOp) +SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::SDotAccSatOp) +SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::SUDotAccSatOp) +SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::UDotAccSatOp) #undef SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP diff --git a/mlir/test/Dialect/SPIRV/IR/availability.mlir b/mlir/test/Dialect/SPIRV/IR/availability.mlir --- a/mlir/test/Dialect/SPIRV/IR/availability.mlir +++ b/mlir/test/Dialect/SPIRV/IR/availability.mlir @@ -143,3 +143,93 @@ %r = spirv.UDot %a, %a: (vector<4xi16>, vector<4xi16>) -> i64 return %r: i64 } + +// CHECK-LABEL: sdot_acc_sat_scalar_i32_i32 +func.func @sdot_acc_sat_scalar_i32_i32(%a: i32) -> i32 { + // CHECK: min version: v1.0 + // CHECK: max version: v1.6 + // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ] + // CHECK: capabilities: [ 
[DotProduct] [DotProductInput4x8BitPacked] ] + %r = spirv.SDotAccSat %a, %a, %a {format = #spirv.packed_vector_format}: (i32, i32, i32) -> i32 + return %r: i32 +} + +// CHECK-LABEL: sdot_acc_sat_vector_4xi8_i64 +func.func @sdot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 { + // CHECK: min version: v1.0 + // CHECK: max version: v1.6 + // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ] + // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ] + %r = spirv.SDotAccSat %a, %a, %acc: (vector<4xi8>, vector<4xi8>, i64) -> i64 + return %r: i64 +} + +// CHECK-LABEL: sdot_acc_sat_vector_4xi16_i64 +func.func @sdot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 { + // CHECK: min version: v1.0 + // CHECK: max version: v1.6 + // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ] + // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ] + %r = spirv.SDotAccSat %a, %a, %acc: (vector<4xi16>, vector<4xi16>, i64) -> i64 + return %r: i64 +} + +// CHECK-LABEL: sudot_acc_sat_scalar_i32_i32 +func.func @sudot_acc_sat_scalar_i32_i32(%a: i32) -> i32 { + // CHECK: min version: v1.0 + // CHECK: max version: v1.6 + // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ] + // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ] + %r = spirv.SUDotAccSat %a, %a, %a {format = #spirv.packed_vector_format}: (i32, i32, i32) -> i32 + return %r: i32 +} + +// CHECK-LABEL: sudot_acc_sat_vector_4xi8_i64 +func.func @sudot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 { + // CHECK: min version: v1.0 + // CHECK: max version: v1.6 + // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ] + // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ] + %r = spirv.SUDotAccSat %a, %a, %acc: (vector<4xi8>, vector<4xi8>, i64) -> i64 + return %r: i64 +} + +// CHECK-LABEL: sudot_acc_sat_vector_4xi16_i64 +func.func @sudot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 { + // CHECK: min version: v1.0 + // CHECK: max 
version: v1.6 + // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ] + // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ] + %r = spirv.SUDotAccSat %a, %a, %acc: (vector<4xi16>, vector<4xi16>, i64) -> i64 + return %r: i64 +} + +// CHECK-LABEL: udot_acc_sat_scalar_i32_i32 +func.func @udot_acc_sat_scalar_i32_i32(%a: i32) -> i32 { + // CHECK: min version: v1.0 + // CHECK: max version: v1.6 + // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ] + // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ] + %r = spirv.UDotAccSat %a, %a, %a {format = #spirv.packed_vector_format}: (i32, i32, i32) -> i32 + return %r: i32 +} + +// CHECK-LABEL: udot_acc_sat_vector_4xi8_i64 +func.func @udot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 { + // CHECK: min version: v1.0 + // CHECK: max version: v1.6 + // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ] + // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ] + %r = spirv.UDotAccSat %a, %a, %acc: (vector<4xi8>, vector<4xi8>, i64) -> i64 + return %r: i64 +} + +// CHECK-LABEL: udot_acc_sat_vector_4xi16_i64 +func.func @udot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 { + // CHECK: min version: v1.0 + // CHECK: max version: v1.6 + // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ] + // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ] + %r = spirv.UDotAccSat %a, %a, %acc: (vector<4xi16>, vector<4xi16>, i64) -> i64 + return %r: i64 +} diff --git a/mlir/test/Dialect/SPIRV/IR/integer-dot-product-ops.mlir b/mlir/test/Dialect/SPIRV/IR/integer-dot-product-ops.mlir --- a/mlir/test/Dialect/SPIRV/IR/integer-dot-product-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/integer-dot-product-ops.mlir @@ -142,3 +142,158 @@ %r = spirv.UDot %a, %b : (vector<4xi8>, vector<4xi8>) -> i32 return %r : i32 } + +// ----- + +//===----------------------------------------------------------------------===// +// spirv.SDotAccSat 
+//===----------------------------------------------------------------------===// + +// CHECK: @sdot_acc_sat_scalar_i32 +func.func @sdot_acc_sat_scalar_i32(%a: i32, %b: i32, %acc: i32) -> i32 { + // CHECK-NEXT: spirv.SDotAccSat + %r = spirv.SDotAccSat %a, %b, %acc {format = #spirv.packed_vector_format}: (i32, i32, i32) -> i32 + return %r : i32 +} + +// CHECK: @sdot_acc_sat_scalar_i64 +func.func @sdot_acc_sat_scalar_i64(%a: i32, %b: i32, %acc: i64) -> i64 { + // CHECK-NEXT: spirv.SDotAccSat + %r = spirv.SDotAccSat %a, %b, %acc {format = #spirv.packed_vector_format}: (i32, i32, i64) -> i64 + return %r : i64 +} + +// CHECK: @sdot_acc_sat_vector_4xi8 +func.func @sdot_acc_sat_vector_4xi8(%a: vector<4xi8>, %b: vector<4xi8>, %acc: i32) -> i32 { + // CHECK-NEXT: spirv.SDotAccSat + %r = spirv.SDotAccSat %a, %b, %acc : (vector<4xi8>, vector<4xi8>, i32) -> i32 + return %r : i32 +} + +// CHECK: @sdot_acc_sat_vector_4xi16 +func.func @sdot_acc_sat_vector_4xi16(%a: vector<4xi16>, %b: vector<4xi16>, %acc: i64) -> i64 { + // CHECK-NEXT: spirv.SDotAccSat + %r = spirv.SDotAccSat %a, %b, %acc : (vector<4xi16>, vector<4xi16>, i64) -> i64 + return %r : i64 +} + +// CHECK: @sdot_acc_sat_vector_8xi8 +func.func @sdot_acc_sat_vector_8xi8(%a: vector<8xi8>, %b: vector<8xi8>, %acc: i64) -> i64 { + // CHECK-NEXT: spirv.SDotAccSat + %r = spirv.SDotAccSat %a, %b, %acc : (vector<8xi8>, vector<8xi8>, i64) -> i64 + return %r : i64 +} + +// ----- + +func.func @sdot_acc_sat_scalar_bad_types(%a: i32, %b: i64, %acc: i32) -> i32 { + // expected-error @+1 {{op requires the same type for both vector operands}} + %r = spirv.SDotAccSat %a, %b, %acc {format = #spirv.packed_vector_format}: (i32, i64, i32) -> i32 + return %r : i32 +} + +// ----- + +func.func @sdot_acc_sat_scalar_bad_types(%a: i32, %b: i32, %acc: i16) -> i16 { + // expected-error @+1 {{op result type has insufficient bit-width (16 bits) for the specified vector operand type (32 bits)}} + %r = spirv.SDotAccSat %a, %b, %acc {format = 
#spirv.packed_vector_format}: (i32, i32, i16) -> i16 + return %r : i16 +} + +// ----- + +func.func @sdot_acc_sat_scalar_bad_types(%a: i64, %b: i64, %acc: i64) -> i64 { + // expected-error @+1 {{op with specified Packed Vector Format (PackedVectorFormat4x8Bit) requires integer vector operands to be 32-bits wide}} + %r = spirv.SDotAccSat %a, %b, %acc {format = #spirv.packed_vector_format}: (i64, i64, i64) -> i64 + return %r : i64 +} + +// ----- + +func.func @sdot_acc_sat_scalar_bad_accumulator(%a: i32, %b: i32, %acc: i32) -> i64 { + // expected-error @+1 {{requires the same accumulator operand and result types}} + %r = spirv.SDotAccSat %a, %b, %acc {format = #spirv.packed_vector_format}: (i32, i32, i32) -> i64 + return %r : i64 +} + +// ----- + +//===----------------------------------------------------------------------===// +// spirv.SUDotAccSat +//===----------------------------------------------------------------------===// + +// CHECK: @sudot_acc_sat_scalar_i32 +func.func @sudot_acc_sat_scalar_i32(%a: i32, %b: i32, %acc: i32) -> i32 { + // CHECK-NEXT: spirv.SUDotAccSat + %r = spirv.SUDotAccSat %a, %b, %acc {format = #spirv.packed_vector_format}: (i32, i32, i32) -> i32 + return %r : i32 +} + +// CHECK: @sudot_acc_sat_scalar_i64 +func.func @sudot_acc_sat_scalar_i64(%a: i32, %b: i32, %acc: i64) -> i64 { + // CHECK-NEXT: spirv.SUDotAccSat + %r = spirv.SUDotAccSat %a, %b, %acc {format = #spirv.packed_vector_format}: (i32, i32, i64) -> i64 + return %r : i64 +} + +// CHECK: @sudot_acc_sat_vector_4xi8 +func.func @sudot_acc_sat_vector_4xi8(%a: vector<4xi8>, %b: vector<4xi8>, %acc: i32) -> i32 { + // CHECK-NEXT: spirv.SUDotAccSat + %r = spirv.SUDotAccSat %a, %b, %acc : (vector<4xi8>, vector<4xi8>, i32) -> i32 + return %r : i32 +} + +// CHECK: @sudot_acc_sat_vector_4xi16 +func.func @sudot_acc_sat_vector_4xi16(%a: vector<4xi16>, %b: vector<4xi16>, %acc: i64) -> i64 { + // CHECK-NEXT: spirv.SUDotAccSat + %r = spirv.SUDotAccSat %a, %b, %acc : (vector<4xi16>, vector<4xi16>, 
i64) -> i64 + return %r : i64 +} + +// CHECK: @sudot_acc_sat_vector_8xi8 +func.func @sudot_acc_sat_vector_8xi8(%a: vector<8xi8>, %b: vector<8xi8>, %acc: i64) -> i64 { + // CHECK-NEXT: spirv.SUDotAccSat + %r = spirv.SUDotAccSat %a, %b, %acc : (vector<8xi8>, vector<8xi8>, i64) -> i64 + return %r : i64 +} + +// ----- + +//===----------------------------------------------------------------------===// +// spirv.UDotAccSat +//===----------------------------------------------------------------------===// + +// CHECK: @udot_acc_sat_scalar_i32 +func.func @udot_acc_sat_scalar_i32(%a: i32, %b: i32, %acc: i32) -> i32 { + // CHECK-NEXT: spirv.UDotAccSat + %r = spirv.UDotAccSat %a, %b, %acc {format = #spirv.packed_vector_format}: (i32, i32, i32) -> i32 + return %r : i32 +} + +// CHECK: @udot_acc_sat_scalar_i64 +func.func @udot_acc_sat_scalar_i64(%a: i32, %b: i32, %acc: i64) -> i64 { + // CHECK-NEXT: spirv.UDotAccSat + %r = spirv.UDotAccSat %a, %b, %acc {format = #spirv.packed_vector_format}: (i32, i32, i64) -> i64 + return %r : i64 +} + +// CHECK: @udot_acc_sat_vector_4xi8 +func.func @udot_acc_sat_vector_4xi8(%a: vector<4xi8>, %b: vector<4xi8>, %acc: i32) -> i32 { + // CHECK-NEXT: spirv.UDotAccSat + %r = spirv.UDotAccSat %a, %b, %acc : (vector<4xi8>, vector<4xi8>, i32) -> i32 + return %r : i32 +} + +// CHECK: @udot_acc_sat_vector_4xi16 +func.func @udot_acc_sat_vector_4xi16(%a: vector<4xi16>, %b: vector<4xi16>, %acc: i64) -> i64 { + // CHECK-NEXT: spirv.UDotAccSat + %r = spirv.UDotAccSat %a, %b, %acc : (vector<4xi16>, vector<4xi16>, i64) -> i64 + return %r : i64 +} + +// CHECK: @udot_acc_sat_vector_8xi8 +func.func @udot_acc_sat_vector_8xi8(%a: vector<8xi8>, %b: vector<8xi8>, %acc: i64) -> i64 { + // CHECK-NEXT: spirv.UDotAccSat + %r = spirv.UDotAccSat %a, %b, %acc : (vector<8xi8>, vector<8xi8>, i64) -> i64 + return %r : i64 +} diff --git a/mlir/test/Dialect/SPIRV/IR/target-env.mlir b/mlir/test/Dialect/SPIRV/IR/target-env.mlir --- a/mlir/test/Dialect/SPIRV/IR/target-env.mlir +++ 
b/mlir/test/Dialect/SPIRV/IR/target-env.mlir @@ -212,6 +212,77 @@ } { // CHECK: test.convert_to_udot_op %0 = "test.convert_to_udot_op"(%operand, %operand): (vector<4xi16>, vector<4xi16>) -> (i64) + return %0 : i64 +} + +// CHECK-LABEL: @sdot_acc_sat_scalar_i32_i32_capabilities +func.func @sdot_acc_sat_scalar_i32_i32_capabilities(%operand: i32) -> i32 attributes { + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> +} { + // CHECK: spirv.SDotAccSat + %0 = "test.convert_to_sdot_acc_sat_op"(%operand, %operand, %operand) + {format = #spirv.packed_vector_format}: (i32, i32, i32) -> (i32) + return %0: i32 +} + +// CHECK-LABEL: @sudot_acc_sat_vector_4xi8_i32_capabilities +func.func @sudot_acc_sat_vector_4xi8_i32_capabilities(%operand: vector<4xi8>, %acc: i32) -> i32 attributes { + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> +} { + // CHECK: spirv.SUDotAccSat + %0 = "test.convert_to_sudot_acc_sat_op"(%operand, %operand, %acc): (vector<4xi8>, vector<4xi8>, i32) -> (i32) + return %0: i32 +} + +// CHECK-LABEL: @udot_acc_sat_vector_4xi8_i32_missing_capability1 +func.func @udot_acc_sat_vector_4xi8_i32_missing_capability1(%operand: vector<4xi8>, %acc: i32) -> i32 attributes { + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> +} { + // CHECK: test.convert_to_udot_acc_sat_op + %0 = "test.convert_to_udot_acc_sat_op"(%operand, %operand, %acc): (vector<4xi8>, vector<4xi8>, i32) -> (i32) + return %0: i32 +} + +// CHECK-LABEL: @udot_acc_sat_vector_4xi8_i32_missing_capability2 +func.func @udot_acc_sat_vector_4xi8_i32_missing_capability2(%operand: vector<4xi8>, %acc: i32) -> i32 attributes { + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> +} { + // CHECK: test.convert_to_udot_acc_sat_op + %0 = "test.convert_to_udot_acc_sat_op"(%operand, %operand, %acc): (vector<4xi8>, vector<4xi8>, i32) -> (i32) + return %0: i32 +} + +// CHECK-LABEL: @udot_acc_sat_vector_4xi16_i64_capabilities 
+func.func @udot_acc_sat_vector_4xi16_i64_capabilities(%operand: vector<4xi16>, %acc: i64) -> i64 attributes { + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> +} { + // CHECK: spirv.UDotAccSat + %0 = "test.convert_to_udot_acc_sat_op"(%operand, %operand, %acc): (vector<4xi16>, vector<4xi16>, i64) -> (i64) + return %0: i64 +} + +// CHECK-LABEL: @udot_acc_sat_vector_4xi16_i64_missing_capability1 +func.func @udot_acc_sat_vector_4xi16_i64_missing_capability1(%operand: vector<4xi16>, %acc: i64) -> i64 attributes { + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> +} { + // CHECK: test.convert_to_udot_acc_sat_op + %0 = "test.convert_to_udot_acc_sat_op"(%operand, %operand, %acc): (vector<4xi16>, vector<4xi16>, i64) -> (i64) + return %0: i64 +} + +// CHECK-LABEL: @udot_acc_sat_vector_4xi16_i64_missing_capability2 +func.func @udot_acc_sat_vector_4xi16_i64_missing_capability2(%operand: vector<4xi16>, %acc: i64) -> i64 attributes { + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> +} { + // CHECK: test.convert_to_udot_acc_sat_op + %0 = "test.convert_to_udot_acc_sat_op"(%operand, %operand, %acc): (vector<4xi16>, vector<4xi16>, i64) -> (i64) return %0: i64 } @@ -304,3 +375,25 @@ %0 = "test.convert_to_udot_op"(%operand, %operand): (vector<4xi16>, vector<4xi16>) -> (i64) return %0: i64 } + +// CHECK-LABEL: @sdot_acc_sat_vector_4xi16_i64_implied_extension +func.func @sdot_acc_sat_vector_4xi16_i64_implied_extension(%operand: vector<4xi16>, %acc: i64) -> i64 attributes { + // Version 1.6 implies SPV_KHR_integer_dot_product. 
+ spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> +} { + // CHECK: spirv.SDotAccSat + %0 = "test.convert_to_sdot_acc_sat_op"(%operand, %operand, %acc): (vector<4xi16>, vector<4xi16>, i64) -> (i64) + return %0: i64 +} + +// CHECK-LABEL: @sdot_acc_sat_vector_4xi16_i64_missing_extension +func.func @sdot_acc_sat_vector_4xi16_i64_missing_extension(%operand: vector<4xi16>, %acc: i64) -> i64 attributes { + // Version 1.5 does not imply SPV_KHR_integer_dot_product. + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> +} { + // CHECK: test.convert_to_sdot_acc_sat_op + %0 = "test.convert_to_sdot_acc_sat_op"(%operand, %operand, %acc): (vector<4xi16>, vector<4xi16>, i64) -> (i64) + return %0: i64 +} diff --git a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp --- a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp +++ b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp @@ -223,14 +223,24 @@ static constexpr char sDotTestOpName[] = "test.convert_to_sdot_op"; static constexpr char suDotTestOpName[] = "test.convert_to_sudot_op"; static constexpr char uDotTestOpName[] = "test.convert_to_udot_op"; + static constexpr char sDotAccSatTestOpName[] = + "test.convert_to_sdot_acc_sat_op"; + static constexpr char suDotAccSatTestOpName[] = + "test.convert_to_sudot_acc_sat_op"; + static constexpr char uDotAccSatTestOpName[] = + "test.convert_to_udot_acc_sat_op"; RewritePatternSet patterns(context); - patterns.add, - ConvertToIntegerDotProd, - ConvertToIntegerDotProd>(context); + patterns.add< + ConvertToAtomCmpExchangeWeak, ConvertToBitReverse, + ConvertToGroupNonUniformBallot, ConvertToModule, ConvertToSubgroupBallot, + ConvertToIntegerDotProd, + ConvertToIntegerDotProd, + ConvertToIntegerDotProd, + ConvertToIntegerDotProd, + ConvertToIntegerDotProd, + ConvertToIntegerDotProd>( + context); if (failed(applyPartialConversion(fn, *target, std::move(patterns)))) return signalPassFailure();