diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/IR/Identifier.h"
 #include "mlir/IR/PatternMatch.h"
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
@@ -114,21 +114,6 @@
 bool applyCmpPredicate(CmpFPredicate predicate, const APFloat &lhs,
                        const APFloat &rhs);
 
-/// If ofr is a constant integer, i.e., an IntegerAttr or a ConstantOp with an
-/// IntegerAttr, return the integer.
-llvm::Optional<int64_t> getConstantIntValue(OpFoldResult ofr);
-
-/// Return true if ofr and value are the same integer.
-/// Ignore integer bitwidth and type mismatch that come from the fact there is
-/// no IndexAttr and that IndexType has no bitwidth.
-bool isEqualConstantInt(OpFoldResult ofr, int64_t value);
-
-/// Return true if ofr1 and ofr2 are the same integer constant attribute values
-/// or the same SSA value.
-/// Ignore integer bitwitdh and type mismatch that come from the fact there is
-/// no IndexAttr and that IndexType have no bitwidth.
-bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2);
-
 /// Returns the identity value attribute associated with an AtomicRMWKind op.
 Attribute getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
                                OpBuilder &builder, Location loc);
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -0,0 +1,64 @@
+//===- StaticValueUtils.h - Utilities for static values ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header file defines utilities for dealing with static values, e.g.,
+// converting back and forth between Value and OpFoldResult. Such functionality
+// is used in multiple dialects.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_UTILS_STATICVALUEUTILS_H
+#define MLIR_DIALECT_UTILS_STATICVALUEUTILS_H
+
+#include "mlir/IR/OpDefinition.h"
+
+#include "llvm/ADT/SmallVector.h"
+
+namespace mlir {
+
+/// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if
+/// it is a Value or into `staticVec` if it is an IntegerAttr.
+/// In the case of a Value, a copy of the `sentinel` value is also pushed to
+/// `staticVec`. This is useful to extract mixed static and dynamic entries that
+/// come from an AttrSizedOperandSegments trait.
+void dispatchIndexOpFoldResult(OpFoldResult ofr,
+                               SmallVectorImpl<Value> &dynamicVec,
+                               SmallVectorImpl<int64_t> &staticVec,
+                               int64_t sentinel);
+
+/// Helper function to dispatch multiple OpFoldResults into either the
+/// `dynamicVec` (for Values) or into `staticVec` (for IntegerAttrs).
+/// In the case of a Value, a copy of the `sentinel` value is also pushed to
+/// `staticVec`. This is useful to extract mixed static and dynamic entries that
+/// come from an AttrSizedOperandSegments trait.
+void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
+                                SmallVectorImpl<Value> &dynamicVec,
+                                SmallVectorImpl<int64_t> &staticVec,
+                                int64_t sentinel);
+
+/// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
+SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr);
+
+/// If ofr is a constant integer, i.e., an IntegerAttr or a Value defined by a
+/// constant-like op with an integer value, return the integer.
+llvm::Optional<int64_t> getConstantIntValue(OpFoldResult ofr);
+
+/// Return true if ofr and value are the same integer.
+/// Ignore integer bitwidth and type mismatch that come from the fact there is
+/// no IndexAttr and that IndexType has no bitwidth.
+bool isEqualConstantInt(OpFoldResult ofr, int64_t value);
+
+/// Return true if ofr1 and ofr2 are the same integer constant attribute values
+/// or the same SSA value.
+/// Ignore integer bitwidth and type mismatch that come from the fact there is
+/// no IndexAttr and that IndexType has no bitwidth.
+bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2);
+
+} // namespace mlir
+
+#endif // MLIR_DIALECT_UTILS_STATICVALUEUTILS_H
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -20,6 +20,7 @@
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
@@ -3388,14 +3389,6 @@
   }
 };
 
-/// Helper function extracts int64_t from the assumedArrayAttr of IntegerAttr.
-static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
-  return llvm::to_vector<4>(
-      llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
-        return a.cast<IntegerAttr>().getInt();
-      }));
-}
-
 /// Conversion pattern that transforms a subview op into:
 /// 1. An `llvm.mlir.undef` operation to create a memref descriptor
 /// 2. Updates to the descriptor to introduce the data ptr, offset, size
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpImplementation.h"
@@ -116,24 +117,6 @@
       }));
 }
 
-/// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if
-/// it is a Value or into `staticVec` if it is an IntegerAttr.
-/// In the case of a Value, a copy of the `sentinel` value is also pushed to
-/// `staticVec`. This is useful to extract mixed static and dynamic entries that
-/// come from an AttrSizedOperandSegments trait.
-static void dispatchIndexOpFoldResult(OpFoldResult ofr,
-                                      SmallVectorImpl<Value> &dynamicVec,
-                                      SmallVectorImpl<int64_t> &staticVec,
-                                      int64_t sentinel) {
-  if (auto v = ofr.dyn_cast<Value>()) {
-    dynamicVec.push_back(v);
-    staticVec.push_back(sentinel);
-    return;
-  }
-  APInt apInt = ofr.dyn_cast<Attribute>().cast<IntegerAttr>().getValue();
-  staticVec.push_back(apInt.getSExtValue());
-}
-
 /// This is a common class used for patterns of the form
 /// ```
 /// someop(memrefcast(%src)) -> someop(%src)
