diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td @@ -269,13 +269,13 @@ // Return true if low padding is guaranteed to be 0. bool hasZeroLowPad() { return llvm::all_of(getMixedLowPad(), [](OpFoldResult ofr) { - return mlir::isEqualConstantInt(ofr, 0); + return mlir::getConstantIntValue(ofr) == static_cast(0); }); } // Return true if high padding is guaranteed to be 0. bool hasZeroHighPad() { return llvm::all_of(getMixedHighPad(), [](OpFoldResult ofr) { - return mlir::isEqualConstantInt(ofr, 0); + return mlir::getConstantIntValue(ofr) == static_cast(0); }); } }]; diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -13,6 +13,7 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/Utils.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/IR/Identifier.h" #include "mlir/IR/PatternMatch.h" diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h @@ -114,21 +114,6 @@ bool applyCmpPredicate(CmpFPredicate predicate, const APFloat &lhs, const APFloat &rhs); -/// If ofr is a constant integer, i.e., an IntegerAttr or a ConstantOp with an -/// IntegerAttr, return the integer. -llvm::Optional getConstantIntValue(OpFoldResult ofr); - -/// Return true if ofr and value are the same integer. -/// Ignore integer bitwidth and type mismatch that come from the fact there is -/// no IndexAttr and that IndexType has no bitwidth. 
-bool isEqualConstantInt(OpFoldResult ofr, int64_t value); - -/// Return true if ofr1 and ofr2 are the same integer constant attribute values -/// or the same SSA value. -/// Ignore integer bitwitdh and type mismatch that come from the fact there is -/// no IndexAttr and that IndexType have no bitwidth. -bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2); - /// Returns the identity value attribute associated with an AtomicRMWKind op. Attribute getIdentityValueAttr(AtomicRMWKind kind, Type resultType, OpBuilder &builder, Location loc); diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h @@ -0,0 +1,58 @@ +//===- StaticValueUtils.h - Utilities for static values ---------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header file defines utilities for dealing with static values, e.g., +// converting back and forth between Value and OpFoldResult. Such functionality +// is used in multiple dialects. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_UTILS_STATICVALUEUTILS_H +#define MLIR_DIALECT_UTILS_STATICVALUEUTILS_H + +#include "mlir/IR/OpDefinition.h" + +#include "llvm/ADT/SmallVector.h" + +namespace mlir { + +/// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if +/// it is a Value or into `staticVec` if it is an IntegerAttr. +/// In the case of a Value, a copy of the `sentinel` value is also pushed to +/// `staticVec`. This is useful to extract mixed static and dynamic entries that +/// come from an AttrSizedOperandSegments trait. 
+void dispatchIndexOpFoldResult(OpFoldResult ofr, + SmallVectorImpl &dynamicVec, + SmallVectorImpl &staticVec, + int64_t sentinel); + +/// Helper function to dispatch multiple OpFoldResults into either the +/// `dynamicVec` (for Values) or into `staticVec` (for IntegerAttrs). +/// In the case of a Value, a copy of the `sentinel` value is also pushed to +/// `staticVec`. This is useful to extract mixed static and dynamic entries that +/// come from an AttrSizedOperandSegments trait. +void dispatchIndexOpFoldResults(ArrayRef ofrs, + SmallVectorImpl &dynamicVec, + SmallVectorImpl &staticVec, + int64_t sentinel); + +/// Extract int64_t values from the assumed ArrayAttr of IntegerAttr. +SmallVector extractFromI64ArrayAttr(Attribute attr); + +/// If ofr is a constant integer or an IntegerAttr, return the integer. +Optional getConstantIntValue(OpFoldResult ofr); + +/// Return true if ofr1 and ofr2 are the same integer constant attribute values +/// or the same SSA value. +/// Ignore integer bitwidth and type mismatch that come from the fact there is +/// no IndexAttr and that IndexType has no bitwidth.
+bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2); + +} // namespace mlir + +#endif // MLIR_DIALECT_UTILS_STATICVALUEUTILS_H diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h --- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h +++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h @@ -13,6 +13,7 @@ #ifndef MLIR_INTERFACES_VIEWLIKEINTERFACE_H_ #define MLIR_INTERFACES_VIEWLIKEINTERFACE_H_ +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" @@ -30,8 +31,6 @@ class OffsetSizeAndStrideOpInterface; -bool isEqualConstantInt(OpFoldResult ofr, int64_t value); - namespace detail { LogicalResult verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op); diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.td b/mlir/include/mlir/Interfaces/ViewLikeInterface.td --- a/mlir/include/mlir/Interfaces/ViewLikeInterface.td +++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.td @@ -444,7 +444,7 @@ /*methodBody=*/"", /*defaultImplementation=*/[{ return ::llvm::all_of(getMixedStrides(), [](OpFoldResult ofr) { - return ::mlir::isEqualConstantInt(ofr, 1); + return ::mlir::getConstantIntValue(ofr) == static_cast(1); }); }] >, @@ -456,7 +456,7 @@ /*methodBody=*/"", /*defaultImplementation=*/[{ return ::llvm::all_of(getMixedOffsets(), [](OpFoldResult ofr) { - return ::mlir::isEqualConstantInt(ofr, 0); + return ::mlir::getConstantIntValue(ofr) == static_cast(0); }); }] >, diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -20,6 +20,7 @@ #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" +#include 
"mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" @@ -3388,14 +3389,6 @@ } }; -/// Helper function extracts int64_t from the assumedArrayAttr of IntegerAttr. -static SmallVector extractFromI64ArrayAttr(Attribute attr) { - return llvm::to_vector<4>( - llvm::map_range(attr.cast(), [](Attribute a) -> int64_t { - return a.cast().getInt(); - })); -} - /// Conversion pattern that transforms a subview op into: /// 1. An `llvm.mlir.undef` operation to create a memref descriptor /// 2. Updates to the descriptor to introduce the data ptr, offset, size diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpImplementation.h" @@ -116,24 +117,6 @@ })); } -/// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if -/// it is a Value or into `staticVec` if it is an IntegerAttr. -/// In the case of a Value, a copy of the `sentinel` value is also pushed to -/// `staticVec`. This is useful to extract mixed static and dynamic entries that -/// come from an AttrSizedOperandSegments trait. 
-static void dispatchIndexOpFoldResult(OpFoldResult ofr, - SmallVectorImpl &dynamicVec, - SmallVectorImpl &staticVec, - int64_t sentinel) { - if (auto v = ofr.dyn_cast()) { - dynamicVec.push_back(v); - staticVec.push_back(sentinel); - return; - } - APInt apInt = ofr.dyn_cast().cast().getValue(); - staticVec.push_back(apInt.getSExtValue()); -} - /// This is a common class used for patterns of the form /// ``` /// someop(memrefcast(%src)) -> someop(%src) @@ -822,14 +805,6 @@ // PadTensorOp //===----------------------------------------------------------------------===// -/// Extract int64_t values from the assumed ArrayAttr of IntegerAttr. -static SmallVector extractFromI64ArrayAttr(Attribute attr) { - return llvm::to_vector<4>( - llvm::map_range(attr.cast(), [](Attribute a) -> int64_t { - return a.cast().getInt(); - })); -} - static LogicalResult verify(PadTensorOp op) { auto sourceType = op.source().getType().cast(); auto resultType = op.result().getType().cast(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp @@ -109,6 +109,7 @@ #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/IR/Operation.h" #include "mlir/Pass/Pass.h" diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -814,8 +814,8 @@ readInBounds.push_back(false); // Write is out-of-bounds if low padding > 0. 
writeInBounds.push_back( - isEqualConstantIntOrValue(padOp.getMixedLowPad()[i], - rewriter.getIndexAttr(0))); + getConstantIntValue(padOp.getMixedLowPad()[i]) == + static_cast(0)); } else { // Neither source nor result dim of padOp is static. Cannot vectorize // the copy. @@ -1098,9 +1098,9 @@ SmallVector expectedSizes(tensorRank - vecRank, 1); expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end()); if (!llvm::all_of( - llvm::zip(insertOp.getMixedSizes(), expectedSizes), - [](auto it) { return isEqualConstantInt(std::get<0>(it), - std::get<1>(it)); })) + llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) { + return getConstantIntValue(std::get<0>(it)) == std::get<1>(it); + })) return failure(); // Generate TransferReadOp: Read entire source tensor and add high padding. diff --git a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt --- a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt @@ -14,6 +14,7 @@ LINK_LIBS PUBLIC MLIRDialect + MLIRDialectUtils MLIRInferTypeOpInterface MLIRIR MLIRMemRefUtils diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -11,6 +11,7 @@ #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/Utils/Utils.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" @@ -32,40 +33,6 @@ return builder.create(loc, type, value); } -/// Extract int64_t values from the assumed ArrayAttr of IntegerAttr. 
-static SmallVector extractFromI64ArrayAttr(Attribute attr) { - return llvm::to_vector<4>( - llvm::map_range(attr.cast(), [](Attribute a) -> int64_t { - return a.cast().getInt(); - })); -} - -/// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if -/// it is a Value or into `staticVec` if it is an IntegerAttr. -/// In the case of a Value, a copy of the `sentinel` value is also pushed to -/// `staticVec`. This is useful to extract mixed static and dynamic entries that -/// come from an AttrSizedOperandSegments trait. -static void dispatchIndexOpFoldResult(OpFoldResult ofr, - SmallVectorImpl &dynamicVec, - SmallVectorImpl &staticVec, - int64_t sentinel) { - if (auto v = ofr.dyn_cast()) { - dynamicVec.push_back(v); - staticVec.push_back(sentinel); - return; - } - APInt apInt = ofr.dyn_cast().cast().getValue(); - staticVec.push_back(apInt.getSExtValue()); -} - -static void dispatchIndexOpFoldResults(ArrayRef ofrs, - SmallVectorImpl &dynamicVec, - SmallVectorImpl &staticVec, - int64_t sentinel) { - for (auto ofr : ofrs) - dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec, sentinel); -} - //===----------------------------------------------------------------------===// // Common canonicalization pattern support logic //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -33,38 +33,6 @@ using namespace mlir; -/// If ofr is a constant integer, i.e., an IntegerAttr or a ConstantOp with an -/// IntegerAttr, return the integer. -llvm::Optional mlir::getConstantIntValue(OpFoldResult ofr) { - Attribute attr = ofr.dyn_cast(); - // Note: isa+cast-like pattern allows writing the condition below as 1 line. 
- if (!attr && ofr.get().getDefiningOp()) - attr = ofr.get().getDefiningOp().getValue(); - if (auto intAttr = attr.dyn_cast_or_null()) - return intAttr.getValue().getSExtValue(); - return llvm::None; -} - -/// Return true if ofr and value are the same integer. -/// Ignore integer bitwidth and type mismatch that come from the fact there is -/// no IndexAttr and that IndexType has no bitwidth. -bool mlir::isEqualConstantInt(OpFoldResult ofr, int64_t value) { - auto ofrValue = getConstantIntValue(ofr); - return ofrValue && *ofrValue == value; -} - -/// Return true if ofr1 and ofr2 are the same integer constant attribute values -/// or the same SSA value. -/// Ignore integer bitwidth and type mismatch that come from the fact there is -/// no IndexAttr and that IndexType has no bitwidth. -bool mlir::isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2) { - auto cst1 = getConstantIntValue(ofr1), cst2 = getConstantIntValue(ofr2); - if (cst1 && cst2 && *cst1 == *cst2) - return true; - auto v1 = ofr1.dyn_cast(), v2 = ofr2.dyn_cast(); - return v1 && v2 && v1 == v2; -} - //===----------------------------------------------------------------------===// // StandardOpsDialect Interfaces //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt @@ -13,6 +13,7 @@ LINK_LIBS PUBLIC MLIRCastInterfaces + MLIRDialectUtils MLIRIR MLIRSideEffectInterfaces MLIRSupport diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -8,6 +8,7 @@ #include "mlir/Dialect/StandardOps/Utils/Utils.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/BlockAndValueMapping.h" #include 
"mlir/IR/Builders.h" #include "mlir/IR/Matchers.h" @@ -516,32 +517,6 @@ // ExtractSliceOp //===----------------------------------------------------------------------===// -/// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if -/// it is a Value or into `staticVec` if it is an IntegerAttr. -/// In the case of a Value, a copy of the `sentinel` value is also pushed to -/// `staticVec`. This is useful to extract mixed static and dynamic entries that -/// come from an AttrSizedOperandSegments trait. -static void dispatchIndexOpFoldResult(OpFoldResult ofr, - SmallVectorImpl &dynamicVec, - SmallVectorImpl &staticVec, - int64_t sentinel) { - if (auto v = ofr.dyn_cast()) { - dynamicVec.push_back(v); - staticVec.push_back(sentinel); - return; - } - APInt apInt = ofr.dyn_cast().cast().getValue(); - staticVec.push_back(apInt.getSExtValue()); -} - -static void dispatchIndexOpFoldResults(ArrayRef ofrs, - SmallVectorImpl &dynamicVec, - SmallVectorImpl &staticVec, - int64_t sentinel) { - for (auto ofr : ofrs) - dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec, sentinel); -} - /// An extract_slice op result type can be fully inferred from the source type /// and the static representation of offsets, sizes and strides. Special /// sentinels encode the dynamic case. @@ -563,14 +538,6 @@ sourceRankedTensorType.getElementType()); } -/// Extract int64_t values from the assumed ArrayAttr of IntegerAttr. 
-static SmallVector extractFromI64ArrayAttr(Attribute attr) { - return llvm::to_vector<4>( - llvm::map_range(attr.cast(), [](Attribute a) -> int64_t { - return a.cast().getInt(); - })); -} - Type ExtractSliceOp::inferResultType( RankedTensorType sourceRankedTensorType, ArrayRef leadingStaticOffsets, @@ -890,17 +857,16 @@ ShapedType shapedType) { OpBuilder b(op.getContext()); for (OpFoldResult ofr : op.getMixedOffsets()) - if (!isEqualConstantIntOrValue(ofr, b.getIndexAttr(0))) + if (getConstantIntValue(ofr) != static_cast(0)) return failure(); // Rank-reducing noops only need to inspect the leading dimensions: llvm::zip // is appropriate. auto shape = shapedType.getShape(); for (auto it : llvm::zip(op.getMixedSizes(), shape)) - if (!isEqualConstantIntOrValue(std::get<0>(it), - b.getIndexAttr(std::get<1>(it)))) + if (getConstantIntValue(std::get<0>(it)) != std::get<1>(it)) return failure(); for (OpFoldResult ofr : op.getMixedStrides()) - if (!isEqualConstantIntOrValue(ofr, b.getIndexAttr(1))) + if (getConstantIntValue(ofr) != static_cast(1)) return failure(); return success(); } diff --git a/mlir/lib/Dialect/Utils/CMakeLists.txt b/mlir/lib/Dialect/Utils/CMakeLists.txt --- a/mlir/lib/Dialect/Utils/CMakeLists.txt +++ b/mlir/lib/Dialect/Utils/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_library(MLIRDialectUtils StructuredOpsUtils.cpp + StaticValueUtils.cpp LINK_LIBS PUBLIC MLIRIR diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp @@ -0,0 +1,78 @@ +//===- StaticValueUtils.cpp - Utilities for dealing with static values ----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/IR/Matchers.h" + +#include "llvm/ADT/APSInt.h" + +using llvm::None; +using llvm::Optional; +using llvm::SmallVector; + +/// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if +/// it is a Value or into `staticVec` if it is an IntegerAttr. +/// In the case of a Value, a copy of the `sentinel` value is also pushed to +/// `staticVec`. This is useful to extract mixed static and dynamic entries that +/// come from an AttrSizedOperandSegments trait. +void mlir::dispatchIndexOpFoldResult(OpFoldResult ofr, + SmallVectorImpl &dynamicVec, + SmallVectorImpl &staticVec, + int64_t sentinel) { + if (auto v = ofr.dyn_cast()) { + dynamicVec.push_back(v); + staticVec.push_back(sentinel); + return; + } + APInt apInt = ofr.dyn_cast().cast().getValue(); + staticVec.push_back(apInt.getSExtValue()); +} + +void mlir::dispatchIndexOpFoldResults(ArrayRef ofrs, + SmallVectorImpl &dynamicVec, + SmallVectorImpl &staticVec, + int64_t sentinel) { + for (auto ofr : ofrs) + dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec, sentinel); +} + +/// Extract int64_t values from the assumed ArrayAttr of IntegerAttr. +SmallVector mlir::extractFromI64ArrayAttr(Attribute attr) { + return llvm::to_vector<4>( + llvm::map_range(attr.cast(), [](Attribute a) -> int64_t { + return a.cast().getInt(); + })); +} + +/// If ofr is a constant integer or an IntegerAttr, return the integer. +Optional mlir::getConstantIntValue(OpFoldResult ofr) { + // Case 1: Check for Constant integer. + if (auto val = ofr.dyn_cast()) { + APSInt intVal; + if (matchPattern(val, m_ConstantInt(&intVal))) + return intVal.getSExtValue(); + return None; + } + // Case 2: Check for IntegerAttr. 
+ Attribute attr = ofr.dyn_cast(); + if (auto intAttr = attr.dyn_cast_or_null()) + return intAttr.getValue().getSExtValue(); + return None; +} + +/// Return true if ofr1 and ofr2 are the same integer constant attribute values +/// or the same SSA value. +/// Ignore integer bitwidth and type mismatch that come from the fact there is +/// no IndexAttr and that IndexType has no bitwidth. +bool mlir::isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2) { + auto cst1 = getConstantIntValue(ofr1), cst2 = getConstantIntValue(ofr2); + if (cst1 && cst2 && *cst1 == *cst2) + return true; + auto v1 = ofr1.dyn_cast(), v2 = ofr2.dyn_cast(); + return v1 && v1 == v2; +}