diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.td b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.td
--- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.td
+++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.td
@@ -16,15 +16,20 @@
 include "mlir/Dialect/StandardOps/Ops.td"
 include "mlir/Dialect/SPIRV/SPIRVOps.td"
 
-class BinaryOpPattern<Op src, Op tgt> :
-  Pat<(src SPV_ScalarOrVector:$l, SPV_ScalarOrVector:$r),
+// Rewrites `src` into `tgt`, forwarding both operands unchanged, when each
+// operand satisfies SPV_ScalarOrVectorOf<type>.
+class BinaryOpPattern<Op src, Op tgt, Type type> :
+  Pat<(src SPV_ScalarOrVectorOf<type>:$l, SPV_ScalarOrVectorOf<type>:$r),
       (tgt $l, $r)>;
 
-def : BinaryOpPattern<AddFOp, SPV_FAddOp>;
-def : BinaryOpPattern<DivFOp, SPV_FDivOp>;
-def : BinaryOpPattern<MulFOp, SPV_FMulOp>;
-def : BinaryOpPattern<RemFOp, SPV_FRemOp>;
-def : BinaryOpPattern<SubFOp, SPV_FSubOp>;
+def : BinaryOpPattern<AddFOp, SPV_FAddOp, AnyFloat>;
+def : BinaryOpPattern<DivFOp, SPV_FDivOp, AnyFloat>;
+def : BinaryOpPattern<MulFOp, SPV_FMulOp, AnyFloat>;
+def : BinaryOpPattern<RemFOp, SPV_FRemOp, AnyFloat>;
+def : BinaryOpPattern<SubFOp, SPV_FSubOp, AnyFloat>;
+// Logical and/or: only i1 operands (scalar or vector) match these patterns.
+def : BinaryOpPattern<AndOp, SPV_LogicalAndOp, I1>;
+def : BinaryOpPattern<OrOp, SPV_LogicalOrOp, I1>;
 
 // Constant Op
 // TODO(ravishankarm): Handle lowering other constant types.
diff --git a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir
--- a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir
@@ -134,6 +134,46 @@
   return
 }
 
+//===----------------------------------------------------------------------===//
+// std logical binary operations
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @logical_scalar
+func @logical_scalar(%arg0 : i1, %arg1 : i1) {
+  // CHECK: spv.LogicalAnd
+  %0 = and %arg0, %arg1 : i1
+  // CHECK: spv.LogicalOr
+  %1 = or %arg0, %arg1 : i1
+  return
+}
+
+// CHECK-LABEL: @logical_vector
+func @logical_vector(%arg0 : vector<4xi1>, %arg1 : vector<4xi1>) {
+  // CHECK: spv.LogicalAnd
+  %0 = and %arg0, %arg1 : vector<4xi1>
+  // CHECK: spv.LogicalOr
+  %1 = or %arg0, %arg1 : vector<4xi1>
+  return
+}
+
+// CHECK-LABEL: @logical_scalar_fail
+func @logical_scalar_fail(%arg0 : i32, %arg1 : i32) {
+  // CHECK-NOT: spv.LogicalAnd
+  %0 = and %arg0, %arg1 : i32
+  // CHECK-NOT: spv.LogicalOr
+  %1 = or %arg0, %arg1 : i32
+  return
+}
+
+// CHECK-LABEL: @logical_vector_fail
+func @logical_vector_fail(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>) {
+  // CHECK-NOT: spv.LogicalAnd
+  %0 = and %arg0, %arg1 : vector<4xi32>
+  // CHECK-NOT: spv.LogicalOr
+  %1 = or %arg0, %arg1 : vector<4xi32>
+  return
+}
+
 //===----------------------------------------------------------------------===//
 // std.select
 //===----------------------------------------------------------------------===//