diff --git a/mlir/include/mlir/Analysis/IntRangeAnalysis.h b/mlir/include/mlir/Analysis/IntRangeAnalysis.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Analysis/IntRangeAnalysis.h
@@ -0,0 +1,47 @@
+//===- IntRangeAnalysis.h - Integer range dataflow analysis -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the dataflow analysis class for integer range inference
+// so that it can be used in transformations over the `arith` dialect such as
+// branch elimination or signed->unsigned rewriting.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ANALYSIS_INTRANGEANALYSIS_H
+#define MLIR_ANALYSIS_INTRANGEANALYSIS_H
+
+#include "mlir/Analysis/DataFlowAnalysis.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
+
+namespace mlir {
+namespace detail {
+class IntRangeAnalysisImpl;
+} // namespace detail
+
+/// Runs integer range inference over the operations nested under a root
+/// operation and gives access to the ranges inferred for individual values.
+class IntRangeAnalysis {
+public:
+  explicit IntRangeAnalysis(MLIRContext *ctx);
+  ~IntRangeAnalysis();
+  IntRangeAnalysis(IntRangeAnalysis &&other);
+
+  /// Analyze all operations rooted under (but not including)
+  /// `topLevelOperation`.
+  void run(Operation *topLevelOperation);
+  /// Get inferred range for value `v` if one exists. Returns llvm::None when
+  /// the analysis holds no initialized lattice element for `v`.
+  Optional<IntRangeAttrs> getResult(Value v);
+
+private:
+  std::unique_ptr<detail::IntRangeAnalysisImpl> impl;
+};
+} // namespace mlir
+
+#endif // MLIR_ANALYSIS_INTRANGEANALYSIS_H
diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt
--- a/mlir/include/mlir/Interfaces/CMakeLists.txt
+++ b/mlir/include/mlir/Interfaces/CMakeLists.txt
@@ -3,6 +3,7 @@
 add_mlir_interface(ControlFlowInterfaces)
 add_mlir_interface(CopyOpInterface)
 add_mlir_interface(DerivedAttributeOpInterface)
+add_mlir_interface(InferIntRangeInterface)
 add_mlir_interface(InferTypeOpInterface)
 add_mlir_interface(LoopLikeInterface)
 add_mlir_interface(SideEffectInterfaces)
