Index: mlir/include/mlir/Dialect/Math/IR/MathOps.td
===================================================================
--- mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -43,6 +43,17 @@
   let assemblyFormat = "$operand attr-dict `:` type($result)";
 }
 
+// Base class for binary math operations on integer types. Require two
+// operands and one result of the same type. This type can be an integer
+// type, vector or tensor thereof.
+class Math_IntegerBinaryOp<string mnemonic, list<Trait> traits = []> :
+    Math_Op<mnemonic, traits # [SameOperandsAndResultType]> {
+  let arguments = (ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs);
+  let results = (outs SignlessIntegerLike:$result);
+
+  let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($result)";
+}
+
 // Base class for binary math operations on floating point types. Require two
 // operands and one result of the same type. This type can be a floating point
 // type, vector or tensor thereof.
@@ -477,6 +488,33 @@
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// IPowSIOp
+//===----------------------------------------------------------------------===//
+
+def Math_IPowSIOp : Math_IntegerBinaryOp<"ipowsi"> {
+  let summary = "signed integer raised to the power of operation";
+  let description = [{
+    Syntax:
+
+    ```
+    operation ::= ssa-id `=` `math.ipowsi` ssa-use `,` ssa-use `:` type
+    ```
+
+    The `ipowsi` operation takes two operands of integer type (i.e., scalar,
+    tensor or vector) and returns one result of the same type. Operands
+    must have the same type.
+
+    Example:
+
+    ```mlir
+    // Scalar signed integer exponentiation.
+    %a = math.ipowsi %b, %c : i32
+    ```
+  }];
+  let hasFolder = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // LogOp
 //===----------------------------------------------------------------------===//
Index: mlir/lib/Dialect/Math/IR/MathOps.cpp
===================================================================
--- mlir/lib/Dialect/Math/IR/MathOps.cpp
+++ mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -87,6 +87,57 @@
       });
 }
 
+//===----------------------------------------------------------------------===//
+// IPowSIOp folder
+//===----------------------------------------------------------------------===//
+
+// Folds `ipowsi` when both operands are integer constants. All arithmetic is
+// done on APInt values of the operand bit width, so overflow wraps around.
+// 0 raised to a negative power is left unfolded.
+OpFoldResult math::IPowSIOp::fold(ArrayRef<Attribute> operands) {
+  return constFoldBinaryOpConditional<IntegerAttr>(
+      operands, [](const APInt &base, const APInt &power) -> Optional<APInt> {
+        unsigned width = base.getBitWidth();
+        APInt zeroValue = APInt::getZero(width);
+        APInt oneValue{width, 1ULL, /*isSigned=*/true};
+        APInt minusOneValue{width, -1ULL, /*isSigned=*/true};
+
+        if (power.isZero())
+          return oneValue;
+
+        if (power.isNegative()) {
+          // Leave 0 raised to negative power not folded.
+          if (base.isZero())
+            return {};
+          if (base.eq(oneValue))
+            return oneValue;
+          // If abs(base) > 1, then the result is zero.
+          if (base.ne(minusOneValue))
+            return zeroValue;
+          // base == -1:
+          //   -1: power is odd
+          //    1: power is even
+          if (power[0] == 1)
+            return minusOneValue;
+
+          return oneValue;
+        }
+
+        // power is positive: exponentiation by squaring.
+        APInt result = oneValue;
+        APInt curBase = base;
+        APInt curPower = power;
+        while (true) {
+          if (curPower[0] == 1)
+            result *= curBase;
+          curPower.lshrInPlace(1);
+          if (curPower.isZero())
+            return result;
+          curBase *= curBase;
+        }
+      });
+}
+
 //===----------------------------------------------------------------------===//
 // Log2Op folder
 //===----------------------------------------------------------------------===//
Index: mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
===================================================================
--- mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
+++ mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
@@ -112,9 +112,104 @@
   return failure();
 }
 
