diff --git a/clang/lib/StaticAnalyzer/Checkers/MPI-Checker/MPIChecker.cpp b/clang/lib/StaticAnalyzer/Checkers/MPI-Checker/MPIChecker.cpp
--- a/clang/lib/StaticAnalyzer/Checkers/MPI-Checker/MPIChecker.cpp
+++ b/clang/lib/StaticAnalyzer/Checkers/MPI-Checker/MPIChecker.cpp
@@ -143,6 +143,57 @@
   }
 }
 
+/// For an MPI_Waitall() request argument that points into an array, returns
+/// the index of the first passed request within MR's direct super region,
+/// paired with the element count that getDynamicElementCountWithOffset()
+/// reports for the request pointer. Returns std::nullopt if either value is
+/// not concrete.
+static std::optional<std::pair<NonLoc, llvm::APSInt>>
+getRequestRegionOffsetAndCount(const MemRegion *const MR, const CallEvent &CE) {
+  if (CE.getNumArgs() < 2)
+    return std::nullopt;
+
+  ProgramStateRef State = CE.getState();
+  SValBuilder &SVB = State->getStateManager().getSValBuilder();
+  ASTContext &ASTCtx = SVB.getContext();
+
+  QualType RequestTy = CE.getArgExpr(1)->getType()->getPointeeType();
+  auto RequestRegionCount =
+      getDynamicElementCountWithOffset(State, CE.getArgSVal(1), RequestTy)
+          .getAs<nonloc::ConcreteInt>();
+  if (!RequestRegionCount)
+    return std::nullopt;
+
+  CharUnits TypeSizeInChars = ASTCtx.getTypeSizeInChars(RequestTy);
+
+  // The MPI_Request handle type may be zero-sized; fall back to one char so
+  // that the division below is well-defined.
+  int64_t TypeSizeInBits =
+      (TypeSizeInChars.isZero() ? 1 : TypeSizeInChars.getQuantity()) *
+      ASTCtx.getCharWidth();
+
+  RegionOffset RequestRegionOffset = MR->getAsOffset();
+  if (RequestRegionOffset.hasSymbolicOffset())
+    return std::nullopt;
+
+  // getAsOffset() is measured from the base region (e.g. the enclosing struct
+  // variable), but the caller creates element regions directly inside MR's
+  // super region. Rebase the offset onto that region; otherwise a request
+  // array that is not the first struct field, or a row of a multidimensional
+  // array, yields a shifted start index.
+  int64_t RequestOffsetInBits = RequestRegionOffset.getOffset();
+  if (const auto *ER = MR->getAs<ElementRegion>()) {
+    RegionOffset SuperRegionOffset = ER->getSuperRegion()->getAsOffset();
+    if (SuperRegionOffset.hasSymbolicOffset())
+      return std::nullopt;
+    RequestOffsetInBits -= SuperRegionOffset.getOffset();
+  }
+
+  return std::make_pair(
+      SVB.makeArrayIndex(RequestOffsetInBits / TypeSizeInBits),
+      RequestRegionCount->getValue());
+}
+
 void MPIChecker::allRegionsUsedByWait(
     llvm::SmallVector<const MemRegion *, 2> &ReqRegions,
     const MemRegion *const MR, const CallEvent &CE, CheckerContext &Ctx) const {
@@ -161,20 +212,34 @@
       return;
     }
 
-    DefinedOrUnknownSVal ElementCount = getDynamicElementCount(
-        Ctx.getState(), SuperRegion, Ctx.getSValBuilder(),
-        CE.getArgExpr(1)->getType()->getPointeeType());
-    const llvm::APSInt &ArrSize =
-        ElementCount.castAs<nonloc::ConcreteInt>().getValue();
+    auto RequestRegionOffsetAndCount = getRequestRegionOffsetAndCount(MR, CE);
+    if (!RequestRegionOffsetAndCount)
+      return;
+
+    auto [StartIndex, RegionCount] = *RequestRegionOffsetAndCount;
 
-    for (size_t i = 0; i < ArrSize; ++i) {
-      const NonLoc Idx = Ctx.getSValBuilder().makeArrayIndex(i);
+    QualType MPIReqTy = CE.getArgExpr(1)->getType()->getPointeeType();
+    SValBuilder &SVB = Ctx.getSValBuilder();
 
-      const ElementRegion *const ER = RegionManager.getElementRegion(
-          CE.getArgExpr(1)->getType()->getPointeeType(), Idx, SuperRegion,
-          Ctx.getASTContext());
+    auto RequestedCountSVal = CE.getArgSVal(0).getAs<nonloc::ConcreteInt>();
+    if (!RequestedCountSVal)
+      return;
 
-      ReqRegions.push_back(ER->getAs<ElementRegion>());
+    const llvm::APSInt &RequestedCount = RequestedCountSVal->getValue();
+    // TODO: RequestedCount > RegionCount is an out-of-bounds access (UB). We
+    // could report it here, but a better approach is adding this constraint
+    // as a summary to a generic checker like StdCLibraryFunctions.
+    for (size_t i = 0; i < RegionCount && i < RequestedCount; ++i) {
+      auto RegionIndex =
+          SVB.evalBinOp(Ctx.getState(), BO_Add, SVB.makeArrayIndex(i),
+                        StartIndex, SVB.getArrayIndexType())
+              .getAs<NonLoc>();
+      if (RegionIndex) {
+        const ElementRegion *const RequestRegion =
+            RegionManager.getElementRegion(MPIReqTy, *RegionIndex, SuperRegion,
+                                           Ctx.getASTContext());
+        ReqRegions.push_back(RequestRegion);
+      }
     }
   } else if (FuncClassifier->isMPI_Wait(CE.getCalleeIdentifier())) {
     ReqRegions.push_back(MR);
diff --git a/clang/test/Analysis/mpichecker.cpp b/clang/test/Analysis/mpichecker.cpp
--- a/clang/test/Analysis/mpichecker.cpp
+++ b/clang/test/Analysis/mpichecker.cpp
@@ -272,6 +272,77 @@
   MPI_Wait(&rs.req2, MPI_STATUS_IGNORE);
 } // no error
 
