diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td --- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td +++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td @@ -818,6 +818,7 @@ }]; let hasFolder = 1; + let hasCanonicalizer = 1; let verifier = [{ return verifyExtOp(*this); }]; } diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticCanonicalization.td b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticCanonicalization.td --- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticCanonicalization.td +++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticCanonicalization.td @@ -128,4 +128,14 @@ def BitcastOfBitcast : Pat<(Arith_BitcastOp (Arith_BitcastOp $x)), (replaceWithValue $x)>; +//===----------------------------------------------------------------------===// +// ExtSIOp +//===----------------------------------------------------------------------===// + +// extsi(extui(x iN : iM) : iL) -> extui(x : iL) +// Sound because extui strictly widens (M > N), so the sign bit of its result +// is always zero and the outer extsi therefore also fills with zeros. +def ExtSIOfExtUI : + Pat<(Arith_ExtSIOp (Arith_ExtUIOp $x)), (Arith_ExtUIOp $x)>; + #endif // ARITHMETIC_PATTERNS diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp --- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp +++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp @@ -788,6 +788,11 @@ return IntegerAttr::get( getType(), lhs.getValue().zext(getType().getIntOrFloatBitWidth())); + if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) { + getInMutable().assign(lhs.getIn()); + return getResult(); + } + return {}; } @@ -804,6 +809,11 @@ return IntegerAttr::get( getType(), lhs.getValue().sext(getType().getIntOrFloatBitWidth())); + if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) { + getInMutable().assign(lhs.getIn()); + return getResult(); + } + return {}; } @@ -811,6 +821,11 @@ return checkWidthChangeCast(inputs, outputs); } +void arith::ExtSIOp::getCanonicalizationPatterns( + OwningRewritePatternList &patterns, MLIRContext *context) { +
patterns.insert<ExtSIOfExtUI>(context); +} + //===----------------------------------------------------------------------===// // ExtFOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir --- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir +++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir @@ -70,6 +70,35 @@ // ----- +// CHECK-LABEL: @extSIOfExtUI +// CHECK: %[[res:.+]] = arith.extui %arg0 : i1 to i64 +// CHECK: return %[[res]] +func @extSIOfExtUI(%arg0: i1) -> i64 { + %ext1 = arith.extui %arg0 : i1 to i8 + %ext2 = arith.extsi %ext1 : i8 to i64 + return %ext2 : i64 +} + +// CHECK-LABEL: @extUIOfExtUI +// CHECK: %[[res:.+]] = arith.extui %arg0 : i1 to i64 +// CHECK: return %[[res]] +func @extUIOfExtUI(%arg0: i1) -> i64 { + %ext1 = arith.extui %arg0 : i1 to i8 + %ext2 = arith.extui %ext1 : i8 to i64 + return %ext2 : i64 +} + +// CHECK-LABEL: @extSIOfExtSI +// CHECK: %[[res:.+]] = arith.extsi %arg0 : i1 to i64 +// CHECK: return %[[res]] +func @extSIOfExtSI(%arg0: i1) -> i64 { + %ext1 = arith.extsi %arg0 : i1 to i8 + %ext2 = arith.extsi %ext1 : i8 to i64 + return %ext2 : i64 +} + +// ----- + // CHECK-LABEL: @indexCastOfSignExtend // CHECK: %[[res:.+]] = arith.index_cast %arg0 : i8 to index // CHECK: return %[[res]]