+//----------------------------------------------------------------------------//
+// IPowSIOp strength reduction.
+//----------------------------------------------------------------------------//
+
+namespace {
+/// Rewrites `math.ipowsi` with a constant scalar or splat exponent `e`, where
+/// |e| <= `exponentThreshold`, into `arith` operations:
+///   e == 0: the constant 1 (for any threshold),
+///   e > 0:  base * base * ... * base (e - 1 multiplications),
+///   e < 0:  the same expansion applied to (1 / base), using signed division.
+struct IPowSIStrengthReduction : public OpRewritePattern<math::IPowSIOp> {
+  // Largest exponent magnitude that is expanded into multiplications. The
+  // default member initializer also applies when one of the inherited
+  // OpRewritePattern constructors is used, which would otherwise leave this
+  // member uninitialized.
+  unsigned exponentThreshold = 3;
+
+  using OpRewritePattern::OpRewritePattern;
+
+  IPowSIStrengthReduction(MLIRContext *context, unsigned exponentThreshold = 3,
+                          PatternBenefit benefit = 1,
+                          ArrayRef<StringRef> generatedNames = {})
+      : OpRewritePattern(context, benefit, generatedNames),
+        exponentThreshold(exponentThreshold) {}
+
+  LogicalResult matchAndRewrite(math::IPowSIOp op,
+                                PatternRewriter &rewriter) const final;
+};
+} // namespace
+
+LogicalResult
+IPowSIStrengthReduction::matchAndRewrite(math::IPowSIOp op,
+                                         PatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  Value base = op.getLhs();
+
+  IntegerAttr scalarExponent;
+  DenseIntElementsAttr vectorExponent;
+
+  bool isScalar = matchPattern(op.getRhs(), m_Constant(&scalarExponent));
+  bool isVector = matchPattern(op.getRhs(), m_Constant(&vectorExponent));
+
+  int64_t exponentValue = 0;
+  if (isScalar)
+    exponentValue = scalarExponent.getInt();
+  else if (isVector && vectorExponent.isSplat())
+    exponentValue = vectorExponent.getSplatValue<IntegerAttr>().getInt();
+  else
+    return failure();
+
+  // Exponents <= 0 need the constant `1` with the result type. It is built as
+  // a scalar constant that is broadcast for vector results; other shaped
+  // result types (e.g. tensors) are only rewritten for positive exponents.
+  Type resultType = op.getType();
+  bool canCreateOne =
+      !resultType.isa<ShapedType>() || resultType.isa<VectorType>();
+  auto createOne = [&]() -> Value {
+    Value one = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getIntegerAttr(getElementTypeOrSelf(resultType), 1));
+    if (auto vec = resultType.dyn_cast<VectorType>())
+      return rewriter.create<vector::BroadcastOp>(loc, vec, one);
+    return one;
+  };
+
+  if (exponentValue == 0) {
+    if (!canCreateOne)
+      return failure();
+    rewriter.replaceOp(op, createOne());
+    return success();
+  }
+
+  // Check the magnitude before negating so that INT64_MIN is rejected here
+  // instead of overflowing below.
+  int64_t threshold = exponentThreshold;
+  if (exponentValue > threshold || exponentValue < -threshold)
+    return failure();
+
+  // Invert the base for a negative exponent.
+  if (exponentValue < 0) {
+    if (!canCreateOne)
+      return failure();
+    base = rewriter.create<arith::DivSIOp>(loc, createOne(), base);
+    exponentValue = -exponentValue;
+  }
+
+  // Transform to naive sequence of multiplications.
+  Value result = base;
+  for (int64_t i = 1; i < exponentValue; ++i)
+    result = rewriter.create<arith::MulIOp>(loc, result, base);
+
+  rewriter.replaceOp(op, result);
+  return success();
+}
+
 //----------------------------------------------------------------------------//
 
 void mlir::populateMathAlgebraicSimplificationPatterns(
     RewritePatternSet &patterns) {
   patterns.add<PowFStrengthReduction>(patterns.getContext());
+  patterns.add<IPowSIStrengthReduction>(patterns.getContext());
 }
Index: mlir/test/Dialect/Math/algebraic-simplification.mlir
===================================================================
--- mlir/test/Dialect/Math/algebraic-simplification.mlir
+++ mlir/test/Dialect/Math/algebraic-simplification.mlir
@@ -73,3 +73,93 @@
   %1 = math.powf %arg1, %v : vector<4xf32>
   return %0, %1 : f32, vector<4xf32>
 }
