diff --git a/mlir/include/mlir/Dialect/Arithmetic/Analysis/IntRangeAnalysis.h b/mlir/include/mlir/Dialect/Arithmetic/Analysis/IntRangeAnalysis.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Arithmetic/Analysis/IntRangeAnalysis.h
@@ -0,0 +1,64 @@
+//===- IntRangeAnalysis.h - Integer range analysis --------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the dataflow analysis class for integer range inference
+// so that it can be used in transformations over the `arith` dialect such as
+// branch elimination or signed->unsigned rewriting.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARITHMETIC_ANALYSIS_INTRANGEANALYSIS_H
+#define MLIR_DIALECT_ARITHMETIC_ANALYSIS_INTRANGEANALYSIS_H
+
+#include "mlir/Analysis/DataFlowAnalysis.h"
+#include "mlir/Dialect/Arithmetic/IR/InferIntRangeInterface.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
+
+namespace mlir {
+namespace arith {
+/// Forward dataflow analysis whose lattice value is an `IntRangeAttrs`. Bounds
+/// are propagated through operations via `InferIntRangeInterface`, and the
+/// induction variables of loop-like ops are bounded from the loop's lower
+/// bound, upper bound and step when those are known.
+struct IntRangeAnalysis : public ForwardDataFlowAnalysis<IntRangeAttrs> {
+  using ForwardDataFlowAnalysis<IntRangeAttrs>::ForwardDataFlowAnalysis;
+  ~IntRangeAnalysis() override = default;
+
+  /// Define bounds on the results of the operation based on the bounds on
+  /// its arguments given in `operands`.
+  ChangeResult
+  visitOperation(Operation *op,
+                 ArrayRef<LatticeElement<IntRangeAttrs> *> operands) final;
+
+  /// Skip successor blocks of branch ops when we can statically infer
+  /// constant values for operands to the branch op and said op tells us it's
+  /// safe to do so.
+  LogicalResult
+  getSuccessorsForOperands(BranchOpInterface branch,
+                           ArrayRef<LatticeElement<IntRangeAttrs> *> operands,
+                           SmallVectorImpl<Block *> &successors) final;
+
+  /// Skip regions of branch or loop ops when we can statically infer constant
+  /// values for operands to the branch op and said op tells us it's safe to do
+  /// so.
+  void
+  getSuccessorsForOperands(RegionBranchOpInterface branch,
+                           Optional<unsigned> sourceIndex,
+                           ArrayRef<LatticeElement<IntRangeAttrs> *> operands,
+                           SmallVectorImpl<RegionSuccessor> &successors) final;
+
+  /// Infer bounds on the induction variable of loop-like ops from their
+  /// lower bound, upper bound and step; other arguments use the default.
+  ChangeResult visitNonControlFlowArguments(
+      Operation *op, const RegionSuccessor &region,
+      ArrayRef<LatticeElement<IntRangeAttrs> *> operands) final;
+};
+} // namespace arith
+} // namespace mlir
+
+#endif // MLIR_DIALECT_ARITHMETIC_ANALYSIS_INTRANGEANALYSIS_H
diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h b/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h
--- a/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h
+++ b/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h
@@ -28,6 +28,12 @@
 
 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOpsEnums.h.inc"
 
+//===----------------------------------------------------------------------===//
+// Arithmetic Dialect Interfaces
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arithmetic/IR/InferIntRangeInterface.h"
+
 //===----------------------------------------------------------------------===//
 // Arithmetic Dialect Operations
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Arithmetic/IR/CMakeLists.txt
--- a/mlir/include/mlir/Dialect/Arithmetic/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Arithmetic/IR/CMakeLists.txt
@@ -3,3 +3,4 @@
 mlir_tablegen(ArithmeticOpsEnums.cpp.inc -gen-enum-defs)
 add_mlir_dialect(ArithmeticOps arith)
 add_mlir_doc(ArithmeticOps ArithmeticOps Dialects/ -gen-dialect-doc)
