diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h
--- a/mlir/include/mlir/IR/Matchers.h
+++ b/mlir/include/mlir/IR/Matchers.h
@@ -118,6 +118,24 @@
   }
 };
 
+/// The matcher that matches a given target constant scalar / vector splat /
+/// tensor splat integer value. Unlike constant_int_value_matcher, the target
+/// value is supplied at runtime and need not be a compile-time constant. Only
+/// constants with the same bit width as the target value can match.
+struct constant_int_dynamic_value_matcher {
+  explicit constant_int_dynamic_value_matcher(APInt targetValue)
+      : targetValue(targetValue) {}
+  bool match(Operation *op) {
+    APInt value;
+    // APInt::operator== asserts that both operands have the same bit width,
+    // so constants of a different width are rejected before comparing.
+    return constant_int_op_binder(&value).match(op) &&
+           value.getBitWidth() == targetValue.getBitWidth() &&
+           value == targetValue;
+  }
+  APInt targetValue;
+};
+
 /// The matcher that matches anything except the given target constant scalar /
 /// vector splat / tensor splat integer value.
 template <int64_t TargetNotValue> struct constant_int_not_value_matcher {
@@ -223,6 +241,13 @@
   return detail::constant_int_value_matcher<1>();
 }
 
+/// Matches a constant scalar / vector splat / tensor splat integer whose value
+/// and bit width both equal `targetValue`.
+inline detail::constant_int_dynamic_value_matcher
+m_ConstantIntValue(APInt targetValue) {
+  return detail::constant_int_dynamic_value_matcher(targetValue);
+}
+
 /// Matches the given OpClass.
template inline detail::op_matcher m_Op() { return detail::op_matcher(); diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -426,6 +426,24 @@ /// and(x, 0) -> 0 if (matchPattern(rhs(), m_Zero())) return rhs(); + /// and(x, allOnes) -> x + auto getIntegerType = [](Value v) -> IntegerType { + Type t = v.getType(); + if (t.isa()) + return t.cast(); + if (VectorType vType = t.dyn_cast()) { + return vType.getElementType().dyn_cast(); + } + if (TensorType tType = t.dyn_cast()) { + return tType.getElementType().dyn_cast(); + } + return nullptr; + }; + if (IntegerType intType = getIntegerType(this->getResult())) { + APInt allOnes = APInt::getAllOnesValue(intType.getWidth()); + if (matchPattern(rhs(), m_ConstantIntValue(allOnes))) + return lhs(); + } /// and(x,x) -> x if (lhs() == rhs()) return rhs(); @@ -2249,6 +2267,9 @@ // subi(x,x) -> 0 if (getOperand(0) == getOperand(1)) return Builder(getContext()).getZeroAttr(getType()); + // subi(x,0) -> x + if (matchPattern(rhs(), m_Zero())) + return lhs(); return constFoldBinaryOp(operands, [](APInt a, APInt b) { return a - b; }); diff --git a/mlir/test/Transforms/constant-fold.mlir b/mlir/test/Transforms/constant-fold.mlir --- a/mlir/test/Transforms/constant-fold.mlir +++ b/mlir/test/Transforms/constant-fold.mlir @@ -92,6 +92,77 @@ // ----- +// CHECK: func @simple_and +// CHECK-SAME: [[ARG0:%[a-zA-Z0-9]+]]: i1 +// CHECK-SAME: [[ARG1:%[a-zA-Z0-9]+]]: i32) +func @simple_and(%arg0 : i1, %arg1 : i32) -> (i1, i32) { + %c1 = constant 1 : i1 + %cAllOnes_32 = constant 4294967295 : i32 + + // CHECK: [[C31:%.*]] = constant 31 : i32 + %c31 = constant 31 : i32 + %1 = and %arg0, %c1 : i1 + %2 = and %arg1, %cAllOnes_32 : i32 + + // CHECK: [[VAL:%.*]] = and [[ARG1]], [[C31]] + %3 = and %2, %c31 : i32 + + // CHECK: return [[ARG0]], [[VAL]] + return %1, %3 : i1, i32 +} + +// ----- + +// CHECK: func @tensor_and 
+// CHECK-SAME: [[ARG0:%[a-zA-Z0-9]+]]: tensor<2xi32>
+func @tensor_and(%arg0 : tensor<2xi32>) -> tensor<2xi32> {
+  %cAllOnes_32 = constant dense<4294967295> : tensor<2xi32>
+
+  // CHECK: [[C31:%.*]] = constant dense<31> : tensor<2xi32>
+  %c31 = constant dense<31> : tensor<2xi32>
+
+  // CHECK: [[CMIXED:%.*]] = constant dense<[31, -1]> : tensor<2xi32>
+  %c_mixed = constant dense<[31, 4294967295]> : tensor<2xi32>
+
+  %0 = and %arg0, %cAllOnes_32 : tensor<2xi32>
+
+  // CHECK: [[T1:%.*]] = and [[ARG0]], [[C31]]
+  %1 = and %0, %c31 : tensor<2xi32>
+
+  // CHECK: [[T2:%.*]] = and [[T1]], [[CMIXED]]
+  %2 = and %1, %c_mixed : tensor<2xi32>
+
+  // CHECK: return [[T2]]
+  return %2 : tensor<2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @vector_and
+// CHECK-SAME: [[ARG0:%[a-zA-Z0-9]+]]: vector<2xi32>
+func @vector_and(%arg0 : vector<2xi32>) -> vector<2xi32> {
+  %cAllOnes_32 = constant dense<4294967295> : vector<2xi32>
+
+  // CHECK: [[C31:%.*]] = constant dense<31> : vector<2xi32>
+  %c31 = constant dense<31> : vector<2xi32>
+
+  // CHECK: [[CMIXED:%.*]] = constant dense<[31, -1]> : vector<2xi32>
+  %c_mixed = constant dense<[31, 4294967295]> : vector<2xi32>
+
+  %0 = and %arg0, %cAllOnes_32 : vector<2xi32>
+
+  // CHECK: [[T1:%.*]] = and [[ARG0]], [[C31]]
+  %1 = and %0, %c31 : vector<2xi32>
+
+  // CHECK: [[T2:%.*]] = and [[T1]], [[CMIXED]]
+  %2 = and %1, %c_mixed : vector<2xi32>
+
+  // CHECK: return [[T2]]
+  return %2 : vector<2xi32>
+}
+
+// -----
+
 // CHECK-LABEL: func @addi_splat_vector
 func @addi_splat_vector() -> vector<8xi32> {
   %0 = constant dense<1> : vector<8xi32>
@@ -134,16 +205,19 @@
 
 // -----
 
 // CHECK-LABEL: func @simple_subi
-func @simple_subi() -> i32 {
+// CHECK-SAME: [[ARG0:%[a-zA-Z0-9]+]]
+func @simple_subi(%arg0 : i32) -> (i32, i32) {
   %0 = constant 4 : i32
   %1 = constant 1 : i32
+  %2 = constant 0 : i32
 
   // CHECK-NEXT:[[C3:%.+]] = constant 3 : i32
-  %2 = subi %0, %1 : i32
+  %3 = subi %0, %1 : i32
+  %4 = subi %arg0, %2 : i32
 
-  // CHECK-NEXT: return [[C3]]
-  return %2 : i32
+  // CHECK-NEXT: return [[C3]], [[ARG0]]
+  return %3, %4 : i32, i32
 }
 
 // -----