diff --git a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
@@ -0,0 +1,104 @@
+//===- InferIntRangeInterface.h - Integer range inference -------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains definitions of the integer range inference interface
+// defined in `InferIntRangeInterface.td`.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_INFERINTRANGEINTERFACE_H
+#define MLIR_INTERFACES_INFERINTRANGEINTERFACE_H
+
+#include "mlir/IR/OpDefinition.h"
+#include "llvm/ADT/APInt.h"
+#include <tuple>
+
+namespace mlir {
+class Value;
+
+/// Minimum and maximum bounds on a scalar integer value, stored as optional
+/// APInts. These bounds are inclusive on both ends, so bounds of [4, 5] mean
+/// 4 <= x <= 5. Separate bounds are tracked for the unsigned and signed
+/// interpretations of values; an unset bound means "unbounded".
+class IntRangeAttrs {
+public:
+  /// Return a range with no bounds for the signed and unsigned interpretations.
+  IntRangeAttrs() = default;
+  /// Bound umin <= (unsigned)x <= umax and smin <= (signed)x <= smax,
+  /// with llvm::None being considered +-infinity.
+  IntRangeAttrs(const Optional<APInt> &umin, const Optional<APInt> &umax,
+                const Optional<APInt> &smin, const Optional<APInt> &smax)
+      : uminVal(umin), umaxVal(umax), sminVal(smin), smaxVal(smax) {
+    // Unsigned integers that have upper bounds but no lower bound have a lower
+    // bound of 0.
+    if (umax && !umin)
+      uminVal = APInt::getZero(umax->getBitWidth());
+  }
+
+  /// Convenience wrapper that allows providing [min, max] ranges.
+  IntRangeAttrs(const std::tuple<Optional<APInt>, Optional<APInt>> &urange,
+                const std::tuple<Optional<APInt>, Optional<APInt>> &srange)
+      : IntRangeAttrs(std::get<0>(urange), std::get<1>(urange),
+                      std::get<0>(srange), std::get<1>(srange)) {}
+
+  /// The minimum value of an integer when it is interpreted as unsigned.
+  const Optional<APInt> &umin() const;
+  /// The maximum value of an integer when it is interpreted as unsigned.
+  const Optional<APInt> &umax() const;
+  /// The minimum value of an integer when it is interpreted as signed.
+  const Optional<APInt> &smin() const;
+  /// The maximum value of an integer when it is interpreted as signed.
+  const Optional<APInt> &smax() const;
+
+  /// Create an `IntRangeAttrs` where `min` is both the signed and unsigned
+  /// minimum and `max` is both the signed and unsigned maximum.
+  static IntRangeAttrs range(const Optional<APInt> &min,
+                             const Optional<APInt> &max);
+
+  /// Create an `IntRangeAttrs` with the signed minimum and maximum equal
+  /// to `smin` and `smax`. The unsigned bounds are derived from the signed
+  /// ones when the signed range lies entirely in the non-negative or entirely
+  /// in the negative half of the value space, and are left unset otherwise.
+  static IntRangeAttrs fromSigned(const Optional<APInt> &smin,
+                                  const Optional<APInt> &smax);
+  static IntRangeAttrs
+  fromSigned(const std::tuple<Optional<APInt>, Optional<APInt>> &srange);
+
+  /// Create an `IntRangeAttrs` with unsigned bounds `umin` and `umax`. The
+  /// signed bounds reuse them when the sign bit does not change between the
+  /// two (or when only a non-negative `umax` is set) and are unset otherwise.
+  static IntRangeAttrs fromUnsigned(const Optional<APInt> &umin,
+                                    const Optional<APInt> &umax);
+  static IntRangeAttrs
+  fromUnsigned(const std::tuple<Optional<APInt>, Optional<APInt>> &urange);
+
+  /// Returns the union (computed separately for signed and unsigned bounds)
+  /// of `a` and `b`.
+  static IntRangeAttrs join(const IntRangeAttrs &a, const IntRangeAttrs &b);
+  /// Helper methods for when IntRangeAttrs is used in a dataflow analysis;
+  /// the pessimistic state is the range with no bounds.
+  static IntRangeAttrs getPessimisticValueState(MLIRContext *context);
+  static IntRangeAttrs getPessimisticValueState(Value v);
+
+  /// If either the signed or unsigned interpretations of the range
+  /// indicate that the value it bounds is a constant, return that constant
+  /// value.
+  Optional<APInt> getConstantValue() const;
+
+  bool operator==(const IntRangeAttrs &other) const;
+
+  friend raw_ostream &operator<<(raw_ostream &os, const IntRangeAttrs &range);
+
+private:
+  Optional<APInt> uminVal, umaxVal, sminVal, smaxVal;
+};
+} // namespace mlir
+
+#include "mlir/Interfaces/InferIntRangeInterface.h.inc"
+
+#endif // MLIR_INTERFACES_INFERINTRANGEINTERFACE_H
diff --git a/mlir/include/mlir/Interfaces/InferIntRangeInterface.td b/mlir/include/mlir/Interfaces/InferIntRangeInterface.td
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/InferIntRangeInterface.td
@@ -0,0 +1,47 @@
+//===- InferIntRangeInterface.td - Integer Range Inference --*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines the interface for range analysis on scalar integers
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_INFERINTRANGEINTERFACE
+#define MLIR_INTERFACES_INFERINTRANGEINTERFACE
+
+include "mlir/IR/OpBase.td"
+
+def InferIntRangeInterface : OpInterface<"InferIntRangeInterface"> {
+  let description = [{
+    Allows operations to participate in range analysis for scalar integer values by
+    providing a method that allows them to specify lower and upper bounds on their
+    result(s) given lower and upper bounds on their input(s) if known.
+  }];
+  let cppNamespace = "::mlir";
+
+  let methods = [
+    InterfaceMethod<[{
+      Infer the bounds on the results of this op given the bounds on its arguments.
+      For each result value or block argument (that isn't a branch argument,
+      since the dataflow analysis handles those cases), the method should call
+      `setResultRanges` with that `Value` as an argument. When `setResultRanges`
+      is not called for some value, it will receive a default value of [-inf, +inf]
+      (the unbounded range).
+
+      This function will only be called when at least one result of the op is a
+      scalar integer value or the op has a region.
+
+      `argRanges` contains one `IntRangeAttrs` for each argument to the op in ODS
+      order. Non-integer arguments will have the unbounded range in their
+      `argRanges` element.
+    }],
+    "void", "inferResultRanges", (ins
+      "::llvm::ArrayRef<::mlir::IntRangeAttrs>":$argRanges,
+      "::llvm::function_ref<void(::mlir::Value, const ::mlir::IntRangeAttrs &)>":$setResultRanges)
+  >];
+}
+#endif // MLIR_INTERFACES_INFERINTRANGEINTERFACE
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -51,6 +51,9 @@
 /// Creates a pass to perform common sub expression elimination.
 std::unique_ptr<Pass> createCSEPass();
 
