diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
@@ -287,6 +287,43 @@
   }
 };
 
+//===----------------------------------------------------------------------===//
+// ConvertBitwiseBinary
+//===----------------------------------------------------------------------===//
+
+/// Conversion pattern template for bitwise binary ops, e.g., `arith.andi`.
+/// Bitwise ops have no carries or other interaction between bit positions, so
+/// the wide op is emulated by applying `BinaryOp` to the low and high halves
+/// independently.
+template <typename BinaryOp>
+struct ConvertBitwiseBinary final : OpConversionPattern<BinaryOp> {
+  using OpConversionPattern<BinaryOp>::OpConversionPattern;
+  using OpAdaptor = typename OpConversionPattern<BinaryOp>::OpAdaptor;
+
+  LogicalResult
+  matchAndRewrite(BinaryOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    auto newTy = this->getTypeConverter()
+                     ->convertType(op.getType())
+                     .template dyn_cast_or_null<VectorType>();
+    if (!newTy)
+      return rewriter.notifyMatchFailure(loc, "unsupported type");
+
+    auto [lhsElem0, lhsElem1] =
+        extractLastDimHalves(rewriter, loc, adaptor.getLhs());
+    auto [rhsElem0, rhsElem1] =
+        extractLastDimHalves(rewriter, loc, adaptor.getRhs());
+
+    Value resElem0 = rewriter.create<BinaryOp>(loc, lhsElem0, rhsElem0);
+    Value resElem1 = rewriter.create<BinaryOp>(loc, lhsElem1, rhsElem1);
+    Value resultVec =
+        constructResultVector(rewriter, loc, newTy, {resElem0, resElem1});
+    rewriter.replaceOp(op, resultVec);
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ConvertMulI
 //===----------------------------------------------------------------------===//
@@ -694,6 +731,9 @@
       ConvertConstant, ConvertVectorPrint,
       // Binary ops.
       ConvertAddI, ConvertMulI, ConvertShRUI,
+      // Bitwise binary ops.
+      ConvertBitwiseBinary<arith::AndIOp>, ConvertBitwiseBinary<arith::OrIOp>,
+      ConvertBitwiseBinary<arith::XOrIOp>,
       // Extension and truncation ops.
       ConvertExtSI, ConvertExtUI, ConvertTruncI>(typeConverter,
                                                  patterns.getContext());
diff --git a/mlir/test/Dialect/Arith/emulate-wide-int.mlir b/mlir/test/Dialect/Arith/emulate-wide-int.mlir
--- a/mlir/test/Dialect/Arith/emulate-wide-int.mlir
+++ b/mlir/test/Dialect/Arith/emulate-wide-int.mlir
@@ -331,3 +331,81 @@
   %m = arith.shrui %a, %b : vector<3xi64>
   return %m : vector<3xi64>
 }
+
+// CHECK-LABEL: func @andi_scalar_a_b
+// CHECK-SAME:    ([[ARG0:%.+]]: vector<2xi32>, [[ARG1:%.+]]: vector<2xi32>) -> vector<2xi32>
+// CHECK-NEXT:    [[LOW0:%.+]]   = vector.extract [[ARG0]][0] : vector<2xi32>
+// CHECK-NEXT:    [[HIGH0:%.+]]  = vector.extract [[ARG0]][1] : vector<2xi32>
+// CHECK-NEXT:    [[LOW1:%.+]]   = vector.extract [[ARG1]][0] : vector<2xi32>
+// CHECK-NEXT:    [[HIGH1:%.+]]  = vector.extract [[ARG1]][1] : vector<2xi32>
+// CHECK-NEXT:    [[RES0:%.+]]   = arith.andi [[LOW0]], [[LOW1]] : i32
+// CHECK-NEXT:    [[RES1:%.+]]   = arith.andi [[HIGH0]], [[HIGH1]] : i32
+// CHECK:         [[INS0:%.+]]   = vector.insert [[RES0]], {{%.+}} [0] : i32 into vector<2xi32>
+// CHECK-NEXT:    [[INS1:%.+]]   = vector.insert [[RES1]], [[INS0]] [1] : i32 into vector<2xi32>
+// CHECK-NEXT:    return [[INS1]] : vector<2xi32>
+func.func @andi_scalar_a_b(%a : i64, %b : i64) -> i64 {
+  %x = arith.andi %a, %b : i64
+  return %x : i64
+}
+
+// CHECK-LABEL: func @andi_vector_a_b
+// CHECK-SAME:    ([[ARG0:%.+]]: vector<3x2xi32>, [[ARG1:%.+]]: vector<3x2xi32>) -> vector<3x2xi32>
+// CHECK:         {{%.+}} = arith.andi {{%.+}}, {{%.+}} : vector<3x1xi32>
+// CHECK-NEXT:    {{%.+}} = arith.andi {{%.+}}, {{%.+}} : vector<3x1xi32>
+// CHECK:         return {{.+}} : vector<3x2xi32>
+func.func @andi_vector_a_b(%a : vector<3xi64>, %b : vector<3xi64>) -> vector<3xi64> {
+  %x = arith.andi %a, %b : vector<3xi64>
+  return %x : vector<3xi64>
+}
+
+// CHECK-LABEL: func @ori_scalar_a_b
+// CHECK-SAME:    ([[ARG0:%.+]]: vector<2xi32>, [[ARG1:%.+]]: vector<2xi32>) -> vector<2xi32>
+// CHECK-NEXT:    [[LOW0:%.+]]   = vector.extract [[ARG0]][0] : vector<2xi32>
+// CHECK-NEXT:    [[HIGH0:%.+]]  = vector.extract [[ARG0]][1] : vector<2xi32>
+// CHECK-NEXT:    [[LOW1:%.+]]   = vector.extract [[ARG1]][0] : vector<2xi32>
+// CHECK-NEXT:    [[HIGH1:%.+]]  = vector.extract [[ARG1]][1] : vector<2xi32>
+// CHECK-NEXT:    [[RES0:%.+]]   = arith.ori [[LOW0]], [[LOW1]] : i32
+// CHECK-NEXT:    [[RES1:%.+]]   = arith.ori [[HIGH0]], [[HIGH1]] : i32
+// CHECK:         [[INS0:%.+]]   = vector.insert [[RES0]], {{%.+}} [0] : i32 into vector<2xi32>
+// CHECK-NEXT:    [[INS1:%.+]]   = vector.insert [[RES1]], [[INS0]] [1] : i32 into vector<2xi32>
+// CHECK-NEXT:    return [[INS1]] : vector<2xi32>
+func.func @ori_scalar_a_b(%a : i64, %b : i64) -> i64 {
+  %x = arith.ori %a, %b : i64
+  return %x : i64
+}
+
+// CHECK-LABEL: func @ori_vector_a_b
+// CHECK-SAME:    ([[ARG0:%.+]]: vector<3x2xi32>, [[ARG1:%.+]]: vector<3x2xi32>) -> vector<3x2xi32>
+// CHECK:         {{%.+}} = arith.ori {{%.+}}, {{%.+}} : vector<3x1xi32>
+// CHECK-NEXT:    {{%.+}} = arith.ori {{%.+}}, {{%.+}} : vector<3x1xi32>
+// CHECK:         return {{.+}} : vector<3x2xi32>
+func.func @ori_vector_a_b(%a : vector<3xi64>, %b : vector<3xi64>) -> vector<3xi64> {
+  %x = arith.ori %a, %b : vector<3xi64>
+  return %x : vector<3xi64>
+}
+
+// CHECK-LABEL: func @xori_scalar_a_b
+// CHECK-SAME:    ([[ARG0:%.+]]: vector<2xi32>, [[ARG1:%.+]]: vector<2xi32>) -> vector<2xi32>
+// CHECK-NEXT:    [[LOW0:%.+]]   = vector.extract [[ARG0]][0] : vector<2xi32>
+// CHECK-NEXT:    [[HIGH0:%.+]]  = vector.extract [[ARG0]][1] : vector<2xi32>
+// CHECK-NEXT:    [[LOW1:%.+]]   = vector.extract [[ARG1]][0] : vector<2xi32>
+// CHECK-NEXT:    [[HIGH1:%.+]]  = vector.extract [[ARG1]][1] : vector<2xi32>
+// CHECK-NEXT:    [[RES0:%.+]]   = arith.xori [[LOW0]], [[LOW1]] : i32
+// CHECK-NEXT:    [[RES1:%.+]]   = arith.xori [[HIGH0]], [[HIGH1]] : i32
+// CHECK:         [[INS0:%.+]]   = vector.insert [[RES0]], {{%.+}} [0] : i32 into vector<2xi32>
+// CHECK-NEXT:    [[INS1:%.+]]   = vector.insert [[RES1]], [[INS0]] [1] : i32 into vector<2xi32>
+// CHECK-NEXT:    return [[INS1]] : vector<2xi32>
+func.func @xori_scalar_a_b(%a : i64, %b : i64) -> i64 {
+  %x = arith.xori %a, %b : i64
+  return %x : i64
+}
+
+// CHECK-LABEL: func @xori_vector_a_b
+// CHECK-SAME:    ([[ARG0:%.+]]: vector<3x2xi32>, [[ARG1:%.+]]: vector<3x2xi32>) -> vector<3x2xi32>
+// CHECK:         {{%.+}} = arith.xori {{%.+}}, {{%.+}} : vector<3x1xi32>
+// CHECK-NEXT:    {{%.+}} = arith.xori {{%.+}}, {{%.+}} : vector<3x1xi32>
+// CHECK:         return {{.+}} : vector<3x2xi32>
+func.func @xori_vector_a_b(%a : vector<3xi64>, %b : vector<3xi64>) -> vector<3xi64> {
+  %x = arith.xori %a, %b : vector<3xi64>
+  return %x : vector<3xi64>
+}