diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -49,6 +49,21 @@
                                              ArrayAttr staticStrides,
                                              ValueRange strides);
 
+/// Decompose a vector of mixed static or dynamic strides/offsets into the
+/// corresponding pair of arrays. This is the inverse function of
+/// `getMixedStrides` and `getMixedOffsets`. Dynamic entries are marked in the
+/// static array with `ShapedType::kDynamicStrideOrOffset`.
+std::pair<ArrayAttr, SmallVector<Value>> decomposeMixedStridesOrOffsets(
+    OpBuilder &b, const SmallVectorImpl<OpFoldResult> &mixedValues);
+
+/// Decompose a vector of mixed static or dynamic sizes into the
+/// corresponding pair of arrays. This is the inverse function of
+/// `getMixedSizes`. Dynamic entries are marked in the static array with
+/// `ShapedType::kDynamicSize`.
+std::pair<ArrayAttr, SmallVector<Value>>
+decomposeMixedSizes(OpBuilder &b,
+                    const SmallVectorImpl<OpFoldResult> &mixedValues);
+
 namespace detail {
 
 LogicalResult verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op);
diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp
--- a/mlir/lib/Interfaces/ViewLikeInterface.cpp
+++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp
@@ -242,3 +242,36 @@
   }
   return res;
 }
+
+/// Splits `mixedValues` into an I64ArrayAttr of the static entries and a
+/// vector of the dynamic entries, each preserving the original order. Every
+/// dynamic entry is marked in the static array by `dynamicValuePlaceholder`.
+/// Static entries are expected to be IntegerAttrs.
+static std::pair<ArrayAttr, SmallVector<Value>>
+decomposeMixedImpl(OpBuilder &b,
+                   const SmallVectorImpl<OpFoldResult> &mixedValues,
+                   const int64_t dynamicValuePlaceholder) {
+  SmallVector<int64_t> staticValues;
+  SmallVector<Value> dynamicValues;
+  for (const auto &it : mixedValues) {
+    if (it.is<Attribute>()) {
+      staticValues.push_back(it.get<Attribute>().cast<IntegerAttr>().getInt());
+    } else {
+      // Use the caller's sentinel rather than a hard-coded one.
+      staticValues.push_back(dynamicValuePlaceholder);
+      dynamicValues.push_back(it.get<Value>());
+    }
+  }
+  return {b.getI64ArrayAttr(staticValues), std::move(dynamicValues)};
+}
+
+std::pair<ArrayAttr, SmallVector<Value>> mlir::decomposeMixedStridesOrOffsets(
+    OpBuilder &b, const SmallVectorImpl<OpFoldResult> &mixedValues) {
+  return decomposeMixedImpl(b, mixedValues, ShapedType::kDynamicStrideOrOffset);
+}
+
+std::pair<ArrayAttr, SmallVector<Value>>
+mlir::decomposeMixedSizes(OpBuilder &b,
+                          const SmallVectorImpl<OpFoldResult> &mixedValues) {
+  return decomposeMixedImpl(b, mixedValues, ShapedType::kDynamicSize);
+}