@@ -822,14 +805,6 @@
 // PadTensorOp
 //===----------------------------------------------------------------------===//
 
-/// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
-static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
-  return llvm::to_vector<4>(
-      llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
-        return a.cast<IntegerAttr>().getInt();
-      }));
-}
-
 static LogicalResult verify(PadTensorOp op) {
   auto sourceType = op.source().getType().cast<RankedTensorType>();
   auto resultType = op.result().getType().cast<RankedTensorType>();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -109,6 +109,7 @@
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/Passes.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/Pass/Pass.h"
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -11,6 +11,7 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/StandardOps/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -32,40 +33,6 @@
   return builder.create<mlir::ConstantOp>(loc, type, value);
 }
 
-/// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
-static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
-  return llvm::to_vector<4>(
-      llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
-        return a.cast<IntegerAttr>().getInt();
-      }));
-}
-
-/// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if
-/// it is a Value or into `staticVec` if it is an IntegerAttr.
-/// In the case of a Value, a copy of the `sentinel` value is also pushed to
-/// `staticVec`. This is useful to extract mixed static and dynamic entries that
-/// come from an AttrSizedOperandSegments trait.
-static void dispatchIndexOpFoldResult(OpFoldResult ofr,
-                                      SmallVectorImpl<Value> &dynamicVec,
-                                      SmallVectorImpl<int64_t> &staticVec,
-                                      int64_t sentinel) {
-  if (auto v = ofr.dyn_cast<Value>()) {
-    dynamicVec.push_back(v);
-    staticVec.push_back(sentinel);
-    return;
-  }
-  APInt apInt = ofr.dyn_cast<Attribute>().cast<IntegerAttr>().getValue();
-  staticVec.push_back(apInt.getSExtValue());
-}
-
-static void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
-                                       SmallVectorImpl<Value> &dynamicVec,
-                                       SmallVectorImpl<int64_t> &staticVec,
-                                       int64_t sentinel) {
-  for (auto ofr : ofrs)
-    dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec, sentinel);
-}
-
 //===----------------------------------------------------------------------===//
 // Common canonicalization pattern support logic
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -33,38 +33,6 @@
 
 using namespace mlir;
 
-/// If ofr is a constant integer, i.e., an IntegerAttr or a ConstantOp with an
-/// IntegerAttr, return the integer.
-llvm::Optional<int64_t> mlir::getConstantIntValue(OpFoldResult ofr) {
-  Attribute attr = ofr.dyn_cast<Attribute>();
-  // Note: isa+cast-like pattern allows writing the condition below as 1 line.
-  if (!attr && ofr.get<Value>().getDefiningOp<ConstantOp>())
-    attr = ofr.get<Value>().getDefiningOp<ConstantOp>().getValue();
-  if (auto intAttr = attr.dyn_cast_or_null<IntegerAttr>())
-    return intAttr.getValue().getSExtValue();
-  return llvm::None;
-}
-
-/// Return true if ofr and value are the same integer.
-/// Ignore integer bitwidth and type mismatch that come from the fact there is
-/// no IndexAttr and that IndexType has no bitwidth.
-bool mlir::isEqualConstantInt(OpFoldResult ofr, int64_t value) {
-  auto ofrValue = getConstantIntValue(ofr);
-  return ofrValue && *ofrValue == value;
-}
-
-/// Return true if ofr1 and ofr2 are the same integer constant attribute values
-/// or the same SSA value.
-/// Ignore integer bitwidth and type mismatch that come from the fact there is
-/// no IndexAttr and that IndexType has no bitwidth.
-bool mlir::isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2) {
-  auto cst1 = getConstantIntValue(ofr1), cst2 = getConstantIntValue(ofr2);
-  if (cst1 && cst2 && *cst1 == *cst2)
-    return true;
-  auto v1 = ofr1.dyn_cast<Value>(), v2 = ofr2.dyn_cast<Value>();
-  return v1 && v2 && v1 == v2;
-}
-
 //===----------------------------------------------------------------------===//
 // StandardOpsDialect Interfaces
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/StandardOps/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Matchers.h"
@@ -516,32 +517,6 @@
 // ExtractSliceOp
 //===----------------------------------------------------------------------===//
 
-/// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if
-/// it is a Value or into `staticVec` if it is an IntegerAttr.
-/// In the case of a Value, a copy of the `sentinel` value is also pushed to
-/// `staticVec`. This is useful to extract mixed static and dynamic entries that
-/// come from an AttrSizedOperandSegments trait.
-static void dispatchIndexOpFoldResult(OpFoldResult ofr,
-                                      SmallVectorImpl<Value> &dynamicVec,
-                                      SmallVectorImpl<int64_t> &staticVec,
-                                      int64_t sentinel) {
-  if (auto v = ofr.dyn_cast<Value>()) {
-    dynamicVec.push_back(v);
-    staticVec.push_back(sentinel);
-    return;
-  }
-  APInt apInt = ofr.dyn_cast<Attribute>().cast<IntegerAttr>().getValue();
-  staticVec.push_back(apInt.getSExtValue());
-}
-
-static void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
-                                       SmallVectorImpl<Value> &dynamicVec,
-                                       SmallVectorImpl<int64_t> &staticVec,
-                                       int64_t sentinel) {
-  for (auto ofr : ofrs)
-    dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec, sentinel);
-}
-
 /// An extract_slice op result type can be fully inferred from the source type
 /// and the static representation of offsets, sizes and strides. Special
 /// sentinels encode the dynamic case.
