diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -19,8 +19,9 @@
 #include "mlir/IR/TypeUtilities.h"
 
 #include "llvm/ADT/APSInt.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallString.h"
-#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/ADT/SmallVector.h"
 
 using namespace mlir;
 using namespace mlir::arith;
@@ -2157,6 +2158,33 @@
         return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
     }
   }
+
+  // Constant-fold constant operands over non-splat constant condition.
+  // select %cst_vec, %cst0, %cst1 => %cst2
+  if (auto cond =
+          adaptor.getCondition().dyn_cast_or_null<DenseElementsAttr>()) {
+    if (auto lhs =
+            adaptor.getTrueValue().dyn_cast_or_null<DenseElementsAttr>()) {
+      if (auto rhs =
+              adaptor.getFalseValue().dyn_cast_or_null<DenseElementsAttr>()) {
+        SmallVector<Attribute> results;
+        results.reserve(static_cast<size_t>(cond.getNumElements()));
+        auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(),
+                                         cond.value_end<BoolAttr>());
+        auto lhsVals = llvm::make_range(lhs.value_begin<Attribute>(),
+                                        lhs.value_end<Attribute>());
+        auto rhsVals = llvm::make_range(rhs.value_begin<Attribute>(),
+                                        rhs.value_end<Attribute>());
+
+        for (auto [condVal, lhsVal, rhsVal] :
+             llvm::zip_equal(condVals, lhsVals, rhsVals))
+          results.push_back(condVal.getValue() ? lhsVal : rhsVal);
+
+        return DenseElementsAttr::get(lhs.getType(), results);
+      }
+    }
+  }
+
   return nullptr;
 }
 
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -54,6 +54,57 @@
   return %res : i1
 }
 
+// CHECK-LABEL: @select_cst_false_scalar
+// CHECK-SAME:  (%[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32)
+// CHECK-NEXT:    return %[[ARG1]]
+func.func @select_cst_false_scalar(%arg0: i32, %arg1: i32) -> i32 {
+  %false = arith.constant false
+  %res = arith.select %false, %arg0, %arg1 : i32
+  return %res : i32
+}
+
+// CHECK-LABEL: @select_cst_true_scalar
+// CHECK-SAME:  (%[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32)
+// CHECK-NEXT:    return %[[ARG0]]
+func.func @select_cst_true_scalar(%arg0: i32, %arg1: i32) -> i32 {
+  %true = arith.constant true
+  %res = arith.select %true, %arg0, %arg1 : i32
+  return %res : i32
+}
+
+// CHECK-LABEL: @select_cst_true_splat
+//       CHECK:   %[[A:.+]] = arith.constant dense<[1, 2, 3]> : vector<3xi32>
+//  CHECK-NEXT:   return %[[A]]
+func.func @select_cst_true_splat() -> vector<3xi32> {
+  %cond = arith.constant dense<true> : vector<3xi1>
+  %a = arith.constant dense<[1, 2, 3]> : vector<3xi32>
+  %b = arith.constant dense<[4, 5, 6]> : vector<3xi32>
+  %res = arith.select %cond, %a, %b : vector<3xi1>, vector<3xi32>
+  return %res : vector<3xi32>
+}
+
+// CHECK-LABEL: @select_cst_vector_i32
+//       CHECK:   %[[RES:.+]] = arith.constant dense<[1, 5, 3]> : vector<3xi32>
+//  CHECK-NEXT:   return %[[RES]]
+func.func @select_cst_vector_i32() -> vector<3xi32> {
+  %cond = arith.constant dense<[true, false, true]> : vector<3xi1>
+  %a = arith.constant dense<[1, 2, 3]> : vector<3xi32>
+  %b = arith.constant dense<[4, 5, 6]> : vector<3xi32>
+  %res = arith.select %cond, %a, %b : vector<3xi1>, vector<3xi32>
+  return %res : vector<3xi32>
+}
+
+// CHECK-LABEL: @select_cst_vector_f32
+//       CHECK:   %[[RES:.+]] = arith.constant dense<[4.000000e+00, 2.000000e+00, 6.000000e+00]> : vector<3xf32>
+//  CHECK-NEXT:   return %[[RES]]
+func.func @select_cst_vector_f32() -> vector<3xf32> {
+  %cond = arith.constant dense<[false, true, false]> : vector<3xi1>
+  %a = arith.constant dense<[1.0, 2.0, 3.0]> : vector<3xf32>
+  %b = arith.constant dense<[4.0, 5.0, 6.0]> : vector<3xf32>
+  %res = arith.select %cond, %a, %b : vector<3xi1>, vector<3xf32>
+  return %res : vector<3xf32>
+}
+
 // CHECK-LABEL: @selToNot
 //       CHECK:   %[[trueval:.+]] = arith.constant true
 //       CHECK:   %[[res:.+]] = arith.xori %arg0, %[[trueval]] : i1