diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
@@ -287,6 +287,34 @@
   }
 };
 
+//===----------------------------------------------------------------------===//
+// ConvertBitwiseBinary
+//===----------------------------------------------------------------------===//
+
+/// Conversion pattern template for bitwise binary ops, e.g., `arith.andi`.
+/// Bitwise ops act on each bit independently, so they can be applied directly
+/// to the emulated (two-halves) vector representation without splitting it.
+template <typename BinaryOp>
+struct ConvertBitwiseBinary final : OpConversionPattern<BinaryOp> {
+  using OpConversionPattern<BinaryOp>::OpConversionPattern;
+  using OpAdaptor = typename OpConversionPattern<BinaryOp>::OpAdaptor;
+
+  LogicalResult
+  matchAndRewrite(BinaryOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    auto newTy = this->getTypeConverter()
+                     ->convertType(op.getType())
+                     .template dyn_cast_or_null<VectorType>();
+    if (!newTy)
+      return rewriter.notifyMatchFailure(loc, "unsupported type");
+
+    rewriter.replaceOpWithNewOp<BinaryOp>(op, adaptor.getLhs(),
+                                          adaptor.getRhs());
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ConvertMulI
 //===----------------------------------------------------------------------===//
@@ -694,6 +722,9 @@
       ConvertConstant, ConvertVectorPrint,
       // Binary ops.
       ConvertAddI, ConvertMulI, ConvertShRUI,
+      // Bitwise binary ops.
+      ConvertBitwiseBinary<arith::AndIOp>, ConvertBitwiseBinary<arith::OrIOp>,
+      ConvertBitwiseBinary<arith::XOrIOp>,
       // Extension and truncation ops.
       ConvertExtSI, ConvertExtUI, ConvertTruncI>(typeConverter,
                                                  patterns.getContext());
diff --git a/mlir/test/Dialect/Arith/emulate-wide-int.mlir b/mlir/test/Dialect/Arith/emulate-wide-int.mlir
--- a/mlir/test/Dialect/Arith/emulate-wide-int.mlir
+++ b/mlir/test/Dialect/Arith/emulate-wide-int.mlir
@@ -331,3 +331,57 @@
   %m = arith.shrui %a, %b : vector<3xi64>
   return %m : vector<3xi64>
 }
+
+// CHECK-LABEL: func @andi_scalar_a_b
+// CHECK-SAME: ([[ARG0:%.+]]: vector<2xi32>, [[ARG1:%.+]]: vector<2xi32>) -> vector<2xi32>
+// CHECK-NEXT: [[RES:%.+]] = arith.andi [[ARG0]], [[ARG1]] : vector<2xi32>
+// CHECK-NEXT: return [[RES]] : vector<2xi32>
+func.func @andi_scalar_a_b(%a : i64, %b : i64) -> i64 {
+  %x = arith.andi %a, %b : i64
+  return %x : i64
+}
+
+// CHECK-LABEL: func @andi_vector_a_b
+// CHECK-SAME: ([[ARG0:%.+]]: vector<3x2xi32>, [[ARG1:%.+]]: vector<3x2xi32>) -> vector<3x2xi32>
+// CHECK-NEXT: [[RES:%.+]] = arith.andi [[ARG0]], [[ARG1]] : vector<3x2xi32>
+// CHECK-NEXT: return [[RES]] : vector<3x2xi32>
+func.func @andi_vector_a_b(%a : vector<3xi64>, %b : vector<3xi64>) -> vector<3xi64> {
+  %x = arith.andi %a, %b : vector<3xi64>
+  return %x : vector<3xi64>
+}
+
+// CHECK-LABEL: func @ori_scalar_a_b
+// CHECK-SAME: ([[ARG0:%.+]]: vector<2xi32>, [[ARG1:%.+]]: vector<2xi32>) -> vector<2xi32>
+// CHECK-NEXT: [[RES:%.+]] = arith.ori [[ARG0]], [[ARG1]] : vector<2xi32>
+// CHECK-NEXT: return [[RES]] : vector<2xi32>
+func.func @ori_scalar_a_b(%a : i64, %b : i64) -> i64 {
+  %x = arith.ori %a, %b : i64
+  return %x : i64
+}
+
+// CHECK-LABEL: func @ori_vector_a_b
+// CHECK-SAME: ([[ARG0:%.+]]: vector<3x2xi32>, [[ARG1:%.+]]: vector<3x2xi32>) -> vector<3x2xi32>
+// CHECK-NEXT: [[RES:%.+]] = arith.ori [[ARG0]], [[ARG1]] : vector<3x2xi32>
+// CHECK-NEXT: return [[RES]] : vector<3x2xi32>
+func.func @ori_vector_a_b(%a : vector<3xi64>, %b : vector<3xi64>) -> vector<3xi64> {
+  %x = arith.ori %a, %b : vector<3xi64>
+  return %x : vector<3xi64>
+}
+
+// CHECK-LABEL: func @xori_scalar_a_b
+// CHECK-SAME: ([[ARG0:%.+]]: vector<2xi32>, [[ARG1:%.+]]: vector<2xi32>) -> vector<2xi32>
+// CHECK-NEXT: [[RES:%.+]] = arith.xori [[ARG0]], [[ARG1]] : vector<2xi32>
+// CHECK-NEXT: return [[RES]] : vector<2xi32>
+func.func @xori_scalar_a_b(%a : i64, %b : i64) -> i64 {
+  %x = arith.xori %a, %b : i64
+  return %x : i64
+}
+
+// CHECK-LABEL: func @xori_vector_a_b
+// CHECK-SAME: ([[ARG0:%.+]]: vector<3x2xi32>, [[ARG1:%.+]]: vector<3x2xi32>) -> vector<3x2xi32>
+// CHECK-NEXT: [[RES:%.+]] = arith.xori [[ARG0]], [[ARG1]] : vector<3x2xi32>
+// CHECK-NEXT: return [[RES]] : vector<3x2xi32>
+func.func @xori_vector_a_b(%a : vector<3xi64>, %b : vector<3xi64>) -> vector<3xi64> {
+  %x = arith.xori %a, %b : vector<3xi64>
+  return %x : vector<3xi64>
+}