+void nestedRequestWithCount() {
+  typedef struct {
+    MPI_Request req[3];
+    MPI_Request req2;
+  } ReqStruct;
+
+  ReqStruct rs;
+  int rank = 0;
+  double buf = 0;
+  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
+
+  MPI_Ireduce(MPI_IN_PLACE, &buf, 1, MPI_DOUBLE, MPI_SUM, 0, MPI_COMM_WORLD,
+              &rs.req[0]);
+  MPI_Ireduce(MPI_IN_PLACE, &buf, 1, MPI_DOUBLE, MPI_SUM, 0, MPI_COMM_WORLD,
+              &rs.req[1]);
+  MPI_Ireduce(MPI_IN_PLACE, &buf, 1, MPI_DOUBLE, MPI_SUM, 0, MPI_COMM_WORLD,
+              &rs.req[2]);
+  MPI_Ireduce(MPI_IN_PLACE, &buf, 1, MPI_DOUBLE, MPI_SUM, 0, MPI_COMM_WORLD,
+              &rs.req2);
+  MPI_Waitall(2, rs.req, MPI_STATUSES_IGNORE);
+  MPI_Waitall(1, rs.req + 2, MPI_STATUSES_IGNORE);
+  MPI_Wait(&rs.req2, MPI_STATUS_IGNORE);
+} // no error
+
+void nestedRequestWithCountMissingNonBlockingWait() {
+  typedef struct {
+    MPI_Request req[3];
+    MPI_Request req2;
+  } ReqStruct;
+
+  ReqStruct rs;
+  int rank = 0;
+  double buf = 0;
+  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
+
+  MPI_Ireduce(MPI_IN_PLACE, &buf, 1, MPI_DOUBLE, MPI_SUM, 0, MPI_COMM_WORLD,
+              &rs.req[0]);
+  MPI_Ireduce(MPI_IN_PLACE, &buf, 1, MPI_DOUBLE, MPI_SUM, 0, MPI_COMM_WORLD,
+              &rs.req[1]);
+  MPI_Ireduce(MPI_IN_PLACE, &buf, 1, MPI_DOUBLE, MPI_SUM, 0, MPI_COMM_WORLD,
+              &rs.req[2]);
+  MPI_Ireduce(MPI_IN_PLACE, &buf, 1, MPI_DOUBLE, MPI_SUM, 0, MPI_COMM_WORLD,
+              &rs.req2);
+  MPI_Waitall(1, rs.req, MPI_STATUSES_IGNORE);
+  // MPI_Waitall(1, rs.req + 1, MPI_STATUSES_IGNORE);
+  MPI_Waitall(1, rs.req + 2, MPI_STATUSES_IGNORE);
+  MPI_Wait(&rs.req2, MPI_STATUS_IGNORE);
+} // expected-warning{{Request 'rs.req[1]' has no matching wait.}}
+
+// The request array does not start at offset 0 of the enclosing struct.
+void nestedRequestArrayAfterLeadingField() {
+  typedef struct {
+    MPI_Request req2;
+    MPI_Request req[2];
+  } ReqStruct;
+
+  ReqStruct rs;
+  int rank = 0;
+  double buf = 0;
+  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
+
+  MPI_Ireduce(MPI_IN_PLACE, &buf, 1, MPI_DOUBLE, MPI_SUM, 0, MPI_COMM_WORLD,
+              &rs.req2);
+  MPI_Ireduce(MPI_IN_PLACE, &buf, 1, MPI_DOUBLE, MPI_SUM, 0, MPI_COMM_WORLD,
+              &rs.req[0]);
+  MPI_Ireduce(MPI_IN_PLACE, &buf, 1, MPI_DOUBLE, MPI_SUM, 0, MPI_COMM_WORLD,
+              &rs.req[1]);
+  MPI_Wait(&rs.req2, MPI_STATUS_IGNORE);
+  MPI_Waitall(2, rs.req, MPI_STATUSES_IGNORE);
+} // no error
+
 void singleRequestInWaitall() {
   MPI_Request r;
   int int rank = 0;