diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -277,6 +277,7 @@
     ```
   }];
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1792,6 +1793,7 @@
 def SubIOp : IntBinaryOp<"subi"> {
   let summary = "integer subtraction operation";
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -283,6 +283,74 @@
       }));
 }
 
+/// Canonicalize `c1 + x`, where `c1` is a constant and `x` is produced by an
+/// addi/subi that has a constant operand `c2`, so that the two constants are
+/// folded into one:
+///   c1 + (c2 + y) -> (c1 + c2) + y
+///   c1 + (c2 - y) -> (c1 + c2) - y
+///   c1 + (y - c2) -> (c1 - c2) + y
+/// Both operand orders of the outer and inner addi are tried. The inner op is
+/// not erased by this pattern.
+struct AddConstantReorder : public OpRewritePattern<AddIOp> {
+  using OpRewritePattern<AddIOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AddIOp addop,
+                                PatternRewriter &rewriter) const override {
+    // m_ConstantInt also matches splat vector/tensor constants, but the folded
+    // constant is built as a scalar IntegerAttr, which is only valid for
+    // integer/index types.
+    if (!addop.getType().isIntOrIndex())
+      return failure();
+
+    // Materializes `value` as a new constant of the result type.
+    auto createConstant = [&](const APInt &value) -> Value {
+      return rewriter.create<ConstantOp>(
+          addop.getLoc(), rewriter.getIntegerAttr(addop.getType(), value));
+    };
+
+    for (int i = 0; i < 2; i++) {
+      APInt origConst;
+      APInt midConst;
+      if (!matchPattern(addop.getOperand(i), m_ConstantInt(&origConst)))
+        continue;
+      Value other = addop.getOperand(1 - i);
+      if (auto midAddOp = other.getDefiningOp<AddIOp>()) {
+        // origConst + (midConst + y) -> (origConst + midConst) + y
+        for (int j = 0; j < 2; j++) {
+          if (matchPattern(midAddOp.getOperand(j), m_ConstantInt(&midConst))) {
+            rewriter.replaceOpWithNewOp<AddIOp>(
+                addop, createConstant(origConst + midConst),
+                midAddOp.getOperand(1 - j));
+            return success();
+          }
+        }
+      }
+      if (auto midSubOp = other.getDefiningOp<SubIOp>()) {
+        // origConst + (midConst - y) -> (origConst + midConst) - y
+        if (matchPattern(midSubOp.getOperand(0), m_ConstantInt(&midConst))) {
+          rewriter.replaceOpWithNewOp<SubIOp>(
+              addop, createConstant(origConst + midConst),
+              midSubOp.getOperand(1));
+          return success();
+        }
+        // origConst + (y - midConst) -> (origConst - midConst) + y
+        if (matchPattern(midSubOp.getOperand(1), m_ConstantInt(&midConst))) {
+          rewriter.replaceOpWithNewOp<AddIOp>(
+              addop, createConstant(origConst - midConst),
+              midSubOp.getOperand(0));
+          return success();
+        }
+      }
+    }
+    return failure();
+  }
+};
+
+void AddIOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                         MLIRContext *context) {
+  results.insert<AddConstantReorder>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // AndOp
 //===----------------------------------------------------------------------===//
@@ -1706,6 +1774,106 @@
                                         [](APInt a, APInt b) { return a - b; });
 }
 
+/// Canonicalize a subi in which one operand is a constant `c1` and the other
+/// is produced by an addi/subi that has a constant operand `c2`, so that the
+/// two constants are folded into one:
+///   c1 - (c2 + y) -> (c1 - c2) - y      (c2 + y) - c1 -> (c2 - c1) + y
+///   c1 - (c2 - y) -> (c1 - c2) + y      (c2 - y) - c1 -> (c2 - c1) - y
+///   c1 - (y - c2) -> (c1 + c2) - y      (y - c2) - c1 -> y - (c2 + c1)
+/// Both operand orders of an inner addi are tried. The inner op is not erased
+/// by this pattern.
+struct SubConstantReorder : public OpRewritePattern<SubIOp> {
+  using OpRewritePattern<SubIOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(SubIOp subOp,
+                                PatternRewriter &rewriter) const override {
+    // m_ConstantInt also matches splat vector/tensor constants, but the folded
+    // constant is built as a scalar IntegerAttr, which is only valid for
+    // integer/index types.
+    if (!subOp.getType().isIntOrIndex())
+      return failure();
+
+    // Materializes `value` as a new constant of the result type.
+    auto createConstant = [&](const APInt &value) -> Value {
+      return rewriter.create<ConstantOp>(
+          subOp.getLoc(), rewriter.getIntegerAttr(subOp.getType(), value));
+    };
+
+    APInt origConst;
+    APInt midConst;
+
+    // m_ConstantInt only matches values defined by constant-like ops, so once
+    // an operand has matched, only the other operand can be an addi/subi.
+    if (matchPattern(subOp.getOperand(0), m_ConstantInt(&origConst))) {
+      Value other = subOp.getOperand(1);
+      if (auto midAddOp = other.getDefiningOp<AddIOp>()) {
+        // origConst - (midConst + y) -> (origConst - midConst) - y
+        for (int j = 0; j < 2; j++) {
+          if (matchPattern(midAddOp.getOperand(j), m_ConstantInt(&midConst))) {
+            rewriter.replaceOpWithNewOp<SubIOp>(
+                subOp, createConstant(origConst - midConst),
+                midAddOp.getOperand(1 - j));
+            return success();
+          }
+        }
+      }
+      if (auto midSubOp = other.getDefiningOp<SubIOp>()) {
+        // origConst - (midConst - y) -> (origConst - midConst) + y
+        if (matchPattern(midSubOp.getOperand(0), m_ConstantInt(&midConst))) {
+          rewriter.replaceOpWithNewOp<AddIOp>(
+              subOp, createConstant(origConst - midConst),
+              midSubOp.getOperand(1));
+          return success();
+        }
+        // origConst - (y - midConst) -> (origConst + midConst) - y
+        if (matchPattern(midSubOp.getOperand(1), m_ConstantInt(&midConst))) {
+          rewriter.replaceOpWithNewOp<SubIOp>(
+              subOp, createConstant(origConst + midConst),
+              midSubOp.getOperand(0));
+          return success();
+        }
+      }
+    }
+
+    if (matchPattern(subOp.getOperand(1), m_ConstantInt(&origConst))) {
+      Value other = subOp.getOperand(0);
+      if (auto midAddOp = other.getDefiningOp<AddIOp>()) {
+        // (midConst + y) - origConst -> (midConst - origConst) + y
+        for (int j = 0; j < 2; j++) {
+          if (matchPattern(midAddOp.getOperand(j), m_ConstantInt(&midConst))) {
+            rewriter.replaceOpWithNewOp<AddIOp>(
+                subOp, createConstant(midConst - origConst),
+                midAddOp.getOperand(1 - j));
+            return success();
+          }
+        }
+      }
+      if (auto midSubOp = other.getDefiningOp<SubIOp>()) {
+        // (midConst - y) - origConst -> (midConst - origConst) - y
+        if (matchPattern(midSubOp.getOperand(0), m_ConstantInt(&midConst))) {
+          rewriter.replaceOpWithNewOp<SubIOp>(
+              subOp, createConstant(midConst - origConst),
+              midSubOp.getOperand(1));
+          return success();
+        }
+        // (y - midConst) - origConst -> y - (midConst + origConst)
+        if (matchPattern(midSubOp.getOperand(1), m_ConstantInt(&midConst))) {
+          rewriter.replaceOpWithNewOp<SubIOp>(
+              subOp, midSubOp.getOperand(0),
+              createConstant(midConst + origConst));
+          return success();
+        }
+      }
+    }
+    return failure();
+  }
+};
+
+void SubIOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                         MLIRContext *context) {
+  results.insert<SubConstantReorder>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // UIToFPOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -428,3 +428,127 @@
   %tr = trunci %c-2 : i32 to i16
   return %tr : i16
 }
+
+// -----
+
+// CHECK-LABEL: @tripleAddAdd
+// CHECK: %[[cres:.+]] = constant 59 : index
+// CHECK: %[[add:.+]] = addi %arg0, %[[cres]] : index
+// CHECK: return %[[add]]
+func @tripleAddAdd(%arg0: index) -> index {
+  %c17 = constant 17 : index
+  %c42 = constant 42 : index
+  %add1 = addi %c17, %arg0 : index
+  %add2 = addi %c42, %add1 : index
+  return %add2 : index
+}
+
+// CHECK-LABEL: @tripleAddSub0
+// CHECK: %[[cres:.+]] = constant 59 : index
+// CHECK: %[[add:.+]] = subi %[[cres]], %arg0 : index
+// CHECK: return %[[add]]
+func @tripleAddSub0(%arg0: index) -> index {
+  %c17 = constant 17 : index
+  %c42 = constant 42 : index
+  %add1 = subi %c17, %arg0 : index
+  %add2 = addi %c42, %add1 : index
+  return %add2 : index
+}
+
+// CHECK-LABEL: @tripleAddSub1
+// CHECK: %[[cres:.+]] = constant 25 : index
+// CHECK: %[[add:.+]] = addi %arg0, %[[cres]] : index
+// CHECK: return %[[add]]
+func @tripleAddSub1(%arg0: index) -> index {
+  %c17 = constant 17 : index
+  %c42 = constant 42 : index
+  %add1 = subi %arg0, %c17 : index
+  %add2 = addi %c42, %add1 : index
+  return %add2 : index
+}
+
+// CHECK-LABEL: @tripleSubAdd0
+// CHECK: %[[cres:.+]] = constant 25 : index
+// CHECK: %[[add:.+]] = subi %[[cres]], %arg0 : index
+// CHECK: return %[[add]]
+func @tripleSubAdd0(%arg0: index) -> index {
+  %c17 = constant 17 : index
+  %c42 = constant 42 : index
+  %add1 = addi %c17, %arg0 : index
+  %add2 = subi %c42, %add1 : index
+  return %add2 : index
+}
+
+// CHECK-LABEL: @tripleSubAdd1
+// CHECK: %[[cres:.+]] = constant -25 : index
+// CHECK: %[[add:.+]] = addi %arg0, %[[cres]] : index
+// CHECK: return %[[add]]
+func @tripleSubAdd1(%arg0: index) -> index {
+  %c17 = constant 17 : index
+  %c42 = constant 42 : index
+  %add1 = addi %c17, %arg0 : index
+  %add2 = subi %add1, %c42 : index
+  return %add2 : index
+}
+
+// CHECK-LABEL: @tripleSubSub0
+// CHECK: %[[cres:.+]] = constant 25 : index
+// CHECK: %[[add:.+]] = addi %arg0, %[[cres]] : index
+// CHECK: return %[[add]]
+func @tripleSubSub0(%arg0: index) -> index {
+  %c17 = constant 17 : index
+  %c42 = constant 42 : index
+  %add1 = subi %c17, %arg0 : index
+  %add2 = subi %c42, %add1 : index
+  return %add2 : index
+}
+
+// CHECK-LABEL: @tripleSubSub1
+// CHECK: %[[cres:.+]] = constant -25 : index
+// CHECK: %[[add:.+]] = subi %[[cres]], %arg0 : index
+// CHECK: return %[[add]]
+func @tripleSubSub1(%arg0: index) -> index {
+  %c17 = constant 17 : index
+  %c42 = constant 42 : index
+  %add1 = subi %c17, %arg0 : index
+  %add2 = subi %add1, %c42 : index
+  return %add2 : index
+}
+
+// CHECK-LABEL: @tripleSubSub2
+// CHECK: %[[cres:.+]] = constant 59 : index
+// CHECK: %[[add:.+]] = subi %[[cres]], %arg0 : index
+// CHECK: return %[[add]]
+func @tripleSubSub2(%arg0: index) -> index {
+  %c17 = constant 17 : index
+  %c42 = constant 42 : index
+  %add1 = subi %arg0, %c17 : index
+  %add2 = subi %c42, %add1 : index
+  return %add2 : index
+}
+
+// CHECK-LABEL: @tripleSubSub3
+// CHECK: %[[cres:.+]] = constant 59 : index
+// CHECK: %[[add:.+]] = subi %arg0, %[[cres]] : index
+// CHECK: return %[[add]]
+func @tripleSubSub3(%arg0: index) -> index {
+  %c17 = constant 17 : index
+  %c42 = constant 42 : index
+  %add1 = subi %arg0, %c17 : index
+  %add2 = subi %add1, %c42 : index
+  return %add2 : index
+}
+
+// CHECK-LABEL: @tripleAddAddVector
+// CHECK-DAG: constant dense<17> : vector<4xi32>
+// CHECK-DAG: constant dense<42> : vector<4xi32>
+// CHECK: addi
+// CHECK: addi
+// CHECK: return
+func @tripleAddAddVector(%arg0: vector<4xi32>) -> vector<4xi32> {
+  %c17 = constant dense<17> : vector<4xi32>
+  %c42 = constant dense<42> : vector<4xi32>
+  %add1 = addi %c17, %arg0 : vector<4xi32>
+  %add2 = addi %c42, %add1 : vector<4xi32>
+  return %add2 : vector<4xi32>
+}