+
+// CHECK-LABEL: @ipowi_zero_exp(
+// CHECK-SAME: %[[ARG0:.+]]: i32
+// CHECK-SAME: %[[ARG1:.+]]: vector<4xi32>
+// CHECK-SAME: -> (i32, vector<4xi32>) {
+func.func @ipowi_zero_exp(%arg0: i32, %arg1: vector<4xi32>) -> (i32, vector<4xi32>) {
+  // CHECK: %[[CST_S:.*]] = arith.constant 1 : i32
+  // CHECK: %[[CST_V:.*]] = arith.constant dense<1> : vector<4xi32>
+  // CHECK: return %[[CST_S]], %[[CST_V]]
+  %c = arith.constant 0 : i32
+  %v = arith.constant dense <0> : vector<4xi32>
+  %0 = math.ipowsi %arg0, %c : i32
+  %1 = math.ipowsi %arg1, %v : vector<4xi32>
+  return %0, %1 : i32, vector<4xi32>
+}
+
+// CHECK-LABEL: @ipowi_exp_one(
+// CHECK-SAME: %[[ARG0:.+]]: i32
+// CHECK-SAME: %[[ARG1:.+]]: vector<4xi32>
+// CHECK-SAME: -> (i32, vector<4xi32>, i32, vector<4xi32>) {
+func.func @ipowi_exp_one(%arg0: i32, %arg1: vector<4xi32>) -> (i32, vector<4xi32>, i32, vector<4xi32>) {
+  // CHECK: %[[CST_S:.*]] = arith.constant 1 : i32
+  // CHECK: %[[CST_V:.*]] = arith.constant dense<1> : vector<4xi32>
+  // CHECK: %[[SCALAR:.*]] = arith.divsi %[[CST_S]], %[[ARG0]]
+  // CHECK: %[[VECTOR:.*]] = arith.divsi %[[CST_V]], %[[ARG1]]
+  // CHECK: return %[[ARG0]], %[[ARG1]], %[[SCALAR]], %[[VECTOR]]
+  %c1 = arith.constant 1 : i32
+  %v1 = arith.constant dense <1> : vector<4xi32>
+  %0 = math.ipowsi %arg0, %c1 : i32
+  %1 = math.ipowsi %arg1, %v1 : vector<4xi32>
+  %cm1 = arith.constant -1 : i32
+  %vm1 = arith.constant dense <-1> : vector<4xi32>
+  %2 = math.ipowsi %arg0, %cm1 : i32
+  %3 = math.ipowsi %arg1, %vm1 : vector<4xi32>
+  return %0, %1, %2, %3 : i32, vector<4xi32>, i32, vector<4xi32>
+}
+
+// CHECK-LABEL: @ipowi_exp_two(
+// CHECK-SAME: %[[ARG0:.+]]: i32
+// CHECK-SAME: %[[ARG1:.+]]: vector<4xi32>
+// CHECK-SAME: -> (i32, vector<4xi32>, i32, vector<4xi32>) {
+func.func @ipowi_exp_two(%arg0: i32, %arg1: vector<4xi32>) -> (i32, vector<4xi32>, i32, vector<4xi32>) {
+  // CHECK: %[[CST_S:.*]] = arith.constant 1 : i32
+  // CHECK: %[[CST_V:.*]] = arith.constant dense<1> : vector<4xi32>
+  // CHECK: %[[SCALAR0:.*]] = arith.muli %[[ARG0]], %[[ARG0]]
+  // CHECK: %[[VECTOR0:.*]] = arith.muli %[[ARG1]], %[[ARG1]]
+  // CHECK: %[[SCALAR1:.*]] = arith.divsi %[[CST_S]], %[[ARG0]]
+  // CHECK: %[[SMUL:.*]] = arith.muli %[[SCALAR1]], %[[SCALAR1]]
+  // CHECK: %[[VECTOR1:.*]] = arith.divsi %[[CST_V]], %[[ARG1]]
+  // CHECK: %[[VMUL:.*]] = arith.muli %[[VECTOR1]], %[[VECTOR1]]
+  // CHECK: return %[[SCALAR0]], %[[VECTOR0]], %[[SMUL]], %[[VMUL]]
+  %c1 = arith.constant 2 : i32
+  %v1 = arith.constant dense <2> : vector<4xi32>
+  %0 = math.ipowsi %arg0, %c1 : i32
+  %1 = math.ipowsi %arg1, %v1 : vector<4xi32>
+  %cm1 = arith.constant -2 : i32
+  %vm1 = arith.constant dense <-2> : vector<4xi32>
+  %2 = math.ipowsi %arg0, %cm1 : i32
+  %3 = math.ipowsi %arg1, %vm1 : vector<4xi32>
+  return %0, %1, %2, %3 : i32, vector<4xi32>, i32, vector<4xi32>
+}
+
+// CHECK-LABEL: @ipowi_exp_three(
+// CHECK-SAME: %[[ARG0:.+]]: i32
+// CHECK-SAME: %[[ARG1:.+]]: vector<4xi32>
+// CHECK-SAME: -> (i32, vector<4xi32>, i32, vector<4xi32>) {
+func.func @ipowi_exp_three(%arg0: i32, %arg1: vector<4xi32>) -> (i32, vector<4xi32>, i32, vector<4xi32>) {
+  // CHECK: %[[CST_S:.*]] = arith.constant 1 : i32
+  // CHECK: %[[CST_V:.*]] = arith.constant dense<1> : vector<4xi32>
+  // CHECK: %[[SMUL0:.*]] = arith.muli %[[ARG0]], %[[ARG0]]
+  // CHECK: %[[SCALAR0:.*]] = arith.muli %[[SMUL0]], %[[ARG0]]
+  // CHECK: %[[VMUL0:.*]] = arith.muli %[[ARG1]], %[[ARG1]]
+  // CHECK: %[[VECTOR0:.*]] = arith.muli %[[VMUL0]], %[[ARG1]]
+  // CHECK: %[[SCALAR1:.*]] = arith.divsi %[[CST_S]], %[[ARG0]]
+  // CHECK: %[[SMUL1:.*]] = arith.muli %[[SCALAR1]], %[[SCALAR1]]
+  // CHECK: %[[SMUL2:.*]] = arith.muli %[[SMUL1]], %[[SCALAR1]]
+  // CHECK: %[[VECTOR1:.*]] = arith.divsi %[[CST_V]], %[[ARG1]]
+  // CHECK: %[[VMUL1:.*]] = arith.muli %[[VECTOR1]], %[[VECTOR1]]
+  // CHECK: %[[VMUL2:.*]] = arith.muli %[[VMUL1]], %[[VECTOR1]]
+  // CHECK: return %[[SCALAR0]], %[[VECTOR0]], %[[SMUL2]], %[[VMUL2]]
+  %c1 = arith.constant 3 : i32
+  %v1 = arith.constant dense <3> : vector<4xi32>
+  %0 = math.ipowsi %arg0, %c1 : i32
+  %1 = math.ipowsi %arg1, %v1 : vector<4xi32>
+  %cm1 = arith.constant -3 : i32
+  %vm1 = arith.constant dense <-3> : vector<4xi32>
+  %2 = math.ipowsi %arg0, %cm1 : i32
+  %3 = math.ipowsi %arg1, %vm1 : vector<4xi32>
+  return %0, %1, %2, %3 : i32, vector<4xi32>, i32, vector<4xi32>
+}
Index: mlir/test/Dialect/Math/canonicalize_ipowsi.mlir
===================================================================
--- /dev/null
+++ mlir/test/Dialect/Math/canonicalize_ipowsi.mlir
@@ -0,0 +1,442 @@
+// RUN: mlir-opt %s -canonicalize | FileCheck %s
+
+// CHECK-LABEL: @ipowsi32_fold(
+// CHECK-SAME: %[[result:.+]]: memref<?xi32>
+func.func @ipowsi32_fold(%result : memref<?xi32>) {
+// CHECK-DAG: %[[cst0:.+]] = arith.constant 0 : i32
+// CHECK-DAG: %[[cst1:.+]] = arith.constant 1 : i32
+// CHECK-DAG: %[[cst1073741824:.+]] = arith.constant 1073741824 : i32
+// CHECK-DAG: %[[cst_m1:.+]] = arith.constant -1 : i32
+// CHECK-DAG: %[[cst_m27:.+]] = arith.constant -27 : i32
+// CHECK-DAG: %[[i0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[i1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[i2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[i3:.+]] = arith.constant 3 : index
+// CHECK-DAG: %[[i4:.+]] = arith.constant 4 : index
+// CHECK-DAG: %[[i5:.+]] = arith.constant 5 : index
+// CHECK-DAG: %[[i6:.+]] = arith.constant 6 : index
+// CHECK-DAG: %[[i7:.+]] = arith.constant 7 : index
+// CHECK-DAG: %[[i8:.+]] = arith.constant 8 : index
+// CHECK-DAG: %[[i9:.+]] = arith.constant 9 : index
+// CHECK-DAG: %[[i10:.+]] = arith.constant 10 : index
+// CHECK-DAG: %[[i11:.+]] = arith.constant 11 : index
+
+// --- Test power == 0 ---
+  %arg0_base = arith.constant 0 : i32
+  %arg0_power = arith.constant 0 : i32
+  %res0 = math.ipowsi %arg0_base, %arg0_power : i32
+  %i0 = arith.constant 0 : index
+  memref.store %res0, %result[%i0] : memref<?xi32>
+// CHECK: memref.store %[[cst1]], %[[result]][%[[i0]]] : memref<?xi32>
+
+  %arg1_base = arith.constant 10 : i32
+  %arg1_power = arith.constant 0 : i32
+  %res1 = math.ipowsi %arg1_base, %arg1_power : i32
+  %i1 = arith.constant 1 : index
+  memref.store %res1, %result[%i1] : memref<?xi32>
+// CHECK: memref.store %[[cst1]], %[[result]][%[[i1]]] : memref<?xi32>
+
+  %arg2_base = arith.constant -10 : i32
+  %arg2_power = arith.constant 0 : i32
+  %res2 = math.ipowsi %arg2_base, %arg2_power : i32
+  %i2 = arith.constant 2 : index
+  memref.store %res2, %result[%i2] : memref<?xi32>
+// CHECK: memref.store %[[cst1]], %[[result]][%[[i2]]] : memref<?xi32>
+
+// --- Test negative powers ---
+  %arg3_base = arith.constant 0 : i32
+  %arg3_power = arith.constant -1 : i32
+  %res3 = math.ipowsi %arg3_base, %arg3_power : i32
+  %i3 = arith.constant 3 : index
+  memref.store %res3, %result[%i3] : memref<?xi32>
+// No folding for ipowsi(0, x) for x < 0:
+// CHECK: %[[res3:.+]] = math.ipowsi %[[cst0]], %[[cst_m1]] : i32
+// CHECK: memref.store %[[res3]], %[[result]][%[[i3]]] : memref<?xi32>
+
+  %arg4_base = arith.constant 1 : i32
+  %arg4_power = arith.constant -10 : i32
+  %res4 = math.ipowsi %arg4_base, %arg4_power : i32
+  %i4 = arith.constant 4 : index
+  memref.store %res4, %result[%i4] : memref<?xi32>
+// CHECK: memref.store %[[cst1]], %[[result]][%[[i4]]] : memref<?xi32>
+
+  %arg5_base = arith.constant 2 : i32
+  %arg5_power = arith.constant -1 : i32
+  %res5 = math.ipowsi %arg5_base, %arg5_power : i32
+  %i5 = arith.constant 5 : index
+  memref.store %res5, %result[%i5] : memref<?xi32>
+// CHECK: memref.store %[[cst0]], %[[result]][%[[i5]]] : memref<?xi32>
+
+  %arg6_base = arith.constant -2 : i32
+  %arg6_power = arith.constant -1 : i32
+  %res6 = math.ipowsi %arg6_base, %arg6_power : i32
+  %i6 = arith.constant 6 : index
+  memref.store %res6, %result[%i6] : memref<?xi32>
+// CHECK: memref.store %[[cst0]], %[[result]][%[[i6]]] : memref<?xi32>
+
+  %arg7_base = arith.constant -1 : i32
+  %arg7_power = arith.constant -10 : i32
+  %res7 = math.ipowsi %arg7_base, %arg7_power : i32
+  %i7 = arith.constant 7 : index
+  memref.store %res7, %result[%i7] : memref<?xi32>
+// CHECK: memref.store %[[cst1]], %[[result]][%[[i7]]] : memref<?xi32>
+
+  %arg8_base = arith.constant -1 : i32
+  %arg8_power = arith.constant -11 : i32
+  %res8 = math.ipowsi %arg8_base, %arg8_power : i32
+  %i8 = arith.constant 8 : index
+  memref.store %res8, %result[%i8] : memref<?xi32>
+// CHECK: memref.store %[[cst_m1]], %[[result]][%[[i8]]] : memref<?xi32>
+
+// --- Test positive powers ---
+  %arg9_base = arith.constant -3 : i32
+  %arg9_power = arith.constant 3 : i32
+  %res9 = math.ipowsi %arg9_base, %arg9_power : i32
+  %i9 = arith.constant 9 : index
+  memref.store %res9, %result[%i9] : memref<?xi32>
+// CHECK: memref.store %[[cst_m27]], %[[result]][%[[i9]]] : memref<?xi32>
+
+  %arg10_base = arith.constant 2 : i32
+  %arg10_power = arith.constant 30 : i32
+  %res10 = math.ipowsi %arg10_base, %arg10_power : i32
+  %i10 = arith.constant 10 : index
+  memref.store %res10, %result[%i10] : memref<?xi32>
+// CHECK: memref.store %[[cst1073741824]], %[[result]][%[[i10]]] : memref<?xi32>
+
+// --- Test vector folding ---
+  %arg11_base = arith.constant 2 : i32
+  %arg11_base_vec = vector.splat %arg11_base : vector<2x2xi32>
+  %arg11_power = arith.constant 30 : i32
+  %arg11_power_vec = vector.splat %arg11_power : vector<2x2xi32>
+  %res11_vec = math.ipowsi %arg11_base_vec, %arg11_power_vec : vector<2x2xi32>
+  %i11 = arith.constant 11 : index
+  %res11 = vector.extract %res11_vec[1, 1] : vector<2x2xi32>
+  memref.store %res11, %result[%i11] : memref<?xi32>
+// CHECK: memref.store %[[cst1073741824]], %[[result]][%[[i11]]] : memref<?xi32>
+
+  return
+}
+
+// CHECK-LABEL: @ipowsi64_fold(
+// CHECK-SAME: %[[result:.+]]: memref<?xi64>
+func.func @ipowsi64_fold(%result : memref<?xi64>) {
+// CHECK-DAG: %[[cst0:.+]] = arith.constant 0 : i64
+// CHECK-DAG: %[[cst1:.+]] = arith.constant 1 : i64
+// CHECK-DAG: %[[cst1073741824:.+]] = arith.constant 1073741824 : i64
+// CHECK-DAG: %[[cst281474976710656:.+]] = arith.constant 281474976710656 : i64
+// CHECK-DAG: %[[cst_m1:.+]] = arith.constant -1 : i64
+// CHECK-DAG: %[[cst_m27:.+]] = arith.constant -27 : i64
+// CHECK-DAG: %[[i0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[i1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[i2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[i3:.+]] = arith.constant 3 : index
+// CHECK-DAG: %[[i4:.+]] = arith.constant 4 : index
+// CHECK-DAG: %[[i5:.+]] = arith.constant 5 : index
+// CHECK-DAG: %[[i6:.+]] = arith.constant 6 : index
+// CHECK-DAG: %[[i7:.+]] = arith.constant 7 : index
+// CHECK-DAG: %[[i8:.+]] = arith.constant 8 : index
+// CHECK-DAG: %[[i9:.+]] = arith.constant 9 : index
+// CHECK-DAG: %[[i10:.+]] = arith.constant 10 : index
+// CHECK-DAG: %[[i11:.+]] = arith.constant 11 : index
+
+// --- Test power == 0 ---
+  %arg0_base = arith.constant 0 : i64
+  %arg0_power = arith.constant 0 : i64
+  %res0 = math.ipowsi %arg0_base, %arg0_power : i64
+  %i0 = arith.constant 0 : index
+  memref.store %res0, %result[%i0] : memref<?xi64>
+// CHECK: memref.store %[[cst1]], %[[result]][%[[i0]]] : memref<?xi64>
+
+  %arg1_base = arith.constant 10 : i64
+  %arg1_power = arith.constant 0 : i64
+  %res1 = math.ipowsi %arg1_base, %arg1_power : i64
+  %i1 = arith.constant 1 : index
+  memref.store %res1, %result[%i1] : memref<?xi64>
+// CHECK: memref.store %[[cst1]], %[[result]][%[[i1]]] : memref<?xi64>
+
+  %arg2_base = arith.constant -10 : i64
+  %arg2_power = arith.constant 0 : i64
+  %res2 = math.ipowsi %arg2_base, %arg2_power : i64
+  %i2 = arith.constant 2 : index
+  memref.store %res2, %result[%i2] : memref<?xi64>
+// CHECK: memref.store %[[cst1]], %[[result]][%[[i2]]] : memref<?xi64>
+
+// --- Test negative powers ---
+  %arg3_base = arith.constant 0 : i64
+  %arg3_power = arith.constant -1 : i64
+  %res3 = math.ipowsi %arg3_base, %arg3_power : i64
+  %i3 = arith.constant 3 : index
+  memref.store %res3, %result[%i3] : memref<?xi64>
+// No folding for ipowsi(0, x) for x < 0:
+// CHECK: %[[res3:.+]] = math.ipowsi %[[cst0]], %[[cst_m1]] : i64
+// CHECK: memref.store %[[res3]], %[[result]][%[[i3]]] : memref<?xi64>
+
+  %arg4_base = arith.constant 1 : i64
+  %arg4_power = arith.constant -10 : i64
+  %res4 = math.ipowsi %arg4_base, %arg4_power : i64
+  %i4 = arith.constant 4 : index
+  memref.store %res4, %result[%i4] : memref<?xi64>
+// CHECK: memref.store %[[cst1]], %[[result]][%[[i4]]] : memref<?xi64>
+
+  %arg5_base = arith.constant 2 : i64
+  %arg5_power = arith.constant -1 : i64
+  %res5 = math.ipowsi %arg5_base, %arg5_power : i64
+  %i5 = arith.constant 5 : index
+  memref.store %res5, %result[%i5] : memref<?xi64>
+// CHECK: memref.store %[[cst0]], %[[result]][%[[i5]]] : memref<?xi64>
+
+  %arg6_base = arith.constant -2 : i64
+  %arg6_power = arith.constant -1 : i64
+  %res6 = math.ipowsi %arg6_base, %arg6_power : i64
+  %i6 = arith.constant 6 : index
+  memref.store %res6, %result[%i6] : memref<?xi64>
+// CHECK: memref.store %[[cst0]], %[[result]][%[[i6]]] : memref<?xi64>
+
+  %arg7_base = arith.constant -1 : i64
+  %arg7_power = arith.constant -10 : i64
+  %res7 = math.ipowsi %arg7_base, %arg7_power : i64
+  %i7 = arith.constant 7 : index
+  memref.store %res7, %result[%i7] : memref<?xi64>
+// CHECK: memref.store %[[cst1]], %[[result]][%[[i7]]] : memref<?xi64>
+
+  %arg8_base = arith.constant -1 : i64
+  %arg8_power = arith.constant -11 : i64
+  %res8 = math.ipowsi %arg8_base, %arg8_power : i64
+  %i8 = arith.constant 8 : index
+  memref.store %res8, %result[%i8] : memref<?xi64>
+// CHECK: memref.store %[[cst_m1]], %[[result]][%[[i8]]] : memref<?xi64>
+
+// --- Test positive powers ---
+  %arg9_base = arith.constant -3 : i64
+  %arg9_power = arith.constant 3 : i64
+  %res9 = math.ipowsi %arg9_base, %arg9_power : i64
+  %i9 = arith.constant 9 : index
+  memref.store %res9, %result[%i9] : memref<?xi64>
+// CHECK: memref.store %[[cst_m27]], %[[result]][%[[i9]]] : memref<?xi64>
+
+  %arg10_base = arith.constant 2 : i64
+  %arg10_power = arith.constant 30 : i64
+  %res10 = math.ipowsi %arg10_base, %arg10_power : i64
+  %i10 = arith.constant 10 : index
+  memref.store %res10, %result[%i10] : memref<?xi64>
+// CHECK: memref.store %[[cst1073741824]], %[[result]][%[[i10]]] : memref<?xi64>
+
+  %arg11_base = arith.constant 2 : i64
+  %arg11_power = arith.constant 48 : i64
+  %res11 = math.ipowsi %arg11_base, %arg11_power : i64
+  %i11 = arith.constant 11 : index
+  memref.store %res11, %result[%i11] : memref<?xi64>
+// CHECK: memref.store %[[cst281474976710656]], %[[result]][%[[i11]]] : memref<?xi64>
+
+  return
+}
+
+// CHECK-LABEL: @ipowsi16_fold(
+// CHECK-SAME: %[[result:.+]]: memref<?xi16>
+func.func @ipowsi16_fold(%result : memref<?xi16>) {
+// CHECK-DAG: %[[cst0:.+]] = arith.constant 0 : i16
+// CHECK-DAG: %[[cst1:.+]] = arith.constant 1 : i16
+// CHECK-DAG: %[[cst16384:.+]] = arith.constant 16384 : i16
+// CHECK-DAG: %[[cst_m1:.+]] = arith.constant -1 : i16
+// CHECK-DAG: %[[cst_m27:.+]] = arith.constant -27 : i16
+// CHECK-DAG: %[[i0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[i1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[i2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[i3:.+]] = arith.constant 3 : index
+// CHECK-DAG: %[[i4:.+]] = arith.constant 4 : index
+// CHECK-DAG: %[[i5:.+]] = arith.constant 5 : index
+// CHECK-DAG: %[[i6:.+]] = arith.constant 6 : index
+// CHECK-DAG: %[[i7:.+]] = arith.constant 7 : index
+// CHECK-DAG: %[[i8:.+]] = arith.constant 8 : index
+// CHECK-DAG: %[[i9:.+]] = arith.constant 9 : index
+// CHECK-DAG: %[[i10:.+]] = arith.constant 10 : index
+
+// --- Test power == 0 ---
+  %arg0_base = arith.constant 0 : i16
+  %arg0_power = arith.constant 0 : i16
+  %res0 = math.ipowsi %arg0_base, %arg0_power : i16
+  %i0 = arith.constant 0 : index
+  memref.store %res0, %result[%i0] : memref<?xi16>
+// CHECK: memref.store %[[cst1]], %[[result]][%[[i0]]] : memref<?xi16>
+
+  %arg1_base = arith.constant 10 : i16
+  %arg1_power = arith.constant 0 : i16
+  %res1 = math.ipowsi %arg1_base, %arg1_power : i16
+  %i1 = arith.constant 1 : index
+  memref.store %res1, %result[%i1] : memref<?xi16>
+// CHECK: memref.store %[[cst1]], %[[result]][%[[i1]]] : memref<?xi16>
+
+  %arg2_base = arith.constant -10 : i16
+  %arg2_power = arith.constant 0 : i16
+  %res2 = math.ipowsi %arg2_base, %arg2_power : i16
+  %i2 = arith.constant 2 : index
+  memref.store %res2, %result[%i2] : memref<?xi16>
+// CHECK: memref.store %[[cst1]], %[[result]][%[[i2]]] : memref<?xi16>
+
+// --- Test negative powers ---
+  %arg3_base = arith.constant 0 : i16
+  %arg3_power = arith.constant -1 : i16
+  %res3 = math.ipowsi %arg3_base, %arg3_power : i16
+  %i3 = arith.constant 3 : index
+  memref.store %res3, %result[%i3] : memref<?xi16>
+// No folding for ipowsi(0, x) for x < 0:
+// CHECK: %[[res3:.+]] = math.ipowsi %[[cst0]], %[[cst_m1]] : i16
+// CHECK: memref.store %[[res3]], %[[result]][%[[i3]]] : memref<?xi16>
+
+  %arg4_base = arith.constant 1 : i16
+  %arg4_power = arith.constant -10 : i16
+  %res4 = math.ipowsi %arg4_base, %arg4_power : i16
+  %i4 = arith.constant 4 : index
+  memref.store %res4, %result[%i4] : memref<?xi16>
+// CHECK: memref.store %[[cst1]], %[[result]][%[[i4]]] : memref<?xi16>
+
+  %arg5_base = arith.constant 2 : i16
+  %arg5_power = arith.constant -1 : i16
+  %res5 = math.ipowsi %arg5_base, %arg5_power : i16
+  %i5 = arith.constant 5 : index
+  memref.store %res5, %result[%i5] : memref<?xi16>
+// CHECK: memref.store %[[cst0]], %[[result]][%[[i5]]] : memref<?xi16>
+
+  %arg6_base = arith.constant -2 : i16
+  %arg6_power = arith.constant -1 : i16
+  %res6 = math.ipowsi %arg6_base, %arg6_power : i16
+  %i6 = arith.constant 6 : index
+  memref.store %res6, %result[%i6] : memref<?xi16>
+// CHECK: memref.store %[[cst0]], %[[result]][%[[i6]]] : memref<?xi16>
+
+  %arg7_base = arith.constant -1 : i16
+  %arg7_power = arith.constant -10 : i16
+  %res7 = math.ipowsi %arg7_base, %arg7_power : i16
+  %i7 = arith.constant 7 : index
+  memref.store %res7, %result[%i7] : memref<?xi16>
+// CHECK: memref.store %[[cst1]], %[[result]][%[[i7]]] : memref<?xi16>
+
+  %arg8_base = arith.constant -1 : i16
+  %arg8_power = arith.constant -11 : i16
+  %res8 = math.ipowsi %arg8_base, %arg8_power : i16
+  %i8 = arith.constant 8 : index
+  memref.store %res8, %result[%i8] : memref<?xi16>
+// CHECK: memref.store %[[cst_m1]], %[[result]][%[[i8]]] : memref<?xi16>
+
+// --- Test positive powers ---
+  %arg9_base = arith.constant -3 : i16
+  %arg9_power = arith.constant 3 : i16
+  %res9 = math.ipowsi %arg9_base, %arg9_power : i16
+  %i9 = arith.constant 9 : index
+  memref.store %res9, %result[%i9] : memref<?xi16>
+// CHECK: memref.store %[[cst_m27]], %[[result]][%[[i9]]] : memref<?xi16>
+
+  %arg10_base = arith.constant 2 : i16
+  %arg10_power = arith.constant 14 : i16
+  %res10 = math.ipowsi %arg10_base, %arg10_power : i16
+  %i10 = arith.constant 10 : index
+  memref.store %res10, %result[%i10] : memref<?xi16>
+// CHECK: memref.store %[[cst16384]], %[[result]][%[[i10]]] : memref<?xi16>
+
+  return
+}
+
+// CHECK-LABEL: @ipowsi8_fold(
+// CHECK-SAME: %[[result:.+]]: memref<?xi8>
+func.func @ipowsi8_fold(%result : memref<?xi8>) {
+// CHECK-DAG: %[[cst0:.+]] = arith.constant 0 : i8
+// CHECK-DAG: %[[cst1:.+]] = arith.constant 1 : i8
+// CHECK-DAG: %[[cst64:.+]] = arith.constant 64 : i8
+// CHECK-DAG: %[[cst_m1:.+]] = arith.constant -1 : i8
+// CHECK-DAG: %[[cst_m27:.+]] = arith.constant -27 : i8
+// CHECK-DAG: %[[i0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[i1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[i2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[i3:.+]] = arith.constant 3 : index
+// CHECK-DAG: %[[i4:.+]] = arith.constant 4 : index
+// CHECK-DAG: %[[i5:.+]] = arith.constant 5 : index
+// CHECK-DAG: %[[i6:.+]] = arith.constant 6 : index
+// CHECK-DAG: %[[i7:.+]] = arith.constant 7 : index
+// CHECK-DAG: %[[i8:.+]] = arith.constant 8 : index
+// CHECK-DAG: %[[i9:.+]] = arith.constant 9 : index
+// CHECK-DAG: %[[i10:.+]] = arith.constant 10 : index
+
+// --- Test power == 0 ---
+  %arg0_base = arith.constant 0 : i8
+  %arg0_power = arith.constant 0 : i8
+  %res0 = math.ipowsi %arg0_base, %arg0_power : i8
+  %i0 = arith.constant 0 : index
+  memref.store %res0, %result[%i0] : memref<?xi8>
+// CHECK: memref.store %[[cst1]], %[[result]][%[[i0]]] : memref<?xi8>
+
+  %arg1_base = arith.constant 10 : i8
+  %arg1_power = arith.constant 0 : i8
+  %res1 = math.ipowsi %arg1_base, %arg1_power : i8
+  %i1 = arith.constant 1 : index
+  memref.store %res1, %result[%i1] : memref<?xi8>
+// CHECK: memref.store %[[cst1]], %[[result]][%[[i1]]] : memref<?xi8>
+
+  %arg2_base = arith.constant -10 : i8
+  %arg2_power = arith.constant 0 : i8
+  %res2 = math.ipowsi %arg2_base, %arg2_power : i8
+  %i2 = arith.constant 2 : index
+  memref.store %res2, %result[%i2] : memref<?xi8>
+// CHECK: memref.store %[[cst1]], %[[result]][%[[i2]]] : memref<?xi8>
+
+// --- Test negative powers ---
+  %arg3_base = arith.constant 0 : i8
+  %arg3_power = arith.constant -1 : i8
+  %res3 = math.ipowsi %arg3_base, %arg3_power : i8
+  %i3 = arith.constant 3 : index
+  memref.store %res3, %result[%i3] : memref<?xi8>
+// No folding for ipowsi(0, x) for x < 0:
+// CHECK: %[[res3:.+]] = math.ipowsi %[[cst0]], %[[cst_m1]] : i8
+// CHECK: memref.store %[[res3]], %[[result]][%[[i3]]] : memref<?xi8>
+
+  %arg4_base = arith.constant 1 : i8
+  %arg4_power = arith.constant -10 : i8
+  %res4 = math.ipowsi %arg4_base, %arg4_power : i8
+  %i4 = arith.constant 4 : index
+  memref.store %res4, %result[%i4] : memref<?xi8>
+// CHECK: memref.store %[[cst1]], %[[result]][%[[i4]]] : memref<?xi8>
+
+  %arg5_base = arith.constant 2 : i8
+  %arg5_power = arith.constant -1 : i8
+  %res5 = math.ipowsi %arg5_base, %arg5_power : i8
+  %i5 = arith.constant 5 : index
+  memref.store %res5, %result[%i5] : memref<?xi8>
+// CHECK: memref.store %[[cst0]], %[[result]][%[[i5]]] : memref<?xi8>
+
+  %arg6_base = arith.constant -2 : i8
+  %arg6_power = arith.constant -1 : i8
+  %res6 = math.ipowsi %arg6_base, %arg6_power : i8
+  %i6 = arith.constant 6 : index
+  memref.store %res6, %result[%i6] : memref<?xi8>
+// CHECK: memref.store %[[cst0]], %[[result]][%[[i6]]] : memref<?xi8>
+
+  %arg7_base = arith.constant -1 : i8
+  %arg7_power = arith.constant -10 : i8
+  %res7 = math.ipowsi %arg7_base, %arg7_power : i8
+  %i7 = arith.constant 7 : index
+  memref.store %res7, %result[%i7] : memref<?xi8>
+// CHECK: memref.store %[[cst1]], %[[result]][%[[i7]]] : memref<?xi8>
+
+  %arg8_base = arith.constant -1 : i8
+  %arg8_power = arith.constant -11 : i8
+  %res8 = math.ipowsi %arg8_base, %arg8_power : i8
+  %i8 = arith.constant 8 : index
+  memref.store %res8, %result[%i8] : memref<?xi8>
+// CHECK: memref.store %[[cst_m1]], %[[result]][%[[i8]]] : memref<?xi8>
+
+// --- Test positive powers ---
+  %arg9_base = arith.constant -3 : i8
+  %arg9_power = arith.constant 3 : i8
+  %res9 = math.ipowsi %arg9_base, %arg9_power : i8
+  %i9 = arith.constant 9 : index
+  memref.store %res9, %result[%i9] : memref<?xi8>
+// CHECK: memref.store %[[cst_m27]], %[[result]][%[[i9]]] : memref<?xi8>
+
+  %arg10_base = arith.constant 2 : i8
+  %arg10_power = arith.constant 6 : i8
+  %res10 = math.ipowsi %arg10_base, %arg10_power : i8
+  %i10 = arith.constant 10 : index
+  memref.store %res10, %result[%i10] : memref<?xi8>
+// CHECK: memref.store %[[cst64]], %[[result]][%[[i10]]] : memref<?xi8>
+
+  return
+}
Index: mlir/test/Dialect/Math/ops.mlir
===================================================================
--- mlir/test/Dialect/Math/ops.mlir
+++ mlir/test/Dialect/Math/ops.mlir
@@ -206,3 +206,15 @@
   %2 = math.round %t : tensor<4x4x?xf32>
   return
 }
+
+// CHECK-LABEL: func @ipowsi(
+// CHECK-SAME: %[[I:.*]]: i32, %[[V:.*]]: vector<4xi32>, %[[T:.*]]: tensor<4x4x?xi32>)
+func.func @ipowsi(%i: i32, %v: vector<4xi32>, %t: tensor<4x4x?xi32>) {
+  // CHECK: %{{.*}} = math.ipowsi %[[I]], %[[I]] : i32
+  %0 = math.ipowsi %i, %i : i32
+  // CHECK: %{{.*}} = math.ipowsi %[[V]], %[[V]] : vector<4xi32>
+  %1 = math.ipowsi %v, %v : vector<4xi32>
+  // CHECK: %{{.*}} = math.ipowsi %[[T]], %[[T]] : tensor<4x4x?xi32>
+  %2 = math.ipowsi %t, %t : tensor<4x4x?xi32>
+  return
+}