@@ -563,14 +538,6 @@
                                sourceRankedTensorType.getElementType());
 }
 
-/// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
-static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
-  return llvm::to_vector<4>(
-      llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
-        return a.cast<IntegerAttr>().getInt();
-      }));
-}
-
 Type ExtractSliceOp::inferResultType(
     RankedTensorType sourceRankedTensorType,
     ArrayRef<int64_t> leadingStaticOffsets,
diff --git a/mlir/lib/Dialect/Utils/CMakeLists.txt b/mlir/lib/Dialect/Utils/CMakeLists.txt
--- a/mlir/lib/Dialect/Utils/CMakeLists.txt
+++ b/mlir/lib/Dialect/Utils/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_library(MLIRDialectUtils
+  StaticValueUtils.cpp
   StructuredOpsUtils.cpp
 
   LINK_LIBS PUBLIC
   MLIRIR
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -0,0 +1,83 @@
+//===- StaticValueUtils.cpp - Utilities for dealing with static values ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/Matchers.h"
+
+/// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if
+/// it is a Value or into `staticVec` if it is an IntegerAttr.
+/// In the case of a Value, a copy of the `sentinel` value is also pushed to
+/// `staticVec`. This is useful to extract mixed static and dynamic entries that
+/// come from an AttrSizedOperandSegments trait.
+void mlir::dispatchIndexOpFoldResult(OpFoldResult ofr,
+                                     SmallVectorImpl<Value> &dynamicVec,
+                                     SmallVectorImpl<int64_t> &staticVec,
+                                     int64_t sentinel) {
+  if (auto v = ofr.dyn_cast<Value>()) {
+    dynamicVec.push_back(v);
+    staticVec.push_back(sentinel);
+    return;
+  }
+  APInt apInt = ofr.get<Attribute>().cast<IntegerAttr>().getValue();
+  staticVec.push_back(apInt.getSExtValue());
+}
+
+/// Helper function to dispatch multiple OpFoldResults, one at a time, with
+/// dispatchIndexOpFoldResult.
+void mlir::dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
+                                      SmallVectorImpl<Value> &dynamicVec,
+                                      SmallVectorImpl<int64_t> &staticVec,
+                                      int64_t sentinel) {
+  for (OpFoldResult ofr : ofrs)
+    dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec, sentinel);
+}
+
+/// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
+llvm::SmallVector<int64_t, 4> mlir::extractFromI64ArrayAttr(Attribute attr) {
+  return llvm::to_vector<4>(
+      llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
+        return a.cast<IntegerAttr>().getInt();
+      }));
+}
+
+/// If ofr is a constant integer, i.e., an IntegerAttr or a Value defined by a
+/// constant-like op with an integer value, return the integer.
+llvm::Optional<int64_t> mlir::getConstantIntValue(OpFoldResult ofr) {
+  // Case 1: a Value. Match any ConstantLike op (rather than the Standard
+  // dialect's ConstantOp) so that MLIRDialectUtils only depends on MLIRIR.
+  if (auto val = ofr.dyn_cast<Value>()) {
+    APInt intVal;
+    if (matchPattern(val, m_ConstantInt(&intVal)))
+      return intVal.getSExtValue();
+    return llvm::None;
+  }
+  // Case 2: an IntegerAttr.
+  if (auto intAttr = ofr.dyn_cast<Attribute>().dyn_cast_or_null<IntegerAttr>())
+    return intAttr.getValue().getSExtValue();
+  return llvm::None;
+}
+
+/// Return true if ofr and value are the same integer.
+/// Ignore integer bitwidth and type mismatch that come from the fact there is
+/// no IndexAttr and that IndexType has no bitwidth.
+bool mlir::isEqualConstantInt(OpFoldResult ofr, int64_t value) {
+  auto ofrValue = getConstantIntValue(ofr);
+  return ofrValue && *ofrValue == value;
+}
+
+/// Return true if ofr1 and ofr2 are the same integer constant attribute values
+/// or the same SSA value.
+/// Ignore integer bitwidth and type mismatch that come from the fact there is
+/// no IndexAttr and that IndexType has no bitwidth.
+bool mlir::isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2) {
+  auto cst1 = getConstantIntValue(ofr1), cst2 = getConstantIntValue(ofr2);
+  if (cst1 && cst2 && *cst1 == *cst2)
+    return true;
+  auto v1 = ofr1.dyn_cast<Value>(), v2 = ofr2.dyn_cast<Value>();
+  return v1 && v2 && v1 == v2;
+}