+add_mlir_interface(InferIntRangeInterface)
diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/InferIntRangeInterface.h
b/mlir/include/mlir/Dialect/Arithmetic/IR/InferIntRangeInterface.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Arithmetic/IR/InferIntRangeInterface.h
@@ -0,0 +1,101 @@
+//===- InferIntRangeInterface.h - Infer Ranges Interfaces --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains definitions of the integer range inference interface
+// defined in `InferIntRangeInterface.td`
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARITHMETIC_IR_INFERINTRANGEINTERFACE_H
+#define MLIR_DIALECT_ARITHMETIC_IR_INFERINTRANGEINTERFACE_H
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/OpDefinition.h"
+#include <tuple>
+
+namespace mlir {
+class Value;
+namespace arith {
+/// A collection of integer attributes representing minimum and maximum bounds
+/// on a scalar integer value. These bounds are inclusive on both ends, so
+/// bounds of [4, 5] mean 4 <= x <= 5. Separate bounds are tracked for
+/// the unsigned and signed interpretations of values. A null attribute means
+/// that the corresponding bound is unknown.
+struct IntRangeAttrs {
+  /// The minimum value of an integer when it is interpreted as unsigned.
+  IntegerAttr umin;
+  /// The maximum value of an integer when it is interpreted as unsigned.
+  IntegerAttr umax;
+  /// The minimum value of an integer when it is interpreted as signed.
+  IntegerAttr smin;
+  /// The maximum value of an integer when it is interpreted as signed.
+  IntegerAttr smax;
+
+  IntRangeAttrs() : umin(), umax(), smin(), smax() {}
+  IntRangeAttrs(IntegerAttr umin, IntegerAttr umax, IntegerAttr smin,
+                IntegerAttr smax)
+      : umin(umin), umax(umax), smin(smin), smax(smax) {
+    // Unsigned integers that have upper bounds but no lower bound have a lower
+    // bound of 0.
+    if (umax && !umin) {
+      this->umin = IntegerAttr::get(
+          umax.getType(), APInt::getZero(umax.getValue().getBitWidth()));
+    }
+  }
+
+  /// Build from `(min, max)` pairs of unsigned and signed bounds. Delegates
+  /// to the four-attribute constructor so the same normalization applies.
+  IntRangeAttrs(std::tuple<IntegerAttr, IntegerAttr> urange,
+                std::tuple<IntegerAttr, IntegerAttr> srange)
+      : IntRangeAttrs(std::get<0>(urange), std::get<1>(urange),
+                      std::get<0>(srange), std::get<1>(srange)) {}
+
+  /// Create an `IntRangeAttrs` where `min` is both the signed and unsigned
+  /// minimum and `max` is both the signed and unsigned maximum.
+  static IntRangeAttrs range(IntegerAttr min, IntegerAttr max);
+
+  /// Create an `IntRangeAttrs` with the signed minimum and maximum equal
+  /// to `smin` and `smax`. The unsigned bounds are derived from the signed
+  /// ones when every value in the signed range has the same sign bit (so the
+  /// range stays contiguous when reinterpreted as unsigned) and are left
+  /// unset otherwise.
+  static IntRangeAttrs fromSigned(IntegerAttr smin, IntegerAttr smax);
+  static IntRangeAttrs fromSigned(std::tuple<IntegerAttr, IntegerAttr> srange);
+
+  /// Create an `IntRangeAttrs` with the unsigned minimum and maximum equal
+  /// to `umin` and `umax` and the signed part equal to `umin` and `umax`
+  /// unless the sign bit changes between the minimum and maximum.
+  static IntRangeAttrs fromUnsigned(IntegerAttr umin, IntegerAttr umax);
+  static IntRangeAttrs
+  fromUnsigned(std::tuple<IntegerAttr, IntegerAttr> urange);
+
+  /// Returns the union (computed separately for signed and unsigned bounds)
+  /// of `a` and `b`.
+  static IntRangeAttrs join(const IntRangeAttrs &a, const IntRangeAttrs &b);
+  /// Returns the state in which nothing is known: all bounds unset.
+  static IntRangeAttrs getPessimisticValueState(MLIRContext *context);
+  static IntRangeAttrs getPessimisticValueState(Value v);
+
+  bool operator==(const IntRangeAttrs &other) const {
+    return umin == other.umin && umax == other.umax && smin == other.smin &&
+           smax == other.smax;
+  }
+
+  /// If either the signed or unsigned interpretations of the range
+  /// indicate that the value it bounds is a constant, return that constant
+  /// value.
+  Optional<IntegerAttr> getConstantValue() const;
+
+  friend raw_ostream &operator<<(raw_ostream &os, const IntRangeAttrs &range);
+};
+} // namespace arith
+} // namespace mlir
+
+#include "mlir/Dialect/Arithmetic/IR/InferIntRangeInterface.h.inc"
+
+#endif // MLIR_DIALECT_ARITHMETIC_IR_INFERINTRANGEINTERFACE_H
diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/InferIntRangeInterface.td b/mlir/include/mlir/Dialect/Arithmetic/IR/InferIntRangeInterface.td
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Arithmetic/IR/InferIntRangeInterface.td
@@ -0,0 +1,49 @@
+//===- InferIntRangeInterface.td - Integer Range Inference --*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===-----------------------------------------------------===//
+//
+// Defines the interface for range analysis on scalar integers
+//
+//===-----------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARITHMETIC_IR_INFERINTRANGEINTERFACE
+#define MLIR_DIALECT_ARITHMETIC_IR_INFERINTRANGEINTERFACE
+
+include "mlir/IR/OpBase.td"
+
+def InferIntRangeInterface : OpInterface<"InferIntRangeInterface"> {
+  let description = [{
+    Allows operations to participate in range analysis for scalar integer values by
+    providing a method that allows them to specify lower and upper bounds on their
+    result(s) given lower and upper bounds on their input(s) if known.
+  }];
+  let cppNamespace = "::mlir::arith";
+
+  let methods = [
+    InterfaceMethod<[{
+      Infer the bounds on the results of this op given the bounds on its arguments.
+      For each result value, the function should append an `IntRangeAttrs`
+      representing the result's minimum and maximum value to `resultRanges`.
+      If a result does not have integer type, or if bounds can not be inferred,
+      the corresponding element should be `IntRangeAttrs()` (no bounds set).
+
+      The returned attributes, if present, should have the same type as the
+      corresponding result.
+
+      This function will only be called when at least one result of the op is a
+      scalar integer value.
+
+      `argRanges` contains one `IntRangeAttrs` for each argument to the op in ODS
+      order. Non-integer arguments have an `IntRangeAttrs()` with no bounds set
+      as their `argRanges` element.
+    }],
+    "void", "inferResultRanges", (ins
+      "::llvm::ArrayRef<::mlir::arith::IntRangeAttrs>":$argRanges,
+      "::llvm::SmallVectorImpl<::mlir::arith::IntRangeAttrs> &":$resultRanges)
+  >];
+}
+#endif // MLIR_DIALECT_ARITHMETIC_IR_INFERINTRANGEINTERFACE
diff --git a/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.h
@@ -26,6 +26,9 @@
 /// Create a pass to legalize Arithmetic ops for LLVM lowering.
 std::unique_ptr<Pass> createArithmeticExpandOpsPass();
 
