# Reviewer preamble (not part of any hunk).
# Purpose: lower std bitwise and shift ops to SPIR-V and add tests for them.
# - StandardToSPIRV.td: adds six `BinaryOpPattern` defs. The new std-to-spirv.mlir CHECK lines show the intended lowerings:
#     and -> spv.BitwiseAnd, or -> spv.BitwiseOr, xor -> spv.BitwiseXor,
#     shift_left -> spv.ShiftLeftLogical, shift_right_signed -> spv.ShiftRightArithmetic,
#     shift_right_unsigned -> spv.ShiftRightLogical
#   Each is tested on i32 and on vector<4xi32>.
# - std-to-spirv.mlir: deletes @logical_scalar_fail / @logical_vector_fail. Those tests asserted (CHECK-NOT) that
#   i32 / vector<4xi32> `and`/`or` did NOT produce spv.LogicalAnd/LogicalOr; they now lower to the Bitwise ops instead.
# - Serialization/bit-ops.mlir: adds a scalar (i32) round-trip test for spv.BitwiseAnd/BitwiseOr/BitwiseXor (no vector variant).
# NOTE(review): every `def : BinaryOpPattern;` line has no template arguments. The `<SrcOp, TgtOp>` pairs look
#   stripped during extraction; as written the .td hunk would not be valid TableGen. Restore them from the original patch.
# NOTE(review): the hunk header `@@ -26,6 +26,14 @@` declares 6 old / 14 new lines, but only 5 context + 6 added
#   lines are visible. Some lines appear to be lost; confirm before applying.
# NOTE(review): the patch text is collapsed onto two physical lines. The original line breaks must be restored
#   before any patch tool can parse the hunks.
# NOTE(review): handling of i1 `and`/`or` (spv.LogicalAnd/LogicalOr) is not visible here. Confirm the new bitwise
#   patterns do not take precedence over a logical-op lowering for boolean operands.
diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.td b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.td --- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.td +++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.td @@ -26,6 +26,14 @@ def : BinaryOpPattern; def : BinaryOpPattern; +def : BinaryOpPattern; +def : BinaryOpPattern; +def : BinaryOpPattern; +def : BinaryOpPattern; +def : BinaryOpPattern; +def : BinaryOpPattern; def : BinaryOpPattern; def : BinaryOpPattern; def : BinaryOpPattern; diff --git a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir --- a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir @@ -95,6 +95,54 @@ } //===----------------------------------------------------------------------===// +// std bit ops +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: @bitwise_scalar +func @bitwise_scalar(%arg0 : i32, %arg1 : i32) { + // CHECK: spv.BitwiseAnd + %0 = and %arg0, %arg1 : i32 + // CHECK: spv.BitwiseOr + %1 = or %arg0, %arg1 : i32 + // CHECK: spv.BitwiseXor + %2 = xor %arg0, %arg1 : i32 + return +} + +// CHECK-LABEL: @bitwise_vector +func @bitwise_vector(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>) { + // CHECK: spv.BitwiseAnd + %0 = and %arg0, %arg1 : vector<4xi32> + // CHECK: spv.BitwiseOr + %1 = or %arg0, %arg1 : vector<4xi32> + // CHECK: spv.BitwiseXor + %2 = xor %arg0, %arg1 : vector<4xi32> + return +} + +// CHECK-LABEL: @shift_scalar +func @shift_scalar(%arg0 : i32, %arg1 : i32) { + // CHECK: spv.ShiftLeftLogical + %0 = shift_left %arg0, %arg1 : i32 + // CHECK: spv.ShiftRightArithmetic + %1 = shift_right_signed %arg0, %arg1 : i32 + // CHECK: spv.ShiftRightLogical + %2 = shift_right_unsigned %arg0, %arg1 : i32 + return +} + +// CHECK-LABEL: @shift_vector +func @shift_vector(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>) { + // CHECK: 
spv.ShiftLeftLogical + %0 = shift_left %arg0, %arg1 : vector<4xi32> + // CHECK: spv.ShiftRightArithmetic + %1 = shift_right_signed %arg0, %arg1 : vector<4xi32> + // CHECK: spv.ShiftRightLogical + %2 = shift_right_unsigned %arg0, %arg1 : vector<4xi32> + return +} + +//===----------------------------------------------------------------------===// // std.cmpi //===----------------------------------------------------------------------===// @@ -156,24 +204,6 @@ return } -// CHECK-LABEL: @logical_scalar_fail -func @logical_scalar_fail(%arg0 : i32, %arg1 : i32) { - // CHECK-NOT: spv.LogicalAnd - %0 = and %arg0, %arg1 : i32 - // CHECK-NOT: spv.LogicalOr - %1 = or %arg0, %arg1 : i32 - return -} - -// CHECK-LABEL: @logical_vector_fail -func @logical_vector_fail(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>) { - // CHECK-NOT: spv.LogicalAnd - %0 = and %arg0, %arg1 : vector<4xi32> - // CHECK-NOT: spv.LogicalOr - %1 = or %arg0, %arg1 : vector<4xi32> - return -} - //===----------------------------------------------------------------------===// // std.fpext //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SPIRV/Serialization/bit-ops.mlir b/mlir/test/Dialect/SPIRV/Serialization/bit-ops.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/bit-ops.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/bit-ops.mlir @@ -31,6 +31,15 @@ %0 = spv.Not %arg : i32 spv.ReturnValue %0 : i32 } + func @bitwise_scalar(%arg0 : i32, %arg1 : i32) { + // CHECK: spv.BitwiseAnd + %0 = spv.BitwiseAnd %arg0, %arg1 : i32 + // CHECK: spv.BitwiseOr + %1 = spv.BitwiseOr %arg0, %arg1 : i32 + // CHECK: spv.BitwiseXor + %2 = spv.BitwiseXor %arg0, %arg1 : i32 + spv.Return + } func @shift_left_logical(%arg0: i32, %arg1 : i16) -> i32 { // CHECK: {{%.*}} = spv.ShiftLeftLogical {{%.*}}, {{%.*}} : i32, i16 %0 = spv.ShiftLeftLogical %arg0, %arg1: i32, i16