Index: lib/Analysis/ScopInfo.cpp =================================================================== --- lib/Analysis/ScopInfo.cpp +++ lib/Analysis/ScopInfo.cpp @@ -501,9 +501,9 @@ } const ScopArrayInfo *ScopArrayInfo::getFromId(isl::id Id) { - void *User = Id.get_user(); - const ScopArrayInfo *SAI = static_cast(User); - return SAI; + if (ScopArrayInfo **User = Id.get_user()) + return *User; + return nullptr; } void MemoryAccess::wrapConstantDimensions() { @@ -666,16 +666,12 @@ const ScopArrayInfo *MemoryAccess::getOriginalScopArrayInfo() const { isl::id ArrayId = getArrayId(); - void *User = ArrayId.get_user(); - const ScopArrayInfo *SAI = static_cast(User); - return SAI; + return ScopArrayInfo::getFromId(ArrayId); } const ScopArrayInfo *MemoryAccess::getLatestScopArrayInfo() const { isl::id ArrayId = getLatestArrayId(); - void *User = ArrayId.get_user(); - const ScopArrayInfo *SAI = static_cast(User); - return SAI; + return ScopArrayInfo::getFromId(ArrayId); } isl::id MemoryAccess::getOriginalArrayId() const { @@ -1212,8 +1208,8 @@ assert(NewAccessSpace.has_tuple_id(isl::dim::set) && "Must specify the array that is accessed"); isl::id NewArrayId = NewAccessSpace.get_tuple_id(isl::dim::set); - auto *SAI = static_cast(NewArrayId.get_user()); - assert(SAI && "Must set a ScopArrayInfo"); + assert(NewArrayId.get_user() && "Must set a ScopArrayInfo"); + auto *SAI = *NewArrayId.get_user(); if (SAI->isArrayKind() && SAI->getBasePtrOriginSAI()) { InvariantEquivClassTy *EqClass = @@ -1998,8 +1994,7 @@ ParameterName = getIslCompatibleName("", ParameterName, ""); } - isl::id Id = isl::id::alloc(getIslCtx(), ParameterName, - const_cast((const void *)Parameter)); + isl::id Id = isl::id::alloc(getIslCtx(), ParameterName, Parameter); ParameterIds[Parameter] = Id; } @@ -2085,9 +2080,9 @@ // Project out newly introduced parameters as they are not otherwise useful.
if (!NewParams.empty()) { for (unsigned u = 0; u < isl_set_n_param(AssumptionCtx); u++) { - auto *Id = isl_set_get_dim_id(AssumptionCtx, isl_dim_param, u); - auto *Param = static_cast(isl_id_get_user(Id)); - isl_id_free(Id); + auto Id = + isl::manage(isl_set_get_dim_id(AssumptionCtx, isl_dim_param, u)); + auto *Param = *Id.get_user(); if (!NewParams.count(Param)) continue; Index: lib/CodeGen/IslAst.cpp =================================================================== --- lib/CodeGen/IslAst.cpp +++ lib/CodeGen/IslAst.cpp @@ -263,17 +263,17 @@ static __isl_give isl_id *astBuildBeforeFor(__isl_keep isl_ast_build *Build, void *User) { AstBuildUserInfo *BuildInfo = (AstBuildUserInfo *)User; - IslAstUserPayload *Payload = new IslAstUserPayload(); - isl_id *Id = isl_id_alloc(isl_ast_build_get_ctx(Build), "", Payload); - Id = isl_id_set_free_user(Id, freeIslAstUserPayload); - BuildInfo->LastForNodeId = Id; + auto Id = isl::id::alloc(isl::ctx(isl_ast_build_get_ctx(Build)), "", + IslAstUserPayload()); + auto *Payload = Id.get_user(); + BuildInfo->LastForNodeId = Id.get(); // Test for parallelism only if we are not already inside a parallel loop if (!BuildInfo->InParallelFor) BuildInfo->InParallelFor = Payload->IsOutermostParallel = astScheduleDimIsParallel(Build, BuildInfo->Deps, Payload); - return Id; + return Id.release(); } // This method is executed after the construction of a for node.
@@ -286,15 +286,16 @@ static __isl_give isl_ast_node * astBuildAfterFor(__isl_take isl_ast_node *Node, __isl_keep isl_ast_build *Build, void *User) { - isl_id *Id = isl_ast_node_get_annotation(Node); + isl::id Id = isl::manage(isl_ast_node_get_annotation(Node)); assert(Id && "Post order visit assumes annotated for nodes"); - IslAstUserPayload *Payload = (IslAstUserPayload *)isl_id_get_user(Id); - assert(Payload && "Post order visit assumes annotated for nodes"); + assert(Id.get_user() && + "Post order visit assumes annotated for nodes"); + IslAstUserPayload *Payload = Id.get_user(); AstBuildUserInfo *BuildInfo = (AstBuildUserInfo *)User; assert(!Payload->Build && "Build environment already set"); Payload->Build = isl_ast_build_copy(Build); - Payload->IsInnermost = (Id == BuildInfo->LastForNodeId); + Payload->IsInnermost = (Id.get() == BuildInfo->LastForNodeId); // Innermost loops that are surrounded by parallel loops have not yet been // tested for parallelism. Test them here to ensure we check all innermost @@ -311,7 +312,6 @@ if (Payload->IsOutermostParallel) BuildInfo->InParallelFor = false; - isl_id_free(Id); return Node; } @@ -345,13 +345,13 @@ void *User) { assert(!isl_ast_node_get_annotation(Node) && "Node already annotated"); - IslAstUserPayload *Payload = new IslAstUserPayload(); - isl_id *Id = isl_id_alloc(isl_ast_build_get_ctx(Build), "", Payload); - Id = isl_id_set_free_user(Id, freeIslAstUserPayload); + auto Id = isl::id::alloc(isl::ctx(isl_ast_build_get_ctx(Build)), "", + IslAstUserPayload()); + auto *Payload = Id.get_user(); Payload->Build = isl_ast_build_copy(Build); - return isl_ast_node_set_annotation(Node, Id); + return isl_ast_node_set_annotation(Node, Id.release()); } // Build alias check condition given a pair of minimal/maximal access.
@@ -606,8 +606,7 @@ isl_id *Id = isl_ast_node_get_annotation(Node); if (!Id) return nullptr; - IslAstUserPayload *Payload = (IslAstUserPayload *)isl_id_get_user(Id); - isl_id_free(Id); + IslAstUserPayload *Payload = isl::manage(Id).get_user(); return Payload; } @@ -693,7 +692,7 @@ isl::ast_expr NodeExpr = AstNode.user_get_expr(); isl::ast_expr CallExpr = NodeExpr.get_op_arg(0); isl::id CallExprId = CallExpr.get_id(); - ScopStmt *AccessStmt = (ScopStmt *)CallExprId.get_user(); + ScopStmt *AccessStmt = *CallExprId.get_user(); P = isl_printer_start_line(P); P = isl_printer_print_str(P, AccessStmt->getBaseName()); Index: lib/CodeGen/IslNodeBuilder.cpp =================================================================== --- lib/CodeGen/IslNodeBuilder.cpp +++ lib/CodeGen/IslNodeBuilder.cpp @@ -292,7 +292,7 @@ static void addReferencesFromStmtSet(isl::set Set, struct SubtreeReferences *UserPtr) { isl::id Id = Set.get_tuple_id(); - auto *Stmt = static_cast(Id.get_user()); + auto *Stmt = *Id.get_user(); return addReferencesFromStmt(Stmt, UserPtr); } @@ -406,7 +406,7 @@ isl_ast_expr *StmtExpr = isl_ast_expr_get_op_arg(Expr, 0); isl_id *Id = isl_ast_expr_get_id(StmtExpr); isl_ast_expr_free(StmtExpr); - ScopStmt *Stmt = (ScopStmt *)isl_id_get_user(Id); + ScopStmt *Stmt = *isl::manage_copy(Id).get_user(); std::vector VLTS(IVS.size()); isl_union_set *Domain = isl_union_set_from_set(Stmt->getDomain().release()); @@ -440,7 +440,7 @@ return; } if (strcmp(isl_id_get_name(Id), "Inter iteration alias-free") == 0) { - auto *BasePtr = static_cast(isl_id_get_user(Id)); + auto *BasePtr = *isl::manage_copy(Id).get_user(); Annotator.addInterIterationAliasFreeBasePtr(BasePtr); } create(Child); @@ -761,8 +761,7 @@ isl::ast_expr StmtExpr = Expr.get_op_arg(0); isl::id Id = StmtExpr.get_id(); - ScopStmt *Stmt = - static_cast(isl_id_get_user(Id.get())); + ScopStmt *Stmt = *Id.get_user(); isl::set StmtDom = Stmt->getDomain(); for (auto *MA : *Stmt) { if (MA->isLatestPartialAccess()) @@ -993,7
+992,7 @@ LTS.insert(OutsideLoopIterations.begin(), OutsideLoopIterations.end()); - Stmt = (ScopStmt *)isl_id_get_user(Id); + Stmt = *isl::manage_copy(Id).get_user(); auto *NewAccesses = createNewAccesses(Stmt, User); if (Stmt->isCopyStmt()) { generateCopyStmt(Stmt, NewAccesses); @@ -1049,7 +1048,7 @@ bool IslNodeBuilder::materializeValue(isl_id *Id) { // If the Id is already mapped, skip it. if (!IDToValue.count(Id)) { - auto *ParamSCEV = (const SCEV *)isl_id_get_user(Id); + auto *ParamSCEV = *isl::manage_copy(Id).get_user(); Value *V = nullptr; // Parameters could refer to invariant loads that need to be Index: lib/External/isl/include/isl/isl-noexceptions.h =================================================================== --- lib/External/isl/include/isl/isl-noexceptions.h +++ lib/External/isl/include/isl/isl-noexceptions.h @@ -25,6 +25,7 @@ #include #include #include +#include #include #include @@ -976,10 +977,13 @@ inline std::string to_str() const; inline void dump() const; - static inline isl::id alloc(isl::ctx ctx, const std::string &name, void * user); + template + static inline isl::id alloc(isl::ctx ctx, const std::string &name, T &&user); + static inline isl::id alloc(isl::ctx ctx, const std::string &name, + std::nullptr_t); inline uint32_t get_hash() const; inline std::string get_name() const; - inline void * get_user() const; + template inline T *get_user() const; }; // declarations for isl::id_list @@ -6956,9 +6960,23 @@ } -isl::id id::alloc(isl::ctx ctx, const std::string &name, void * user) +isl::id id::alloc(isl::ctx ctx, const std::string &name, std::nullptr_t) { - auto res = isl_id_alloc(ctx.release(), name.c_str(), user); + auto res = isl_id_alloc(ctx.release(), name.c_str(), nullptr); + return manage(res); +} + +// isl_id free_user callback that deletes the llvm::Any allocated by id::alloc(ctx, name, T &&). +inline void freeAny(void *User) { + delete static_cast(User); +} + +template +isl::id id::alloc(isl::ctx ctx, const std::string &name, T &&user) +{ + auto res = isl_id_alloc(ctx.release(), name.c_str(), + new
llvm::Any(std::forward(user))); + res = isl_id_set_free_user(res, freeAny); return manage(res); } @@ -6975,10 +6993,13 @@ return tmp; } -void * id::get_user() const -{ - auto res = isl_id_get_user(get()); - return res; +template T *id::get_user() const { + auto Any = static_cast(isl_id_get_user(get())); + if (!Any) + return nullptr; + T *User = llvm::any_cast(Any); + ISLPP_ASSERT(User, "Tried to fetch wrong user type from isl_id"); + return User; } // implementations for isl::id_list Index: lib/Transform/ForwardOpTree.cpp =================================================================== --- lib/Transform/ForwardOpTree.cpp +++ lib/Transform/ForwardOpTree.cpp @@ -246,7 +246,7 @@ for (isl::map Map : MustKnown.get_map_list()) { // Get the array this is accessing. isl::id ArrayId = Map.get_tuple_id(isl::dim::out); - ScopArrayInfo *SAI = static_cast(ArrayId.get_user()); + ScopArrayInfo *SAI = *ArrayId.get_user(); // No support for generation of indirect array accesses. if (SAI->getBasePtrOriginSAI()) @@ -348,7 +348,7 @@ MemoryAccess *makeReadArrayAccess(ScopStmt *Stmt, LoadInst *LI, isl::map AccessRelation) { isl::id ArrayId = AccessRelation.get_tuple_id(isl::dim::out); - ScopArrayInfo *SAI = reinterpret_cast(ArrayId.get_user()); + ScopArrayInfo *SAI = *ArrayId.get_user(); // Create a dummy SCEV access, to be replaced anyway.
SmallVector Sizes; Index: lib/Transform/MaximalStaticExpansion.cpp =================================================================== --- lib/Transform/MaximalStaticExpansion.cpp +++ lib/Transform/MaximalStaticExpansion.cpp @@ -149,8 +149,7 @@ auto TmpMapDomainId = Map.get_space().domain().unwrap().range().get_tuple_id(isl::dim::set); - ScopArrayInfo *UserSAI = - static_cast(TmpMapDomainId.get_user()); + ScopArrayInfo *UserSAI = *TmpMapDomainId.get_user(); if (SAI != UserSAI) continue; Index: lib/Transform/ScheduleOptimizer.cpp =================================================================== --- lib/Transform/ScheduleOptimizer.cpp +++ lib/Transform/ScheduleOptimizer.cpp @@ -688,7 +688,7 @@ static bool containsOnlyMatrMultAcc(isl::map PartialSchedule, MatMulInfoTy &MMI) { auto InputDimId = PartialSchedule.get_tuple_id(isl::dim::in); - auto *Stmt = static_cast(InputDimId.get_user()); + auto *Stmt = *InputDimId.get_user(); unsigned OutDimNum = PartialSchedule.dim(isl::dim::out); assert(OutDimNum > 2 && "In case of the matrix multiplication the loop nest " "and, consequently, the corresponding scheduling " @@ -773,7 +773,7 @@ static bool containsMatrMult(isl::map PartialSchedule, const Dependences *D, MatMulInfoTy &MMI) { auto InputDimsId = PartialSchedule.get_tuple_id(isl::dim::in); - auto *Stmt = static_cast(InputDimsId.get_user()); + auto *Stmt = *InputDimsId.get_user(); if (Stmt->size() <= 1) return false; @@ -1095,7 +1095,7 @@ MacroKernelParamsTy MacroParams, MatMulInfoTy &MMI) { auto InputDimsId = MapOldIndVar.get_tuple_id(isl::dim::in); - auto *Stmt = static_cast(InputDimsId.get_user()); + auto *Stmt = *InputDimsId.get_user(); // Create a copy statement that corresponds to the memory access to the // matrix B, the second operand of the matrix multiplication.
Index: lib/Transform/Simplify.cpp =================================================================== --- lib/Transform/Simplify.cpp +++ lib/Transform/Simplify.cpp @@ -311,9 +311,9 @@ // Iterate through the candidates. for (isl::map Map : Filtered.get_map_list()) { - MemoryAccess *OtherMA = (MemoryAccess *)Map.get_space() + MemoryAccess *OtherMA = *Map.get_space() .get_tuple_id(isl::dim::out) - .get_user(); + .get_user(); isl::map OtherAccRel = OtherMA->getLatestAccessRelation().intersect_domain(Domain); @@ -352,21 +352,21 @@ SmallPtrSet TouchedAccesses; for (isl::map Map : FutureWrites.intersect_domain(AccRelWrapped).get_map_list()) { - MemoryAccess *MA = (MemoryAccess *)Map.get_space() + MemoryAccess *MA = *Map.get_space() .range() .unwrap() .get_tuple_id(isl::dim::out) - .get_user(); + .get_user(); TouchedAccesses.insert(MA); } isl::union_map NewFutureWrites = isl::union_map::empty(FutureWrites.get_space()); for (isl::map FutureWrite : FutureWrites.get_map_list()) { - MemoryAccess *MA = (MemoryAccess *)FutureWrite.get_space() + MemoryAccess *MA = *FutureWrite.get_space() .range() .unwrap() .get_tuple_id(isl::dim::out) - .get_user(); + .get_user(); if (!TouchedAccesses.count(MA)) NewFutureWrites = NewFutureWrites.add_map(FutureWrite); } Index: lib/Transform/ZoneAlgo.cpp =================================================================== --- lib/Transform/ZoneAlgo.cpp +++ lib/Transform/ZoneAlgo.cpp @@ -706,7 +706,7 @@ isl::map ZoneAlgorithm::getScalarReachingDefinition(isl::set DomainDef) { auto DomId = DomainDef.get_tuple_id(); - auto *Stmt = static_cast(isl_id_get_user(DomId.get())); + auto *Stmt = *DomId.get_user(); auto StmtResult = getScalarReachingDefinition(Stmt); @@ -848,8 +848,8 @@ continue; } - auto *PHI = dyn_cast(static_cast( - RangeSpace.unwrap().get_tuple_id(isl::dim::out).get_user())); + auto *PHI = dyn_cast( + *RangeSpace.unwrap().get_tuple_id(isl::dim::out).get_user()); // If no normalization is necessary, then the ValInst stands for itself.
if (!ComputedPHIs.count(PHI)) { @@ -924,14 +924,17 @@ isl::id OutTupleId = Unwrapped.get_tuple_id(isl::dim::out); if (OutTupleId.is_null()) return isl::boolean(); - auto *PHI = dyn_cast(static_cast(OutTupleId.get_user())); + auto *User = OutTupleId.get_user(); + if (!User) + return true; + auto *PHI = dyn_cast(*User); if (!PHI) return true; isl::id InTupleId = Unwrapped.get_tuple_id(isl::dim::in); - if (OutTupleId.is_null()) + if (InTupleId.is_null()) return isl::boolean(); - auto *IncomingStmt = static_cast(InTupleId.get_user()); + auto *IncomingStmt = *InTupleId.get_user(); MemoryAccess *PHIRead = IncomingStmt->lookupPHIReadOf(PHI); if (!isNormalizable(PHIRead)) return true;