diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h --- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h @@ -39,6 +39,19 @@ SmallVectorImpl &staticVec, int64_t sentinel); +/// Return a vector of OpFoldResults from `staticValues`, substituting the next +/// entry of `dynamicValues` wherever `dynamicValueIndicator` appears. +SmallVector getMixedValues(ArrayAttr staticValues, + ValueRange dynamicValues, + int64_t dynamicValueIndicator); + +/// Decompose a vector of mixed static or dynamic values into the corresponding +/// pair of arrays. This is the inverse function of `getMixedValues`. +std::pair> +decomposeMixedValues(Builder &b, + const SmallVectorImpl &mixedValues, + const int64_t dynamicValueIndicator); + /// Extract int64_t values from the assumed ArrayAttr of IntegerAttr. SmallVector extractFromI64ArrayAttr(Attribute attr); diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.td b/mlir/include/mlir/Interfaces/ViewLikeInterface.td --- a/mlir/include/mlir/Interfaces/ViewLikeInterface.td +++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.td @@ -237,7 +237,30 @@ return ::mlir::ShapedType::isDynamicStrideOrOffset(v.getSExtValue()); }] >, - + StaticInterfaceMethod< + /*desc=*/"Return constant that indicates the offset is dynamic", + /*retTy=*/"int64_t", + /*methodName=*/"getDynamicOffsetIndicator", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImpl=*/[{ return ::mlir::ShapedType::kDynamicStrideOrOffset; }] + >, + StaticInterfaceMethod< + /*desc=*/"Return constant that indicates the size is dynamic", + /*retTy=*/"int64_t", + /*methodName=*/"getDynamicSizeIndicator", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImpl=*/[{ return ::mlir::ShapedType::kDynamicSize; }] + >, + StaticInterfaceMethod< + /*desc=*/"Return constant that indicates the stride is dynamic", + /*retTy=*/"int64_t", + /*methodName=*/"getDynamicStrideIndicator", + /*args=*/(ins), 
/*methodBody=*/"", + /*defaultImpl=*/[{ return ::mlir::ShapedType::kDynamicStrideOrOffset; }] + >, InterfaceMethod< /*desc=*/[{ Assert the offset `idx` is a static constant and return its value. diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp --- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp +++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/IR/Builders.h" #include "mlir/IR/Matchers.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/APSInt.h" @@ -109,4 +110,40 @@ auto v1 = ofr1.dyn_cast(), v2 = ofr2.dyn_cast(); return v1 && v1 == v2; } + +/// Return a vector of OpFoldResults from `staticValues`, substituting the next +/// entry of `dynamicValues` wherever `dynamicValueIndicator` appears. +SmallVector getMixedValues(ArrayAttr staticValues, + ValueRange dynamicValues, + int64_t dynamicValueIndicator) { + SmallVector res; + res.reserve(staticValues.size()); + unsigned numDynamic = 0; + unsigned count = static_cast(staticValues.size()); + for (unsigned idx = 0; idx < count; ++idx) { + APInt value = staticValues[idx].cast().getValue(); + res.push_back(value.getSExtValue() == dynamicValueIndicator + ? 
OpFoldResult{dynamicValues[numDynamic++]} + : OpFoldResult{staticValues[idx]}); + } + return res; +} + +std::pair> +decomposeMixedValues(Builder &b, + const SmallVectorImpl &mixedValues, + const int64_t dynamicValueIndicator) { + SmallVector staticValues; + SmallVector dynamicValues; + for (const auto &it : mixedValues) { + if (it.is()) { + staticValues.push_back(it.get().cast().getInt()); + } else { + staticValues.push_back(dynamicValueIndicator); + dynamicValues.push_back(it.get()); + } + } + return {b.getI64ArrayAttr(staticValues), dynamicValues}; +} + } // namespace mlir diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp --- a/mlir/lib/Interfaces/ViewLikeInterface.cpp +++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp @@ -182,72 +182,29 @@ SmallVector mlir::getMixedOffsets(OffsetSizeAndStrideOpInterface op, ArrayAttr staticOffsets, ValueRange offsets) { - SmallVector res; - unsigned numDynamic = 0; - unsigned count = static_cast(staticOffsets.size()); - for (unsigned idx = 0; idx < count; ++idx) { - if (op.isDynamicOffset(idx)) - res.push_back(offsets[numDynamic++]); - else - res.push_back(staticOffsets[idx]); - } - return res; + return getMixedValues(staticOffsets, offsets, op.getDynamicOffsetIndicator()); } SmallVector mlir::getMixedSizes(OffsetSizeAndStrideOpInterface op, ArrayAttr staticSizes, ValueRange sizes) { - SmallVector res; - unsigned numDynamic = 0; - unsigned count = static_cast(staticSizes.size()); - for (unsigned idx = 0; idx < count; ++idx) { - if (op.isDynamicSize(idx)) - res.push_back(sizes[numDynamic++]); - else - res.push_back(staticSizes[idx]); - } - return res; + return getMixedValues(staticSizes, sizes, op.getDynamicSizeIndicator()); } SmallVector mlir::getMixedStrides(OffsetSizeAndStrideOpInterface op, ArrayAttr staticStrides, ValueRange strides) { - SmallVector res; - unsigned numDynamic = 0; - unsigned count = static_cast(staticStrides.size()); - for (unsigned idx = 0; idx < count; ++idx) { 
- if (op.isDynamicStride(idx)) - res.push_back(strides[numDynamic++]); - else - res.push_back(staticStrides[idx]); - } - return res; -} - -static std::pair> -decomposeMixedImpl(OpBuilder &b, - const SmallVectorImpl &mixedValues, - const int64_t dynamicValuePlaceholder) { - SmallVector staticValues; - SmallVector dynamicValues; - for (const auto &it : mixedValues) { - if (it.is()) { - staticValues.push_back(it.get().cast().getInt()); - } else { - staticValues.push_back(ShapedType::kDynamicStrideOrOffset); - dynamicValues.push_back(it.get()); - } - } - return {b.getI64ArrayAttr(staticValues), dynamicValues}; + return getMixedValues(staticStrides, strides, op.getDynamicStrideIndicator()); } std::pair> mlir::decomposeMixedStridesOrOffsets( OpBuilder &b, const SmallVectorImpl &mixedValues) { - return decomposeMixedImpl(b, mixedValues, ShapedType::kDynamicStrideOrOffset); + return decomposeMixedValues(b, mixedValues, + ShapedType::kDynamicStrideOrOffset); } std::pair> mlir::decomposeMixedSizes(OpBuilder &b, const SmallVectorImpl &mixedValues) { - return decomposeMixedImpl(b, mixedValues, ShapedType::kDynamicSize); + return decomposeMixedValues(b, mixedValues, ShapedType::kDynamicSize); }