diff --git a/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.h
@@ -29,6 +29,10 @@
 /// Create a pass to constant fold based on the results of range inferrence.
 std::unique_ptr<Pass> createArithmeticFoldInferredConstantsPass();
 
+/// Create a pass to replace signed ops with unsigned ones when all arguments
+/// and results are proven non-negative.
+std::unique_ptr<Pass> createArithmeticUnsignedWhenEquivalentPass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.td
--- a/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.td
@@ -38,4 +38,9 @@
   let constructor = "mlir::arith::createArithmeticFoldInferredConstantsPass()";
 }
 
+def ArithmeticUnsignedWhenEquivalent : Pass<"arith-unsigned-when-equivalent"> {
+  let summary = "Replace signed ops with unsigned ones when all arguments and results are proven non-negative";
+  let constructor = "mlir::arith::createArithmeticUnsignedWhenEquivalentPass()";
+}
+
 #endif // MLIR_DIALECT_ARITHMETIC_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@
   Bufferize.cpp
   ExpandOps.cpp
   FoldInferredConstants.cpp
+  UnsignedWhenEquivalent.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Arithmetic/Transforms
diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp
@@ -0,0 +1,141 @@
+//===- UnsignedWhenEquivalent.cpp - Pass to replace signed operations with
+// unsigned ones when all their arguments and results are statically
+// non-negative ----------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Analysis/IntRangeAnalysis.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
+
+using namespace mlir;
+using namespace mlir::arith;
+
+using OpList = llvm::SmallVector<Operation *>;
+
+// In this pass, we first collect all the ops we want to make unsigned and
+// only then rewrite them in one go, so that we don't have to maintain a value
+// mapping in order to do lookups in the analysis results.
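+//
+// The rewrites are sound because, in two's complement, the signed and
+// unsigned forms of an op agree whenever the sign bit of every operand and
+// result is clear. For example, on i8 the bits 0xfc mean -4 under divsi
+// (-4 / 2 = -2) but 252 under divui (252 / 2 = 126); once all values are
+// proven non-negative, no such divergence is possible.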
+
+static bool staticallyNonNegative(IntRangeAnalysis &analysis, Value v) {
+  Optional<IntRangeAttrs> result = analysis.getResult(v);
+  if (!result.hasValue())
+    return false;
+  const IntRangeAttrs &range = result.getValue();
+  return (range.smin() && range.smin()->isNonNegative());
+}
+
+template <typename... Ops>
+static void getConvertableOps(Operation *root, OpList &toRewrite,
+                              IntRangeAnalysis &analysis) {
+  auto nonNegativePred = [&analysis](Value v) -> bool {
+    return staticallyNonNegative(analysis, v);
+  };
+  root->walk([&nonNegativePred, &toRewrite](Operation *orig) {
+    if (isa<Ops...>(orig) &&
+        llvm::all_of(orig->getOperands(), nonNegativePred) &&
+        llvm::all_of(orig->getResults(), nonNegativePred)) {
+      toRewrite.push_back(orig);
+    }
+  });
+}
+
+static CmpIPredicate toUnsignedPred(CmpIPredicate pred) {
+  switch (pred) {
+  case CmpIPredicate::sle:
+    return CmpIPredicate::ule;
+  case CmpIPredicate::slt:
+    return CmpIPredicate::ult;
+  case CmpIPredicate::sge:
+    return CmpIPredicate::uge;
+  case CmpIPredicate::sgt:
+    return CmpIPredicate::ugt;
+  default:
+    return pred;
+  }
+}
+
+static void getConvertableCmpi(Operation *root, OpList &toRewrite,
+                               IntRangeAnalysis &analysis) {
+  auto nonNegativePred = [&analysis](Value v) -> bool {
+    return staticallyNonNegative(analysis, v);
+  };
+  root->walk([&nonNegativePred, &toRewrite](arith::CmpIOp orig) {
+    CmpIPredicate pred = orig.getPredicate();
+    if (toUnsignedPred(pred) != pred &&
+        // i1 will spuriously and trivially show up as potentially negative,
+        // so don't check the results.
+        llvm::all_of(orig->getOperands(), nonNegativePred)) {
+      toRewrite.push_back(orig.getOperation());
+    }
+  });
+}
+
+static OpList getMatching(Operation *root, IntRangeAnalysis &analysis) {
+  OpList ret;
+  getConvertableOps<DivSIOp, CeilDivSIOp, FloorDivSIOp, RemSIOp, MinSIOp,
+                    MaxSIOp, ExtSIOp>(root, ret, analysis);
+  // Since the cmpi rewrites are in-place changes, they don't need to be in
+  // topological order like the others.
+  getConvertableCmpi(root, ret, analysis);
+  return ret;
+}
+
+template <typename SignedOp, typename UnsignedOp>
+static void rewriteOp(Operation *op, OpBuilder &b) {
+  if (isa<SignedOp>(op)) {
+    OpBuilder::InsertionGuard guard(b);
+    b.setInsertionPoint(op);
+    Operation *newOp = b.create<UnsignedOp>(op->getLoc(), op->getResultTypes(),
+                                            op->getOperands(), op->getAttrs());
+    op->replaceAllUsesWith(newOp->getResults());
+    op->erase();
+  }
+}
+
+static void rewriteCmpI(Operation *op, OpBuilder &b) {
+  if (auto cmpOp = dyn_cast<CmpIOp>(op)) {
+    cmpOp.setPredicateAttr(CmpIPredicateAttr::get(
+        b.getContext(), toUnsignedPred(cmpOp.getPredicate())));
+  }
+}
+
+static void rewrite(Operation *root, const OpList &toReplace) {
+  OpBuilder b(root->getContext());
+  b.setInsertionPoint(root);
+  for (Operation *op : toReplace) {
+    rewriteOp<DivSIOp, DivUIOp>(op, b);
+    rewriteOp<CeilDivSIOp, CeilDivUIOp>(op, b);
+    // There is no unsigned floor division; it coincides with divui when both
+    // operands are non-negative.
+    rewriteOp<FloorDivSIOp, DivUIOp>(op, b);
+    rewriteOp<RemSIOp, RemUIOp>(op, b);
+    rewriteOp<MinSIOp, MinUIOp>(op, b);
+    rewriteOp<MaxSIOp, MaxUIOp>(op, b);
+    rewriteOp<ExtSIOp, ExtUIOp>(op, b);
+    rewriteCmpI(op, b);
+  }
+}
+
+namespace {
+struct ArithmeticUnsignedWhenEquivalentPass
+    : public ArithmeticUnsignedWhenEquivalentBase<
+          ArithmeticUnsignedWhenEquivalentPass> {
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    MLIRContext *ctx = op->getContext();
+
+    IntRangeAnalysis analysis(ctx);
+    analysis.run(op);
+    rewrite(op, getMatching(op, analysis));
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass>
+mlir::arith::createArithmeticUnsignedWhenEquivalentPass() {
+  return std::make_unique<ArithmeticUnsignedWhenEquivalentPass>();
+}
diff --git a/mlir/test/Dialect/Arithmetic/unsigned-when-equivalent.mlir b/mlir/test/Dialect/Arithmetic/unsigned-when-equivalent.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Arithmetic/unsigned-when-equivalent.mlir
@@ -0,0 +1,88 @@
+// RUN: mlir-opt -arith-unsigned-when-equivalent %s | FileCheck %s
+
+// CHECK-LABEL: func @not_with_maybe_overflow
+// CHECK: arith.divsi
+// CHECK: arith.ceildivsi
+// CHECK: arith.floordivsi
+// CHECK: arith.remsi
+// CHECK: arith.minsi
+// CHECK: arith.maxsi
+// CHECK: arith.extsi
+// CHECK: arith.cmpi sle
+// CHECK: arith.cmpi slt
+// CHECK: arith.cmpi sge
+// CHECK: arith.cmpi sgt
+func @not_with_maybe_overflow(%arg0 : i32) {
+  %ci32_smax = arith.constant 0x7fffffff : i32
+  %c1 = arith.constant 1 : i32
+  %c4 = arith.constant 4 : i32
+  %0 = arith.minui %arg0, %ci32_smax : i32
+  %1 = arith.addi %0, %c1 : i32
+  %2 = arith.divsi %1, %c4 : i32
+  %3 = arith.ceildivsi %1, %c4 : i32
+  %4 = arith.floordivsi %1, %c4 : i32
+  %5 = arith.remsi %1, %c4 : i32
+  %6 = arith.minsi %1, %c4 : i32
+  %7 = arith.maxsi %1, %c4 : i32
+  %8 = arith.extsi %1 : i32 to i64
+  %9 = arith.cmpi sle, %1, %c4 : i32
+  %10 = arith.cmpi slt, %1, %c4 : i32
+  %11 = arith.cmpi sge, %1, %c4 : i32
+  %12 = arith.cmpi sgt, %1, %c4 : i32
+  return
+}
+
+// CHECK-LABEL: func @yes_with_no_overflow
+// CHECK: arith.divui
+// CHECK: arith.ceildivui
+// CHECK: arith.divui
+// CHECK: arith.remui
+// CHECK: arith.minui
+// CHECK: arith.maxui
+// CHECK: arith.extui
+// CHECK: arith.cmpi ule
+// CHECK: arith.cmpi ult
+// CHECK: arith.cmpi uge
+// CHECK: arith.cmpi ugt
+func @yes_with_no_overflow(%arg0 : i32) {
+  %ci32_almost_smax = arith.constant 0x7ffffffe : i32
+  %c1 = arith.constant 1 : i32
+  %c4 = arith.constant 4 : i32
+  %0 = arith.minui %arg0, %ci32_almost_smax : i32
+  %1 = arith.addi %0, %c1 : i32
+  %2 = arith.divsi %1, %c4 : i32
+  %3 = arith.ceildivsi %1, %c4 : i32
+  %4 = arith.floordivsi %1, %c4 : i32
+  %5 = arith.remsi %1, %c4 : i32
+  %6 = arith.minsi %1, %c4 : i32
+  %7 = arith.maxsi %1, %c4 : i32
+  %8 = arith.extsi %1 : i32 to i64
+  %9 = arith.cmpi sle, %1, %c4 : i32
+  %10 = arith.cmpi slt, %1, %c4 : i32
+  %11 = arith.cmpi sge, %1, %c4 : i32
+  %12 = arith.cmpi sgt, %1, %c4 : i32
+  return
+}
+
+// CHECK-LABEL: func @preserves_structure
+// CHECK: scf.for %[[arg1:.*]] =
+// CHECK: %[[v:.*]] = arith.remui %[[arg1]]
+// CHECK: %[[w:.*]] = arith.addi %[[v]], %[[v]]
+// CHECK: %[[test:.*]] = arith.cmpi ule, %[[w]]
+// CHECK: scf.if %[[test]]
+// CHECK: memref.store %[[w]]
+func @preserves_structure(%arg0 : memref<8xindex>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %c8 = arith.constant 8 : index
+  scf.for %arg1 = %c0 to %c8 step %c1 {
+    %v = arith.remsi %arg1, %c4 : index
+    %w = arith.addi %v, %v : index
+    %test = arith.cmpi sle, %w, %c4 : index
+    scf.if %test {
+      memref.store %w, %arg0[%arg1] : memref<8xindex>
+    }
+  }
+  return
+}
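
For illustration, the rewrite asserted by @preserves_structure above: the scf.for
bounds prove 0 <= %arg1 < 8, so the pass turns

  %v = arith.remsi %arg1, %c4 : index
  %test = arith.cmpi sle, %w, %c4 : index

into

  %v = arith.remui %arg1, %c4 : index
  %test = arith.cmpi ule, %w, %c4 : index

while leaving the surrounding scf.for/scf.if structure untouched.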