diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h --- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h +++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h @@ -20,6 +20,7 @@ #include "mlir/IR/OpImplementation.h" namespace mlir { + /// Auxiliary range data structure to unpack the offset, size and stride /// operands into a list of triples. Such a list can be more convenient to /// manipulate. @@ -29,31 +30,19 @@ OpFoldResult stride; }; -class OffsetSizeAndStrideOpInterface; - /// Return a vector of OpFoldResults given the special value /// that indicates whether of the value is dynamic or not. SmallVector getMixedValues(ArrayAttr staticValues, ValueRange dynamicValues, int64_t dynamicValueIndicator); -/// Return a vector of all the static or dynamic offsets of the op from provided -/// external static and dynamic offsets. -SmallVector getMixedOffsets(OffsetSizeAndStrideOpInterface op, - ArrayAttr staticOffsets, - ValueRange offsets); +/// Return a vector of all the static and dynamic offsets/strides. +SmallVector getMixedStridesOrOffsets(ArrayAttr staticValues, + ValueRange dynamicValues); -/// Return a vector of all the static or dynamic sizes of the op from provided -/// external static and dynamic sizes. -SmallVector getMixedSizes(OffsetSizeAndStrideOpInterface op, - ArrayAttr staticSizes, - ValueRange sizes); - -/// Return a vector of all the static or dynamic strides of the op from provided -/// external static and dynamic strides. -SmallVector getMixedStrides(OffsetSizeAndStrideOpInterface op, - ArrayAttr staticStrides, - ValueRange strides); +/// Return a vector of all the static and dynamic sizes. +SmallVector getMixedSizes(ArrayAttr staticValues, + ValueRange dynamicValues); /// Decompose a vector of mixed static or dynamic values into the corresponding /// pair of arrays. This is the inverse function of `getMixedValues`. @@ -62,9 +51,9 @@ const SmallVectorImpl &mixedValues, const int64_t dynamicValueIndicator); -/// Decompose a vector of mixed static or dynamic strides/offsets into the +/// Decompose a vector of mixed static and dynamic strides/offsets into the /// corresponding pair of arrays. This is the inverse function of -/// `getMixedStrides` and `getMixedOffsets`. +/// `getMixedStridesOrOffsets`. std::pair> decomposeMixedStridesOrOffsets( OpBuilder &b, const SmallVectorImpl &mixedValues); @@ -75,12 +64,16 @@ decomposeMixedSizes(OpBuilder &b, const SmallVectorImpl &mixedValues); +class OffsetSizeAndStrideOpInterface; + namespace detail { + LogicalResult verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op); bool sameOffsetsSizesAndStrides( OffsetSizeAndStrideOpInterface a, OffsetSizeAndStrideOpInterface b, llvm::function_ref cmp); + } // namespace detail } // namespace mlir diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.td b/mlir/include/mlir/Interfaces/ViewLikeInterface.td --- a/mlir/include/mlir/Interfaces/ViewLikeInterface.td +++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.td @@ -165,8 +165,8 @@ /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return ::mlir::getMixedOffsets($_op, $_op.static_offsets(), - $_op.offsets()); + return ::mlir::getMixedStridesOrOffsets($_op.static_offsets(), + $_op.offsets()); }] >, InterfaceMethod< @@ -178,7 +178,7 @@ /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return ::mlir::getMixedSizes($_op, $_op.static_sizes(), $_op.sizes()); + return ::mlir::getMixedSizes($_op.static_sizes(), $_op.sizes()); }] >, InterfaceMethod< @@ -190,8 +190,8 @@ /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return ::mlir::getMixedStrides($_op, $_op.static_strides(), - $_op.strides()); + return ::mlir::getMixedStridesOrOffsets($_op.static_strides(), + $_op.strides()); }] >, @@ -237,30 +237,6 @@ return ::mlir::ShapedType::isDynamicStrideOrOffset(v.getSExtValue()); }] >, - StaticInterfaceMethod< - /*desc=*/"Return constant that indicates the offset is dynamic", - /*retTy=*/"int64_t", - /*methodName=*/"getDynamicOffsetIndicator", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImpl=*/[{ return ::mlir::ShapedType::kDynamicStrideOrOffset; }] - >, - StaticInterfaceMethod< - /*desc=*/"Return constant that indicates the size is dynamic", - /*retTy=*/"int64_t", - /*methodName=*/"getDynamicSizeIndicator", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImpl=*/[{ return ::mlir::ShapedType::kDynamicSize; }] - >, - StaticInterfaceMethod< - /*desc=*/"Return constant that indicates the stride is dynamic", - /*retTy=*/"int64_t", - /*methodName=*/"getDynamicStrideIndicator", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImpl=*/[{ return ::mlir::ShapedType::kDynamicStrideOrOffset; }] - >, InterfaceMethod< /*desc=*/[{ Assert the offset `idx` is a static constant and return its value. diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp --- a/mlir/lib/Interfaces/ViewLikeInterface.cpp +++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp @@ -181,7 +181,7 @@ SmallVector mlir::getMixedValues(ArrayAttr staticValues, ValueRange dynamicValues, - int64_t dynamicValueIndicator) { + const int64_t dynamicValueIndicator) { SmallVector res; res.reserve(staticValues.size()); unsigned numDynamic = 0; @@ -196,21 +196,15 @@ } SmallVector -mlir::getMixedOffsets(OffsetSizeAndStrideOpInterface op, - ArrayAttr staticOffsets, ValueRange offsets) { - return getMixedValues(staticOffsets, offsets, op.getDynamicOffsetIndicator()); +mlir::getMixedStridesOrOffsets(ArrayAttr staticValues, + ValueRange dynamicValues) { + return getMixedValues(staticValues, dynamicValues, + ShapedType::kDynamicStrideOrOffset); } -SmallVector -mlir::getMixedSizes(OffsetSizeAndStrideOpInterface op, ArrayAttr staticSizes, - ValueRange sizes) { - return getMixedValues(staticSizes, sizes, op.getDynamicSizeIndicator()); -} - -SmallVector -mlir::getMixedStrides(OffsetSizeAndStrideOpInterface op, - ArrayAttr staticStrides, ValueRange strides) { - return getMixedValues(staticStrides, strides, op.getDynamicStrideIndicator()); +SmallVector mlir::getMixedSizes(ArrayAttr staticValues, + ValueRange dynamicValues) { + return getMixedValues(staticValues, dynamicValues, ShapedType::kDynamicSize); } std::pair>