diff --git a/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.h
@@ -26,6 +26,10 @@
 /// Create a pass to legalize Arithmetic ops for LLVM lowering.
 std::unique_ptr<Pass> createArithmeticExpandOpsPass();
 
+/// Create a pass to replace signed ops with unsigned ones where they are
+/// proven equivalent.
+std::unique_ptr<Pass> createArithmeticUnsignedWhenEquivalentPass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.td
--- a/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.td
@@ -33,4 +33,20 @@
   let constructor = "mlir::arith::createArithmeticExpandOpsPass()";
 }
 
+def ArithmeticUnsignedWhenEquivalent : Pass<"arith-unsigned-when-equivalent"> {
+  let summary = "Replace signed ops with unsigned ones where they are proven equivalent";
+  let description = [{
+    Replace signed ops with their unsigned equivalents when integer range
+    analysis determines that their arguments and results are all guaranteed to
+    be non-negative when interpreted as signed integers. When this is the case,
+    the signed and unsigned operations have the same semantics, since they
+    share the same behavior when their operands and results are in the range
+    [0, signed_max(type)].
+
+    The affected ops include signed division, remainder, min, max, sign
+    extension, and signed integer comparisons.
+  }];
+  let constructor = "mlir::arith::createArithmeticUnsignedWhenEquivalentPass()";
+}
+
 #endif // MLIR_DIALECT_ARITHMETIC_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
   ExpandOps.cpp
+  UnsignedWhenEquivalent.cpp
 
   ADDITIONAL_HEADER_DIRS
   {$MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Arithmetic/Transforms
@@ -10,6 +11,7 @@
   MLIRArithmeticTransformsIncGen
 
   LINK_LIBS PUBLIC
+  MLIRAnalysis
   MLIRArithmetic
   MLIRBufferization
   MLIRBufferizationTransforms
diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp
new file
--- /dev/null
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp
@@ -0,0 +1,144 @@
+//===- UnsignedWhenEquivalent.cpp - Pass to replace signed operations with
+// unsigned ones when all their arguments and results are statically
+// non-negative --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Analysis/IntRangeAnalysis.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
+
+using namespace mlir;
+using namespace mlir::arith;
+
+using OpList = llvm::SmallVector<Operation *>;
+
+/// Returns true when a value is statically non-negative, i.e. it has a lower
+/// bound on its value (when treated as signed) and that bound is
+/// non-negative.
+static bool staticallyNonNegative(IntRangeAnalysis &analysis, Value v) {
+  Optional<ConstantIntRanges> result = analysis.getResult(v);
+  if (!result.hasValue())
+    return false;
+  const ConstantIntRanges &range = result.getValue();
+  return range.smin().isNonNegative();
+}
+
+/// Identify all operations in a block that have signed equivalents and have
+/// operands and results that are statically non-negative.
+template <typename... Ops>
+static void getConvertableOps(Operation *root, OpList &toRewrite,
+                              IntRangeAnalysis &analysis) {
+  auto nonNegativePred = [&analysis](Value v) -> bool {
+    return staticallyNonNegative(analysis, v);
+  };
+  root->walk([&nonNegativePred, &toRewrite](Operation *orig) {
+    if (isa<Ops...>(orig) &&
+        llvm::all_of(orig->getOperands(), nonNegativePred) &&
+        llvm::all_of(orig->getResults(), nonNegativePred)) {
+      toRewrite.push_back(orig);
+    }
+  });
+}
+
+static CmpIPredicate toUnsignedPred(CmpIPredicate pred) {
+  switch (pred) {
+  case CmpIPredicate::sle:
+    return CmpIPredicate::ule;
+  case CmpIPredicate::slt:
+    return CmpIPredicate::ult;
+  case CmpIPredicate::sge:
+    return CmpIPredicate::uge;
+  case CmpIPredicate::sgt:
+    return CmpIPredicate::ugt;
+  default:
+    return pred;
+  }
+}
+
+/// Find all cmpi ops that can be replaced by their unsigned equivalents.
+static void getConvertableCmpi(Operation *root, OpList &toRewrite,
+                               IntRangeAnalysis &analysis) {
+  auto nonNegativePred = [&analysis](Value v) -> bool {
+    return staticallyNonNegative(analysis, v);
+  };
+  root->walk([&nonNegativePred, &toRewrite](arith::CmpIOp orig) {
+    CmpIPredicate pred = orig.getPredicate();
+    if (toUnsignedPred(pred) != pred &&
+        // i1 results will spuriously and trivially show up as potentially
+        // negative, so don't check the results.
+        llvm::all_of(orig->getOperands(), nonNegativePred)) {
+      toRewrite.push_back(orig.getOperation());
+    }
+  });
+}
+
+/// Return ops to be replaced in the order they should be rewritten.
+static OpList getMatching(Operation *root, IntRangeAnalysis &analysis) {
+  OpList ret;
+  getConvertableOps<DivSIOp, CeilDivSIOp, FloorDivSIOp, RemSIOp, MinSIOp,
+                    MaxSIOp, ExtSIOp>(root, ret, analysis);
+  // Since the cmpi rewrites are in-place changes, they don't need to be in
+  // topological order like the others.
+  getConvertableCmpi(root, ret, analysis);
+  return ret;
+}
+
+template <typename T, typename U>
+static void rewriteOp(Operation *op, OpBuilder &b) {
+  if (isa<T>(op)) {
+    OpBuilder::InsertionGuard guard(b);
+    b.setInsertionPoint(op);
+    Operation *newOp = b.create<U>(op->getLoc(), op->getResultTypes(),
+                                   op->getOperands(), op->getAttrs());
+    op->replaceAllUsesWith(newOp->getResults());
+    op->erase();
+  }
+}
+
+static void rewriteCmpI(Operation *op, OpBuilder &b) {
+  if (auto cmpOp = dyn_cast<arith::CmpIOp>(op)) {
+    cmpOp.setPredicateAttr(CmpIPredicateAttr::get(
+        b.getContext(), toUnsignedPred(cmpOp.getPredicate())));
+  }
+}
+
+static void rewrite(Operation *root, const OpList &toReplace) {
+  OpBuilder b(root->getContext());
+  b.setInsertionPoint(root);
+  for (Operation *op : toReplace) {
+    rewriteOp<DivSIOp, DivUIOp>(op, b);
+    rewriteOp<CeilDivSIOp, CeilDivUIOp>(op, b);
+    rewriteOp<FloorDivSIOp, DivUIOp>(op, b);
+    rewriteOp<RemSIOp, RemUIOp>(op, b);
+    rewriteOp<MinSIOp, MinUIOp>(op, b);
+    rewriteOp<MaxSIOp, MaxUIOp>(op, b);
+    rewriteOp<ExtSIOp, ExtUIOp>(op, b);
+    rewriteCmpI(op, b);
+  }
+}
+
+namespace {
+struct ArithmeticUnsignedWhenEquivalentPass
+    : public ArithmeticUnsignedWhenEquivalentBase<
+          ArithmeticUnsignedWhenEquivalentPass> {
+  /// Implementation structure: first find all equivalent ops and collect them,
+  /// then perform all the rewrites in a second pass over the target op. This
+  /// ensures that analysis results are not invalidated during rewriting.
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    IntRangeAnalysis analysis(op);
+    rewrite(op, getMatching(op, analysis));
+  }
+};
+} // end anonymous namespace
+
+std::unique_ptr<Pass>
+mlir::arith::createArithmeticUnsignedWhenEquivalentPass() {
+  return std::make_unique<ArithmeticUnsignedWhenEquivalentPass>();
+}
diff --git a/mlir/test/Dialect/Arithmetic/unsigned-when-equivalent.mlir b/mlir/test/Dialect/Arithmetic/unsigned-when-equivalent.mlir
new file
--- /dev/null
+++ b/mlir/test/Dialect/Arithmetic/unsigned-when-equivalent.mlir
@@ -0,0 +1,88 @@
+// RUN: mlir-opt -arith-unsigned-when-equivalent %s | FileCheck %s
+
+// CHECK-LABEL: func @not_with_maybe_overflow
+// CHECK: arith.divsi
+// CHECK: arith.ceildivsi
+// CHECK: arith.floordivsi
+// CHECK: arith.remsi
+// CHECK: arith.minsi
+// CHECK: arith.maxsi
+// CHECK: arith.extsi
+// CHECK: arith.cmpi sle
+// CHECK: arith.cmpi slt
+// CHECK: arith.cmpi sge
+// CHECK: arith.cmpi sgt
+func.func @not_with_maybe_overflow(%arg0 : i32) {
+    %ci32_smax = arith.constant 0x7fffffff : i32
+    %c1 = arith.constant 1 : i32
+    %c4 = arith.constant 4 : i32
+    %0 = arith.minui %arg0, %ci32_smax : i32
+    %1 = arith.addi %0, %c1 : i32
+    %2 = arith.divsi %1, %c4 : i32
+    %3 = arith.ceildivsi %1, %c4 : i32
+    %4 = arith.floordivsi %1, %c4 : i32
+    %5 = arith.remsi %1, %c4 : i32
+    %6 = arith.minsi %1, %c4 : i32
+    %7 = arith.maxsi %1, %c4 : i32
+    %8 = arith.extsi %1 : i32 to i64
+    %9 = arith.cmpi sle, %1, %c4 : i32
+    %10 = arith.cmpi slt, %1, %c4 : i32
+    %11 = arith.cmpi sge, %1, %c4 : i32
+    %12 = arith.cmpi sgt, %1, %c4 : i32
+    func.return
+}
+
+// CHECK-LABEL: func @yes_with_no_overflow
+// CHECK: arith.divui
+// CHECK: arith.ceildivui
+// CHECK: arith.divui
+// CHECK: arith.remui
+// CHECK: arith.minui
+// CHECK: arith.maxui
+// CHECK: arith.extui
+// CHECK: arith.cmpi ule
+// CHECK: arith.cmpi ult
+// CHECK: arith.cmpi uge
+// CHECK: arith.cmpi ugt
+func.func @yes_with_no_overflow(%arg0 : i32) {
+    %ci32_almost_smax = arith.constant 0x7ffffffe : i32
+    %c1 = arith.constant 1 : i32
+    %c4 = arith.constant 4 : i32
+    %0 = arith.minui %arg0, %ci32_almost_smax : i32
+    %1 = arith.addi %0, %c1 : i32
+    %2 = arith.divsi %1, %c4 : i32
+    %3 = arith.ceildivsi %1, %c4 : i32
+    %4 = arith.floordivsi %1, %c4 : i32
+ %5 = arith.remsi %1, %c4 : i32 + %6 = arith.minsi %1, %c4 : i32 + %7 = arith.maxsi %1, %c4 : i32 + %8 = arith.extsi %1 : i32 to i64 + %9 = arith.cmpi sle, %1, %c4 : i32 + %10 = arith.cmpi slt, %1, %c4 : i32 + %11 = arith.cmpi sge, %1, %c4 : i32 + %12 = arith.cmpi sgt, %1, %c4 : i32 + func.return +} + +// CHECK-LABEL: func @preserves_structure +// CHECK: scf.for %[[arg1:.*]] = +// CHECK: %[[v:.*]] = arith.remui %[[arg1]] +// CHECK: %[[w:.*]] = arith.addi %[[v]], %[[v]] +// CHECK: %[[test:.*]] = arith.cmpi ule, %[[w]] +// CHECK: scf.if %[[test]] +// CHECK: memref.store %[[w]] +func.func @preserves_structure(%arg0 : memref<8xindex>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c8 = arith.constant 8 : index + scf.for %arg1 = %c0 to %c8 step %c1 { + %v = arith.remsi %arg1, %c4 : index + %w = arith.addi %v, %v : index + %test = arith.cmpi sle, %w, %c4 : index + scf.if %test { + memref.store %w, %arg0[%arg1] : memref<8xindex> + } + } + func.return +}
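
A minimal sketch of scheduling the new pass from C++ in a downstream pipeline; the helper name below is illustrative, and only the createArithmeticUnsignedWhenEquivalentPass() entry point added in Passes.h, the standard OpPassManager API, and the MLIRArithmeticTransforms/MLIRAnalysis link dependencies added above are assumed:

#include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
#include "mlir/Pass/PassManager.h"

// Append the signed-to-unsigned rewrite to an existing pipeline. This is the
// C++ equivalent of running `mlir-opt -arith-unsigned-when-equivalent`.
void addUnsignedWhenEquivalent(mlir::OpPassManager &pm) {
  pm.addPass(mlir::arith::createArithmeticUnsignedWhenEquivalentPass());
}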