diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td @@ -338,6 +338,8 @@ %2 = spirv.BitwiseAnd %0, %1 : vector<4xi32> ``` }]; + + let hasFolder = 1; } // ----- @@ -373,6 +375,8 @@ %2 = spirv.BitwiseOr %0, %1 : vector<4xi32> ``` }]; + + let hasFolder = 1; } // ----- diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -35,6 +35,7 @@ #include "llvm/ADT/StringExtras.h" #include #include +#include #include using namespace mlir; @@ -1961,6 +1962,81 @@ return verifyShiftOp(*this); } +//===----------------------------------------------------------------------===// +// spirv.BitwiseAndOp +//===----------------------------------------------------------------------===// + +// Returns the value of `attr` if it is an integer attribute or a splat whose +// element is an integer attribute, and std::nullopt otherwise (e.g. null). +static std::optional extractIntConstant(Attribute attr) { + IntegerAttr intAttr; + if (auto splat = dyn_cast_if_present(attr)) + intAttr = dyn_cast(splat.getSplatValue()); + else + intAttr = dyn_cast_if_present(attr); + + if (!intAttr) + return std::nullopt; + + return intAttr.getValue(); +} + +// Folds `x & c` for an integer constant or splat `c`: +// - c == 0 -> c +// - c all ones -> x +// - x = UConvert(y : iN) and the low N bits of c are all ones -> x +OpFoldResult +spirv::BitwiseAndOp::fold(spirv::BitwiseAndOp::FoldAdaptor adaptor) { + std::optional rhsVal = extractIntConstant(adaptor.getOperand2()); + if (!rhsVal) + return {}; + + APInt rhsMask = *rhsVal; + + // x & 0 -> 0 + if (rhsMask.isZero()) + return getOperand2(); + + // x & all-ones -> x + if (rhsMask.isAllOnes()) + return getOperand1(); + + // (UConvert x : iN to iK) & mask -> UConvert x, if mask's low N bits are set + if (auto zext = getOperand1().getDefiningOp()) { + int valueBits = + getElementTypeOrSelf(zext.getOperand()).getIntOrFloatBitWidth(); + if (rhsMask.zextOrTrunc(valueBits).isAllOnes()) + return getOperand1(); + } + + return {}; +} + +//===----------------------------------------------------------------------===// +//
spirv.BitwiseOrOp +//===----------------------------------------------------------------------===// + +// Folds `x | c` for an integer constant or splat `c`: +// - c == 0 -> x +// - c all ones -> c +OpFoldResult spirv::BitwiseOrOp::fold(spirv::BitwiseOrOp::FoldAdaptor adaptor) { + std::optional rhsVal = extractIntConstant(adaptor.getOperand2()); + if (!rhsVal) + return {}; + + APInt rhsMask = *rhsVal; + + // x | 0 -> x + if (rhsMask.isZero()) + return getOperand1(); + + // x | all-ones -> all-ones + if (rhsMask.isAllOnes()) + return getOperand2(); + + return {}; +} + //===----------------------------------------------------------------------===// // spirv.ImageQuerySize //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir b/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir --- a/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir @@ -1,4 +1,6 @@ -// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s +// RUN: mlir-opt --split-input-file --verify-diagnostics %s | FileCheck %s +// RUN: mlir-opt --split-input-file --verify-diagnostics --canonicalize %s \ +// RUN: | FileCheck %s --check-prefix=CANON //===----------------------------------------------------------------------===// // spirv.BitCount @@ -82,18 +84,56 @@ // spirv.BitwiseOr //===----------------------------------------------------------------------===// +// CHECK-LABEL: func @bitwise_or_scalar func.func @bitwise_or_scalar(%arg: i32) -> i32 { // CHECK: spirv.BitwiseOr %0 = spirv.BitwiseOr %arg, %arg : i32 return %0 : i32 } +// CHECK-LABEL: func @bitwise_or_vector func.func @bitwise_or_vector(%arg: vector<4xi32>) -> vector<4xi32> { // CHECK: spirv.BitwiseOr %0 = spirv.BitwiseOr %arg, %arg : vector<4xi32> return %0 : vector<4xi32> } +// CANON-LABEL: func @bitwise_or_zero +// CANON-SAME: (%[[ARG:.+]]: i32) +func.func @bitwise_or_zero(%arg: i32) -> i32 { + // CANON: return %[[ARG]] + %zero = spirv.Constant 0 : i32 + %0 = spirv.BitwiseOr %arg, %zero : i32 + return %0 : i32 +} + +// CANON-LABEL: func
@bitwise_or_zero_vector +// CANON-SAME: (%[[ARG:.+]]: vector<4xi32>) +func.func @bitwise_or_zero_vector(%arg: vector<4xi32>) -> vector<4xi32> { + // CANON: return %[[ARG]] + %zero = spirv.Constant dense<0> : vector<4xi32> + %0 = spirv.BitwiseOr %arg, %zero : vector<4xi32> + return %0 : vector<4xi32> +} + +// CANON-LABEL: func @bitwise_or_all_ones +func.func @bitwise_or_all_ones(%arg: i8) -> i8 { + // CANON: %[[CST:.+]] = spirv.Constant -1 + // CANON: return %[[CST]] + %ones = spirv.Constant 255 : i8 + %0 = spirv.BitwiseOr %arg, %ones : i8 + return %0 : i8 +} + +// CANON-LABEL: func @bitwise_or_all_ones_vector +func.func @bitwise_or_all_ones_vector(%arg: vector<3xi8>) -> vector<3xi8> { + // CANON: %[[CST:.+]] = spirv.Constant dense<-1> + // CANON: return %[[CST]] + %ones = spirv.Constant dense<255> : vector<3xi8> + %0 = spirv.BitwiseOr %arg, %ones : vector<3xi8> + return %0 : vector<3xi8> +} + // ----- func.func @bitwise_or_float(%arg0: f16, %arg1: f16) -> f16 { @@ -134,18 +174,101 @@ // spirv.BitwiseAnd //===----------------------------------------------------------------------===// +// CHECK-LABEL: func @bitwise_and_scalar func.func @bitwise_and_scalar(%arg: i32) -> i32 { // CHECK: spirv.BitwiseAnd %0 = spirv.BitwiseAnd %arg, %arg : i32 return %0 : i32 } +// CHECK-LABEL: func @bitwise_and_vector func.func @bitwise_and_vector(%arg: vector<4xi32>) -> vector<4xi32> { // CHECK: spirv.BitwiseAnd %0 = spirv.BitwiseAnd %arg, %arg : vector<4xi32> return %0 : vector<4xi32> } +// CANON-LABEL: func @bitwise_and_zero +func.func @bitwise_and_zero(%arg: i32) -> i32 { + // CANON: %[[CST:.+]] = spirv.Constant 0 + // CANON: return %[[CST]] + %zero = spirv.Constant 0 : i32 + %0 = spirv.BitwiseAnd %arg, %zero : i32 + return %0 : i32 +} + +// CANON-LABEL: func @bitwise_and_zero_vector +func.func @bitwise_and_zero_vector(%arg: vector<4xi32>) -> vector<4xi32> { + // CANON: %[[CST:.+]] = spirv.Constant dense<0> + // CANON: return %[[CST]] + %zero = spirv.Constant dense<0> :
vector<4xi32> + %0 = spirv.BitwiseAnd %arg, %zero : vector<4xi32> + return %0 : vector<4xi32> +} + +// CANON-LABEL: func @bitwise_and_all_ones +// CANON-SAME: (%[[ARG:.+]]: i8) +func.func @bitwise_and_all_ones(%arg: i8) -> i8 { + // CANON: return %[[ARG]] + %ones = spirv.Constant 255 : i8 + %0 = spirv.BitwiseAnd %arg, %ones : i8 + return %0 : i8 +} + +// CANON-LABEL: func @bitwise_and_all_ones_vector +// CANON-SAME: (%[[ARG:.+]]: vector<3xi8>) +func.func @bitwise_and_all_ones_vector(%arg: vector<3xi8>) -> vector<3xi8> { + // CANON: return %[[ARG]] + %ones = spirv.Constant dense<255> : vector<3xi8> + %0 = spirv.BitwiseAnd %arg, %ones : vector<3xi8> + return %0 : vector<3xi8> +} + +// CANON-LABEL: func @bitwise_and_zext_1 +// CANON-SAME: (%[[ARG:.+]]: i8) +func.func @bitwise_and_zext_1(%arg: i8) -> i32 { + // CANON: %[[ZEXT:.+]] = spirv.UConvert %[[ARG]] + // CANON: return %[[ZEXT]] + %zext = spirv.UConvert %arg : i8 to i32 + %ones = spirv.Constant 255 : i32 + %0 = spirv.BitwiseAnd %zext, %ones : i32 + return %0 : i32 +} + +// CANON-LABEL: func @bitwise_and_zext_2 +// CANON-SAME: (%[[ARG:.+]]: i8) +func.func @bitwise_and_zext_2(%arg: i8) -> i32 { + // CANON: %[[ZEXT:.+]] = spirv.UConvert %[[ARG]] + // CANON: return %[[ZEXT]] + %zext = spirv.UConvert %arg : i8 to i32 + %ones = spirv.Constant 0x12345ff : i32 + %0 = spirv.BitwiseAnd %zext, %ones : i32 + return %0 : i32 +} + +// CANON-LABEL: func @bitwise_and_zext_3 +// CANON-SAME: (%[[ARG:.+]]: i8) +func.func @bitwise_and_zext_3(%arg: i8) -> i32 { + // CANON: %[[ZEXT:.+]] = spirv.UConvert %[[ARG]] + // CANON: %[[AND:.+]] = spirv.BitwiseAnd %[[ZEXT]] + // CANON: return %[[AND]] + %zext = spirv.UConvert %arg : i8 to i32 + %ones = spirv.Constant 254 : i32 + %0 = spirv.BitwiseAnd %zext, %ones : i32 + return %0 : i32 +} + +// CANON-LABEL: func @bitwise_and_zext_vector +// CANON-SAME: (%[[ARG:.+]]: vector<2xi8>) +func.func @bitwise_and_zext_vector(%arg: vector<2xi8>) -> vector<2xi32> { + // CANON: %[[ZEXT:.+]] =
spirv.UConvert %[[ARG]] + // CANON: return %[[ZEXT]] + %zext = spirv.UConvert %arg : vector<2xi8> to vector<2xi32> + %ones = spirv.Constant dense<255> : vector<2xi32> + %0 = spirv.BitwiseAnd %zext, %ones : vector<2xi32> + return %0 : vector<2xi32> +} + // ----- func.func @bitwise_and_float(%arg0: f16, %arg1: f16) -> f16 {