diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
@@ -723,6 +723,7 @@
     %2 = spirv.LogicalNotEqual %0, %1 : vector<4xi1>
     ```
   }];
+  let hasFolder = true;
 }
 
 // -----
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -251,6 +251,25 @@
   return Attribute();
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.LogicalNotEqualOp
+//===----------------------------------------------------------------------===//
+
+// Folds `x != false` to `x`. Only the second operand is checked for a
+// constant (scalar or splat); a constant first operand is not folded here.
+OpFoldResult spirv::LogicalNotEqualOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 2 &&
+         "spirv.LogicalNotEqual should take two operands");
+
+  if (Optional<bool> rhs = getScalarOrSplatBoolAttr(operands.back())) {
+    // x != false = x (inequality with false is the identity on booleans).
+    if (!rhs.value())
+      return getOperand1();
+  }
+
+  return Attribute();
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.LogicalNot
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -470,6 +470,30 @@
   spirv.ReturnValue %3 : vector<3xi1>
 }
 
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.LogicalNotEqual
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @convert_logical_not_equal_false
+// CHECK-SAME: %[[ARG:.+]]: vector<4xi1>
+func.func @convert_logical_not_equal_false(%arg: vector<4xi1>) -> vector<4xi1> {
+  %cst = spirv.Constant dense<false> : vector<4xi1>
+  // CHECK: spirv.ReturnValue %[[ARG]] : vector<4xi1>
+  %0 = spirv.LogicalNotEqual %arg, %cst : vector<4xi1>
+  spirv.ReturnValue %0 : vector<4xi1>
+}
+
+// CHECK-LABEL: @convert_logical_not_equal_false_scalar
+// CHECK-SAME: %[[ARG:.+]]: i1
+func.func @convert_logical_not_equal_false_scalar(%arg: i1) -> i1 {
+  %cst = spirv.Constant false
+  // CHECK: spirv.ReturnValue %[[ARG]] : i1
+  %0 = spirv.LogicalNotEqual %arg, %cst : i1
+  spirv.ReturnValue %0 : i1
+}
+
 // -----
 
 func.func @convert_logical_not_to_equal(%arg0: vector<3xi64>, %arg1: vector<3xi64>) -> vector<3xi1> {