+/// Create a pass to constant fold based on the results of range inference.
+std::unique_ptr<Pass> createArithmeticFoldInferredConstantsPass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.td
--- a/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.td
@@ -33,4 +33,14 @@
   let constructor = "mlir::arith::createArithmeticExpandOpsPass()";
 }
 
+def ArithmeticFoldInferredConstants : Pass<"arith-fold-inferred-constants"> {
+  let summary = "Constant fold based on the results of integer range inference";
+  let description = [{
+    Runs integer range analysis and replaces every integer value whose
+    inferred range contains exactly one value with an `arith.constant`,
+    erasing operations that become trivially dead as a result.
+  }];
+  let constructor = "mlir::arith::createArithmeticFoldInferredConstantsPass()";
+}
+
 #endif // MLIR_DIALECT_ARITHMETIC_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Arithmetic/Analysis/CMakeLists.txt b/mlir/lib/Dialect/Arithmetic/Analysis/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Arithmetic/Analysis/CMakeLists.txt
@@ -0,0 +1,15 @@
+add_mlir_dialect_library(MLIRArithmeticAnalysis
+  IntRangeAnalysis.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Arithmetic
+
+  DEPENDS
+  mlir-headers
+
+  LINK_LIBS PUBLIC
+  MLIRArithmetic
+  MLIRAnalysis
+  MLIRControlFlowInterfaces
+  MLIRLoopLikeInterface
+  )
diff --git a/mlir/lib/Dialect/Arithmetic/Analysis/IntRangeAnalysis.cpp b/mlir/lib/Dialect/Arithmetic/Analysis/IntRangeAnalysis.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Arithmetic/Analysis/IntRangeAnalysis.cpp
@@ -0,0 +1,190 @@
+//===- IntRangeAnalysis.cpp - Integer range analysis ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the dataflow analysis class for integer range inference
+// which is used in transformations over the `arith` dialect such as
+// branch elimination or signed->unsigned rewriting.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arithmetic/Analysis/IntRangeAnalysis.h"
+#include "mlir/Analysis/DataFlowAnalysis.h"
+#include "mlir/Dialect/Arithmetic/IR/InferIntRangeInterface.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "arith"
+
+namespace {
+using namespace mlir;
+using namespace mlir::arith;
+
+/// Returns the value of a loop bound or step given as an `OpFoldResult`: the
+/// constant itself if it is an `IntegerAttr`, otherwise the signed maximum
+/// (`getUpper`) or signed minimum (`!getUpper`) of the range the analysis has
+/// inferred for it so far. Loop upper bounds are exclusive while range bounds
+/// are inclusive, so upper bounds are decremented by one. Returns a null
+/// attribute when nothing is known.
+IntegerAttr getLoopBoundFromFold(Optional<OpFoldResult> loopBound,
+                                 IntRangeAnalysis &analysis, bool getUpper) {
+  if (!loopBound.hasValue())
+    return {};
+  IntegerAttr ret;
+  if (loopBound->is<Attribute>()) {
+    if (auto bound =
+            loopBound->get<Attribute>().dyn_cast_or_null<IntegerAttr>())
+      ret = bound;
+  } else if (loopBound->is<Value>()) {
+    LatticeElement<IntRangeAttrs> *result =
+        analysis.lookupLatticeElement(loopBound->get<Value>());
+    if (result)
+      ret = getUpper ? result->getValue().smax : result->getValue().smin;
+  }
+  // Loop bounds don't include the upper index, but integer range bounds do
+  if (ret && getUpper) {
+    // Note: loops that don't execute (ex. %i = 0 to 0) will create bad bounds
+    // with this method, but they don't execute so it doesn't matter
+    ret = IntegerAttr::get(ret.getType(), ret.getValue() - 1);
+  }
+  return ret;
+}
+} // end namespace
+
+namespace mlir {
+namespace arith {
+
+ChangeResult IntRangeAnalysis::visitOperation(
+    Operation *op, ArrayRef<LatticeElement<IntRangeAttrs> *> operands) {
+  ChangeResult ret = ChangeResult::NoChange;
+  // Ignore non-integer outputs - return early if the op has no scalar
+  // integer results
+  bool hasIntegerResult = false;
+  bool hasYieldedResult = false;
+  for (Value v : op->getResults()) {
+    if (v.getType().isIntOrIndex())
+      hasIntegerResult = true;
+    else
+      ret |= markAllPessimisticFixpoint(v);
+    for (Operation *user : v.getUsers())
+      hasYieldedResult |= user->hasTrait<OpTrait::ReturnLike>();
+  }
+  if (!hasIntegerResult)
+    return ret;
+  if (auto inferrable = dyn_cast<InferIntRangeInterface>(op)) {
+    LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for ");
+    LLVM_DEBUG(inferrable->print(llvm::dbgs()));
+    LLVM_DEBUG(llvm::dbgs() << "\n");
+    SmallVector<IntRangeAttrs> argRanges(
+        llvm::map_range(operands, [](LatticeElement<IntRangeAttrs> *val) {
+          return val->getValue();
+        }));
+    SmallVector<IntRangeAttrs> resultRanges;
+    resultRanges.reserve(op->getNumResults());
+    inferrable.inferResultRanges(argRanges, resultRanges);
+    assert(resultRanges.size() == op->getNumResults() &&
+           "Range inference should provide one value per result");
+    for (auto pair : llvm::zip(op->getResults(), resultRanges)) {
+      LLVM_DEBUG(llvm::dbgs() << "Result range " << std::get<1>(pair) << "\n");
+      LatticeElement<IntRangeAttrs> &lattice =
+          getLatticeElement(std::get<0>(pair));
+      Optional<IntRangeAttrs> oldRange;
+      if (!lattice.isUninitialized())
+        oldRange = lattice.getValue();
+      ret |= lattice.join(std::get<1>(pair));
+      // Catch loop results with loop variant bounds and conservatively make
+      // them [-inf, inf] so we don't circle around infinitely often (because
+      // the dataflow analysis in MLIR doesn't attempt to work out trip counts
+      // and often can't).
+      if (hasYieldedResult && oldRange.hasValue() &&
+          !(lattice.getValue() == *oldRange)) {
+        LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
+        ret |= lattice.join(IntRangeAttrs());
+        ret |= lattice.markPessimisticFixpoint();
+      }
+    }
+  } else if (op->getNumRegions() == 0) {
+    // No regions + no result inference method -> unbounded results (ex. memory
+    // ops)
+    ret |= markAllPessimisticFixpoint(op->getResults());
+  }
+  return ret;
+}
+
+LogicalResult IntRangeAnalysis::getSuccessorsForOperands(
+    BranchOpInterface branch,
+    ArrayRef<LatticeElement<IntRangeAttrs> *> operands,
+    SmallVectorImpl<Block *> &successors) {
+  SmallVector<Attribute> inferredConsts(llvm::map_range(
+      operands, [](LatticeElement<IntRangeAttrs> *range) -> Attribute {
+        Optional<IntegerAttr> mbConstValue =
+            range->getValue().getConstantValue();
+        if (mbConstValue)
+          return *mbConstValue;
+        return {};
+      }));
+  if (Block *singleSucc = branch.getSuccessorForOperands(inferredConsts)) {
+    successors.push_back(singleSucc);
+    return success();
+  }
+  return failure();
+}
+
+void IntRangeAnalysis::getSuccessorsForOperands(
+    RegionBranchOpInterface branch, Optional<unsigned> sourceIndex,
+    ArrayRef<LatticeElement<IntRangeAttrs> *> operands,
+    SmallVectorImpl<RegionSuccessor> &successors) {
+  SmallVector<Attribute> inferredConsts(llvm::map_range(
+      operands, [](LatticeElement<IntRangeAttrs> *range) -> Attribute {
+        Optional<IntegerAttr> mbConstValue =
+            range->getValue().getConstantValue();
+        if (mbConstValue)
+          return *mbConstValue;
+        return {};
+      }));
+  branch.getSuccessorRegions(sourceIndex, inferredConsts, successors);
+}
+
+ChangeResult IntRangeAnalysis::visitNonControlFlowArguments(
+    Operation *op, const RegionSuccessor &region,
+    ArrayRef<LatticeElement<IntRangeAttrs> *> operands) {
+  // Infer bounds for loop arguments that have static bounds
+  if (auto loop = dyn_cast<LoopLikeOpInterface>(op)) {
+    Optional<Value> iv = loop.getSingleInductionVar();
+    if (!iv.hasValue())
+      return ForwardDataFlowAnalysis<
+          IntRangeAttrs>::visitNonControlFlowArguments(op, region, operands);
+    Optional<OpFoldResult> lowerBound = loop.getSingleLowerBound();
+    Optional<OpFoldResult> upperBound = loop.getSingleUpperBound();
+    Optional<OpFoldResult> step = loop.getSingleStep();
+    IntegerAttr min =
+        getLoopBoundFromFold(lowerBound, *this, /*getUpper=*/false);
+    IntegerAttr max =
+        getLoopBoundFromFold(upperBound, *this, /*getUpper=*/true);
+    IntegerAttr stepVal = getLoopBoundFromFold(step, *this, /*getUpper=*/false);
+    if (stepVal && stepVal.getValue().isNegative()) {
+      std::swap(min, max);
+      // Undo the upper bound correction we needed in the positive case; the
+      // bound may be unknown (null), in which case there is nothing to undo.
+      if (min)
+        min = IntegerAttr::get(min.getType(), min.getValue() + 1);
+    }
+
+    LatticeElement<IntRangeAttrs> &ivEntry = getLatticeElement(*iv);
+    return ivEntry.join(IntRangeAttrs::fromSigned(min, max));
+  }
+  return ForwardDataFlowAnalysis<IntRangeAttrs>::visitNonControlFlowArguments(
+      op, region, operands);
+}
+} // namespace arith
+} // namespace mlir
diff --git a/mlir/lib/Dialect/Arithmetic/CMakeLists.txt b/mlir/lib/Dialect/Arithmetic/CMakeLists.txt
--- a/mlir/lib/Dialect/Arithmetic/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arithmetic/CMakeLists.txt
@@ -1,3 +1,4 @@
+add_subdirectory(Analysis)
 add_subdirectory(IR)
 add_subdirectory(Transforms)
 add_subdirectory(Utils)
