diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -285,6 +285,7 @@
 def Arith_MulIOp : Arith_TotalIntBinaryOp<"muli", [Commutative]> {
   let summary = "integer multiplication operation";
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -18,9 +18,12 @@
 // Add two integer attributes and create a new one with the result.
 def AddIntAttrs : NativeCodeCall<"addIntegerAttrs($_builder, $0, $1, $2)">;
 
-// Subtract two integer attributes and createa a new one with the result.
+// Subtract two integer attributes and create a new one with the result.
 def SubIntAttrs : NativeCodeCall<"subIntegerAttrs($_builder, $0, $1, $2)">;
 
+// Multiply two integer attributes and create a new one with the result.
+def MulIntAttrs : NativeCodeCall<"mulIntegerAttrs($_builder, $0, $1, $2)">;
+
 class cast<string type> : NativeCodeCall<"::mlir::cast<" # type # ">($0)">;
 
 //===----------------------------------------------------------------------===//
@@ -72,6 +75,13 @@
         (Arith_SubIOp $y, $x),
         [(IsScalarOrSplatNegativeOne $c0)]>;
 
+// muli(muli(x, c0), c1) -> muli(x, c0 * c1)
+def MulIMulIConstant :
+    Pat<(Arith_MulIOp:$res
+          (Arith_MulIOp $x, (ConstantLikeMatcher APIntAttr:$c0)),
+          (ConstantLikeMatcher APIntAttr:$c1)),
+        (Arith_MulIOp $x, (Arith_ConstantOp (MulIntAttrs $res, $c0, $c1)))>;
+
 //===----------------------------------------------------------------------===//
 // AddUIExtendedOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -8,6 +8,7 @@
 
 #include <cassert>
 #include <cstdint>
+#include <functional>
 #include <utility>
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
@@ -34,18 +35,34 @@
 // Pattern helpers
 //===----------------------------------------------------------------------===//
 
+/// Applies `binFn` to the values of the integer attributes `lhs` and `rhs` and
+/// returns the result as an IntegerAttr of the type of `res`. The arithmetic is
+/// done on APInt, so it wraps modulo 2^bitwidth like the arith ops themselves,
+/// cannot hit signed-overflow UB, and works for integers wider than 64 bits
+/// (IntegerAttr::getInt() asserts on values that do not fit in int64_t).
+/// `lhs` and `rhs` must have the same bit width as the type of `res`.
+static IntegerAttr
+applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs,
+                    Attribute rhs,
+                    function_ref<APInt(const APInt &, const APInt &)> binFn) {
+  APInt lhsVal = llvm::cast<IntegerAttr>(lhs).getValue();
+  APInt rhsVal = llvm::cast<IntegerAttr>(rhs).getValue();
+  return builder.getIntegerAttr(res.getType(), binFn(lhsVal, rhsVal));
+}
+
 static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
                                    Attribute lhs, Attribute rhs) {
-  return builder.getIntegerAttr(res.getType(),
-                                llvm::cast<IntegerAttr>(lhs).getInt() +
-                                llvm::cast<IntegerAttr>(rhs).getInt());
+  return applyToIntegerAttrs(builder, res, lhs, rhs, std::plus<APInt>());
 }
 
 static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
                                    Attribute lhs, Attribute rhs) {
-  return builder.getIntegerAttr(res.getType(),
-                                llvm::cast<IntegerAttr>(lhs).getInt() -
-                                llvm::cast<IntegerAttr>(rhs).getInt());
+  return applyToIntegerAttrs(builder, res, lhs, rhs, std::minus<APInt>());
+}
+
+static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
+                                   Attribute lhs, Attribute rhs) {
+  return applyToIntegerAttrs(builder, res, lhs, rhs, std::multiplies<APInt>());
 }
 
 /// Invert an integer comparison predicate.
@@ -382,6 +399,11 @@
       [](const APInt &a, const APInt &b) { return a * b; });
 }
 
+void arith::MulIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                MLIRContext *context) {
+  patterns.add<MulIMulIConstant>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // MulSIExtendedOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -885,6 +885,42 @@
   return %add : index
 }
 
+// CHECK-LABEL: @tripleMulIMulIIndex
+// CHECK: %[[cres:.+]] = arith.constant 15 : index
+// CHECK: %[[muli:.+]] = arith.muli %arg0, %[[cres]] : index
+// CHECK: return %[[muli]]
+func.func @tripleMulIMulIIndex(%arg0: index) -> index {
+  %c3 = arith.constant 3 : index
+  %c5 = arith.constant 5 : index
+  %mul1 = arith.muli %arg0, %c3 : index
+  %mul2 = arith.muli %mul1, %c5 : index
+  return %mul2 : index
+}
+
+// CHECK-LABEL: @tripleMulIMulII32
+// CHECK: %[[cres:.+]] = arith.constant -21 : i32
+// CHECK: %[[muli:.+]] = arith.muli %arg0, %[[cres]] : i32
+// CHECK: return %[[muli]]
+func.func @tripleMulIMulII32(%arg0: i32) -> i32 {
+  %c_n3 = arith.constant -3 : i32
+  %c7 = arith.constant 7 : i32
+  %mul1 = arith.muli %arg0, %c_n3 : i32
+  %mul2 = arith.muli %mul1, %c7 : i32
+  return %mul2 : i32
+}
+
+// CHECK-LABEL: @tripleMulIMulII128
+// CHECK: %[[cres:.+]] = arith.constant 55340232221128654848 : i128
+// CHECK: %[[muli:.+]] = arith.muli %arg0, %[[cres]] : i128
+// CHECK: return %[[muli]]
+func.func @tripleMulIMulII128(%arg0: i128) -> i128 {
+  %c2_64 = arith.constant 18446744073709551616 : i128
+  %c3 = arith.constant 3 : i128
+  %mul1 = arith.muli %arg0, %c2_64 : i128
+  %mul2 = arith.muli %mul1, %c3 : i128
+  return %mul2 : i128
+}
+
 // CHECK-LABEL: @addiMuliToSubiRhsI32
 // CHECK-SAME: (%[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32)
 // CHECK: %[[SUB:.+]] = arith.subi %[[ARG0]], %[[ARG1]] : i32