+/// Creates a pass to constant fold based on the results of range inference.
+std::unique_ptr<Pass> createFoldInferredConstantsPass();
+
 /// Creates a loop invariant code motion pass that hoists loop invariant
 /// instructions out of the loop.
 std::unique_ptr<Pass> createLoopInvariantCodeMotionPass();
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -77,6 +77,11 @@
   ];
 }
 
+def FoldInferredConstants : Pass<"fold-inferred-constants"> {
+  let summary = "Constant fold based on the results of integer range inference";
+  let constructor = "mlir::createFoldInferredConstantsPass()";
+}
+
 def Inliner : Pass<"inline"> {
   let summary = "Inline function calls";
   let constructor = "mlir::createInlinerPass()";
diff --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt
--- a/mlir/lib/Analysis/CMakeLists.txt
+++ b/mlir/lib/Analysis/CMakeLists.txt
@@ -4,6 +4,7 @@
   CallGraph.cpp
   DataFlowAnalysis.cpp
   DataLayoutAnalysis.cpp
+  IntRangeAnalysis.cpp
   Liveness.cpp
   SliceAnalysis.cpp
 
@@ -16,6 +17,7 @@
   CallGraph.cpp
   DataFlowAnalysis.cpp
   DataLayoutAnalysis.cpp
+  IntRangeAnalysis.cpp
   Liveness.cpp
   SliceAnalysis.cpp
 
@@ -31,6 +33,7 @@
   MLIRCallInterfaces
   MLIRControlFlowInterfaces
   MLIRDataLayoutInterfaces
+  MLIRInferIntRangeInterface
   MLIRInferTypeOpInterface
   MLIRSideEffectInterfaces
   MLIRViewLikeInterface
diff --git a/mlir/lib/Analysis/IntRangeAnalysis.cpp b/mlir/lib/Analysis/IntRangeAnalysis.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Analysis/IntRangeAnalysis.cpp
@@ -0,0 +1,264 @@
+//===- IntRangeAnalysis.cpp - Integer range dataflow analysis -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the dataflow analysis class for integer range inference
+// which is used in transformations over the `arith` dialect such as
+// branch elimination or signed->unsigned rewriting.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/IntRangeAnalysis.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "int-range-analysis"
+
+using namespace mlir;
+
+namespace mlir {
+namespace detail {
+class IntRangeAnalysisImpl : public ForwardDataFlowAnalysis<IntRangeAttrs> {
+  using ForwardDataFlowAnalysis<IntRangeAttrs>::ForwardDataFlowAnalysis;
+
+public:
+  /// Define bounds on the results or block arguments of the operation
+  /// based on the bounds on the arguments given in `operands`.
+  ChangeResult
+  visitOperation(Operation *op,
+                 ArrayRef<LatticeElement<IntRangeAttrs> *> operands) final;
+
+  /// Skip successors of branch ops when we can statically infer constant
+  /// values for operands to the branch op and said op tells us it's safe to do
+  /// so.
+  LogicalResult
+  getSuccessorsForOperands(BranchOpInterface branch,
+                           ArrayRef<LatticeElement<IntRangeAttrs> *> operands,
+                           SmallVectorImpl<Block *> &successors) final;
+
+  /// Skip regions of branch or loop ops when we can statically infer constant
+  /// values for operands to the branch op and said op tells us it's safe to do
+  /// so.
+  void
+  getSuccessorsForOperands(RegionBranchOpInterface branch,
+                           Optional<unsigned> sourceIndex,
+                           ArrayRef<LatticeElement<IntRangeAttrs> *> operands,
+                           SmallVectorImpl<RegionSuccessor> &successors) final;
+
+  /// Infer bounds on non-control-flow region arguments (e.g. loop IVs).
+  ChangeResult visitNonControlFlowArguments(
+      Operation *op, const RegionSuccessor &region,
+      ArrayRef<LatticeElement<IntRangeAttrs> *> operands) final;
+};
+} // namespace detail
+} // namespace mlir
+
+/// Given the result of getSingle{Lower,Upper}Bound() or getSingleStep() on a
+/// LoopLikeOpInterface, return the constant it folds to or, for an SSA value,
+/// its inferred signed upper (`getUpper`) or lower bound, if known.
+static Optional<APInt>
+getLoopBoundFromFold(Optional<OpFoldResult> loopBound,
+                     detail::IntRangeAnalysisImpl &analysis, bool getUpper) {
+  if (!loopBound.hasValue())
+    return {};
+  Optional<APInt> result;
+  if (loopBound->is<Attribute>()) {
+    if (auto bound =
+            loopBound->get<Attribute>().dyn_cast_or_null<IntegerAttr>())
+      result = bound.getValue();
+  } else if (loopBound->is<Value>()) {
+    LatticeElement<IntRangeAttrs> *lattice =
+        analysis.lookupLatticeElement(loopBound->get<Value>());
+    if (lattice != nullptr && !lattice->isUninitialized())
+      result =
+          getUpper ? lattice->getValue().smax() : lattice->getValue().smin();
+  }
+  return result;
+}
+
+ChangeResult detail::IntRangeAnalysisImpl::visitOperation(
+    Operation *op, ArrayRef<LatticeElement<IntRangeAttrs> *> operands) {
+  ChangeResult result = ChangeResult::NoChange;
+  // Ignore non-integer outputs - return early if the op has no scalar
+  // integer results
+  bool hasIntegerResult = false;
+  for (Value v : op->getResults()) {
+    if (v.getType().isIntOrIndex())
+      hasIntegerResult = true;
+    else
+      result |= markAllPessimisticFixpoint(v);
+  }
+  if (!hasIntegerResult)
+    return result;
+
+  if (auto inferrable = dyn_cast<InferIntRangeInterface>(op)) {
+    LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for ");
+    LLVM_DEBUG(inferrable->print(llvm::dbgs()));
+    LLVM_DEBUG(llvm::dbgs() << "\n");
+    SmallVector<IntRangeAttrs> argRanges(
+        llvm::map_range(operands, [](LatticeElement<IntRangeAttrs> *val) {
+          return val->isUninitialized() ? IntRangeAttrs() : val->getValue();
+        }));
+
+    auto joinCallback = [&](Value v, const IntRangeAttrs &attrs) {
+      LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
+      LatticeElement<IntRangeAttrs> &lattice = getLatticeElement(v);
+      Optional<IntRangeAttrs> oldRange;
+      if (!lattice.isUninitialized())
+        oldRange = lattice.getValue();
+      result |= lattice.join(attrs);
+
+      // Catch loop results with loop variant bounds and conservatively make
+      // them [-inf, inf] so we don't circle around infinitely often (because
+      // the dataflow analysis in MLIR doesn't attempt to work out trip counts
+      // and often can't).
+      bool isYieldedResult = llvm::any_of(v.getUsers(), [](Operation *op) {
+        return op->hasTrait<OpTrait::IsTerminator>();
+      });
+      if (isYieldedResult && oldRange.hasValue() &&
+          !(lattice.getValue() == *oldRange)) {
+        LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
+        result |= lattice.markPessimisticFixpoint();
+      }
+    };
+
+    inferrable.inferResultRanges(argRanges, joinCallback);
+  } else if (op->getNumRegions() == 0) {
+    // No regions + no result inference method -> unbounded results (ex. memory
+    // ops)
+    result |= markAllPessimisticFixpoint(op->getResults());
+  }
+  return result;
+}
+
+LogicalResult detail::IntRangeAnalysisImpl::getSuccessorsForOperands(
+    BranchOpInterface branch,
+    ArrayRef<LatticeElement<IntRangeAttrs> *> operands,
+    SmallVectorImpl<Block *> &successors) {
+  auto toConstantAttr = [&branch](auto enumPair) -> Attribute {
+    Optional<APInt> maybeConstValue =
+        enumPair.value()->getValue().getConstantValue();
+
+    if (maybeConstValue)
+      return IntegerAttr::get(branch->getOperand(enumPair.index()).getType(),
+                              *maybeConstValue);
+    return {};
+  };
+  SmallVector<Attribute> inferredConsts(
+      llvm::map_range(llvm::enumerate(operands), toConstantAttr));
+  if (Block *singleSucc = branch.getSuccessorForOperands(inferredConsts)) {
+    successors.push_back(singleSucc);
+    return success();
+  }
+  return failure();
+}
+
+void detail::IntRangeAnalysisImpl::getSuccessorsForOperands(
+    RegionBranchOpInterface branch, Optional<unsigned> sourceIndex,
+    ArrayRef<LatticeElement<IntRangeAttrs> *> operands,
+    SmallVectorImpl<RegionSuccessor> &successors) {
+  auto toConstantAttr = [&](auto enumPair) -> Attribute {
+    // Terminator operands (`sourceIndex` set) don't match `branch`'s operands.
+    Optional<APInt> maybeConstValue =
+        enumPair.value()->getValue().getConstantValue();
+    if (maybeConstValue && !sourceIndex)
+      return IntegerAttr::get(branch->getOperand(enumPair.index()).getType(),
+                              *maybeConstValue);
+    return {};
+  };
+  SmallVector<Attribute> inferredConsts(
+      llvm::map_range(llvm::enumerate(operands), toConstantAttr));
+  branch.getSuccessorRegions(sourceIndex, inferredConsts, successors);
+}
+
+ChangeResult detail::IntRangeAnalysisImpl::visitNonControlFlowArguments(
+    Operation *op, const RegionSuccessor &region,
+    ArrayRef<LatticeElement<IntRangeAttrs> *> operands) {
+  if (auto inferrable = dyn_cast<InferIntRangeInterface>(op)) {
+    LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for ");
+    LLVM_DEBUG(inferrable->print(llvm::dbgs()));
+    LLVM_DEBUG(llvm::dbgs() << "\n");
+    SmallVector<IntRangeAttrs> argRanges(
+        llvm::map_range(operands, [](LatticeElement<IntRangeAttrs> *val) {
+          return val->isUninitialized() ? IntRangeAttrs() : val->getValue();
+        }));
+
+    ChangeResult result = ChangeResult::NoChange;
+    auto joinCallback = [&](Value v, const IntRangeAttrs &attrs) {
+      LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
+      LatticeElement<IntRangeAttrs> &lattice = getLatticeElement(v);
+      Optional<IntRangeAttrs> oldRange;
+      if (!lattice.isUninitialized())
+        oldRange = lattice.getValue();
+      result |= lattice.join(attrs);
+
+      // Catch loop results with loop variant bounds and conservatively make
+      // them [-inf, inf] so we don't circle around infinitely often (because
+      // the dataflow analysis in MLIR doesn't attempt to work out trip counts
+      // and often can't).
+      bool isYieldedValue = llvm::any_of(v.getUsers(), [](Operation *op) {
+        return op->hasTrait<OpTrait::IsTerminator>();
+      });
+      if (isYieldedValue && oldRange.hasValue() &&
+          !(lattice.getValue() == *oldRange)) {
+        LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
+        result |= lattice.markPessimisticFixpoint();
+      }
+    };
+
+    inferrable.inferResultRanges(argRanges, joinCallback);
+    return result;
+  }
+
+  // Infer bounds for loop arguments that have static bounds
+  if (auto loop = dyn_cast<LoopLikeOpInterface>(op)) {
+    Optional<Value> iv = loop.getSingleInductionVar();
+    if (!iv.hasValue())
+      return ForwardDataFlowAnalysis<
+          IntRangeAttrs>::visitNonControlFlowArguments(op, region, operands);
+    Optional<OpFoldResult> lowerBound = loop.getSingleLowerBound();
+    Optional<OpFoldResult> upperBound = loop.getSingleUpperBound();
+    Optional<OpFoldResult> step = loop.getSingleStep();
+    Optional<APInt> min =
+        getLoopBoundFromFold(lowerBound, *this, /*getUpper=*/false);
+    Optional<APInt> max =
+        getLoopBoundFromFold(upperBound, *this, /*getUpper=*/true);
+    Optional<APInt> stepVal =
+        getLoopBoundFromFold(step, *this, /*getUpper=*/false);
+
+    if (stepVal && stepVal->isNegative())
+      std::swap(min, max);
+    else if (max)
+      // Loops exclude their upper bound, so subtract 1 to get an inclusive
+      // bound. An unknown upper bound (llvm::None) is left unbounded.
+      max = *max - 1;
+
+    LatticeElement<IntRangeAttrs> &ivEntry = getLatticeElement(*iv);
+    return ivEntry.join(IntRangeAttrs::fromSigned(min, max));
+  }
+  return ForwardDataFlowAnalysis<IntRangeAttrs>::visitNonControlFlowArguments(
+      op, region, operands);
+}
+
+IntRangeAnalysis::IntRangeAnalysis(MLIRContext *ctx) {
+  impl = std::make_unique<detail::IntRangeAnalysisImpl>(ctx);
+}
+
+IntRangeAnalysis::~IntRangeAnalysis() = default;
+IntRangeAnalysis::IntRangeAnalysis(IntRangeAnalysis &&other) = default;
+
+void IntRangeAnalysis::run(Operation *topLevelOperation) {
+  impl->run(topLevelOperation);
+}
+
+Optional<IntRangeAttrs> IntRangeAnalysis::getResult(Value v) {
+  LatticeElement<IntRangeAttrs> *result = impl->lookupLatticeElement(v);
+  if (result == nullptr || result->isUninitialized())
+    return llvm::None;
+  return result->getValue();
+}
diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt
--- a/mlir/lib/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Interfaces/CMakeLists.txt
@@ -5,6 +5,7 @@
   CopyOpInterface.cpp
   DataLayoutInterfaces.cpp
   DerivedAttributeOpInterface.cpp
+  InferIntRangeInterface.cpp
   InferTypeOpInterface.cpp
   LoopLikeInterface.cpp
   SideEffectInterfaces.cpp
@@ -35,6 +36,7 @@
 add_mlir_interface_library(CopyOpInterface)
 add_mlir_interface_library(DataLayoutInterfaces)
 add_mlir_interface_library(DerivedAttributeOpInterface)
+add_mlir_interface_library(InferIntRangeInterface)
 add_mlir_interface_library(InferTypeOpInterface)
 add_mlir_interface_library(SideEffectInterfaces)
 add_mlir_interface_library(TilingInterface)
diff --git a/mlir/lib/Interfaces/InferIntRangeInterface.cpp b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
@@ -0,0 +1,127 @@
+//===- InferIntRangeInterface.cpp - Integer range inference interface ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Interfaces/InferIntRangeInterface.h"
+#include "mlir/Interfaces/InferIntRangeInterface.cpp.inc"
+
+using namespace mlir;
+
+const Optional<APInt> &IntRangeAttrs::umin() const { return uminVal; }
+const Optional<APInt> &IntRangeAttrs::umax() const { return umaxVal; }
+const Optional<APInt> &IntRangeAttrs::smin() const { return sminVal; }
+const Optional<APInt> &IntRangeAttrs::smax() const { return smaxVal; }
+
+IntRangeAttrs IntRangeAttrs::range(const Optional<APInt> &min,
+                                   const Optional<APInt> &max) {
+  return {min, max, min, max};
+}
+
+IntRangeAttrs IntRangeAttrs::fromSigned(const Optional<APInt> &smin,
+                                        const Optional<APInt> &smax) {
+  Optional<APInt> umin, umax;
+  if (smin && smax) {
+    if (smin->isNonNegative() == smax->isNonNegative()) {
+      umin = smin->ult(*smax) ? smin : smax;
+      umax = smin->ugt(*smax) ? smin : smax;
+    }
+  } else if (smin) {
+    if (smin->isNonNegative()) {
+      umin = smin;
+      umax = APInt::getSignedMaxValue(smin->getBitWidth());
+    }
+  } else if (smax) {
+    if (smax->isNegative()) {
+      // All of [INT_MIN, smax] has the sign bit set; unsigned order matches.
+      umin = APInt::getSignedMinValue(smax->getBitWidth());
+      umax = smax;
+    }
+  }
+  return {umin, umax, smin, smax};
+}
+
+IntRangeAttrs IntRangeAttrs::fromSigned(
+    const std::tuple<Optional<APInt>, Optional<APInt>> &srange) {
+  return IntRangeAttrs::fromSigned(std::get<0>(srange), std::get<1>(srange));
+}
+
+IntRangeAttrs IntRangeAttrs::fromUnsigned(const Optional<APInt> &umin,
+                                          const Optional<APInt> &umax) {
+  Optional<APInt> smin, smax;
+  // Note: a lone lower bound (umin <= x) is not translated to signed bounds.
+  if (umin && umax) {
+    if (umin->isNonNegative() == umax->isNonNegative()) {
+      smin = umin->slt(*umax) ? umin : umax;
+      smax = umin->sgt(*umax) ? umin : umax;
+    }
+  } else if (umax) {
+    if (umax->isNonNegative()) {
+      smax = umax;
+      smin = APInt::getZero(umax->getBitWidth());
+    }
+  }
+  return {umin, umax, smin, smax};
+}
+
+IntRangeAttrs IntRangeAttrs::fromUnsigned(
+    const std::tuple<Optional<APInt>, Optional<APInt>> &urange) {
+  return IntRangeAttrs::fromUnsigned(std::get<0>(urange), std::get<1>(urange));
+}
+
+IntRangeAttrs IntRangeAttrs::join(const IntRangeAttrs &a,
+                                  const IntRangeAttrs &b) {
+  Optional<APInt> umin, umax, smin, smax;
+  if (a.umin() && b.umin())
+    umin = a.umin()->ult(*b.umin()) ? a.umin() : b.umin();
+  if (a.umax() && b.umax())
+    umax = a.umax()->ugt(*b.umax()) ? a.umax() : b.umax();
+  if (a.smin() && b.smin())
+    smin = a.smin()->slt(*b.smin()) ? a.smin() : b.smin();
+  if (a.smax() && b.smax())
+    smax = a.smax()->sgt(*b.smax()) ? a.smax() : b.smax();
+  return {umin, umax, smin, smax};
+}
+
+IntRangeAttrs IntRangeAttrs::getPessimisticValueState(MLIRContext *context) {
+  return {};
+}
+
+IntRangeAttrs IntRangeAttrs::getPessimisticValueState(Value v) {
+  return getPessimisticValueState(v.getContext());
+}
+
+Optional<APInt> IntRangeAttrs::getConstantValue() const {
+  if (umin() && umax() && umin() == umax())
+    return umin();
+  if (smin() && smax() && smin() == smax())
+    return smin();
+  return None;
+}
+
+bool IntRangeAttrs::operator==(const IntRangeAttrs &other) const {
+  return umin() == other.umin() && umax() == other.umax() &&
+         smin() == other.smin() && smax() == other.smax();
+}
+
+raw_ostream &mlir::operator<<(raw_ostream &os, const IntRangeAttrs &range) {
+  auto printValIfPresent = [&os](StringRef sign, const Optional<APInt> &val,
+                                 bool isSigned) {
+    if (val)
+      val->print(os, isSigned);
+    else
+      os << sign << "inf";
+  };
+  os << "unsigned : [";
+  printValIfPresent("-", range.umin(), /*isSigned=*/false);
+  os << ", ";
+  printValIfPresent("+", range.umax(), /*isSigned=*/false);
+  os << "] signed : [";
+  printValIfPresent("-", range.smin(), /*isSigned=*/true);
+  os << ", ";
+  printValIfPresent("+", range.smax(), /*isSigned=*/true);
+  return os << "]";
+}
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@
   Canonicalizer.cpp
   ControlFlowSink.cpp
   CSE.cpp
+  FoldInferredConstants.cpp
   Inliner.cpp
   LocationSnapshot.cpp
   LoopInvariantCodeMotion.cpp
@@ -23,7 +24,9 @@
   LINK_LIBS PUBLIC
   MLIRAnalysis
   MLIRCopyOpInterface
+  MLIRInferIntRangeInterface
   MLIRLoopLikeInterface
+  MLIRSideEffectInterfaces
   MLIRPass
   MLIRSupport
   MLIRTransformUtils
diff --git a/mlir/lib/Transforms/FoldInferredConstants.cpp b/mlir/lib/Transforms/FoldInferredConstants.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Transforms/FoldInferredConstants.cpp
@@ -0,0 +1,116 @@
+//===- FoldInferredConstants.cpp - Fold range-inferred constants ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass that replaces integer values whose inferred
+// range contains exactly one value with a materialized constant.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Analysis/IntRangeAnalysis.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Transforms/FoldUtils.h"
+#include "mlir/Transforms/Passes.h"
+
+using namespace mlir;
+
+/// Replaces all uses of `value` with a constant if `analysis` inferred that
+/// it can only take a single value. Patterned after SCCP.cpp.
+static LogicalResult replaceWithConstant(IntRangeAnalysis &analysis,
+                                         OpBuilder &b, OperationFolder &folder,
+                                         Value value) {
+  Optional<IntRangeAttrs> maybeInferredRange = analysis.getResult(value);
+  if (!maybeInferredRange)
+    return failure();
+  const IntRangeAttrs &inferredRange = maybeInferredRange.getValue();
+  Optional<APInt> maybeConstValue = inferredRange.getConstantValue();
+  if (!maybeConstValue.hasValue())
+    return failure();
+
+  Operation *maybeDefiningOp = value.getDefiningOp();
+  Dialect *valueDialect =
+      maybeDefiningOp ? maybeDefiningOp->getDialect()
+                      : value.getParentRegion()->getParentOp()->getDialect();
+  // Without a dialect there is nothing to materialize the constant with.
+  if (!valueDialect)
+    return failure();
+  Attribute constAttr = b.getIntegerAttr(value.getType(), *maybeConstValue);
+  Value constant = folder.getOrCreateConstant(b, valueDialect, constAttr,
+                                              value.getType(), value.getLoc());
+  if (!constant)
+    return failure();
+
+  value.replaceAllUsesWith(constant);
+  return success();
+}
+
+/// Walks every block in `initialRegions` (and in regions nested below them),
+/// replacing values that `analysis` proved constant and erasing operations
+/// whose results were all replaced and that are trivially dead.
+static void rewrite(IntRangeAnalysis &analysis, MLIRContext *context,
+                    MutableArrayRef<Region> initialRegions) {
+  SmallVector<Block *> worklist;
+  auto addToWorklist = [&](MutableArrayRef<Region> regions) {
+    for (Region &region : regions)
+      for (Block &block : llvm::reverse(region))
+        worklist.push_back(&block);
+  };
+
+  OpBuilder builder(context);
+  OperationFolder folder(context);
+
+  addToWorklist(initialRegions);
+  while (!worklist.empty()) {
+    Block *block = worklist.pop_back_val();
+
+    for (Operation &op : llvm::make_early_inc_range(*block)) {
+      builder.setInsertionPoint(&op);
+
+      // Replace any result with constants.
+      bool replacedAll = op.getNumResults() != 0;
+      for (Value res : op.getResults())
+        replacedAll &=
+            succeeded(replaceWithConstant(analysis, builder, folder, res));
+
+      // If all of the results of the operation were replaced, try to erase
+      // the operation completely.
+      if (replacedAll && wouldOpBeTriviallyDead(&op)) {
+        assert(op.use_empty() && "expected all uses to be replaced");
+        op.erase();
+        continue;
+      }
+
+      // Add the regions of this operation to the worklist.
+      addToWorklist(op.getRegions());
+    }
+
+    // Replace any block arguments with constants.
+    builder.setInsertionPointToStart(block);
+    for (BlockArgument arg : block->getArguments())
+      (void)replaceWithConstant(analysis, builder, folder, arg);
+  }
+}
+
+namespace {
+/// Folds integer values whose inferred range is a single constant.
+struct FoldInferredConstantsPass
+    : public FoldInferredConstantsBase<FoldInferredConstantsPass> {
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    MLIRContext *ctx = op->getContext();
+
+    IntRangeAnalysis analysis(ctx);
+    analysis.run(op);
+    rewrite(analysis, ctx, op->getRegions());
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::createFoldInferredConstantsPass() {
+  return std::make_unique<FoldInferredConstantsPass>();
+}