diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -71,6 +71,15 @@
 /// Convert `arrayAttr` to a vector of OpFoldResult.
 SmallVector<OpFoldResult> getAsOpFoldResult(ArrayAttr arrayAttr);
 
+/// Convert `val` to an IntegerAttr of index type and return it as an
+/// OpFoldResult.
+OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val);
+
+/// Convert each element of `values` to an IntegerAttr of index type and
+/// return them as OpFoldResults.
+SmallVector<OpFoldResult> getAsIndexOpFoldResult(MLIRContext *ctx,
+                                                 ArrayRef<int64_t> values);
+
 /// If ofr is a constant integer or an IntegerAttr, return the integer.
 std::optional<int64_t> getConstantIntValue(OpFoldResult ofr);
 
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -102,6 +102,20 @@
   return res;
 }
 
+/// Convert `val` to an IntegerAttr of index type and return it as an
+/// OpFoldResult.
+OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val) {
+  return IntegerAttr::get(IndexType::get(ctx), val);
+}
+
+/// Convert each element of `values` to an IntegerAttr of index type and
+/// return them as OpFoldResults.
+SmallVector<OpFoldResult> getAsIndexOpFoldResult(MLIRContext *ctx,
+                                                 ArrayRef<int64_t> values) {
+  return llvm::to_vector(llvm::map_range(
+      values, [ctx](int64_t v) { return getAsIndexOpFoldResult(ctx, v); }));
+}
+
 /// If ofr is a constant integer or an IntegerAttr, return the integer.
 std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
   // Case 1: Check for Constant integer.