diff --git a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h @@ -50,9 +50,9 @@ /// maximally compose chains of AffineApplyOps. FailureOr decompose(RewriterBase &rewriter, AffineApplyOp op); -/// Reify a bound for the given index-typed value or shape dimension size in -/// terms of the owning op's operands. `dim` must be `nullopt` if and only if -/// `value` is index-typed. +/// Reify a bound for the given integer/index-typed value or shape dimension +/// size in terms of the owning op's operands. `dim` must be `nullopt` if and +/// only if `value` is integer/index-typed. The reified result has index type. /// /// By default, lower/equal bounds are closed and upper bounds are open. If /// `closedUB` is set to "true", upper bounds are also closed. @@ -61,9 +61,10 @@ std::optional dim, bool closedUB = false); -/// Reify a bound for the given index-typed value or shape dimension size in -/// terms of SSA values for which `stopCondition` is met. `dim` must be -/// `nullopt` if and only if `value` is index-typed. +/// Reify a bound for the given integer/index-typed value or shape dimension +/// size in terms of SSA values for which `stopCondition` is met. `dim` must be +/// `nullopt` if and only if `value` is integer/index-typed. The reified result +/// has index type. /// /// By default, lower/equal bounds are closed and upper bounds are open. If /// `closedUB` is set to "true", upper bounds are also closed. 
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Arith/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Arith/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Arith/Transforms/Transforms.h @@ -9,12 +9,14 @@ #ifndef MLIR_DIALECT_ARITH_TRANSFORMS_TRANSFORMS_H #define MLIR_DIALECT_ARITH_TRANSFORMS_TRANSFORMS_H +#include "mlir/IR/Types.h" #include "mlir/Support/LogicalResult.h" namespace mlir { class Location; class OpBuilder; class OpFoldResult; +class Type; class Value; namespace presburger { @@ -23,20 +25,29 @@ namespace arith { -/// Reify a bound for the given index-typed value or shape dimension size in -/// terms of the owning op's operands. `dim` must be `nullopt` if and only if -/// `value` is index-typed. +/// Reify a bound for the given integer/index-typed value or shape dimension +/// size in terms of the owning op's operands. `dim` must be `nullopt` if and +/// only if `value` is integer/index-typed. +/// +/// The type of the reified result (IntegerType or IndexType) can be specified. +/// By default, it is the same type as the input. (IndexType in case of a +/// dimension). /// /// By default, lower/equal bounds are closed and upper bounds are open. If /// `closedUB` is set to "true", upper bounds are also closed. FailureOr reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type, Value value, std::optional dim, - bool closedUB = false); + bool closedUB = false, + Type resultType = {}); -/// Reify a bound for the given index-typed value or shape dimension size in -/// terms of SSA values for which `stopCondition` is met. `dim` must be -/// `nullopt` if and only if `value` is index-typed. +/// Reify a bound for the given integer/index-typed value or shape dimension +/// size in terms of SSA values for which `stopCondition` is met. `dim` must be +/// `nullopt` if and only if `value` is integer/index-typed. 
+/// +/// The type of the reified result (IntegerType or IndexType) can be specified. +/// By default, it is the same type as the input. (IndexType in case of a +/// dimension). /// /// By default, lower/equal bounds are closed and upper bounds are open. If /// `closedUB` is set to "true", upper bounds are also closed. @@ -44,7 +55,7 @@ reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type, Value value, std::optional dim, function_ref)> stopCondition, - bool closedUB = false); + bool closedUB = false, Type resultType = {}); } // namespace arith } // namespace mlir diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h --- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h +++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h @@ -23,9 +23,11 @@ using ValueDimList = SmallVector>>; /// A helper class to be used with `ValueBoundsOpInterface`. This class stores a -/// constraint system and mapping of constrained variables to index-typed +/// constraint system and mapping of constrained variables to integer-typed /// values or dimension sizes of shaped values. /// +/// Note: "Integer type" refers to both integer and index types. +/// /// Interface implementations of `ValueBoundsOpInterface` use `addBounds` to /// insert constraints about their results and/or region block arguments into /// the constraint set in the form of an AffineExpr. When a bound should be @@ -39,16 +41,16 @@ class ValueBoundsConstraintSet { public: /// The stop condition when traversing the backward slice of a shaped value/ - /// index-type value. The traversal continues until the stop condition + /// integer-type value. The traversal continues until the stop condition /// evaluates to "true" for a value. using StopConditionFn = function_ref)>; - /// Compute a bound for the given index-typed value or shape dimension size. + /// Compute a bound for the given integer-typed value or shape dimension size. 
/// The computed bound is stored in `resultMap`. The operands of the bound are - /// stored in `mapOperands`. An operand is either an index-type SSA value + /// stored in `mapOperands`. An operand is either an integer-type SSA value /// or a shaped value and a dimension. /// - /// `dim` must be `nullopt` if and only if `value` is index-typed. The bound + /// `dim` must be `nullopt` if and only if `value` is integer-typed. The bound /// is computed in terms of values/dimensions for which `stopCondition` /// evaluates to "true". To that end, the backward slice (reverse use-def /// chain) of the given value is visited in a worklist-driven manner and the @@ -80,10 +82,10 @@ std::optional dim, ValueRange independencies, bool closedUB = false); - /// Compute a constant bound for the given index-typed value or shape + /// Compute a constant bound for the given integer-typed value or shape /// dimension size. /// - /// `dim` must be `nullopt` if and only if `value` is index-typed. This + /// `dim` must be `nullopt` if and only if `value` is integer-typed. This /// function traverses the backward slice of the given value in a /// worklist-driven manner until `stopCondition` evaluates to "true". The /// constraint set is populated according to `ValueBoundsOpInterface` for each @@ -102,43 +104,43 @@ StopConditionFn stopCondition = nullptr, bool closedUB = false); - /// Bound the given index-typed value by the given expression. + /// Bound the given integer-typed value by the given expression. void addBound(presburger::BoundType type, Value value, AffineExpr expr); /// Bound the given shaped value dimension by the given expression. void addBound(presburger::BoundType type, Value value, int64_t dim, AffineExpr expr); - /// Bound the given index-typed value by a constant or SSA value. + /// Bound the given integer-typed value by a constant or SSA value. 
void addBound(presburger::BoundType type, Value value, OpFoldResult ofr); /// Bound the given shaped value dimension by a constant or SSA value. void addBound(presburger::BoundType type, Value value, int64_t dim, OpFoldResult ofr); - /// Return an expression that represents the given index-typed value or shaped - /// value dimension. If this value/dimension was not used so far, it is added - /// to the worklist. + /// Return an expression that represents the given integer-typed value or + /// shaped value dimension. If this value/dimension was not used so far, it is + /// added to the worklist. /// - /// `dim` must be `nullopt` if and only if the given value is of index type. + /// `dim` must be `nullopt` if and only if the given value is of integer type. AffineExpr getExpr(Value value, std::optional dim = std::nullopt); - /// Return an expression that represents a constant or index-typed SSA value. - /// In case of a value, if this value was not used so far, it is added to the - /// worklist. + /// Return an expression that represents a constant or integer-typed SSA + /// value. In case of a value, if this value was not used so far, it is added + /// to the worklist. AffineExpr getExpr(OpFoldResult ofr); /// Return an expression that represents a constant. AffineExpr getExpr(int64_t constant); protected: - /// Dimension identifier to indicate a value is index-typed. - static constexpr int64_t kIndexValue = -1; + /// Dimension identifier to indicate a value is integer-typed. + static constexpr int64_t kIntegerValue = -1; using ValueDim = std::pair; ValueBoundsConstraintSet(ValueDim valueDim); - /// Iteratively process all elements on the worklist until an index-typed + /// Iteratively process all elements on the worklist until an integer-typed /// value or shaped value meets `stopCondition`. Such values are not processed /// any further. 
void processWorklist(StopConditionFn stopCondition); @@ -149,7 +151,7 @@ /// Return the column position of the given value/dimension. Asserts that the /// value/dimension exists in the constraint set. - int64_t getPos(Value value, int64_t dim = kIndexValue) const; + int64_t getPos(Value value, int64_t dim = kIntegerValue) const; /// Insert a value/dimension into the constraint set. If `isSymbol` is set to /// "false", a dimension is added. diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.td b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.td --- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.td +++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.td @@ -13,24 +13,25 @@ def ValueBoundsOpInterface : OpInterface<"ValueBoundsOpInterface"> { let description = [{ - This interface allows operations with index-typed and/or shaped value-typed - results/block arguments to specify range bounds. These bounds are stored in - a constraint set. The constraint set can then be queried to compute bounds - in terms of other values that are stored in the constraint set. + This interface allows operations with integer/index-typed and/or shaped + value-typed results/block arguments to specify range bounds. These bounds + are stored in a constraint set. The constraint set can then be queried to + compute bounds in terms of other values that are stored in the constraint + set. }]; let cppNamespace = "::mlir"; let methods = [ InterfaceMethod< /*desc=*/[{ - Populate the constraint set with bounds for the given index-typed - value. + Populate the constraint set with bounds for the given + integer/index-typed value. Note: If `value` is a block argument, it must belong to an entry block of a region. Unstructured control flow graphs are not supported at the moment. 
}], /*retType=*/"void", - /*methodName=*/"populateBoundsForIndexValue", + /*methodName=*/"populateBoundsForIntegerValue", /*args=*/(ins "::mlir::Value":$value, "::mlir::ValueBoundsConstraintSet &":$cstr), /*methodBody=*/"", diff --git a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp @@ -20,8 +20,8 @@ struct AffineApplyOpInterface : public ValueBoundsOpInterface::ExternalModel { - void populateBoundsForIndexValue(Operation *op, Value value, - ValueBoundsConstraintSet &cstr) const { + void populateBoundsForIntegerValue(Operation *op, Value value, + ValueBoundsConstraintSet &cstr) const { auto applyOp = cast(op); assert(value == applyOp.getResult() && "invalid value"); assert(applyOp.getAffineMap().getNumResults() == 1 && @@ -42,8 +42,8 @@ struct AffineMinOpInterface : public ValueBoundsOpInterface::ExternalModel { - void populateBoundsForIndexValue(Operation *op, Value value, - ValueBoundsConstraintSet &cstr) const { + void populateBoundsForIntegerValue(Operation *op, Value value, + ValueBoundsConstraintSet &cstr) const { auto minOp = cast(op); assert(value == minOp.getResult() && "invalid value"); @@ -64,8 +64,8 @@ struct AffineMaxOpInterface : public ValueBoundsOpInterface::ExternalModel { - void populateBoundsForIndexValue(Operation *op, Value value, - ValueBoundsConstraintSet &cstr) const { + void populateBoundsForIntegerValue(Operation *op, Value value, + ValueBoundsConstraintSet &cstr) const { auto maxOp = cast(op); assert(value == maxOp.getResult() && "invalid value"); diff --git a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp --- a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp @@ -51,8 +51,8 @@ std::optional dim = 
valueDim.second; if (!dim.has_value()) { - // This is an index-typed value. - assert(value.getType().isIndex() && "expected index type"); + // This is an integer-typed value. + assert(value.getType().isIntOrIndex() && "expected integer type"); operands.push_back(value); continue; } diff --git a/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp @@ -20,8 +20,8 @@ struct AddIOpInterface : public ValueBoundsOpInterface::ExternalModel { - void populateBoundsForIndexValue(Operation *op, Value value, - ValueBoundsConstraintSet &cstr) const { + void populateBoundsForIntegerValue(Operation *op, Value value, + ValueBoundsConstraintSet &cstr) const { auto addIOp = cast(op); assert(value == addIOp.getResult() && "invalid value"); @@ -34,8 +34,8 @@ struct ConstantOpInterface : public ValueBoundsOpInterface::ExternalModel { - void populateBoundsForIndexValue(Operation *op, Value value, - ValueBoundsConstraintSet &cstr) const { + void populateBoundsForIntegerValue(Operation *op, Value value, + ValueBoundsConstraintSet &cstr) const { auto constantOp = cast(op); assert(value == constantOp.getResult() && "invalid value"); @@ -46,8 +46,8 @@ struct SubIOpInterface : public ValueBoundsOpInterface::ExternalModel { - void populateBoundsForIndexValue(Operation *op, Value value, - ValueBoundsConstraintSet &cstr) const { + void populateBoundsForIntegerValue(Operation *op, Value value, + ValueBoundsConstraintSet &cstr) const { auto subIOp = cast(op); assert(value == subIOp.getResult() && "invalid value"); @@ -59,8 +59,8 @@ struct MulIOpInterface : public ValueBoundsOpInterface::ExternalModel { - void populateBoundsForIndexValue(Operation *op, Value value, - ValueBoundsConstraintSet &cstr) const { + void populateBoundsForIntegerValue(Operation *op, Value value, + ValueBoundsConstraintSet &cstr) const { 
auto mulIOp = cast(op); assert(value == mulIOp.getResult() && "invalid value"); diff --git a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp --- a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp @@ -17,14 +17,15 @@ using namespace mlir::arith; /// Build Arith IR for the given affine map and its operands. -static Value buildArithValue(OpBuilder &b, Location loc, AffineMap map, - ValueRange operands) { +static Value buildArithValue(OpBuilder &b, Location loc, Type type, + AffineMap map, ValueRange operands) { assert(map.getNumResults() == 1 && "multiple results not supported yet"); std::function buildExpr = [&](AffineExpr e) -> Value { switch (e.getKind()) { case AffineExprKind::Constant: - return b.create(loc, - e.cast().getValue()); + return b.create( + loc, type, + b.getIntegerAttr(type, e.cast().getValue())); case AffineExprKind::DimId: return operands[e.cast().getPosition()]; case AffineExprKind::SymbolId: @@ -62,23 +63,28 @@ return buildExpr(map.getResult(0)); } -FailureOr mlir::arith::reifyValueBound(OpBuilder &b, Location loc, - presburger::BoundType type, - Value value, - std::optional dim, - bool closedUB) { +FailureOr mlir::arith::reifyValueBound( + OpBuilder &b, Location loc, presburger::BoundType type, Value value, + std::optional dim, bool closedUB, Type resultType) { auto stopCondition = [&](Value v, std::optional d) { // Reify in terms of SSA values that are different from `value`. 
return v != value; }; - return reifyValueBound(b, loc, type, value, dim, stopCondition, closedUB); + return reifyValueBound(b, loc, type, value, dim, stopCondition, closedUB, + resultType); } FailureOr mlir::arith::reifyValueBound( OpBuilder &b, Location loc, presburger::BoundType type, Value value, std::optional dim, function_ref)> stopCondition, - bool closedUB) { + bool closedUB, Type resultType) { + // Default result type if not specified: same as value. + if (!resultType) { + resultType = dim.has_value() ? b.getIndexType() : value.getType(); + } + assert(resultType.isIntOrIndex() && "invalid result type"); + // Compute bound. AffineMap boundMap; ValueDimList mapOperands; @@ -93,8 +99,10 @@ std::optional dim = valueDim.second; if (!dim.has_value()) { - // This is an index-typed value. - assert(value.getType().isIndex() && "expected index type"); + // This is an integer-typed value. + assert(value.getType().isIntOrIndex() && "expected integer type"); + if (value.getType() != resultType) + value = b.create(loc, resultType, value); operands.push_back(value); continue; } @@ -103,10 +111,16 @@ "expected dynamic dim"); if (value.getType().isa()) { // A tensor dimension is used: generate a tensor.dim. - operands.push_back(b.create(loc, value, *dim)); + Value dimSize = b.create(loc, value, *dim); + if (!resultType.isIndex()) + dimSize = b.create(loc, resultType, dimSize); + operands.push_back(dimSize); } else if (value.getType().isa()) { // A memref dimension is used: generate a memref.dim. - operands.push_back(b.create(loc, value, *dim)); + Value dimSize = b.create(loc, value, *dim); + if (!resultType.isIndex()) + dimSize = b.create(loc, resultType, dimSize); + operands.push_back(dimSize); } else { llvm_unreachable("cannot generate DimOp for unsupported shaped type"); } @@ -125,5 +139,6 @@ return static_cast( operands[expr.getPosition() + boundMap.getNumDims()]); // General case: build Arith ops. 
- return static_cast(buildArithValue(b, loc, boundMap, operands)); + return static_cast( + buildArithValue(b, loc, resultType, boundMap, operands)); } diff --git a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp @@ -49,8 +49,8 @@ struct DimOpInterface : public ValueBoundsOpInterface::ExternalModel { - void populateBoundsForIndexValue(Operation *op, Value value, - ValueBoundsConstraintSet &cstr) const { + void populateBoundsForIntegerValue(Operation *op, Value value, + ValueBoundsConstraintSet &cstr) const { auto dimOp = cast(op); assert(value == dimOp.getResult() && "invalid value"); @@ -79,8 +79,8 @@ struct RankOpInterface : public ValueBoundsOpInterface::ExternalModel { - void populateBoundsForIndexValue(Operation *op, Value value, - ValueBoundsConstraintSet &cstr) const { + void populateBoundsForIntegerValue(Operation *op, Value value, + ValueBoundsConstraintSet &cstr) const { auto rankOp = cast(op); assert(value == rankOp.getResult() && "invalid value"); diff --git a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp @@ -91,8 +91,8 @@ addEqBound(); } - void populateBoundsForIndexValue(Operation *op, Value value, - ValueBoundsConstraintSet &cstr) const { + void populateBoundsForIntegerValue(Operation *op, Value value, + ValueBoundsConstraintSet &cstr) const { auto forOp = cast(op); if (value == forOp.getInductionVar()) { diff --git a/mlir/lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp @@ -35,8 
+35,8 @@ struct DimOpInterface : public ValueBoundsOpInterface::ExternalModel { - void populateBoundsForIndexValue(Operation *op, Value value, - ValueBoundsConstraintSet &cstr) const { + void populateBoundsForIntegerValue(Operation *op, Value value, + ValueBoundsConstraintSet &cstr) const { auto dimOp = cast(op); assert(value == dimOp.getResult() && "invalid value"); @@ -100,8 +100,8 @@ struct RankOpInterface : public ValueBoundsOpInterface::ExternalModel { - void populateBoundsForIndexValue(Operation *op, Value value, - ValueBoundsConstraintSet &cstr) const { + void populateBoundsForIntegerValue(Operation *op, Value value, + ValueBoundsConstraintSet &cstr) const { auto rankOp = cast(op); assert(value == rankOp.getResult() && "invalid value"); diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp --- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp +++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp @@ -43,7 +43,7 @@ #ifndef NDEBUG static void assertValidValueDim(Value value, std::optional dim) { - if (value.getType().isIndex()) { + if (value.getType().isIntOrIndex()) { assert(!dim.has_value() && "invalid dim value"); } else if (auto shapedType = value.getType().dyn_cast()) { assert(*dim >= 0 && "invalid dim value"); @@ -66,7 +66,7 @@ void ValueBoundsConstraintSet::addBound(BoundType type, Value value, AffineExpr expr) { - assert(value.getType().isIndex() && "expected index type"); + assert(value.getType().isIntOrIndex() && "expected integer type"); assert((value.isa() || value.cast().getOwner()->isEntryBlock()) && "unstructured control flow is not supported"); @@ -88,7 +88,7 @@ void ValueBoundsConstraintSet::addBound(BoundType type, Value value, OpFoldResult ofr) { - assert(value.getType().isIndex() && "expected index type"); + assert(value.getType().isIntOrIndex() && "expected integer type"); addBound(type, getPos(value), getExpr(ofr)); } @@ -112,16 +112,16 @@ if (shapedType.hasRank() && 
!shapedType.isDynamicDim(*dim)) return builder.getAffineConstantExpr(shapedType.getDimSize(*dim)); } else { - // Constant index value: return directly. + // Constant integer value: return directly. if (auto constInt = getConstantIntValue(value)) return builder.getAffineConstantExpr(*constInt); } // Dynamic value: add to constraint set. - ValueDim valueDim = std::make_pair(value, dim.value_or(kIndexValue)); + ValueDim valueDim = std::make_pair(value, dim.value_or(kIntegerValue)); if (valueDimToPosition.find(valueDim) == valueDimToPosition.end()) (void)insert(valueDim); - int64_t pos = getPos(value, dim.value_or(kIndexValue)); + int64_t pos = getPos(value, dim.value_or(kIntegerValue)); return pos < cstr.getNumDimVars() ? builder.getAffineDimExpr(pos) : builder.getAffineSymbolExpr(pos - cstr.getNumDimVars()); @@ -155,7 +155,7 @@ int64_t ValueBoundsConstraintSet::getPos(Value value, int64_t dim) const { #ifndef NDEBUG - assertValidValueDim(value, dim == kIndexValue + assertValidValueDim(value, dim == kIntegerValue ? std::nullopt : std::make_optional(dim)); #endif // NDEBUG @@ -180,7 +180,7 @@ int64_t dim = valueDim.second; // Check for static dim size. - if (dim != kIndexValue) { + if (dim != kIntegerValue) { auto shapedType = value.getType().cast(); if (shapedType.hasRank() && !shapedType.isDynamicDim(dim)) { addBound(BoundType::EQ, value, dim, @@ -190,7 +190,8 @@ } // Do not process any further if the stop condition is met. - auto maybeDim = dim == kIndexValue ? std::nullopt : std::make_optional(dim); + auto maybeDim = + dim == kIntegerValue ? 
std::nullopt : std::make_optional(dim); if (stopCondition(value, maybeDim)) continue; @@ -200,8 +201,8 @@ dyn_cast(getOwnerOfValue(value)); if (!valueBoundsOp) continue; - if (dim == kIndexValue) { - valueBoundsOp.populateBoundsForIndexValue(value, *this); + if (dim == kIntegerValue) { + valueBoundsOp.populateBoundsForIntegerValue(value, *this); } else { valueBoundsOp.populateBoundsForShapedValueDim(value, dim, *this); } @@ -262,7 +263,7 @@ // Process the backward slice of `value` (i.e., reverse use-def chain) until // `stopCondition` is met. - ValueDim valueDim = std::make_pair(value, dim.value_or(kIndexValue)); + ValueDim valueDim = std::make_pair(value, dim.value_or(kIntegerValue)); ValueBoundsConstraintSet cstr(valueDim); cstr.processWorklist(stopCondition); @@ -273,12 +274,12 @@ if (valueDim == p) return false; auto maybeDim = - p.second == kIndexValue ? std::nullopt : std::make_optional(p.second); + p.second == kIntegerValue ? std::nullopt : std::make_optional(p.second); return !stopCondition(p.first, maybeDim); }); // Compute lower and upper bounds for `valueDim`. - int64_t pos = cstr.getPos(value, dim.value_or(kIndexValue)); + int64_t pos = cstr.getPos(value, dim.value_or(kIntegerValue)); SmallVector lb(1), ub(1); cstr.cstr.getSliceBounds(pos, 1, value.getContext(), &lb, &ub, /*getClosedUB=*/true); @@ -355,10 +356,10 @@ ValueBoundsConstraintSet::ValueDim valueDim = cstr.positionToValueDim[i]; Value value = valueDim.first; int64_t dim = valueDim.second; - if (dim == ValueBoundsConstraintSet::kIndexValue) { - // An index-type value is used: can be used directly in the affine.apply + if (dim == ValueBoundsConstraintSet::kIntegerValue) { + // An integer-type value is used: can be used directly in the affine.apply // op. 
- assert(value.getType().isIndex() && "expected index type"); + assert(value.getType().isIntOrIndex() && "expected integer type"); mapOperands.push_back(std::make_pair(value, std::nullopt)); continue; } @@ -428,9 +429,9 @@ // Process the backward slice of `value` (i.e., reverse use-def chain) until // `stopCondition` is met. - ValueDim valueDim = std::make_pair(value, dim.value_or(kIndexValue)); + ValueDim valueDim = std::make_pair(value, dim.value_or(kIntegerValue)); ValueBoundsConstraintSet cstr(valueDim); - int64_t pos = cstr.getPos(value, dim.value_or(kIndexValue)); + int64_t pos = cstr.getPos(value, dim.value_or(kIntegerValue)); if (stopCondition) { cstr.processWorklist(stopCondition); } else { diff --git a/mlir/test/Dialect/Arith/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Arith/value-bounds-op-interface-impl.mlir --- a/mlir/test/Dialect/Arith/value-bounds-op-interface-impl.mlir +++ b/mlir/test/Dialect/Arith/value-bounds-op-interface-impl.mlir @@ -45,12 +45,24 @@ // CHECK-LABEL: func @arith_muli( // CHECK-SAME: %[[a:.*]]: index // CHECK: %[[apply:.*]] = affine.apply #[[$map]]()[%[[a]]] -// CHECK: return %[[apply]] -func.func @arith_muli(%a: index) -> index { +// CHECK: %[[apply2:.*]] = affine.apply #[[$map]]()[%[[a]]] +// CHECK: %[[casted:.*]] = arith.index_cast %[[apply2]] : index to i32 +// CHECK: return %[[apply]], %[[casted]] + +// CHECK-ARITH-LABEL: func @arith_muli( +// CHECK-ARITH-SAME: %[[a:.*]]: index +// CHECK-ARITH: %[[c5:.*]] = arith.constant 5 : index +// CHECK-ARITH: %[[mul:.*]] = arith.muli %[[a]], %[[c5]] +// CHECK-ARITH: %[[a_casted:.*]] = arith.index_cast %[[a]] : index to i32 +// CHECK-ARITH: %[[c5_i32:.*]] = arith.constant 5 : i32 +// CHECK-ARITH: %[[mul_i32:.*]] = arith.muli %[[a_casted]], %[[c5_i32]] +// CHECK-ARITH: return %[[mul]], %[[mul_i32]] +func.func @arith_muli(%a: index) -> (index, i32) { %0 = arith.constant 5 : index %1 = arith.muli %0, %a : index - %2 = "test.reify_bound"(%1) : (index) -> (index) - return %2 : 
index + %bound = "test.reify_bound"(%1) : (index) -> (index) + %bound_i32 = "test.reify_bound"(%1) : (index) -> (i32) + return %bound, %bound_i32 : index, i32 } // ----- diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp --- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp +++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp @@ -74,15 +74,15 @@ // Look for test.reify_bound ops. if (op->getName().getStringRef() == "test.reify_bound") { if (op->getNumOperands() != 1 || op->getNumResults() != 1 || - !op->getResultTypes()[0].isIndex()) { + !op->getResultTypes()[0].isIntOrIndex()) { op->emitOpError("invalid op"); return WalkResult::skip(); } Value value = op->getOperand(0); - if (value.getType().isa() != + if (value.getType().isIntOrIndex() != !op->hasAttrOfType("dim")) { // Op should have "dim" attribute if and only if the operand is an - // index-typed value. + // integer/index-typed value. op->emitOpError("invalid op"); return WalkResult::skip(); } @@ -98,7 +98,7 @@ } // Get shape dimension (if any). - auto dim = value.getType().isIndex() + auto dim = value.getType().isIntOrIndex() ? 
std::nullopt : std::make_optional( op->getAttrOfType("dim").getInt()); @@ -131,12 +131,13 @@ auto reifiedConst = ValueBoundsConstraintSet::computeConstantBound( *boundType, value, dim, /*stopCondition=*/nullptr); if (succeeded(reifiedConst)) - reified = - FailureOr(rewriter.getIndexAttr(*reifiedConst)); + reified = FailureOr( + rewriter.getIntegerAttr(op->getResultTypes()[0], *reifiedConst)); } else { if (useArithOps) { - reified = arith::reifyValueBound(rewriter, op->getLoc(), *boundType, - value, dim, stopCondition); + reified = arith::reifyValueBound( + rewriter, op->getLoc(), *boundType, value, dim, stopCondition, + /*closedUB=*/false, op->getResultTypes()[0]); } else { reified = reifyValueBound(rewriter, op->getLoc(), *boundType, value, dim, stopCondition); @@ -148,12 +149,16 @@ } // Replace the op with the reified bound. - if (auto val = reified->dyn_cast()) { + if (Value val = reified->dyn_cast()) { + if (val.getType() != op->getResultTypes()[0]) + val = rewriter.create( + op->getLoc(), op->getResultTypes()[0], val); rewriter.replaceOp(op, val); return WalkResult::skip(); } - Value constOp = rewriter.create( - op->getLoc(), reified->get().cast().getInt()); + Value constOp = rewriter.create( + op->getLoc(), op->getResultTypes()[0], + reified->get().cast()); rewriter.replaceOp(op, constOp); return WalkResult::skip(); }