diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td
@@ -526,6 +526,8 @@
     %2 = spv.LogicalAnd %0, %1 : vector<4xi1>
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
@@ -656,6 +658,8 @@
     %2 = spv.LogicalOr %0, %1 : vector<4xi1>
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp
--- a/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp
@@ -24,6 +24,26 @@
 // Common utility functions
 //===----------------------------------------------------------------------===//
 
+/// Returns true if `boolAttr` is a scalar i1 constant or a splat vector-of-i1
+/// constant equal to `boolVal`; returns false for null or any other attribute.
+static bool isScalarOrSplatBoolAttr(Attribute boolAttr, bool boolVal) {
+  if (!boolAttr)
+    return false;
+
+  // Scalar case. dyn_cast (rather than cast) everywhere below so unexpected
+  // attribute kinds or types fall through to `false` instead of asserting.
+  if (auto attr = boolAttr.dyn_cast<BoolAttr>())
+    return attr.getValue() == boolVal;
+
+  // Vector case: only splat constants of a vector-of-i1 type are recognized.
+  auto vecType = boolAttr.getType().dyn_cast<VectorType>();
+  if (vecType && vecType.getElementType().isInteger(1))
+    if (auto attr = boolAttr.dyn_cast<SplatElementsAttr>())
+      return attr.getSplatValue().cast<BoolAttr>().getValue() == boolVal;
+
+  return false;
+}
+
 // Extracts an element from the given `composite` by following the given
 // `indices`. Returns a null Attribute if error happens.
static Attribute extractCompositeElement(Attribute composite,
@@ -187,6 +207,34 @@
                                         [](APInt a, APInt b) { return a - b; });
 }
 
+//===----------------------------------------------------------------------===//
+// spv.LogicalAnd
+//===----------------------------------------------------------------------===//
+
+/// Folds `x && true` to `x` and `x && false` to `false`, where the constant
+/// may be a scalar i1 or a splat bool vector and may appear on either side.
+OpFoldResult spirv::LogicalAndOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 2 && "spv.LogicalAnd should take two operands");
+
+  // The constant is not necessarily on the right-hand side (e.g. when fold()
+  // is reached through OpBuilder::createOrFold rather than the canonicalizer),
+  // so both sides are checked, right-hand side first.
+
+  // x && false = false
+  if (isScalarOrSplatBoolAttr(operands.back(), false))
+    return operands.back();
+  if (isScalarOrSplatBoolAttr(operands.front(), false))
+    return operands.front();
+
+  // x && true = x
+  if (isScalarOrSplatBoolAttr(operands.back(), true))
+    return operand1();
+  if (isScalarOrSplatBoolAttr(operands.front(), true))
+    return operand2();
+
+  return Attribute();
+}
+
 //===----------------------------------------------------------------------===//
 // spv.LogicalNot
 //===----------------------------------------------------------------------===//
@@ -198,6 +246,34 @@
                  ConvertLogicalNotOfLogicalNotEqual>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// spv.LogicalOr
+//===----------------------------------------------------------------------===//
+
+/// Folds `x || true` to `true` and `x || false` to `x`, where the constant
+/// may be a scalar i1 or a splat bool vector and may appear on either side.
+OpFoldResult spirv::LogicalOrOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 2 && "spv.LogicalOr should take two operands");
+
+  // The constant is not necessarily on the right-hand side (e.g. when fold()
+  // is reached through OpBuilder::createOrFold rather than the canonicalizer),
+  // so both sides are checked, right-hand side first.
+
+  // x || true = true
+  if (isScalarOrSplatBoolAttr(operands.back(), true))
+    return operands.back();
+  if (isScalarOrSplatBoolAttr(operands.front(), true))
+    return operands.front();
+
+  // x || false = x
+  if (isScalarOrSplatBoolAttr(operands.back(), false))
+    return operand1();
+  if (isScalarOrSplatBoolAttr(operands.front(), false))
+    return operand2();
+
+  return Attribute();
+}
+
 //===----------------------------------------------------------------------===//
 // spv.selection
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/canonicalize.mlir b/mlir/test/Dialect/SPIRV/canonicalize.mlir
--- a/mlir/test/Dialect/SPIRV/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/canonicalize.mlir
@@ -362,6 +362,36 @@
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spv.LogicalAnd
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @convert_logical_and_true_false_scalar
+// CHECK-SAME: %[[ARG:.+]]: i1
+func @convert_logical_and_true_false_scalar(%arg: i1) -> (i1, i1) {
+  %true = spv.constant true
+  // CHECK: %[[FALSE:.+]] = spv.constant false
+  %false = spv.constant false
+  %0 = spv.LogicalAnd %true, %arg: i1
+  %1 = spv.LogicalAnd %arg, %false: i1
+  // CHECK: return %[[ARG]], %[[FALSE]]
+  return %0, %1: i1, i1
+}
+
+// CHECK-LABEL: @convert_logical_and_true_false_vector
+// CHECK-SAME: %[[ARG:.+]]: vector<3xi1>
+func @convert_logical_and_true_false_vector(%arg: vector<3xi1>) -> (vector<3xi1>, vector<3xi1>) {
+  %true = spv.constant dense<true> : vector<3xi1>
+  // CHECK: %[[FALSE:.+]] = spv.constant dense<false>
+  %false = spv.constant dense<false> : vector<3xi1>
+  %0 = spv.LogicalAnd %true, %arg: vector<3xi1>
+  %1 = spv.LogicalAnd %arg, %false: vector<3xi1>
+  // CHECK: return %[[ARG]], %[[FALSE]]
+  return %0, %1: vector<3xi1>, vector<3xi1>
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spv.LogicalNot
 //===----------------------------------------------------------------------===//
@@ -419,6 +449,36 @@
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spv.LogicalOr
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @convert_logical_or_true_false_scalar
+// CHECK-SAME: %[[ARG:.+]]: i1
+func @convert_logical_or_true_false_scalar(%arg: i1) -> (i1, i1) {
+  // CHECK: %[[TRUE:.+]] = spv.constant true
+  %true = spv.constant true
+  %false = spv.constant false
+  %0 = spv.LogicalOr %true, %arg: i1
+  %1 = spv.LogicalOr %arg, %false: i1
+  // CHECK: return %[[TRUE]], %[[ARG]]
+  return %0, %1: i1, i1
+}
+
+// CHECK-LABEL: @convert_logical_or_true_false_vector
+// CHECK-SAME: %[[ARG:.+]]: vector<3xi1>
+func @convert_logical_or_true_false_vector(%arg: vector<3xi1>) -> (vector<3xi1>, vector<3xi1>) {
+  // CHECK: %[[TRUE:.+]] = spv.constant dense<true>
+  %true = spv.constant dense<true> : vector<3xi1>
+  %false = spv.constant dense<false> : vector<3xi1>
+  %0 = spv.LogicalOr %true, %arg: vector<3xi1>
+  %1 = spv.LogicalOr %arg, %false: vector<3xi1>
+  // CHECK: return %[[TRUE]], %[[ARG]]
+  return %0, %1: vector<3xi1>, vector<3xi1>
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spv.selection
 //===----------------------------------------------------------------------===//