diff --git a/mlir/include/mlir/Dialect/Arithmetic/Analysis/IntRangeAnalysis.h b/mlir/include/mlir/Dialect/Arithmetic/Analysis/IntRangeAnalysis.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Arithmetic/Analysis/IntRangeAnalysis.h
@@ -0,0 +1,59 @@
+//===- IntRangeAnalysis.h - Infer Ranges Interfaces --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the dataflow analysis class for integer range inference
+// so that it can be used in transformations over the `arith` dialect such as
+// branch elimination or signed->unsigned rewriting
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARITHMETIC_ANALYSIS_INTRANGEANALYSIS_H
+#define MLIR_DIALECT_ARITHMETIC_ANALYSIS_INTRANGEANALYSIS_H
+
+#include "mlir/Analysis/DataFlowAnalysis.h"
+#include "mlir/Dialect/Arithmetic/IR/InferIntRangeInterface.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
+
+namespace mlir {
+namespace arith {
+struct IntRangeAnalysis : public ForwardDataFlowAnalysis<IntRangeAttrs> {
+  using ForwardDataFlowAnalysis::ForwardDataFlowAnalysis;
+  ~IntRangeAnalysis() override = default;
+
+  /// Define bounds on the results or block arguments of the operation
+  /// based on the bounds on the arguments given in `operands`
+  ChangeResult
+  visitOperation(Operation *op,
+                 ArrayRef<LatticeElement<IntRangeAttrs> *> operands) final;
+
+  /// Skip regions of branch ops when we can statically infer constant
+  /// values for operands to the branch op and said op tells us it's safe to do
+  /// so.
+  LogicalResult
+  getSuccessorsForOperands(BranchOpInterface branch,
+                           ArrayRef<LatticeElement<IntRangeAttrs> *> operands,
+                           SmallVectorImpl<Block *> &successors) final;
+
+  /// Skip regions of branch or loop ops when we can statically infer constant
+  /// values for operands to the branch op and said op tells us it's safe to do
+  /// so.
+  void
+  getSuccessorsForOperands(RegionBranchOpInterface branch,
+                           Optional<unsigned> sourceIndex,
+                           ArrayRef<LatticeElement<IntRangeAttrs> *> operands,
+                           SmallVectorImpl<RegionSuccessor> &successors) final;
+
+  /// Infer bounds on loop bounds
+  ChangeResult visitNonControlFlowArguments(
+      Operation *op, const RegionSuccessor &region,
+      ArrayRef<LatticeElement<IntRangeAttrs> *> operands) final;
+};
+} // namespace arith
+} // namespace mlir
+
+#endif
diff --git a/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.h
@@ -26,6 +26,9 @@
 /// Create a pass to legalize Arithmetic ops for LLVM lowering.
 std::unique_ptr<Pass> createArithmeticExpandOpsPass();
 
+/// Create a pass to constant fold based on the results of range inference.
+std::unique_ptr<Pass> createArithmeticFoldInferredConstantsPass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.td
--- a/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.td
@@ -33,4 +33,9 @@
   let constructor = "mlir::arith::createArithmeticExpandOpsPass()";
 }
 
+def ArithmeticFoldInferredConstants : Pass<"arith-fold-inferred-constants"> {
+  let summary = "Constant fold based on the results of integer range inference";
+  let constructor = "mlir::arith::createArithmeticFoldInferredConstantsPass()";
+}
+
 #endif // MLIR_DIALECT_ARITHMETIC_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Arithmetic/Analysis/CMakeLists.txt b/mlir/lib/Dialect/Arithmetic/Analysis/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Arithmetic/Analysis/CMakeLists.txt
