diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -1249,6 +1249,7 @@
   }];
 
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1711,6 +1712,7 @@
   let printer = [{
     return printStandardCastOp(this->getOperation(), p);
   }];
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1241,6 +1241,34 @@
   return {};
 }
 
+namespace {
+/// index_cast(sexti x) => index_cast(x)
+///
+/// index_cast sign-extends narrower integers (and truncates wider ones), so
+/// an intervening sexti does not change the resulting index value.
+struct IndexCastOfSExt : public OpRewritePattern<IndexCastOp> {
+  using OpRewritePattern<IndexCastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(IndexCastOp op,
+                                PatternRewriter &rewriter) const override {
+    auto extop = op.getOperand().getDefiningOp<SignExtendIOp>();
+    if (!extop)
+      return failure();
+    // In-place updates must go through the rewriter so the pattern driver is
+    // notified that `op` changed.
+    rewriter.updateRootInPlace(op,
+                               [&] { op.setOperand(extop.getOperand()); });
+    return success();
+  }
+};
+
+} // namespace
+
+void IndexCastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                              MLIRContext *context) {
+  results.insert<IndexCastOfSExt>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // MulFOp
 //===----------------------------------------------------------------------===//
@@ -1439,6 +1467,21 @@
   return success();
 }
 
+/// Folds sexti of an integer constant into the sign-extended constant.
+OpFoldResult SignExtendIOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 1 && "unary operation takes one operand");
+
+  if (!operands[0])
+    return {};
+
+  if (auto lhs = operands[0].dyn_cast<IntegerAttr>()) {
+    return IntegerAttr::get(
+        getType(), lhs.getValue().sext(getType().getIntOrFloatBitWidth()));
+  }
+
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // SignedDivIOp
 //===----------------------------------------------------------------------===//
@@ -2686,7 +2729,18 @@
       matchPattern(getOperand(), m_Op<SignExtendIOp>()))
     return getOperand().getDefiningOp()->getOperand(0);
 
-  return nullptr;
+  // trunci(constant) -> constant of the narrower result type.
+  assert(operands.size() == 1 && "unary operation takes one operand");
+
+  if (!operands[0])
+    return {};
+
+  if (auto lhs = operands[0].dyn_cast<IntegerAttr>()) {
+    return IntegerAttr::get(
+        getType(), lhs.getValue().trunc(getType().getIntOrFloatBitWidth()));
+  }
+
+  return {};
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
--- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
@@ -698,15 +698,13 @@
 
 // Check sign and zero extension and truncation of integers.
 // CHECK-LABEL: @integer_extension_and_truncation
-func @integer_extension_and_truncation() {
-// CHECK-NEXT: %0 = llvm.mlir.constant(-3 : i3) : i3
-  %0 = constant 5 : i3
-// CHECK-NEXT: = llvm.sext %0 : i3 to i6
-  %1 = sexti %0 : i3 to i6
-// CHECK-NEXT: = llvm.zext %0 : i3 to i6
-  %2 = zexti %0 : i3 to i6
-// CHECK-NEXT: = llvm.trunc %0 : i3 to i2
-  %3 = trunci %0 : i3 to i2
+func @integer_extension_and_truncation(%arg0 : i3) {
+// CHECK-NEXT: = llvm.sext %arg0 : i3 to i6
+  %0 = sexti %arg0 : i3 to i6
+// CHECK-NEXT: = llvm.zext %arg0 : i3 to i6
+  %1 = zexti %arg0 : i3 to i6
+// CHECK-NEXT: = llvm.trunc %arg0 : i3 to i2
+  %2 = trunci %arg0 : i3 to i2
   return
 }
 
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -399,3 +399,32 @@
   %1 = select %0, %arg0, %arg1 : i64
   return %1 : i64
 }
+
+// -----
+
+// CHECK-LABEL: @indexCastOfSignExtend
+// CHECK: %[[res:.+]] = index_cast %arg0 : i8 to index
+// CHECK: return %[[res]]
+func @indexCastOfSignExtend(%arg0: i8) -> index {
+  %ext = sexti %arg0 : i8 to i16
+  %idx = index_cast %ext : i16 to index
+  return %idx : index
+}
+
+// CHECK-LABEL: @signExtendConstant
+// CHECK: %[[cres:.+]] = constant -2 : i16
+// CHECK: return %[[cres]]
+func @signExtendConstant() -> i16 {
+  %c-2 = constant -2 : i8
+  %ext = sexti %c-2 : i8 to i16
+  return %ext : i16
+}
+
+// CHECK-LABEL: @truncConstant
+// CHECK: %[[cres:.+]] = constant -2 : i16
+// CHECK: return %[[cres]]
+func @truncConstant() -> i16 {
+  %c-2 = constant -2 : i32
+  %tr = trunci %c-2 : i32 to i16
+  return %tr : i16
+}