diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -31,6 +31,30 @@
 
 class OffsetSizeAndStrideOpInterface;
 
+/// Return a vector of all the static or dynamic offsets, built from the
+/// provided external static and dynamic offsets. Entries of `staticOffsets`
+/// equal to ShapedType::kDynamicStrideOrOffset are replaced, in order, by the
+/// values in `offsets`; `op` is not consulted.
+SmallVector<OpFoldResult, 4> getMixedOffsets(OffsetSizeAndStrideOpInterface op,
+                                             ArrayAttr staticOffsets,
+                                             ValueRange offsets);
+
+/// Return a vector of all the static or dynamic sizes, built from the provided
+/// external static and dynamic sizes. Entries of `staticSizes` equal to
+/// ShapedType::kDynamicSize are replaced, in order, by the values in `sizes`;
+/// `op` is not consulted.
+SmallVector<OpFoldResult, 4> getMixedSizes(OffsetSizeAndStrideOpInterface op,
+                                           ArrayAttr staticSizes,
+                                           ValueRange sizes);
+
+/// Return a vector of all the static or dynamic strides, built from the
+/// provided external static and dynamic strides. Entries of `staticStrides`
+/// equal to ShapedType::kDynamicStrideOrOffset are replaced, in order, by the
+/// values in `strides`; `op` is not consulted.
+SmallVector<OpFoldResult, 4> getMixedStrides(OffsetSizeAndStrideOpInterface op,
+                                             ArrayAttr staticStrides,
+                                             ValueRange strides);
+
 namespace detail {
 
 LogicalResult verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op);
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.td b/mlir/include/mlir/Interfaces/ViewLikeInterface.td
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.td
@@ -165,16 +165,8 @@
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        ::mlir::SmallVector<::mlir::OpFoldResult, 4> res;
-        unsigned numDynamic = 0;
-        unsigned count = $_op.static_offsets().size();
-        for (unsigned idx = 0; idx < count; ++idx) {
-          if (isDynamicOffset(idx))
-            res.push_back($_op.offsets()[numDynamic++]);
-          else
-            res.push_back($_op.static_offsets()[idx]);
-        }
-        return res;
+        return ::mlir::getMixedOffsets($_op, $_op.static_offsets(),
+                                       $_op.offsets());
       }]
     >,
     InterfaceMethod<
@@ -186,16 +178,7 @@
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        ::mlir::SmallVector<::mlir::OpFoldResult, 4> res;
-        unsigned numDynamic = 0;
-        unsigned count = $_op.static_sizes().size();
-        for (unsigned idx = 0; idx < count; ++idx) {
-          if (isDynamicSize(idx))
-            res.push_back($_op.sizes()[numDynamic++]);
-          else
-            res.push_back($_op.static_sizes()[idx]);
-        }
-        return res;
+        return ::mlir::getMixedSizes($_op, $_op.static_sizes(), $_op.sizes());
       }]
     >,
     InterfaceMethod<
@@ -207,16 +190,8 @@
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        ::mlir::SmallVector<::mlir::OpFoldResult, 4> res;
-        unsigned numDynamic = 0;
-        unsigned count = $_op.static_strides().size();
-        for (unsigned idx = 0; idx < count; ++idx) {
-          if (isDynamicStride(idx))
-            res.push_back($_op.strides()[numDynamic++]);
-          else
-            res.push_back($_op.static_strides()[idx]);
-        }
-        return res;
+        return ::mlir::getMixedStrides($_op, $_op.static_strides(),
+                                       $_op.strides());
       }]
     >,
 
diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp
--- a/mlir/lib/Interfaces/ViewLikeInterface.cpp
+++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp
@@ -195,3 +195,49 @@
     strides.push_back(one);
   }
 }
+
+/// Interleave `staticValues` and `dynamicValues` into one list of
+/// OpFoldResult. Each entry of `staticValues` equal to `dynamicValueIndicator`
+/// is a placeholder replaced, in order, by the next value of `dynamicValues`;
+/// every other entry is forwarded as-is.
+///
+/// Dynamism is decided from `staticValues` alone, never from an op's own
+/// attributes, so that externally provided lists are decoded according to
+/// their own contents.
+static SmallVector<OpFoldResult, 4>
+getMixedValuesImpl(ArrayAttr staticValues, ValueRange dynamicValues,
+                   int64_t dynamicValueIndicator) {
+  SmallVector<OpFoldResult, 4> res;
+  res.reserve(staticValues.size());
+  unsigned numDynamic = 0;
+  for (Attribute staticValue : staticValues) {
+    if (staticValue.cast<IntegerAttr>().getInt() != dynamicValueIndicator) {
+      res.push_back(staticValue);
+      continue;
+    }
+    assert(numDynamic < dynamicValues.size() &&
+           "fewer dynamic values than dynamic entries in the static list");
+    res.push_back(dynamicValues[numDynamic++]);
+  }
+  return res;
+}
+
+SmallVector<OpFoldResult, 4>
+mlir::getMixedOffsets(OffsetSizeAndStrideOpInterface op,
+                      ArrayAttr staticOffsets, ValueRange offsets) {
+  return getMixedValuesImpl(staticOffsets, offsets,
+                            ShapedType::kDynamicStrideOrOffset);
+}
+
+SmallVector<OpFoldResult, 4>
+mlir::getMixedSizes(OffsetSizeAndStrideOpInterface op, ArrayAttr staticSizes,
+                    ValueRange sizes) {
+  return getMixedValuesImpl(staticSizes, sizes, ShapedType::kDynamicSize);
+}
+
+SmallVector<OpFoldResult, 4>
+mlir::getMixedStrides(OffsetSizeAndStrideOpInterface op,
+                      ArrayAttr staticStrides, ValueRange strides) {
+  return getMixedValuesImpl(staticStrides, strides,
+                            ShapedType::kDynamicStrideOrOffset);
+}