diff --git a/mlir/lib/Dialect/Arithmetic/IR/CMakeLists.txt b/mlir/lib/Dialect/Arithmetic/IR/CMakeLists.txt
--- a/mlir/lib/Dialect/Arithmetic/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arithmetic/IR/CMakeLists.txt
@@ -5,12 +5,14 @@
 add_mlir_dialect_library(MLIRArithmetic
   ArithmeticOps.cpp
   ArithmeticDialect.cpp
+  InferIntRangeInterface.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Arithmetic
 
   DEPENDS
   MLIRArithmeticOpsIncGen
+  MLIRInferIntRangeInterfaceIncGen
 
   LINK_LIBS PUBLIC
   MLIRDialect
diff --git a/mlir/lib/Dialect/Arithmetic/IR/InferIntRangeInterface.cpp b/mlir/lib/Dialect/Arithmetic/IR/InferIntRangeInterface.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Arithmetic/IR/InferIntRangeInterface.cpp
@@ -0,0 +1,129 @@
+//===- InferIntRangeInterface.cpp - Integer range inference interface ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arithmetic/IR/InferIntRangeInterface.h"
+#include "mlir/Dialect/Arithmetic/IR/InferIntRangeInterface.cpp.inc"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "llvm/Support/raw_ostream.h"
+
+namespace mlir {
+namespace arith {
+
+IntRangeAttrs IntRangeAttrs::range(IntegerAttr min, IntegerAttr max) {
+  return {min, max, min, max};
+}
+
+IntRangeAttrs IntRangeAttrs::fromSigned(IntegerAttr smin, IntegerAttr smax) {
+  IntegerAttr umin, umax;
+  if (smin && smax) {
+    APInt sminVal = smin.getValue();
+    APInt smaxVal = smax.getValue();
+    if (sminVal.isNonNegative() == smaxVal.isNonNegative()) {
+      umin = sminVal.ult(smaxVal) ? smin : smax;
+      umax = sminVal.ugt(smaxVal) ? smin : smax;
+    }
+  } else if (smin) {
+    if (smin.getValue().isNonNegative()) {
+      umin = smin;
+      umax = IntegerAttr::get(
+          smin.getType(),
+          APInt::getSignedMaxValue(smin.getValue().getBitWidth()));
+    }
+  } else if (smax) {
+    // All of [INT_MIN, smax] is negative, so INT_MIN is the unsigned minimum.
+    if (smax.getValue().isNegative()) {
+      umin = IntegerAttr::get(
+          smax.getType(),
+          APInt::getSignedMinValue(smax.getValue().getBitWidth()));
+      umax = smax;
+    }
+  }
+  return {umin, umax, smin, smax};
+}
+
+IntRangeAttrs
+IntRangeAttrs::fromSigned(std::tuple<IntegerAttr, IntegerAttr> srange) {
+  return IntRangeAttrs::fromSigned(std::get<0>(srange), std::get<1>(srange));
+}
+
+IntRangeAttrs IntRangeAttrs::fromUnsigned(IntegerAttr umin, IntegerAttr umax) {
+  IntegerAttr smin, smax;
+  // Note, unlike in fromSigned, we have no case for umin <= x < inf,
+  // since the upper bound is potentially a negative number.
+  if (umin && umax) {
+    APInt uminVal = umin.getValue();
+    APInt umaxVal = umax.getValue();
+    if (uminVal.isNonNegative() == umaxVal.isNonNegative()) {
+      smin = uminVal.slt(umaxVal) ? umin : umax;
+      smax = uminVal.sgt(umaxVal) ? umin : umax;
+    }
+  } else if (umax) {
+    if (umax.getValue().isNonNegative()) {
+      smax = umax;
+      smin = IntegerAttr::get(umax.getType(),
+                              APInt::getZero(umax.getValue().getBitWidth()));
+    }
+  }
+  return {umin, umax, smin, smax};
+}
+
+IntRangeAttrs
+IntRangeAttrs::fromUnsigned(std::tuple<IntegerAttr, IntegerAttr> urange) {
+  return IntRangeAttrs::fromUnsigned(std::get<0>(urange), std::get<1>(urange));
+}
+
+IntRangeAttrs IntRangeAttrs::join(const IntRangeAttrs &a,
+                                  const IntRangeAttrs &b) {
+  IntegerAttr umin, umax, smin, smax;
+  if (a.umin && b.umin)
+    umin = a.umin.getValue().ult(b.umin.getValue()) ? a.umin : b.umin;
+  if (a.umax && b.umax)
+    umax = a.umax.getValue().ugt(b.umax.getValue()) ? a.umax : b.umax;
+  if (a.smin && b.smin)
+    smin = a.smin.getValue().slt(b.smin.getValue()) ? a.smin : b.smin;
+  if (a.smax && b.smax)
+    smax = a.smax.getValue().sgt(b.smax.getValue()) ? a.smax : b.smax;
+
+  return {umin, umax, smin, smax};
+}
+
+IntRangeAttrs IntRangeAttrs::getPessimisticValueState(MLIRContext *context) {
+  return {};
+}
+
+IntRangeAttrs IntRangeAttrs::getPessimisticValueState(Value v) {
+  return getPessimisticValueState(v.getContext());
+}
+
+Optional<IntegerAttr> IntRangeAttrs::getConstantValue() const {
+  if (umin && umax && umin.getValue() == umax.getValue())
+    return umin;
+  if (smin && smax && smin.getValue() == smax.getValue())
+    return smin;
+  return None;
+}
+
+raw_ostream &operator<<(raw_ostream &os, const IntRangeAttrs &range) {
+  auto printBound = [&os](StringRef sign, IntegerAttr val, bool isSigned) {
+    if (val)
+      val.getValue().print(os, isSigned);
+    else
+      os << sign << "inf";
+  };
+  os << "unsigned : [";
+  printBound("-", range.umin, /*isSigned=*/false);
+  os << ", ";
+  printBound("+", range.umax, /*isSigned=*/false);
+  os << "] signed : [";
+  printBound("-", range.smin, /*isSigned=*/true);
+  os << ", ";
+  printBound("+", range.smax, /*isSigned=*/true);
+  return os << "]";
+}
+} // namespace arith
+} // namespace mlir
diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt
---
a/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
   ExpandOps.cpp
+  FoldInferredConstants.cpp
 
   ADDITIONAL_HEADER_DIRS
   {$MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Arithmetic/Transforms
@@ -11,6 +12,7 @@
 
   LINK_LIBS PUBLIC
   MLIRArithmetic
+  MLIRArithmeticAnalysis
   MLIRBufferization
   MLIRBufferizationTransforms
   MLIRIR
diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/FoldInferredConstants.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/FoldInferredConstants.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/FoldInferredConstants.cpp
@@ -0,0 +1,116 @@
+//===- FoldInferredConstants.cpp - Fold range-inferred constants ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass that materializes the integer constants proven
+// by integer range analysis and erases operations left trivially dead.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Analysis/DataFlowAnalysis.h"
+#include "mlir/Dialect/Arithmetic/Analysis/IntRangeAnalysis.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arithmetic/IR/InferIntRangeInterface.h"
+#include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/Transforms/FoldUtils.h"
+
+using namespace mlir;
+using namespace mlir::arith;
+
+namespace {
+/// If the range analysis proved that `value` can only hold a single integer,
+/// replace all of its uses with an `arith.constant` created at the builder's
+/// current insertion point. Returns success iff `value` is known to be
+/// constant (it then has no remaining uses), failure otherwise.
+/// Patterned after mlir/lib/Transforms/SCCP.cpp.
+LogicalResult replaceWithConstant(IntRangeAnalysis &analysis, OpBuilder &b,
+                                  Value value) {
+  LatticeElement<IntRangeAttrs> *mbInferredRange =
+      analysis.lookupLatticeElement(value);
+  if (!mbInferredRange)
+    return failure();
+  const IntRangeAttrs &inferredRange = mbInferredRange->getValue();
+  Optional<IntegerAttr> mbConstValue = inferredRange.getConstantValue();
+  if (!mbConstValue.hasValue())
+    return failure();
+  // Nothing uses the value, so there is nothing to rewrite; materializing a
+  // constant here would only leave a dead op behind.
+  if (value.use_empty())
+    return success();
+  Value constant = b.createOrFold<ConstantOp>(value.getLoc(), *mbConstValue);
+  value.replaceAllUsesWith(constant);
+  return success();
+}
+
+/// Walk the blocks of `initialRegions` and of every region nested in them,
+/// replacing each op result and block argument that the analysis proved
+/// constant. Ops whose results were all replaced are erased when they are
+/// trivially dead.
+void rewrite(IntRangeAnalysis &analysis, MLIRContext *context,
+             MutableArrayRef<Region> initialRegions) {
+  SmallVector<Block *> worklist;
+  auto addToWorklist = [&](MutableArrayRef<Region> regions) {
+    for (Region &region : regions)
+      for (Block &block : llvm::reverse(region))
+        worklist.push_back(&block);
+  };
+
+  OpBuilder builder(context);
+
+  addToWorklist(initialRegions);
+  while (!worklist.empty()) {
+    Block *block = worklist.pop_back_val();
+
+    for (Operation &op : llvm::make_early_inc_range(*block)) {
+      builder.setInsertionPoint(&op);
+
+      // Replace any result with constants.
+      bool replacedAll = op.getNumResults() != 0;
+      for (Value res : op.getResults())
+        replacedAll &= succeeded(replaceWithConstant(analysis, builder, res));
+
+      // If all of the results of the operation were replaced, try to erase
+      // the operation completely.
+      if (replacedAll && wouldOpBeTriviallyDead(&op)) {
+        assert(op.use_empty() && "expected all uses to be replaced");
+        op.erase();
+        continue;
+      }
+
+      // Add the regions of this operation to the worklist.
+      addToWorklist(op.getRegions());
+    }
+
+    // Replace any block arguments with constants.
+    builder.setInsertionPointToStart(block);
+    for (BlockArgument arg : block->getArguments())
+      (void)replaceWithConstant(analysis, builder, arg);
+  }
+}
+
+/// Runs integer range analysis on the target op and folds every value it
+/// proves constant.
+struct ArithmeticFoldInferredConstantsPass
+    : public ArithmeticFoldInferredConstantsBase<
+          ArithmeticFoldInferredConstantsPass> {
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    MLIRContext *ctx = op->getContext();
+
+    IntRangeAnalysis analysis(ctx);
+    analysis.run(op);
+    rewrite(analysis, ctx, op->getRegions());
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::arith::createArithmeticFoldInferredConstantsPass() {
+  return std::make_unique<ArithmeticFoldInferredConstantsPass>();
+}