@@ -0,0 +1,15 @@
+add_mlir_dialect_library(MLIRArithmeticAnalysis
+  IntRangeAnalysis.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Arithmetic
+
+  DEPENDS
+  mlir-headers
+
+  LINK_LIBS PUBLIC
+  MLIRArithmetic
+  MLIRAnalysis
+  MLIRControlFlowInterfaces
+  MLIRLoopLikeInterface
+  )
diff --git a/mlir/lib/Dialect/Arithmetic/Analysis/IntRangeAnalysis.cpp b/mlir/lib/Dialect/Arithmetic/Analysis/IntRangeAnalysis.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Arithmetic/Analysis/IntRangeAnalysis.cpp
@@ -0,0 +1,175 @@
+//===- IntRangeAnalysis.cpp - Infer Ranges Interfaces --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the dataflow analysis class for integer range inference
+// which is used in transformations over the `arith` dialect such as
+// branch elimination or signed->unsigned rewriting
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arithmetic/Analysis/IntRangeAnalysis.h"
+#include "mlir/Analysis/DataFlowAnalysis.h"
+#include "mlir/Dialect/Arithmetic/IR/InferIntRangeInterface.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "arith"
+
+namespace {
+using namespace mlir;
+using namespace mlir::arith;
+
+IntegerAttr getLoopBoundFromFold(Optional<OpFoldResult> loopBound,
+                                 IntRangeAnalysis &analysis, bool getUpper) {
+  if (!loopBound.hasValue())
+    return {};
+  IntegerAttr ret;
+  if (loopBound->is<Attribute>()) {
+    if (auto bound =
+            loopBound->get<Attribute>().dyn_cast_or_null<IntegerAttr>())
+      ret = bound;
+  } else if (loopBound->is<Value>()) {
+    LatticeElement<IntRangeAttrs> *result =
+        analysis.lookupLatticeElement(loopBound->get<Value>());
+    if (result)
+      ret = getUpper ? result->getValue().second : result->getValue().first;
+  }
+  // Loop bounds don't include the upper index, but integer range bounds do
+  if (ret && getUpper) {
+    // Note: loops that don't execute (ex. %i = 0 to 0) will create bad bounds
+    // with this method, but they don't execute so it doesn't matter
+    ret = IntegerAttr::get(ret.getType(), ret.getValue() - 1);
+  }
+  return ret;
+}
+} // end namespace
+
+namespace mlir {
+namespace arith {
+
+ChangeResult IntRangeAnalysis::visitOperation(
+    Operation *op, ArrayRef<LatticeElement<IntRangeAttrs> *> operands) {
+  ChangeResult ret = ChangeResult::NoChange;
+  // Ignore non-integer outputs - return early if the op has no scalar
+  // integer results
+  bool hasIntegerResult = false;
+  bool hasYieldedResult = false;
+  for (Value v : op->getResults()) {
+    if (v.getType().isIntOrIndex())
+      hasIntegerResult = true;
+    else
+      ret |= markAllPessimisticFixpoint(v);
+    for (Operation *user : v.getUsers())
+      hasYieldedResult |= user->hasTrait<OpTrait::IsTerminator>();
+  }
+  if (!hasIntegerResult)
+    return ret;
+  if (auto inferrable = dyn_cast<InferIntRangeInterface>(op)) {
+    LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for ");
+    LLVM_DEBUG(inferrable->print(llvm::dbgs()));
+    LLVM_DEBUG(llvm::dbgs() << "\n");
+    SmallVector<IntRangeAttrs> argRanges(
+        llvm::map_range(operands, [](LatticeElement<IntRangeAttrs> *val) {
+          return val->getValue();
+        }));
+    SmallVector<IntRangeAttrs> resultRanges;
+    resultRanges.reserve(op->getNumResults());
+    inferrable.inferResultRanges(argRanges, resultRanges);
+    assert(resultRanges.size() == op->getNumResults() &&
+           "Range inference should provide one value per result");
+    for (auto pair : llvm::zip(op->getResults(), resultRanges)) {
+      LLVM_DEBUG(llvm::dbgs() << "Result range " << std::get<1>(pair) << "\n");
+      LatticeElement<IntRangeAttrs> &lattice =
+          getLatticeElement(std::get<0>(pair));
+      Optional<IntRangeAttrs> oldRange;
+      if (!lattice.isUninitialized())
+        oldRange = lattice.getValue();
+      ret |= lattice.join(std::get<1>(pair));
+      // Catch loop results with loop variant bounds and conservatively make
+      // them
+      // [-inf, inf] so we don't circle around infinitely often (because the
+      // dataflow analysis in MLIR doesn't attempt to work out trip counts and
+      // often can't)
+      if (hasYieldedResult && oldRange.hasValue() &&
+          lattice.getValue() != *oldRange) {
+        LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
+        ret |= lattice.join({{}, {}});
+        ret |= lattice.markPessimisticFixpoint();
+      }
+    }
+  } else if (op->getNumRegions() == 0) {
+    // No regions + no result inference method -> unbounded results (ex. memory
+    // ops)
+    ret |= markAllPessimisticFixpoint(op->getResults());
+  }
+  return ret;
+}
+
+LogicalResult IntRangeAnalysis::getSuccessorsForOperands(
+    BranchOpInterface branch,
+    ArrayRef<LatticeElement<IntRangeAttrs> *> operands,
+    SmallVectorImpl<Block *> &successors) {
+  SmallVector<Attribute> inferredConsts(llvm::map_range(
+      operands, [](LatticeElement<IntRangeAttrs> *range) -> Attribute {
+        IntegerAttr min, max;
+        std::tie(min, max) = range->getValue();
+        if (min == max)
+          return min;
+        return {};
+      }));
+  if (Block *singleSucc = branch.getSuccessorForOperands(inferredConsts)) {
+    successors.push_back(singleSucc);
+    return success();
+  }
+  return failure();
+}
+
+void IntRangeAnalysis::getSuccessorsForOperands(
+    RegionBranchOpInterface branch, Optional<unsigned> sourceIndex,
+    ArrayRef<LatticeElement<IntRangeAttrs> *> operands,
+    SmallVectorImpl<RegionSuccessor> &successors) {
+  SmallVector<Attribute> inferredConsts(llvm::map_range(
+      operands, [](LatticeElement<IntRangeAttrs> *range) -> Attribute {
+        IntegerAttr min, max;
+        std::tie(min, max) = range->getValue();
+        if (min == max)
+          return min;
+        return {};
+      }));
+  branch.getSuccessorRegions(sourceIndex, inferredConsts, successors);
+}
+
+ChangeResult IntRangeAnalysis::visitNonControlFlowArguments(
+    Operation *op, const RegionSuccessor &region,
+    ArrayRef<LatticeElement<IntRangeAttrs> *> operands) {
+  // Infer bounds for loop arguments that have static bounds
+  if (auto loop = dyn_cast<LoopLikeOpInterface>(op)) {
+    Optional<Value> iv = loop.getSingleInductionVar();
+    if (!iv.hasValue())
+      return ForwardDataFlowAnalysis<
+          IntRangeAttrs>::visitNonControlFlowArguments(op, region, operands);
+    Optional<OpFoldResult> lowerBound = loop.getSingleLowerBound();
+    Optional<OpFoldResult> upperBound = loop.getSingleUpperBound();
+    IntegerAttr min =
+        getLoopBoundFromFold(lowerBound, *this, /*getUpper=*/false);
+    IntegerAttr max =
+        getLoopBoundFromFold(upperBound, *this, /*getUpper=*/true);
+    LatticeElement<IntRangeAttrs> &ivEntry = getLatticeElement(*iv);
+    return ivEntry.join({min, max});
+  }
+  return ForwardDataFlowAnalysis<IntRangeAttrs>::visitNonControlFlowArguments(
+      op, region, operands);
+}
+} // namespace arith
+} // namespace mlir
diff --git a/mlir/lib/Dialect/Arithmetic/CMakeLists.txt b/mlir/lib/Dialect/Arithmetic/CMakeLists.txt
--- a/mlir/lib/Dialect/Arithmetic/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arithmetic/CMakeLists.txt
@@ -1,3 +1,4 @@
+add_subdirectory(Analysis)
 add_subdirectory(IR)
 add_subdirectory(Transforms)
 add_subdirectory(Utils)
diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
   ExpandOps.cpp
+  FoldInferredConstants.cpp
 
   ADDITIONAL_HEADER_DIRS
   {$MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Arithmetic/Transforms
@@ -11,6 +12,7 @@
 
   LINK_LIBS PUBLIC
   MLIRArithmetic
+  MLIRArithmeticAnalysis
   MLIRBufferization
   MLIRBufferizationTransforms
   MLIRIR
diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/FoldInferredConstants.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/FoldInferredConstants.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/FoldInferredConstants.cpp
@@ -0,0 +1,98 @@
+//===- FoldInferredConstants.cpp - Pass to materialize constants that can be
+// inferred by range analysis --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Analysis/DataFlowAnalysis.h"
+#include "mlir/Dialect/Arithmetic/Analysis/IntRangeAnalysis.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arithmetic/IR/InferIntRangeInterface.h"
+#include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/Transforms/FoldUtils.h"
+
+using namespace mlir;
+using namespace mlir::arith;
+
+namespace {
+/// Patterned after mlir/lib/Transforms/SCCP.cpp
+LogicalResult replaceWithConstant(IntRangeAnalysis &analysis, OpBuilder &b,
+                                  Value value) {
+  LatticeElement<IntRangeAttrs> *mbInferredRange =
+      analysis.lookupLatticeElement(value);
+  if (!mbInferredRange)
+    return failure();
+  const IntRangeAttrs &inferredRange = mbInferredRange->getValue();
+  if (!inferredRange.first || !inferredRange.second ||
+      inferredRange.first != inferredRange.second)
+    return failure();
+  Value constant =
+      b.createOrFold<arith::ConstantOp>(value.getLoc(), inferredRange.first);
+  value.replaceAllUsesWith(constant);
+  return success();
+}
+
+void rewrite(IntRangeAnalysis &analysis, MLIRContext *context,
+             MutableArrayRef<Region> initialRegions) {
+  SmallVector<Block *> worklist;
+  auto addToWorklist = [&](MutableArrayRef<Region> regions) {
+    for (Region &region : regions)
+      for (Block &block : llvm::reverse(region))
+        worklist.push_back(&block);
+  };
+
+  OpBuilder builder(context);
+
+  addToWorklist(initialRegions);
+  while (!worklist.empty()) {
+    Block *block = worklist.pop_back_val();
+
+    for (Operation &op : llvm::make_early_inc_range(*block)) {
+      builder.setInsertionPoint(&op);
+
+      // Replace any result with constants.
+      bool replacedAll = op.getNumResults() != 0;
+      for (Value res : op.getResults())
+        replacedAll &= succeeded(replaceWithConstant(analysis, builder, res));
+
+      // If all of the results of the operation were replaced, try to erase
+      // the operation completely.
+      if (replacedAll && wouldOpBeTriviallyDead(&op)) {
+        assert(op.use_empty() && "expected all uses to be replaced");
+        op.erase();
+        continue;
+      }
+
+      // Add the regions of this operation to the worklist.
+      addToWorklist(op.getRegions());
+    }
+
+    // Replace any block arguments with constants.
+    builder.setInsertionPointToStart(block);
+    for (BlockArgument arg : block->getArguments())
+      (void)replaceWithConstant(analysis, builder, arg);
+  }
+}
+
+struct ArithmeticFoldInferredConstantsPass
+    : public ArithmeticFoldInferredConstantsBase<
+          ArithmeticFoldInferredConstantsPass> {
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    MLIRContext *ctx = op->getContext();
+
+    IntRangeAnalysis analysis(ctx);
+    analysis.run(op);
+    rewrite(analysis, ctx, op->getRegions());
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::arith::createArithmeticFoldInferredConstantsPass() {
+  return std::make_unique<ArithmeticFoldInferredConstantsPass>();
+}
diff --git a/mlir/test/Dialect/Arithmetic/fold-inferred-constants.mlir b/mlir/test/Dialect/Arithmetic/fold-inferred-constants.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Arithmetic/fold-inferred-constants.mlir
@@ -0,0 +1,626 @@
+// RUN: mlir-opt -arith-fold-inferred-constants -canonicalize %s | FileCheck %s
+
+// CHECK-LABEL: func @add_min_max
+// CHECK: %[[c3:.*]] = arith.constant 3 : index
+// CHECK: return %[[c3]]
+func @add_min_max(%a: index, %b: index) -> index {
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %0 = arith.minsi %a, %c1 : index
+  %1 = arith.maxsi %0, %c1 : index
+  %2 = arith.minui %b, %c2 : index
+  %3 = arith.maxui %2, %c2 : index
+  %4 = arith.addi %1, %3 : index
+  return %4 : index
+}
+
+// CHECK-LABEL: func @add_lower_bound
+// CHECK:
%[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func @add_lower_bound(%a : i32, %b : i32) -> i1 { + %c1 = arith.constant 1 : i32 + %c2 = arith.constant 2 : i32 + %0 = arith.maxsi %a, %c1 : i32 + %1 = arith.maxsi %b, %c1 : i32 + %2 = arith.addi %0, %1 : i32 + %3 = arith.cmpi sge, %2, %c2 : i32 + %4 = arith.cmpi uge, %2, %c2 : i32 + %5 = arith.andi %3, %4 : i1 + return %5 : i1 +} + +// CHECK-LABEL: func @sub_signed_vs_unsigned +// CHECK-NOT: arith.cmpi sle +// CHECK: %[[unsigned:.*]] = arith.cmpi ule +// CHECK: return %[[unsigned]] : i1 +func @sub_signed_vs_unsigned(%v : i64) -> i1 { + %c0 = arith.constant 0 : i64 + %c2 = arith.constant 2 : i64 + %0 = arith.minsi %v, %c2 : i64 + %1 = arith.subi %0, %c2 : i64 + %2 = arith.cmpi sle, %1, %c0 : i64 + %3 = arith.cmpi ule, %1, %c0 : i64 + %4 = arith.andi %2, %3 : i1 + return %4 : i1 +} + +// CHECK-LABEL: func @multiply_negatives +// CHECK: %[[false:.*]] = arith.constant false +// CHECK: return %[[false]] +func @multiply_negatives(%a : index, %b : index) -> i1 { + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %c_1 = arith.constant -1 : index + %c_2 = arith.constant -2 : index + %c_4 = arith.constant -4 : index + %c_12 = arith.constant -12 : index + %0 = arith.maxsi %a, %c2 : index + %1 = arith.minsi %0, %c3 : index + %2 = arith.minsi %b, %c_1 : index + %3 = arith.maxsi %2, %c_4 : index + %4 = arith.muli %1, %3 : index + %5 = arith.cmpi slt, %4, %c_12 : index + %6 = arith.cmpi slt, %c_1, %4 : index + %7 = arith.ori %5, %6 : i1 + return %7 : i1 +} + +// CHECK-LABEL: func @multiply_unsigned_bounds +// CHECK: %[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func @multiply_unsigned_bounds(%a : i16, %b : i16) -> i1 { + %c0 = arith.constant 0 : i16 + %c4 = arith.constant 4 : i16 + %c_mask = arith.constant 0x3fff : i16 + %c_bound = arith.constant 0xfffc : i16 + %0 = arith.andi %a, %c_mask : i16 + %1 = arith.minui %b, %c4 : i16 + %2 = arith.muli %0, %1 : i16 + %3 = arith.cmpi uge, 
%2, %c0 : i16 + %4 = arith.cmpi ule, %2, %c_bound : i16 + %5 = arith.andi %3, %4 : i1 + return %5 : i1 +} + +// CHECK-LABEL: @for_loop_with_increasing_arg +// CHECK: %[[ret:.*]] = arith.cmpi ule +// CHECK: return %[[ret]] +func @for_loop_with_increasing_arg() -> i1 { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c16 = arith.constant 16 : index + %0 = scf.for %arg0 = %c0 to %c4 step %c1 iter_args(%arg1 = %c0) -> index { + %10 = arith.addi %arg0, %arg1 : index + scf.yield %10 : index + } + %1 = arith.cmpi ule, %0, %c16 : index + return %1 : i1 +} + +// CHECK-LABEL: @for_loop_with_constant_result +// CHECK: %[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func @for_loop_with_constant_result() -> i1 { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %true = arith.constant true + %0 = scf.for %arg0 = %c0 to %c4 step %c1 iter_args(%arg1 = %true) -> i1 { + %10 = arith.cmpi ule, %arg0, %c4 : index + %11 = arith.andi %10, %arg1 : i1 + scf.yield %11 : i1 + } + return %0 : i1 +} + +// CHECK-LABEL: func @div_bounds_positive +// CHECK: %[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func @div_bounds_positive(%arg0 : index) -> i1 { + %c0 = arith.constant 0 : index + %c2 = arith.constant 2 : index + %c4 = arith.constant 4 : index + %0 = arith.maxsi %arg0, %c2 : index + %1 = arith.divsi %c4, %0 : index + %2 = arith.divui %c4, %0 : index + + %3 = arith.cmpi sge, %1, %c0 : index + %4 = arith.cmpi sle, %1, %c2 : index + %5 = arith.cmpi sge, %2, %c0 : index + %6 = arith.cmpi sle, %1, %c2 : index + + %7 = arith.andi %3, %4 : i1 + %8 = arith.andi %7, %5 : i1 + %9 = arith.andi %8, %6 : i1 + return %9 : i1 +} + +// CHECK-LABEL: func @div_bounds_negative +// CHECK: %[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func @div_bounds_negative(%arg0 : index) -> i1 { + %c0 = arith.constant 0 : index + %c_2 = arith.constant -2 : index + %c4 = 
arith.constant 4 : index + %0 = arith.minsi %arg0, %c_2 : index + %1 = arith.divsi %c4, %0 : index + %2 = arith.divui %c4, %0 : index + + %3 = arith.cmpi sle, %1, %c0 : index + %4 = arith.cmpi sge, %1, %c_2 : index + %5 = arith.cmpi eq, %2, %c0 : index + + %7 = arith.andi %3, %4 : i1 + %8 = arith.andi %7, %5 : i1 + return %8 : i1 +} + +// CHECK-LABEL: func @div_zero_undefined +// CHECK: %[[ret:.*]] = arith.cmpi ule +// CHECK: return %[[ret]] +func @div_zero_undefined(%arg0 : index) -> i1 { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %0 = arith.andi %arg0, %c1 : index + %1 = arith.divui %c4, %0 : index + %2 = arith.cmpi ule, %1, %c4 : index + return %2 : i1 +} + +// CHECK-LABEL: func @ceil_divui +// CHECK: %[[ret:.*]] = arith.cmpi eq +// CHECK: return %[[ret]] +func @ceil_divui(%arg0 : index) -> i1 { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c3 = arith.constant 3 : index + %c4 = arith.constant 4 : index + + %0 = arith.minui %arg0, %c3 : index + %1 = arith.maxui %0, %c1 : index + %2 = arith.ceildivui %1, %c4 : index + %3 = arith.cmpi eq, %2, %c1 : index + + %4 = arith.maxui %0, %c0 : index + %5 = arith.ceildivui %4, %c4 : index + %6 = arith.cmpi eq, %5, %c1 : index + %7 = arith.andi %3, %6 : i1 + return %7 : i1 +} + +// CHECK-LABEL: func @ceil_divsi +// CHECK: %[[ret:.*]] = arith.cmpi eq +// CHECK: return %[[ret]] +func @ceil_divsi(%arg0 : index) -> i1 { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c3 = arith.constant 3 : index + %c4 = arith.constant 4 : index + %c-4 = arith.constant -4 : index + + %0 = arith.minsi %arg0, %c3 : index + %1 = arith.maxsi %0, %c1 : index + %2 = arith.ceildivsi %1, %c4 : index + %3 = arith.cmpi eq, %2, %c1 : index + %4 = arith.ceildivsi %1, %c-4 : index + %5 = arith.cmpi eq, %4, %c0 : index + %6 = arith.andi %3, %5 : i1 + + %7 = arith.maxsi %0, %c0 : index + %8 = arith.ceildivsi %7, %c4 : index + %9 = arith.cmpi eq, %8, %c1 : index 
+ %10 = arith.andi %6, %9 : i1 + return %10 : i1 +} + +// CHECK-LABEL: func @floor_divsi +// CHECK: %[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func @floor_divsi(%arg0 : index) -> i1 { + %c4 = arith.constant 4 : index + %c-1 = arith.constant -1 : index + %c-3 = arith.constant -3 : index + %c-4 = arith.constant -4 : index + + %0 = arith.minsi %arg0, %c-1 : index + %1 = arith.maxsi %0, %c-4 : index + %2 = arith.floordivsi %1, %c4 : index + %3 = arith.cmpi eq, %2, %c-1 : index + return %3 : i1 +} + +// CHECK-LABEL: func @remui_base +// CHECK: %[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func @remui_base(%arg0 : index, %arg1 : index ) -> i1 { + %c2 = arith.constant 2 : index + %c4 = arith.constant 4 : index + + %0 = arith.minui %arg1, %c4 : index + %1 = arith.maxui %0, %c2 : index + %2 = arith.remui %arg0, %1 : index + %3 = arith.cmpi ult, %2, %c4 : index + return %3 : i1 +} + +// CHECK-LABEL: func @remsi_base +// CHECK: %[[ret:.*]] = arith.cmpi sge +// CHECK: return %[[ret]] +func @remsi_base(%arg0 : index, %arg1 : index ) -> i1 { + %c0 = arith.constant 0 : index + %c2 = arith.constant 2 : index + %c4 = arith.constant 4 : index + %c-4 = arith.constant -4 : index + %true = arith.constant true + + %0 = arith.minsi %arg1, %c4 : index + %1 = arith.maxsi %0, %c2 : index + %2 = arith.remsi %arg0, %1 : index + %3 = arith.cmpi sgt, %2, %c-4 : index + %4 = arith.cmpi slt, %2, %c4 : index + %5 = arith.cmpi sge, %2, %c0 : index + %6 = arith.andi %3, %4 : i1 + %7 = arith.andi %5, %6 : i1 + return %7 : i1 +} + +// CHECK-LABEL: func @remsi_positive +// CHECK: %[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func @remsi_positive(%arg0 : index, %arg1 : index ) -> i1 { + %c0 = arith.constant 0 : index + %c2 = arith.constant 2 : index + %c4 = arith.constant 4 : index + %true = arith.constant true + + %0 = arith.minsi %arg1, %c4 : index + %1 = arith.maxsi %0, %c2 : index + %2 = arith.maxsi %arg0, %c0 : index + %3 = arith.remsi %2, %1 
: index + %4 = arith.cmpi sge, %3, %c0 : index + %5 = arith.cmpi slt, %3, %c4 : index + %6 = arith.andi %4, %5 : i1 + return %6 : i1 +} + +// CHECK-LABEL: func @remui_restricted +// CHECK: %[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func @remui_restricted(%arg0 : index) -> i1 { + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %c4 = arith.constant 4 : index + + %0 = arith.minui %arg0, %c3 : index + %1 = arith.maxui %0, %c2 : index + %2 = arith.remui %1, %c4 : index + %3 = arith.cmpi ule, %2, %c3 : index + %4 = arith.cmpi uge, %2, %c2 : index + %5 = arith.andi %3, %4 : i1 + return %5 : i1 +} + +// CHECK-LABEL: func @remsi_restricted +// CHECK: %[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func @remsi_restricted(%arg0 : index) -> i1 { + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %c-4 = arith.constant -4 : index + + %0 = arith.minsi %arg0, %c3 : index + %1 = arith.maxsi %0, %c2 : index + %2 = arith.remsi %1, %c-4 : index + %3 = arith.cmpi ule, %2, %c3 : index + %4 = arith.cmpi uge, %2, %c2 : index + %5 = arith.andi %3, %4 : i1 + return %5 : i1 +} + +// CHECK-LABEL: func @remui_restricted_fails +// CHECK: %[[ret:.*]] = arith.cmpi ne +// CHECK: return %[[ret]] +func @remui_restricted_fails(%arg0 : index) -> i1 { + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %c4 = arith.constant 4 : index + %c5 = arith.constant 5 : index + + %0 = arith.minui %arg0, %c5 : index + %1 = arith.maxui %0, %c3 : index + %2 = arith.remui %1, %c4 : index + %3 = arith.cmpi ne, %2, %c2 : index + return %3 : i1 +} + +// CHECK-LABEL: func @remsi_restricted_fails +// CHECK: %[[ret:.*]] = arith.cmpi ne +// CHECK: return %[[ret]] +func @remsi_restricted_fails(%arg0 : index) -> i1 { + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %c5 = arith.constant 5 : index + %c-4 = arith.constant -4 : index + + %0 = arith.minsi %arg0, %c5 : index + %1 = arith.maxsi %0, %c3 : index + %2 = 
arith.remsi %1, %c-4 : index + %3 = arith.cmpi ne, %2, %c2 : index + return %3 : i1 +} + +// CHECK-LABEL: func @andi +// CHECK: %[[ret:.*]] = arith.cmpi ugt +// CHECK: return %[[ret]] +func @andi(%arg0 : index) -> i1 { + %c2 = arith.constant 2 : index + %c5 = arith.constant 5 : index + %c7 = arith.constant 7 : index + + %0 = arith.minsi %arg0, %c5 : index + %1 = arith.maxsi %0, %c2 : index + %2 = arith.andi %1, %c7 : index + %3 = arith.cmpi ugt, %2, %c5 : index + %4 = arith.cmpi ule, %2, %c7 : index + %5 = arith.andi %3, %4 : i1 + return %5 : i1 +} + +// CHECK-LABEL: func @andi_doesnt_make_nonnegative +// CHECK: %[[ret:.*]] = arith.cmpi sge +// CHECK: return %[[ret]] +func @andi_doesnt_make_nonnegative(%arg0 : index) -> i1 { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %0 = arith.addi %arg0, %c1 : index + %1 = arith.andi %arg0, %0 : index + %2 = arith.cmpi sge, %1, %c0 : index + return %2 : i1 +} + + +// CHECK-LABEL: func @ori +// CHECK: %[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func @ori(%arg0 : i128, %arg1 : i128) -> i1 { + %c-1 = arith.constant -1 : i128 + %c0 = arith.constant 0 : i128 + + %0 = arith.minsi %arg1, %c-1 : i128 + %1 = arith.ori %arg0, %0 : i128 + %2 = arith.cmpi slt, %1, %c0 : i128 + return %2 : i1 +} + +// CHECK-LABEL: func @xori +// CHECK: %[[false:.*]] = arith.constant false +// CHECK: return %[[false]] +func @xori(%arg0 : i64, %arg1 : i64) -> i1 { + %c0 = arith.constant 0 : i64 + %c7 = arith.constant 7 : i64 + %c15 = arith.constant 15 : i64 + %true = arith.constant true + + %0 = arith.minui %arg0, %c7 : i64 + %1 = arith.minui %arg1, %c15 : i64 + %2 = arith.xori %0, %1 : i64 + %3 = arith.cmpi sle, %2, %c15 : i64 + %4 = arith.xori %3, %true : i1 + return %4 : i1 +} + +// CHECK-LABEL: func @extui +// CHECK: %[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func @extui(%arg0 : i16) -> i1 { + %ci16_max = arith.constant 0xffff : i32 + %0 = arith.extui %arg0 : i16 to i32 + %1 = arith.cmpi ule, 
%0, %ci16_max : i32 + return %1 : i1 +} + +// CHECK-LABEL: func @extsi +// CHECK: %[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func @extsi(%arg0 : i16) -> i1 { + %ci16_smax = arith.constant 0x7fff : i32 + %ci16_smin = arith.constant 0xffff8000 : i32 + %0 = arith.extsi %arg0 : i16 to i32 + %1 = arith.cmpi sle, %0, %ci16_smax : i32 + %2 = arith.cmpi sge, %0, %ci16_smin : i32 + %3 = arith.andi %1, %2 : i1 + return %3 : i1 +} + +// CHECK-LABEL: func @trunci +// CHECK: %[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func @trunci(%arg0 : i32) -> i1 { + %c-14_i32 = arith.constant -14 : i32 + %c-14_i16 = arith.constant -14 : i16 + %ci16_smin = arith.constant 0xffff8000 : i32 + %0 = arith.minsi %arg0, %c-14_i32 : i32 + %1 = arith.trunci %0 : i32 to i16 + %2 = arith.cmpi sle, %1, %c-14_i16 : i16 + %3 = arith.extsi %1 : i16 to i32 + %4 = arith.cmpi sle, %3, %c-14_i32 : i32 + %5 = arith.cmpi sge, %3, %ci16_smin : i32 + %6 = arith.andi %2, %4 : i1 + %7 = arith.andi %6, %5 : i1 + return %7 : i1 +} + +// CHECK-LABEL: func @index_cast +// CHECK: %[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func @index_cast(%arg0 : index) -> i1 { + %ci32_smin = arith.constant 0xffffffff80000000 : i64 + %0 = arith.index_cast %arg0 : index to i32 + %1 = arith.index_cast %0 : i32 to index + %2 = arith.index_cast %ci32_smin : i64 to index + %3 = arith.cmpi sge, %1, %2 : index + return %3 : i1 +} + +// CHECK-LABEL: func @shli +// CHECK: %[[ret:.*]] = arith.cmpi sgt +// CHECK: return %[[ret]] +func @shli(%arg0 : i32, %arg1 : i1) -> i1 { + %c2 = arith.constant 2 : i32 + %c4 = arith.constant 4 : i32 + %c8 = arith.constant 8 : i32 + %c32 = arith.constant 32 : i32 + %c-1 = arith.constant -1 : i32 + %c-16 = arith.constant -16 : i32 + %0 = arith.maxsi %arg0, %c-1 : i32 + %1 = arith.minsi %0, %c2 : i32 + %2 = arith.select %arg1, %c2, %c4 : i32 + %3 = arith.shli %1, %2 : i32 + %4 = arith.cmpi sge, %3, %c-16 : i32 + %5 = arith.cmpi sle, %3, %c32 : i32 + %6 = 
arith.cmpi sgt, %3, %c8 : i32 + %7 = arith.andi %4, %5 : i1 + %8 = arith.andi %7, %6 : i1 + return %8 : i1 +} + +// CHECK-LABEL: func @shrui +// CHECK: %[[ret:.*]] = arith.cmpi uge +// CHECK: return %[[ret]] +func @shrui(%arg0 : i1) -> i1 { + %c2 = arith.constant 2 : i32 + %c4 = arith.constant 4 : i32 + %c8 = arith.constant 8 : i32 + %c32 = arith.constant 32 : i32 + %0 = arith.select %arg0, %c2, %c4 : i32 + %1 = arith.shrui %c32, %0 : i32 + %2 = arith.cmpi ule, %1, %c8 : i32 + %3 = arith.cmpi uge, %1, %c2 : i32 + %4 = arith.cmpi uge, %1, %c8 : i32 + %5 = arith.andi %2, %3 : i1 + %6 = arith.andi %5, %4 : i1 + return %6 : i1 +} + +// CHECK-LABEL: func @shrsi +// CHECK: %[[ret:.*]] = arith.cmpi slt +// CHECK: return %[[ret]] +func @shrsi(%arg0 : i32, %arg1 : i1) -> i1 { + %c2 = arith.constant 2 : i32 + %c4 = arith.constant 4 : i32 + %c8 = arith.constant 8 : i32 + %c32 = arith.constant 32 : i32 + %c-8 = arith.constant -8 : i32 + %c-32 = arith.constant -32 : i32 + %0 = arith.maxsi %arg0, %c-32 : i32 + %1 = arith.minsi %0, %c32 : i32 + %2 = arith.select %arg1, %c2, %c4 : i32 + %3 = arith.shrsi %1, %2 : i32 + %4 = arith.cmpi sge, %3, %c-8 : i32 + %5 = arith.cmpi sle, %3, %c8 : i32 + %6 = arith.cmpi slt, %3, %c2 : i32 + %7 = arith.andi %4, %5 : i1 + %8 = arith.andi %7, %6 : i1 + return %8 : i1 +} + +// CHECK-LABEL: func @no_aggressive_eq +// CHECK: %[[ret:.*]] = arith.cmpi eq +// CHECK: return %[[ret]] +func @no_aggressive_eq(%arg0 : index) -> i1 { + %c1 = arith.constant 1 : index + %0 = arith.andi %arg0, %c1 : index + %1 = arith.minui %arg0, %c1 : index + %2 = arith.cmpi eq, %0, %1 : index + return %2 : i1 +} + +// CHECK-LABEL: func @select_union +// CHECK: %[[ret:.*]] = arith.cmpi ne +// CHECK: return %[[ret]] + +func @select_union(%arg0 : index, %arg1 : i1) -> i1 { + %c64 = arith.constant 64 : index + %c100 = arith.constant 100 : index + %c128 = arith.constant 128 : index + %c192 = arith.constant 192 : index + %0 = arith.remui %arg0, %c64 : index + %1 = arith.addi %0, 
%c128 : index + %2 = arith.select %arg1, %0, %1 : index + %3 = arith.cmpi slt, %2, %c192 : index + %4 = arith.cmpi ne, %c100, %2 : index + %5 = arith.andi %3, %4 : i1 + return %5 : i1 +} + +// CHECK-LABEL: func @if_union +// CHECK: %[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func @if_union(%arg0 : index, %arg1 : i1) -> i1 { + %c4 = arith.constant 4 : index + %c16 = arith.constant 16 : index + %c-1 = arith.constant -1 : index + %c-4 = arith.constant -4 : index + %0 = arith.minui %arg0, %c4 : index + %1 = scf.if %arg1 -> index { + %10 = arith.muli %0, %0 : index + scf.yield %10 : index + } else { + %20 = arith.muli %0, %c-1 : index + scf.yield %20 : index + } + %2 = arith.cmpi sle, %1, %c16 : index + %3 = arith.cmpi sge, %1, %c-4 : index + %4 = arith.andi %2, %3 : i1 + return %4 : i1 +} + +// CHECK-LABEL: func @branch_union +// CHECK: %[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func @branch_union(%arg0 : index, %arg1 : i1) -> i1 { + %c4 = arith.constant 4 : index + %c16 = arith.constant 16 : index + %c-1 = arith.constant -1 : index + %c-4 = arith.constant -4 : index + %0 = arith.minui %arg0, %c4 : index + cf.cond_br %arg1, ^bb1, ^bb2 +^bb1 : + %1 = arith.muli %0, %0 : index + cf.br ^bb3(%1 : index) +^bb2 : + %2 = arith.muli %0, %c-1 : index + cf.br ^bb3(%2 : index) +^bb3(%3 : index) : + %4 = arith.cmpi sle, %3, %c16 : index + %5 = arith.cmpi sge, %3, %c-4 : index + %6 = arith.andi %4, %5 : i1 + return %6 : i1 +} + +// CHECK-LABEL: func @loop_bound_not_inferred_with_branch +// CHECK-DAG: %[[min:.*]] = arith.cmpi sge +// CHECK-DAG: %[[max:.*]] = arith.cmpi slt +// CHECK-DAG: %[[ret:.*]] = arith.andi %[[min]], %[[max]] +// CHECK: return %[[ret]] +func @loop_bound_not_inferred_with_branch(%arg0 : index, %arg1 : i1) -> i1 { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %0 = arith.minui %arg0, %c4 : index + cf.br ^bb2(%c0 : index) +^bb1(%1 : index) : + %2 = arith.addi %1, %c1 : 
index + cf.br ^bb2(%2 : index) +^bb2(%3 : index): + %4 = arith.cmpi ult, %3, %c4 : index + cf.cond_br %4, ^bb1(%3 : index), ^bb3(%3 : index) +^bb3(%5 : index) : + %6 = arith.cmpi sge, %5, %c0 : index + %7 = arith.cmpi slt, %5, %c4 : index + %8 = arith.andi %6, %7 : i1 + return %8 : i1 +} +