diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
@@ -491,6 +491,49 @@
   }
 };
 
+//===----------------------------------------------------------------------===//
+// ConvertSelect
+//===----------------------------------------------------------------------===//
+
+/// Emulates a wide-integer `arith.select`. The type converter represents each
+/// wide integer as a vector whose trailing dimension holds its low and high
+/// halves (e.g. i64 -> vector<2xi32>); both halves are selected separately
+/// with the same condition and then packed back into the result vector.
+struct ConvertSelect final : OpConversionPattern<arith::SelectOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    auto newTy = getTypeConverter()
+                     ->convertType(op.getType())
+                     .dyn_cast_or_null<VectorType>();
+    if (!newTy)
+      return rewriter.notifyMatchFailure(
+          loc, llvm::formatv("unsupported type: {0}", op.getType()));
+
+    auto [trueElem0, trueElem1] =
+        extractLastDimHalves(rewriter, loc, adaptor.getTrueValue());
+    auto [falseElem0, falseElem1] =
+        extractLastDimHalves(rewriter, loc, adaptor.getFalseValue());
+    // A vector condition gets a trailing unit dimension (vector<Nxi1> ->
+    // vector<Nx1xi1>) so it matches the shape of each extracted half; a
+    // scalar i1 condition selects whole values and is used as is.
+    Value cond = appendX1Dim(rewriter, loc, adaptor.getCondition());
+
+    // Both halves of a wide integer must take the same branch: share `cond`.
+    Value resElem0 =
+        rewriter.create<arith::SelectOp>(loc, cond, trueElem0, falseElem0);
+    Value resElem1 =
+        rewriter.create<arith::SelectOp>(loc, cond, trueElem1, falseElem1);
+    Value resultVec =
+        constructResultVector(rewriter, loc, newTy, {resElem0, resElem1});
+    rewriter.replaceOp(op, resultVec);
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ConvertShLI
 //===----------------------------------------------------------------------===//
@@ -828,7 +871,7 @@
   // Populate `arith.*` conversion patterns.
   patterns.add<
       // Misc ops.
-      ConvertConstant, ConvertVectorPrint,
+      ConvertConstant, ConvertVectorPrint, ConvertSelect,
       // Binary ops.
       ConvertAddI, ConvertMulI, ConvertShLI, ConvertShRUI,
       // Bitwise binary ops.
diff --git a/mlir/test/Dialect/Arith/emulate-wide-int.mlir b/mlir/test/Dialect/Arith/emulate-wide-int.mlir
--- a/mlir/test/Dialect/Arith/emulate-wide-int.mlir
+++ b/mlir/test/Dialect/Arith/emulate-wide-int.mlir
@@ -224,6 +224,49 @@
   return %b : vector<3xi16>
 }
 
+// Scalar select: the low and high i32 halves are selected separately with the same i1 condition.
+// CHECK-LABEL: func.func @select_scalar
+// CHECK-SAME: ([[ARG0:%.+]]: vector<2xi32>, [[ARG1:%.+]]: vector<2xi32>, [[ARG2:%.+]]: i1)
+// CHECK-SAME: -> vector<2xi32>
+// CHECK-NEXT: [[TLOW:%.+]] = vector.extract [[ARG0]][0] : vector<2xi32>
+// CHECK-NEXT: [[THIGH:%.+]] = vector.extract [[ARG0]][1] : vector<2xi32>
+// CHECK-NEXT: [[FLOW:%.+]] = vector.extract [[ARG1]][0] : vector<2xi32>
+// CHECK-NEXT: [[FHIGH:%.+]] = vector.extract [[ARG1]][1] : vector<2xi32>
+// CHECK-NEXT: [[SLOW:%.+]] = arith.select [[ARG2]], [[TLOW]], [[FLOW]] : i32
+// CHECK-NEXT: [[SHIGH:%.+]] = arith.select [[ARG2]], [[THIGH]], [[FHIGH]] : i32
+// CHECK-NEXT: [[VZ:%.+]] = arith.constant dense<0> : vector<2xi32>
+// CHECK-NEXT: [[INS0:%.+]] = vector.insert [[SLOW]], [[VZ]] [0] : i32 into vector<2xi32>
+// CHECK-NEXT: [[INS1:%.+]] = vector.insert [[SHIGH]], [[INS0]] [1] : i32 into vector<2xi32>
+// CHECK: return [[INS1]] : vector<2xi32>
+func.func @select_scalar(%a : i64, %b : i64, %c : i1) -> i64 {
+  %r = arith.select %c, %a, %b : i64
+  return %r : i64
+}
+
+// Vector select with a scalar i1 condition: each vector<3x1xi32> half is selected as a whole.
+// CHECK-LABEL: func.func @select_vector_whole
+// CHECK-SAME: ([[ARG0:%.+]]: vector<3x2xi32>, [[ARG1:%.+]]: vector<3x2xi32>, [[ARG2:%.+]]: i1)
+// CHECK-SAME: -> vector<3x2xi32>
+// CHECK: arith.select {{%.+}}, {{%.+}}, {{%.+}} : vector<3x1xi32>
+// CHECK-NEXT: arith.select {{%.+}}, {{%.+}}, {{%.+}} : vector<3x1xi32>
+// CHECK: return {{%.+}} : vector<3x2xi32>
+func.func @select_vector_whole(%a : vector<3xi64>, %b : vector<3xi64>, %c : i1) -> vector<3xi64> {
+  %r = arith.select %c, %a, %b : vector<3xi64>
+  return %r : vector<3xi64>
+}
+
+// Elementwise vector select: the vector<3xi1> condition gains a trailing unit dim (vector<3x1xi1>) to match each half.
+// CHECK-LABEL: func.func @select_vector_elementwise
+// CHECK-SAME: ([[ARG0:%.+]]: vector<3x2xi32>, [[ARG1:%.+]]: vector<3x2xi32>, [[ARG2:%.+]]: vector<3xi1>)
+// CHECK-SAME: -> vector<3x2xi32>
+// CHECK: arith.select {{%.+}}, {{%.+}}, {{%.+}} : vector<3x1xi1>, vector<3x1xi32>
+// CHECK-NEXT: arith.select {{%.+}}, {{%.+}}, {{%.+}} : vector<3x1xi1>, vector<3x1xi32>
+// CHECK: return {{%.+}} : vector<3x2xi32>
+func.func @select_vector_elementwise(%a : vector<3xi64>, %b : vector<3xi64>, %c : vector<3xi1>) -> vector<3xi64> {
+  %r = arith.select %c, %a, %b : vector<3xi1>, vector<3xi64>
+  return %r : vector<3xi64>
+}
+
 // CHECK-LABEL: func.func @muli_scalar
 // CHECK-SAME: ([[ARG0:%.+]]: vector<2xi32>, [[ARG1:%.+]]: vector<2xi32>) -> vector<2xi32>
 // CHECK-NEXT: [[LOW0:%.+]] = vector.extract [[ARG0]][0] : vector<2xi32>