diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h --- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h @@ -39,6 +39,19 @@ SmallVectorImpl &staticVec, int64_t sentinel); +/// Return a vector of OpFoldResults given the special value +/// that indicates whether the value is dynamic or not. +SmallVector getMixedValues(ArrayAttr staticValues, + ValueRange dynamicValues, + int64_t dynamicValueIndicator); + +/// Decompose a vector of mixed static or dynamic values into the corresponding +/// pair of arrays. This is the inverse function of `getMixedValues`. +std::pair> +decomposeMixedValues(Builder &b, + const SmallVectorImpl &mixedValues, + const int64_t dynamicValueIndicator); + /// Extract int64_t values from the assumed ArrayAttr of IntegerAttr. SmallVector extractFromI64ArrayAttr(Attribute attr); diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h --- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h +++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h @@ -20,6 +20,7 @@ #include "mlir/IR/OpImplementation.h" namespace mlir { + /// Auxiliary range data structure to unpack the offset, size and stride /// operands into a list of triples. Such a list can be more convenient to /// manipulate. @@ -29,42 +30,17 @@ OpFoldResult stride; }; -class OffsetSizeAndStrideOpInterface; +/// Return a vector of all the static and dynamic offsets/strides. +SmallVector getMixedStridesOrOffsets(ArrayAttr staticValues, + ValueRange dynamicValues); -/// Return a vector of OpFoldResults given the special value -/// that indicates whether of the value is dynamic or not. 
-SmallVector getMixedValues(ArrayAttr staticValues, - ValueRange dynamicValues, - int64_t dynamicValueIndicator); - -/// Return a vector of all the static or dynamic offsets of the op from provided -/// external static and dynamic offsets. -SmallVector getMixedOffsets(OffsetSizeAndStrideOpInterface op, - ArrayAttr staticOffsets, - ValueRange offsets); - -/// Return a vector of all the static or dynamic sizes of the op from provided -/// external static and dynamic sizes. -SmallVector getMixedSizes(OffsetSizeAndStrideOpInterface op, - ArrayAttr staticSizes, - ValueRange sizes); - -/// Return a vector of all the static or dynamic strides of the op from provided -/// external static and dynamic strides. -SmallVector getMixedStrides(OffsetSizeAndStrideOpInterface op, - ArrayAttr staticStrides, - ValueRange strides); - -/// Decompose a vector of mixed static or dynamic values into the corresponding -/// pair of arrays. This is the inverse function of `getMixedValues`. -std::pair> -decomposeMixedValues(Builder &b, - const SmallVectorImpl &mixedValues, - const int64_t dynamicValueIndicator); +/// Return a vector of all the static and dynamic sizes. +SmallVector getMixedSizes(ArrayAttr staticValues, + ValueRange dynamicValues); -/// Decompose a vector of mixed static or dynamic strides/offsets into the +/// Decompose a vector of mixed static and dynamic strides/offsets into the /// corresponding pair of arrays. This is the inverse function of -/// `getMixedStrides` and `getMixedOffsets`. +/// `getMixedStridesOrOffsets`. 
std::pair> decomposeMixedStridesOrOffsets( OpBuilder &b, const SmallVectorImpl &mixedValues); @@ -75,12 +51,16 @@ decomposeMixedSizes(OpBuilder &b, const SmallVectorImpl &mixedValues); +class OffsetSizeAndStrideOpInterface; + namespace detail { + LogicalResult verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op); bool sameOffsetsSizesAndStrides( OffsetSizeAndStrideOpInterface a, OffsetSizeAndStrideOpInterface b, llvm::function_ref cmp); + } // namespace detail } // namespace mlir diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.td b/mlir/include/mlir/Interfaces/ViewLikeInterface.td --- a/mlir/include/mlir/Interfaces/ViewLikeInterface.td +++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.td @@ -165,8 +165,8 @@ /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return ::mlir::getMixedOffsets($_op, $_op.static_offsets(), - $_op.offsets()); + return ::mlir::getMixedStridesOrOffsets($_op.static_offsets(), + $_op.offsets()); }] >, InterfaceMethod< @@ -178,7 +178,7 @@ /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return ::mlir::getMixedSizes($_op, $_op.static_sizes(), $_op.sizes()); + return ::mlir::getMixedSizes($_op.static_sizes(), $_op.sizes()); }] >, InterfaceMethod< @@ -190,8 +190,8 @@ /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return ::mlir::getMixedStrides($_op, $_op.static_strides(), - $_op.strides()); + return ::mlir::getMixedStridesOrOffsets($_op.static_strides(), + $_op.strides()); }] >, @@ -237,30 +237,6 @@ return ::mlir::ShapedType::isDynamicStrideOrOffset(v.getSExtValue()); }] >, - StaticInterfaceMethod< - /*desc=*/"Return constant that indicates the offset is dynamic", - /*retTy=*/"int64_t", - /*methodName=*/"getDynamicOffsetIndicator", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImpl=*/[{ return ::mlir::ShapedType::kDynamicStrideOrOffset; }] - >, - StaticInterfaceMethod< - /*desc=*/"Return constant that indicates the size is dynamic", - /*retTy=*/"int64_t", - 
/*methodName=*/"getDynamicSizeIndicator", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImpl=*/[{ return ::mlir::ShapedType::kDynamicSize; }] - >, - StaticInterfaceMethod< - /*desc=*/"Return constant that indicates the stride is dynamic", - /*retTy=*/"int64_t", - /*methodName=*/"getDynamicStrideIndicator", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImpl=*/[{ return ::mlir::ShapedType::kDynamicStrideOrOffset; }] - >, InterfaceMethod< /*desc=*/[{ Assert the offset `idx` is a static constant and return its value. diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp --- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp +++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/IR/Builders.h" #include "mlir/IR/Matchers.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/APSInt.h" @@ -109,4 +110,40 @@ auto v1 = ofr1.dyn_cast(), v2 = ofr2.dyn_cast(); return v1 && v1 == v2; } + +/// Return a vector of OpFoldResults given the special value +/// that indicates whether the value is dynamic or not. +SmallVector getMixedValues(ArrayAttr staticValues, + ValueRange dynamicValues, + int64_t dynamicValueIndicator) { + SmallVector res; + res.reserve(staticValues.size()); + unsigned numDynamic = 0; + unsigned count = static_cast(staticValues.size()); + for (unsigned idx = 0; idx < count; ++idx) { + APInt value = staticValues[idx].cast().getValue(); + res.push_back(value.getSExtValue() == dynamicValueIndicator + ? 
OpFoldResult{dynamicValues[numDynamic++]} + : OpFoldResult{staticValues[idx]}); + } + return res; +} + +std::pair> +decomposeMixedValues(Builder &b, + const SmallVectorImpl &mixedValues, + const int64_t dynamicValueIndicator) { + SmallVector staticValues; + SmallVector dynamicValues; + for (const auto &it : mixedValues) { + if (it.is()) { + staticValues.push_back(it.get().cast().getInt()); + } else { + staticValues.push_back(dynamicValueIndicator); + dynamicValues.push_back(it.get()); + } + } + return {b.getI64ArrayAttr(staticValues), dynamicValues}; +} + } // namespace mlir diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp --- a/mlir/lib/Interfaces/ViewLikeInterface.cpp +++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp @@ -180,54 +180,15 @@ } SmallVector -mlir::getMixedValues(ArrayAttr staticValues, ValueRange dynamicValues, - int64_t dynamicValueIndicator) { - SmallVector res; - res.reserve(staticValues.size()); - unsigned numDynamic = 0; - unsigned count = static_cast(staticValues.size()); - for (unsigned idx = 0; idx < count; ++idx) { - APInt value = staticValues[idx].cast().getValue(); - res.push_back(value.getSExtValue() == dynamicValueIndicator - ? 
OpFoldResult{dynamicValues[numDynamic++]} - : OpFoldResult{staticValues[idx]}); - } - return res; -} - -SmallVector -mlir::getMixedOffsets(OffsetSizeAndStrideOpInterface op, - ArrayAttr staticOffsets, ValueRange offsets) { - return getMixedValues(staticOffsets, offsets, op.getDynamicOffsetIndicator()); -} - -SmallVector -mlir::getMixedSizes(OffsetSizeAndStrideOpInterface op, ArrayAttr staticSizes, - ValueRange sizes) { - return getMixedValues(staticSizes, sizes, op.getDynamicSizeIndicator()); +mlir::getMixedStridesOrOffsets(ArrayAttr staticValues, + ValueRange dynamicValues) { + return getMixedValues(staticValues, dynamicValues, + ShapedType::kDynamicStrideOrOffset); } -SmallVector -mlir::getMixedStrides(OffsetSizeAndStrideOpInterface op, - ArrayAttr staticStrides, ValueRange strides) { - return getMixedValues(staticStrides, strides, op.getDynamicStrideIndicator()); -} - -std::pair> -mlir::decomposeMixedValues(Builder &b, - const SmallVectorImpl &mixedValues, - const int64_t dynamicValueIndicator) { - SmallVector staticValues; - SmallVector dynamicValues; - for (const auto &it : mixedValues) { - if (it.is()) { - staticValues.push_back(it.get().cast().getInt()); - } else { - staticValues.push_back(dynamicValueIndicator); - dynamicValues.push_back(it.get()); - } - } - return {b.getI64ArrayAttr(staticValues), dynamicValues}; +SmallVector mlir::getMixedSizes(ArrayAttr staticValues, + ValueRange dynamicValues) { + return getMixedValues(staticValues, dynamicValues, ShapedType::kDynamicSize); } std::pair> mlir::decomposeMixedStridesOrOffsets(