# Summary of this patch (free-text preamble; ignored by git apply / patch):
#
# Replaces the op-dependent helpers
#   getMixedOffsets(OffsetSizeAndStrideOpInterface op, ArrayAttr, ValueRange)
#   getMixedSizes(OffsetSizeAndStrideOpInterface op, ArrayAttr, ValueRange)
#   getMixedStrides(OffsetSizeAndStrideOpInterface op, ArrayAttr, ValueRange)
# with op-independent helpers
#   getMixedStridesOrOffsets(ArrayAttr staticValues, ValueRange dynamicValues)
#   getMixedSizes(ArrayAttr staticValues, ValueRange dynamicValues)
#
# Both new helpers forward to a file-static getMixedImpl in ViewLikeInterface.cpp.
# getMixedImpl walks staticValues in order:
#   - If an entry's integer value equals the given placeholder, it emits the
#     next Value from dynamicValues.
#   - Otherwise it emits the static Attribute unchanged.
# The placeholders are:
#   - ShapedType::kDynamicStrideOrOffset for offsets and strides;
#   - ShapedType::kDynamicSize for sizes.
#
# The ViewLikeInterface.td default implementations of getMixedOffsets,
# getMixedSizes and getMixedStrides are switched to the new helpers.
# The forward declaration of OffsetSizeAndStrideOpInterface moves below the
# free functions, because they no longer mention it.
#
# NOTE(review): this copy of the patch has lost its newlines and most template
# argument lists, for example:
#   - `SmallVector` with no element type
#   - `std::pair>`
#   - `staticAttr.cast()` (presumably cast<IntegerAttr> -- TODO confirm)
#   - `static_cast(...)`
# It cannot be applied with `git apply`/`patch` as-is. Restore it from the
# original review before use.
#
# NOTE(review): behaviour shift. The removed code asked the op
# (op.isDynamicOffset/isDynamicSize/isDynamicStride(idx)). The new code
# compares directly against the ShapedType sentinels. These agree only if the
# op's isDynamic* methods use the same sentinels. Their definitions are not
# part of this patch -- confirm, especially for any op that overrides them.
#
# NOTE(review): getMixedImpl indexes `dynamicValues[idxDynamic++]` with no
# range check. It presumably relies on callers (e.g. a verified op) supplying
# exactly one dynamic Value per placeholder entry. Consider an assert that
# idxDynamic stays below dynamicValues.size() and that every dynamic value was
# consumed.
#
# NOTE(review): LLVM coding standards prefer omitting braces around
# single-statement if/else bodies, as the removed loops did.
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h --- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h +++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h @@ -20,6 +20,7 @@ #include "mlir/IR/OpImplementation.h" namespace mlir { + /// Auxiliary range data structure to unpack the offset, size and stride /// operands into a list of triples. Such a list can be more convenient to /// manipulate. @@ -29,29 +30,17 @@ OpFoldResult stride; }; -class OffsetSizeAndStrideOpInterface; - -/// Return a vector of all the static or dynamic offsets of the op from provided -/// external static and dynamic offsets. -SmallVector getMixedOffsets(OffsetSizeAndStrideOpInterface op, - ArrayAttr staticOffsets, - ValueRange offsets); +/// Return a vector of all the static and dynamic offsets/strides. +SmallVector getMixedStridesOrOffsets(ArrayAttr staticValues, + ValueRange dynamicValues); -/// Return a vector of all the static or dynamic sizes of the op from provided -/// external static and dynamic sizes. -SmallVector getMixedSizes(OffsetSizeAndStrideOpInterface op, - ArrayAttr staticSizes, - ValueRange sizes); +/// Return a vector of all the static and dynamic sizes. +SmallVector getMixedSizes(ArrayAttr staticValues, + ValueRange dynamicValues); -/// Return a vector of all the static or dynamic strides of the op from provided -/// external static and dynamic strides. -SmallVector getMixedStrides(OffsetSizeAndStrideOpInterface op, - ArrayAttr staticStrides, - ValueRange strides); - -/// Decompose a vector of mixed static or dynamic strides/offsets into the +/// Decompose a vector of mixed static and dynamic strides/offsets into the /// corresponding pair of arrays. This is the inverse function of -/// `getMixedStrides` and `getMixedOffsets`. +/// `getMixedStridesOrOffsets`. 
std::pair> decomposeMixedStridesOrOffsets( OpBuilder &b, const SmallVectorImpl &mixedValues); @@ -62,12 +51,16 @@ decomposeMixedSizes(OpBuilder &b, const SmallVectorImpl &mixedValues); +class OffsetSizeAndStrideOpInterface; + namespace detail { + LogicalResult verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op); bool sameOffsetsSizesAndStrides( OffsetSizeAndStrideOpInterface a, OffsetSizeAndStrideOpInterface b, llvm::function_ref cmp); + } // namespace detail } // namespace mlir diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.td b/mlir/include/mlir/Interfaces/ViewLikeInterface.td --- a/mlir/include/mlir/Interfaces/ViewLikeInterface.td +++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.td @@ -165,8 +165,8 @@ /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return ::mlir::getMixedOffsets($_op, $_op.static_offsets(), - $_op.offsets()); + return ::mlir::getMixedStridesOrOffsets($_op.static_offsets(), + $_op.offsets()); }] >, InterfaceMethod< @@ -178,7 +178,7 @@ /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return ::mlir::getMixedSizes($_op, $_op.static_sizes(), $_op.sizes()); + return ::mlir::getMixedSizes($_op.static_sizes(), $_op.sizes()); }] >, InterfaceMethod< @@ -190,8 +190,8 @@ /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return ::mlir::getMixedStrides($_op, $_op.static_strides(), - $_op.strides()); + return ::mlir::getMixedStridesOrOffsets($_op.static_strides(), + $_op.strides()); }] >, diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp --- a/mlir/lib/Interfaces/ViewLikeInterface.cpp +++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp @@ -179,49 +179,32 @@ return true; } -SmallVector -mlir::getMixedOffsets(OffsetSizeAndStrideOpInterface op, - ArrayAttr staticOffsets, ValueRange offsets) { - SmallVector res; - unsigned numDynamic = 0; - unsigned count = static_cast(staticOffsets.size()); - for (unsigned idx = 0; idx < count; ++idx) { - if 
(op.isDynamicOffset(idx)) - res.push_back(offsets[numDynamic++]); - else - res.push_back(staticOffsets[idx]); +static SmallVector +getMixedImpl(ArrayAttr staticValues, ValueRange dynamicValues, + const int64_t dynamicValuePlaceholder) { + int64_t idxDynamic = 0; + SmallVector result; + for (const Attribute &staticAttr : staticValues) { + int64_t staticInt = staticAttr.cast().getInt(); + if (staticInt == dynamicValuePlaceholder) { + result.push_back(dynamicValues[idxDynamic++]); + } else { + result.push_back(staticAttr); + } } - return res; + return result; } SmallVector -mlir::getMixedSizes(OffsetSizeAndStrideOpInterface op, ArrayAttr staticSizes, - ValueRange sizes) { - SmallVector res; - unsigned numDynamic = 0; - unsigned count = static_cast(staticSizes.size()); - for (unsigned idx = 0; idx < count; ++idx) { - if (op.isDynamicSize(idx)) - res.push_back(sizes[numDynamic++]); - else - res.push_back(staticSizes[idx]); - } - return res; +mlir::getMixedStridesOrOffsets(ArrayAttr staticValues, + ValueRange dynamicValues) { + return getMixedImpl(staticValues, dynamicValues, + ShapedType::kDynamicStrideOrOffset); } -SmallVector -mlir::getMixedStrides(OffsetSizeAndStrideOpInterface op, - ArrayAttr staticStrides, ValueRange strides) { - SmallVector res; - unsigned numDynamic = 0; - unsigned count = static_cast(staticStrides.size()); - for (unsigned idx = 0; idx < count; ++idx) { - if (op.isDynamicStride(idx)) - res.push_back(strides[numDynamic++]); - else - res.push_back(staticStrides[idx]); - } - return res; +SmallVector mlir::getMixedSizes(ArrayAttr staticValues, + ValueRange dynamicValues) { + return getMixedImpl(staticValues, dynamicValues, ShapedType::kDynamicSize); } static std::pair>