Index: lib/Analysis/ScopInfo.cpp =================================================================== --- lib/Analysis/ScopInfo.cpp +++ lib/Analysis/ScopInfo.cpp @@ -501,9 +501,7 @@ } const ScopArrayInfo *ScopArrayInfo::getFromId(isl::id Id) { - void *User = Id.get_user(); - const ScopArrayInfo *SAI = static_cast(User); - return SAI; + return Id.get_user(); } void MemoryAccess::wrapConstantDimensions() { @@ -666,16 +664,12 @@ const ScopArrayInfo *MemoryAccess::getOriginalScopArrayInfo() const { isl::id ArrayId = getArrayId(); - void *User = ArrayId.get_user(); - const ScopArrayInfo *SAI = static_cast(User); - return SAI; + return ScopArrayInfo::getFromId(ArrayId); } const ScopArrayInfo *MemoryAccess::getLatestScopArrayInfo() const { isl::id ArrayId = getLatestArrayId(); - void *User = ArrayId.get_user(); - const ScopArrayInfo *SAI = static_cast(User); - return SAI; + return ScopArrayInfo::getFromId(ArrayId); } isl::id MemoryAccess::getOriginalArrayId() const { @@ -1212,7 +1206,7 @@ assert(NewAccessSpace.has_tuple_id(isl::dim::set) && "Must specify the array that is accessed"); isl::id NewArrayId = NewAccessSpace.get_tuple_id(isl::dim::set); - auto *SAI = static_cast(NewArrayId.get_user()); + auto *SAI = NewArrayId.get_user(); assert(SAI && "Must set a ScopArrayInfo"); if (SAI->isArrayKind() && SAI->getBasePtrOriginSAI()) { @@ -1998,8 +1992,7 @@ ParameterName = getIslCompatibleName("", ParameterName, ""); } - isl::id Id = isl::id::alloc(getIslCtx(), ParameterName, - const_cast((const void *)Parameter)); + isl::id Id = isl::id::alloc(getIslCtx(), ParameterName, Parameter); ParameterIds[Parameter] = Id; } @@ -2086,8 +2079,7 @@ if (!NewParams.empty()) { for (unsigned u = 0; u < isl_set_n_param(AssumptionCtx); u++) { auto *Id = isl_set_get_dim_id(AssumptionCtx, isl_dim_param, u); - auto *Param = static_cast(isl_id_get_user(Id)); - isl_id_free(Id); + auto *Param = isl::manage(Id).get_user(); if (!NewParams.count(Param)) continue; Index: lib/CodeGen/IslAst.cpp 
=================================================================== --- lib/CodeGen/IslAst.cpp +++ lib/CodeGen/IslAst.cpp @@ -263,17 +263,17 @@ static __isl_give isl_id *astBuildBeforeFor(__isl_keep isl_ast_build *Build, void *User) { AstBuildUserInfo *BuildInfo = (AstBuildUserInfo *)User; - IslAstUserPayload *Payload = new IslAstUserPayload(); - isl_id *Id = isl_id_alloc(isl_ast_build_get_ctx(Build), "", Payload); - Id = isl_id_set_free_user(Id, freeIslAstUserPayload); - BuildInfo->LastForNodeId = Id; + auto Id = isl::id::alloc(isl::ctx(isl_ast_build_get_ctx(Build)), "", + IslAstUserPayload()); + auto &Payload = Id.get_user(); + BuildInfo->LastForNodeId = Id.get(); // Test for parallelism only if we are not already inside a parallel loop if (!BuildInfo->InParallelFor) - BuildInfo->InParallelFor = Payload->IsOutermostParallel = - astScheduleDimIsParallel(Build, BuildInfo->Deps, Payload); + BuildInfo->InParallelFor = Payload.IsOutermostParallel = + astScheduleDimIsParallel(Build, BuildInfo->Deps, &Payload); - return Id; + return Id.release(); } // This method is executed after the construction of a for node. 
@@ -288,7 +288,8 @@ void *User) { isl_id *Id = isl_ast_node_get_annotation(Node); assert(Id && "Post order visit assumes annotated for nodes"); - IslAstUserPayload *Payload = (IslAstUserPayload *)isl_id_get_user(Id); + IslAstUserPayload *Payload = + isl::manage_copy(Id).get_user(); assert(Payload && "Post order visit assumes annotated for nodes"); AstBuildUserInfo *BuildInfo = (AstBuildUserInfo *)User; @@ -345,13 +346,13 @@ void *User) { assert(!isl_ast_node_get_annotation(Node) && "Node already annotated"); - IslAstUserPayload *Payload = new IslAstUserPayload(); - isl_id *Id = isl_id_alloc(isl_ast_build_get_ctx(Build), "", Payload); - Id = isl_id_set_free_user(Id, freeIslAstUserPayload); + auto Id = isl::id::alloc(isl::ctx(isl_ast_build_get_ctx(Build)), "", + IslAstUserPayload()); + auto &Payload = Id.get_user(); - Payload->Build = isl_ast_build_copy(Build); + Payload.Build = isl_ast_build_copy(Build); - return isl_ast_node_set_annotation(Node, Id); + return isl_ast_node_set_annotation(Node, Id.release()); } // Build alias check condition given a pair of minimal/maximal access. 
@@ -606,8 +607,7 @@ isl_id *Id = isl_ast_node_get_annotation(Node); if (!Id) return nullptr; - IslAstUserPayload *Payload = (IslAstUserPayload *)isl_id_get_user(Id); - isl_id_free(Id); + IslAstUserPayload *Payload = isl::manage(Id).get_user(); return Payload; } @@ -693,7 +693,7 @@ isl::ast_expr NodeExpr = AstNode.user_get_expr(); isl::ast_expr CallExpr = NodeExpr.get_op_arg(0); isl::id CallExprId = CallExpr.get_id(); - ScopStmt *AccessStmt = (ScopStmt *)CallExprId.get_user(); + ScopStmt *AccessStmt = CallExprId.get_user(); P = isl_printer_start_line(P); P = isl_printer_print_str(P, AccessStmt->getBaseName()); Index: lib/CodeGen/IslNodeBuilder.cpp =================================================================== --- lib/CodeGen/IslNodeBuilder.cpp +++ lib/CodeGen/IslNodeBuilder.cpp @@ -292,7 +292,7 @@ static void addReferencesFromStmtSet(isl::set Set, struct SubtreeReferences *UserPtr) { isl::id Id = Set.get_tuple_id(); - auto *Stmt = static_cast(Id.get_user()); + auto *Stmt = Id.get_user(); return addReferencesFromStmt(Stmt, UserPtr); } @@ -406,7 +406,7 @@ isl_ast_expr *StmtExpr = isl_ast_expr_get_op_arg(Expr, 0); isl_id *Id = isl_ast_expr_get_id(StmtExpr); isl_ast_expr_free(StmtExpr); - ScopStmt *Stmt = (ScopStmt *)isl_id_get_user(Id); + ScopStmt *Stmt = isl::manage_copy(Id).get_user(); std::vector VLTS(IVS.size()); isl_union_set *Domain = isl_union_set_from_set(Stmt->getDomain().release()); @@ -440,7 +440,7 @@ return; } if (strcmp(isl_id_get_name(Id), "Inter iteration alias-free") == 0) { - auto *BasePtr = static_cast(isl_id_get_user(Id)); + auto *BasePtr = isl::manage_copy(Id).get_user(); Annotator.addInterIterationAliasFreeBasePtr(BasePtr); } create(Child); @@ -761,8 +761,7 @@ isl::ast_expr StmtExpr = Expr.get_op_arg(0); isl::id Id = StmtExpr.get_id(); - ScopStmt *Stmt = - static_cast(isl_id_get_user(Id.get())); + ScopStmt *Stmt = Id.get_user(); isl::set StmtDom = Stmt->getDomain(); for (auto *MA : *Stmt) { if (MA->isLatestPartialAccess()) @@ -993,7 +992,7 @@ 
LTS.insert(OutsideLoopIterations.begin(), OutsideLoopIterations.end()); - Stmt = (ScopStmt *)isl_id_get_user(Id); + Stmt = isl::manage_copy(Id).get_user(); auto *NewAccesses = createNewAccesses(Stmt, User); if (Stmt->isCopyStmt()) { generateCopyStmt(Stmt, NewAccesses); @@ -1049,7 +1048,7 @@ bool IslNodeBuilder::materializeValue(isl_id *Id) { // If the Id is already mapped, skip it. if (!IDToValue.count(Id)) { - auto *ParamSCEV = (const SCEV *)isl_id_get_user(Id); + auto *ParamSCEV = isl::manage_copy(Id).get_user(); Value *V = nullptr; // Parameters could refer to invariant loads that need to be Index: lib/CodeGen/PPCGCodeGeneration.cpp =================================================================== --- lib/CodeGen/PPCGCodeGeneration.cpp +++ lib/CodeGen/PPCGCodeGeneration.cpp @@ -1175,7 +1175,7 @@ isl_ast_expr *Expr = isl_ast_node_user_get_expr(TransferStmt); isl_ast_expr *Arg = isl_ast_expr_get_op_arg(Expr, 0); isl_id *Id = isl_ast_expr_get_id(Arg); - auto Array = (gpu_array_info *)isl_id_get_user(Id); + auto Array = (gpu_array_info *)isl_id_get_user(Id); /* PPCG-created id: raw user pointer, not an llvm::Any */ auto ScopArray = (ScopArrayInfo *)(Array->user); Value *Size = getArraySize(Array); @@ -1264,7 +1264,7 @@ isl_id *Anno = isl_ast_node_get_annotation(UserStmt); struct ppcg_kernel_stmt *KernelStmt = - (struct ppcg_kernel_stmt *)isl_id_get_user(Anno); + (struct ppcg_kernel_stmt *)isl_id_get_user(Anno); /* PPCG-created id: raw user pointer, not an llvm::Any */ isl_id_free(Anno); switch (KernelStmt->type) { @@ -1378,7 +1378,7 @@ return isl_bool_true; Id = isl_ast_node_get_annotation(Node); - auto *KernelStmt = (ppcg_kernel_stmt *)isl_id_get_user(Id); + auto *KernelStmt = (ppcg_kernel_stmt *)isl_id_get_user(Id); /* PPCG-created id: raw user pointer, not an llvm::Any */ auto Stmt = (ScopStmt *)KernelStmt->u.d.stmt->stmt; isl_id_free(Id); @@ -1780,7 +1780,7 @@ } void GPUNodeBuilder::createKernel(__isl_take isl_ast_node *KernelStmt) { isl_id *Id = isl_ast_node_get_annotation(KernelStmt); - ppcg_kernel *Kernel = (ppcg_kernel *)isl_id_get_user(Id); + ppcg_kernel *Kernel = (ppcg_kernel *)isl_id_get_user(Id); /* PPCG-created id: raw user pointer, not an llvm::Any */ isl_id_free(Id); isl_ast_node_free(KernelStmt); @@ -3146,7
+3146,7 @@ return P; } - auto Kernel = (struct ppcg_kernel *)isl_id_get_user(Id); + auto Kernel = (struct ppcg_kernel *)isl_id_get_user(Id); /* PPCG-created id: raw user pointer, not an llvm::Any */ isl_id_free(Id); Data->Kernels.push_back(Kernel); } Index: lib/External/isl/include/isl/isl-noexceptions.h =================================================================== --- lib/External/isl/include/isl/isl-noexceptions.h +++ lib/External/isl/include/isl/isl-noexceptions.h @@ -25,6 +25,7 @@ #include #include #include +#include #include #include @@ -976,10 +977,11 @@ inline std::string to_str() const; inline void dump() const; - static inline isl::id alloc(isl::ctx ctx, const std::string &name, void * user); + template <typename T> + static inline isl::id alloc(isl::ctx ctx, const std::string &name, T &&user); inline uint32_t get_hash() const; inline std::string get_name() const; - inline void * get_user() const; + template <typename T> inline T &get_user() const; }; // declarations for isl::id_list @@ -6956,9 +6958,12 @@ } -isl::id id::alloc(isl::ctx ctx, const std::string &name, void * user) +template <typename T> +isl::id id::alloc(isl::ctx ctx, const std::string &name, T &&user) { - auto res = isl_id_alloc(ctx.release(), name.c_str(), user); + auto res = isl_id_alloc(ctx.release(), name.c_str(), + new llvm::Any(std::forward<T>(user))); + res = isl_id_set_free_user(res, [](void *user) { /* cast back so ~Any runs; deleting a void * is undefined */ delete static_cast<llvm::Any *>(user); }); return manage(res); } @@ -6975,10 +6980,11 @@ return tmp; } -void * id::get_user() const -{ - auto res = isl_id_get_user(get()); - return res; +template <typename T> T &id::get_user() const { + auto Any = static_cast<llvm::Any *>(isl_id_get_user(get())); /* only valid for ids created by id::alloc */ + T *User = llvm::any_cast<T>(Any); + ISLPP_ASSERT(User, "Tried to fetch wrong user type from isl_id"); + return *User; } // implementations for isl::id_list Index: lib/Transform/ForwardOpTree.cpp =================================================================== --- lib/Transform/ForwardOpTree.cpp +++ lib/Transform/ForwardOpTree.cpp @@ -246,7 +246,7 @@ for (isl::map Map : MustKnown.get_map_list()) { // Get the array this is accessing.
isl::id ArrayId = Map.get_tuple_id(isl::dim::out); - ScopArrayInfo *SAI = static_cast(ArrayId.get_user()); + ScopArrayInfo *SAI = ArrayId.get_user(); // No support for generation of indirect array accesses. if (SAI->getBasePtrOriginSAI()) @@ -348,7 +348,7 @@ MemoryAccess *makeReadArrayAccess(ScopStmt *Stmt, LoadInst *LI, isl::map AccessRelation) { isl::id ArrayId = AccessRelation.get_tuple_id(isl::dim::out); - ScopArrayInfo *SAI = reinterpret_cast(ArrayId.get_user()); + ScopArrayInfo *SAI = ArrayId.get_user(); // Create a dummy SCEV access, to be replaced anyway. SmallVector Sizes; Index: lib/Transform/MaximalStaticExpansion.cpp =================================================================== --- lib/Transform/MaximalStaticExpansion.cpp +++ lib/Transform/MaximalStaticExpansion.cpp @@ -149,8 +149,7 @@ auto TmpMapDomainId = Map.get_space().domain().unwrap().range().get_tuple_id(isl::dim::set); - ScopArrayInfo *UserSAI = - static_cast(TmpMapDomainId.get_user()); + ScopArrayInfo *UserSAI = TmpMapDomainId.get_user(); if (SAI != UserSAI) continue; Index: lib/Transform/ScheduleOptimizer.cpp =================================================================== --- lib/Transform/ScheduleOptimizer.cpp +++ lib/Transform/ScheduleOptimizer.cpp @@ -688,7 +688,7 @@ static bool containsOnlyMatrMultAcc(isl::map PartialSchedule, MatMulInfoTy &MMI) { auto InputDimId = PartialSchedule.get_tuple_id(isl::dim::in); - auto *Stmt = static_cast(InputDimId.get_user()); + auto *Stmt = InputDimId.get_user(); unsigned OutDimNum = PartialSchedule.dim(isl::dim::out); assert(OutDimNum > 2 && "In case of the matrix multiplication the loop nest " "and, consequently, the corresponding scheduling " @@ -773,7 +773,7 @@ static bool containsMatrMult(isl::map PartialSchedule, const Dependences *D, MatMulInfoTy &MMI) { auto InputDimsId = PartialSchedule.get_tuple_id(isl::dim::in); - auto *Stmt = static_cast(InputDimsId.get_user()); + auto *Stmt = InputDimsId.get_user(); if (Stmt->size() <= 1) return 
false; @@ -1095,7 +1095,7 @@ MacroKernelParamsTy MacroParams, MatMulInfoTy &MMI) { auto InputDimsId = MapOldIndVar.get_tuple_id(isl::dim::in); - auto *Stmt = static_cast(InputDimsId.get_user()); + auto *Stmt = InputDimsId.get_user(); // Create a copy statement that corresponds to the memory access to the // matrix B, the second operand of the matrix multiplication. Index: lib/Transform/Simplify.cpp =================================================================== --- lib/Transform/Simplify.cpp +++ lib/Transform/Simplify.cpp @@ -311,9 +311,9 @@ // Iterate through the candidates. for (isl::map Map : Filtered.get_map_list()) { - MemoryAccess *OtherMA = (MemoryAccess *)Map.get_space() + MemoryAccess *OtherMA = Map.get_space() .get_tuple_id(isl::dim::out) - .get_user(); + .get_user(); isl::map OtherAccRel = OtherMA->getLatestAccessRelation().intersect_domain(Domain); @@ -352,21 +352,21 @@ SmallPtrSet TouchedAccesses; for (isl::map Map : FutureWrites.intersect_domain(AccRelWrapped).get_map_list()) { - MemoryAccess *MA = (MemoryAccess *)Map.get_space() + MemoryAccess *MA = Map.get_space() .range() .unwrap() .get_tuple_id(isl::dim::out) - .get_user(); + .get_user(); TouchedAccesses.insert(MA); } isl::union_map NewFutureWrites = isl::union_map::empty(FutureWrites.get_space()); for (isl::map FutureWrite : FutureWrites.get_map_list()) { - MemoryAccess *MA = (MemoryAccess *)FutureWrite.get_space() + MemoryAccess *MA = FutureWrite.get_space() .range() .unwrap() .get_tuple_id(isl::dim::out) - .get_user(); + .get_user(); if (!TouchedAccesses.count(MA)) NewFutureWrites = NewFutureWrites.add_map(FutureWrite); } Index: lib/Transform/ZoneAlgo.cpp =================================================================== --- lib/Transform/ZoneAlgo.cpp +++ lib/Transform/ZoneAlgo.cpp @@ -706,7 +706,7 @@ isl::map ZoneAlgorithm::getScalarReachingDefinition(isl::set DomainDef) { auto DomId = DomainDef.get_tuple_id(); - auto *Stmt = static_cast(isl_id_get_user(DomId.get())); + auto *Stmt = 
DomId.get_user(); auto StmtResult = getScalarReachingDefinition(Stmt); @@ -848,8 +848,8 @@ continue; } - auto *PHI = dyn_cast(static_cast( - RangeSpace.unwrap().get_tuple_id(isl::dim::out).get_user())); + auto *PHI = dyn_cast( + RangeSpace.unwrap().get_tuple_id(isl::dim::out).get_user()); // If no normalization is necessary, then the ValInst stands for itself. if (!ComputedPHIs.count(PHI)) { @@ -924,14 +924,14 @@ isl::id OutTupleId = Unwrapped.get_tuple_id(isl::dim::out); if (OutTupleId.is_null()) return isl::boolean(); - auto *PHI = dyn_cast(static_cast(OutTupleId.get_user())); + auto *PHI = dyn_cast(OutTupleId.get_user()); if (!PHI) return true; isl::id InTupleId = Unwrapped.get_tuple_id(isl::dim::in); if (OutTupleId.is_null()) return isl::boolean(); - auto *IncomingStmt = static_cast(InTupleId.get_user()); + auto *IncomingStmt = InTupleId.get_user(); MemoryAccess *PHIRead = IncomingStmt->lookupPHIReadOf(PHI); if (!isNormalizable(PHIRead)) return true;