diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -32,6 +32,12 @@
     OffsetSizeAndStrideOpInterface a, OffsetSizeAndStrideOpInterface b,
     llvm::function_ref<bool(OpFoldResult, OpFoldResult)> cmp);
 
+/// Return the number of dynamic entries (as determined by
+/// `ShapedType::isDynamic`) among the first `idx` elements of `staticVals`.
+/// `idx` must not exceed `staticVals.size()`.
+unsigned getNumDynamicEntriesUpToIdx(ArrayRef<int64_t> staticVals,
+                                     unsigned idx);
+
 } // namespace detail
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.td b/mlir/include/mlir/Interfaces/ViewLikeInterface.td
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.td
@@ -279,8 +279,8 @@
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
         assert($_op.isDynamicOffset(idx) && "expected dynamic offset");
-        auto numDynamic = getNumDynamicEntriesUpToIdx(
-          getStaticOffsets(), ::mlir::ShapedType::isDynamic, idx);
+        auto numDynamic = ::mlir::detail::getNumDynamicEntriesUpToIdx(
+          getStaticOffsets(), idx);
         return $_op.getOffsetSizeAndStrideStartOperandIndex() + numDynamic;
       }]
     >,
@@ -295,8 +295,8 @@
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
         assert($_op.isDynamicSize(idx) && "expected dynamic size");
-        auto numDynamic = getNumDynamicEntriesUpToIdx(
-          getStaticSizes(), ::mlir::ShapedType::isDynamic, idx);
+        auto numDynamic = ::mlir::detail::getNumDynamicEntriesUpToIdx(
+          getStaticSizes(), idx);
         return $_op.getOffsetSizeAndStrideStartOperandIndex() +
           getOffsets().size() + numDynamic;
       }]
@@ -312,32 +312,12 @@
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
         assert($_op.isDynamicStride(idx) && "expected dynamic stride");
-        auto numDynamic = getNumDynamicEntriesUpToIdx(
-          getStaticStrides(), ::mlir::ShapedType::isDynamic, idx);
+        auto numDynamic = ::mlir::detail::getNumDynamicEntriesUpToIdx(
+          getStaticStrides(), idx);
         return $_op.getOffsetSizeAndStrideStartOperandIndex() +
           getOffsets().size() + getSizes().size() + numDynamic;
       }]
     >,
-    InterfaceMethod<
-      /*desc=*/[{
-        Helper method to compute the number of dynamic entries of `staticVals`,
-        up to `idx` using `isDynamic` to determine whether an entry is dynamic.
-      }],
-      /*retTy=*/"unsigned",
-      /*methodName=*/"getNumDynamicEntriesUpToIdx",
-      /*args=*/(ins "::llvm::ArrayRef<int64_t>":$staticVals,
-                    "::llvm::function_ref<bool(int64_t)>":$isDynamic,
-                    "unsigned":$idx),
-      /*methodBody=*/"",
-      /*defaultImplementation=*/[{
-        return std::count_if(
-          staticVals.begin(), staticVals.begin() + idx,
-          [&](int64_t val) {
-            return isDynamic(val);
-          });
-      }]
-    >,
-
     InterfaceMethod<
       /*desc=*/[{
         Assert the offset `idx` is dynamic and return its value.
@@ -417,29 +397,6 @@
     >,
   ];
 
-  let extraClassDeclaration = [{
-    static unsigned getOffsetOperandGroupPosition() { return 0; }
-    static unsigned getSizeOperandGroupPosition() { return 1; }
-    static unsigned getStrideOperandGroupPosition() { return 2; }
-    static ::llvm::StringRef getStaticOffsetsAttrName() {
-      return "static_offsets";
-    }
-    static ::llvm::StringRef getStaticSizesAttrName() {
-      return "static_sizes";
-    }
-    static ::llvm::StringRef getStaticStridesAttrName() {
-      return "static_strides";
-    }
-    static ::llvm::ArrayRef<::llvm::StringRef> getSpecialAttrNames() {
-      static ::llvm::SmallVector<::llvm::StringRef, 4> names{
-        ::mlir::OffsetSizeAndStrideOpInterface::getStaticOffsetsAttrName(),
-        ::mlir::OffsetSizeAndStrideOpInterface::getStaticSizesAttrName(),
-        ::mlir::OffsetSizeAndStrideOpInterface::getStaticStridesAttrName(),
-        ::mlir::OpTrait::AttrSizedOperandSegments::getOperandSegmentSizeAttr()};
-      return names;
-    }
-  }];
-
   let verify = [{
     return ::mlir::detail::verifyOffsetSizeAndStrideOp(
         ::mlir::cast<::mlir::OffsetSizeAndStrideOpInterface>($_op));
diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp
--- a/mlir/lib/Interfaces/ViewLikeInterface.cpp
+++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp
@@ -196,3 +196,11 @@
       return false;
   return true;
 }
+
+unsigned mlir::detail::getNumDynamicEntriesUpToIdx(ArrayRef<int64_t> staticVals,
+                                                   unsigned idx) {
+  assert(idx <= staticVals.size() && "index out of bounds");
+  return llvm::count_if(staticVals.take_front(idx), [](int64_t val) {
+    return ShapedType::isDynamic(val);
+  });
+}