[OpenMPIRBuilder] Pass the reduction element type explicitly.

ReductionInfo gains an ElementType member, set via a new leading constructor
argument. The constructor asserts it is compatible with the pointee type of
Variable. The getElementType() helper, which relied on
getPointerElementType(), is removed.

AtomicReductionGenTy callbacks now also receive the element type as their
second argument. The unit tests now pass SumType/XorType explicitly instead of
querying pointee types. The MLIR translation still derives the element type
from the reduction variable's pointer type with getPointerElementType().

Index: llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
===================================================================
--- llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -539,24 +539,27 @@
       function_ref<InsertPointTy(InsertPointTy, Value *, Value *, Value *&)>;
 
   /// Functions used to generate atomic reductions. Such functions take two
-  /// Values representing pointers to LHS and RHS of the reduction. They are
-  /// expected to atomically update the LHS to the reduced value.
+  /// Values representing pointers to LHS and RHS of the reduction, as well as
+  /// the element type of these pointers. They are expected to atomically
+  /// update the LHS to the reduced value.
   using AtomicReductionGenTy =
-      function_ref<InsertPointTy(InsertPointTy, Value *, Value *)>;
+      function_ref<InsertPointTy(InsertPointTy, Type *, Value *, Value *)>;
 
   /// Information about an OpenMP reduction.
   struct ReductionInfo {
-    ReductionInfo(Value *Variable, Value *PrivateVariable,
+    ReductionInfo(Type *ElementType, Value *Variable, Value *PrivateVariable,
                   ReductionGenTy ReductionGen,
                   AtomicReductionGenTy AtomicReductionGen)
-        : Variable(Variable), PrivateVariable(PrivateVariable),
-          ReductionGen(ReductionGen), AtomicReductionGen(AtomicReductionGen) {}
-
-    /// Returns the type of the element being reduced.
-    Type *getElementType() const {
-      return Variable->getType()->getPointerElementType();
+        : ElementType(ElementType), Variable(Variable),
+          PrivateVariable(PrivateVariable), ReductionGen(ReductionGen),
+          AtomicReductionGen(AtomicReductionGen) {
+      assert(cast<PointerType>(Variable->getType())
+          ->isOpaqueOrPointeeTypeMatches(ElementType) && "Invalid elem type");
     }
 
+    /// Reduction element type, must match pointee type of variable.
+    Type *ElementType;
+
     /// Reduction variable of pointer type.
     Value *Variable;
 
Index: llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
===================================================================
--- llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -1156,7 +1156,7 @@
   Builder.SetInsertPoint(NonAtomicRedBlock);
   for (auto En : enumerate(ReductionInfos)) {
     const ReductionInfo &RI = En.value();
-    Type *ValueType = RI.getElementType();
+    Type *ValueType = RI.ElementType;
     Value *RedValue = Builder.CreateLoad(ValueType, RI.Variable,
                                          "red.value." + Twine(En.index()));
     Value *PrivateRedValue =
@@ -1181,8 +1181,8 @@
   Builder.SetInsertPoint(AtomicRedBlock);
   if (CanGenerateAtomic) {
     for (const ReductionInfo &RI : ReductionInfos) {
-      Builder.restoreIP(RI.AtomicReductionGen(Builder.saveIP(), RI.Variable,
-                                              RI.PrivateVariable));
+      Builder.restoreIP(RI.AtomicReductionGen(Builder.saveIP(), RI.ElementType,
+                                              RI.Variable, RI.PrivateVariable));
       if (!Builder.GetInsertBlock())
         return InsertPointTy();
     }
@@ -1207,13 +1207,13 @@
         RedArrayTy, LHSArrayPtr, 0, En.index());
     Value *LHSI8Ptr = Builder.CreateLoad(Builder.getInt8PtrTy(), LHSI8PtrPtr);
     Value *LHSPtr = Builder.CreateBitCast(LHSI8Ptr, RI.Variable->getType());
-    Value *LHS = Builder.CreateLoad(RI.getElementType(), LHSPtr);
+    Value *LHS = Builder.CreateLoad(RI.ElementType, LHSPtr);
     Value *RHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64(
         RedArrayTy, RHSArrayPtr, 0, En.index());
     Value *RHSI8Ptr = Builder.CreateLoad(Builder.getInt8PtrTy(), RHSI8PtrPtr);
     Value *RHSPtr =
         Builder.CreateBitCast(RHSI8Ptr, RI.PrivateVariable->getType());
-    Value *RHS = Builder.CreateLoad(RI.getElementType(), RHSPtr);
+    Value *RHS = Builder.CreateLoad(RI.ElementType, RHSPtr);
     Value *Reduced;
     Builder.restoreIP(RI.ReductionGen(Builder.saveIP(), LHS, RHS, Reduced));
     if (!Builder.GetInsertBlock())
Index: llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
===================================================================
--- llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -3028,10 +3028,10 @@
 }
 
 static OpenMPIRBuilder::InsertPointTy
-sumAtomicReduction(OpenMPIRBuilder::InsertPointTy IP, Value *LHS, Value *RHS) {
+sumAtomicReduction(OpenMPIRBuilder::InsertPointTy IP, Type *Ty, Value *LHS,
+                   Value *RHS) {
   IRBuilder<> Builder(IP.getBlock(), IP.getPoint());
-  Value *Partial = Builder.CreateLoad(RHS->getType()->getPointerElementType(),
-                                      RHS, "red.partial");
+  Value *Partial = Builder.CreateLoad(Ty, RHS, "red.partial");
   Builder.CreateAtomicRMW(AtomicRMWInst::FAdd, LHS, Partial, None,
                           AtomicOrdering::Monotonic);
   return Builder.saveIP();
@@ -3046,10 +3046,10 @@
 }
 
 static OpenMPIRBuilder::InsertPointTy
-xorAtomicReduction(OpenMPIRBuilder::InsertPointTy IP, Value *LHS, Value *RHS) {
+xorAtomicReduction(OpenMPIRBuilder::InsertPointTy IP, Type *Ty, Value *LHS,
+                   Value *RHS) {
   IRBuilder<> Builder(IP.getBlock(), IP.getPoint());
-  Value *Partial = Builder.CreateLoad(RHS->getType()->getPointerElementType(),
-                                      RHS, "red.partial");
+  Value *Partial = Builder.CreateLoad(Ty, RHS, "red.partial");
   Builder.CreateAtomicRMW(AtomicRMWInst::Xor, LHS, Partial, None,
                           AtomicOrdering::Monotonic);
   return Builder.saveIP();
@@ -3081,13 +3081,15 @@
   // Create variables to be reduced.
   InsertPointTy OuterAllocaIP(&F->getEntryBlock(),
                               F->getEntryBlock().getFirstInsertionPt());
+  Type *SumType = Builder.getFloatTy();
+  Type *XorType = Builder.getInt32Ty();
   Value *SumReduced;
   Value *XorReduced;
   {
     IRBuilderBase::InsertPointGuard Guard(Builder);
     Builder.restoreIP(OuterAllocaIP);
-    SumReduced = Builder.CreateAlloca(Builder.getFloatTy());
-    XorReduced = Builder.CreateAlloca(Builder.getInt32Ty());
+    SumReduced = Builder.CreateAlloca(SumType);
+    XorReduced = Builder.CreateAlloca(XorType);
   }
 
   // Store initial values of reductions into global variables.
@@ -3109,12 +3111,8 @@
     Value *TID = OMPBuilder.getOrCreateThreadID(Ident);
     Value *SumLocal =
         Builder.CreateUIToFP(TID, Builder.getFloatTy(), "sum.local");
-    Value *SumPartial =
-        Builder.CreateLoad(SumReduced->getType()->getPointerElementType(),
-                           SumReduced, "sum.partial");
-    Value *XorPartial =
-        Builder.CreateLoad(XorReduced->getType()->getPointerElementType(),
-                           XorReduced, "xor.partial");
+    Value *SumPartial = Builder.CreateLoad(SumType, SumReduced, "sum.partial");
+    Value *XorPartial = Builder.CreateLoad(XorType, XorReduced, "xor.partial");
     Value *Sum = Builder.CreateFAdd(SumPartial, SumLocal, "sum");
     Value *Xor = Builder.CreateXor(XorPartial, TID, "xor");
     Builder.CreateStore(Sum, SumReduced);
@@ -3164,8 +3162,8 @@
   Builder.restoreIP(AfterIP);
 
   OpenMPIRBuilder::ReductionInfo ReductionInfos[] = {
-      {SumReduced, SumPrivatized, sumReduction, sumAtomicReduction},
-      {XorReduced, XorPrivatized, xorReduction, xorAtomicReduction}};
+      {SumType, SumReduced, SumPrivatized, sumReduction, sumAtomicReduction},
+      {XorType, XorReduced, XorPrivatized, xorReduction, xorAtomicReduction}};
 
   OMPBuilder.createReductions(BodyIP, BodyAllocaIP, ReductionInfos);
 
@@ -3319,13 +3317,15 @@
   // Create variables to be reduced.
   InsertPointTy OuterAllocaIP(&F->getEntryBlock(),
                               F->getEntryBlock().getFirstInsertionPt());
+  Type *SumType = Builder.getFloatTy();
+  Type *XorType = Builder.getInt32Ty();
   Value *SumReduced;
   Value *XorReduced;
   {
     IRBuilderBase::InsertPointGuard Guard(Builder);
     Builder.restoreIP(OuterAllocaIP);
-    SumReduced = Builder.CreateAlloca(Builder.getFloatTy());
-    XorReduced = Builder.CreateAlloca(Builder.getInt32Ty());
+    SumReduced = Builder.CreateAlloca(SumType);
+    XorReduced = Builder.CreateAlloca(XorType);
   }
 
   // Store initial values of reductions into global variables.
@@ -3344,9 +3344,7 @@
     Value *TID = OMPBuilder.getOrCreateThreadID(Ident);
     Value *SumLocal =
         Builder.CreateUIToFP(TID, Builder.getFloatTy(), "sum.local");
-    Value *SumPartial =
-        Builder.CreateLoad(SumReduced->getType()->getPointerElementType(),
-                           SumReduced, "sum.partial");
+    Value *SumPartial = Builder.CreateLoad(SumType, SumReduced, "sum.partial");
     Value *Sum = Builder.CreateFAdd(SumPartial, SumLocal, "sum");
     Builder.CreateStore(Sum, SumReduced);
 
@@ -3364,9 +3362,7 @@
     Constant *SrcLocStr = OMPBuilder.getOrCreateSrcLocStr(Loc);
     Value *Ident = OMPBuilder.getOrCreateIdent(SrcLocStr);
     Value *TID = OMPBuilder.getOrCreateThreadID(Ident);
-    Value *XorPartial =
-        Builder.CreateLoad(XorReduced->getType()->getPointerElementType(),
-                           XorReduced, "xor.partial");
+    Value *XorPartial = Builder.CreateLoad(XorType, XorReduced, "xor.partial");
     Value *Xor = Builder.CreateXor(XorPartial, TID, "xor");
     Builder.CreateStore(Xor, XorReduced);
 
@@ -3421,10 +3417,10 @@
 
   OMPBuilder.createReductions(
       FirstBodyIP, FirstBodyAllocaIP,
-      {{SumReduced, SumPrivatized, sumReduction, sumAtomicReduction}});
+      {{SumType, SumReduced, SumPrivatized, sumReduction, sumAtomicReduction}});
   OMPBuilder.createReductions(
       SecondBodyIP, SecondBodyAllocaIP,
-      {{XorReduced, XorPrivatized, xorReduction, xorAtomicReduction}});
+      {{XorType, XorReduced, XorPrivatized, xorReduction, xorAtomicReduction}});
 
   Builder.restoreIP(AfterIP);
   Builder.CreateRetVoid();
Index: mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
===================================================================
--- mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -415,7 +415,8 @@
     llvm::Value *&)>;
 using OwningAtomicReductionGen =
     std::function<llvm::OpenMPIRBuilder::InsertPointTy(
-        llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *)>;
+        llvm::OpenMPIRBuilder::InsertPointTy, llvm::Type *, llvm::Value *,
+        llvm::Value *)>;
 } // namespace
 
 /// Create an OpenMPIRBuilder-compatible reduction generator for the given
@@ -463,7 +464,7 @@
   // avoid the dangling reference after the parent function returns.
   OwningAtomicReductionGen atomicGen =
       [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
-                llvm::Value *lhs, llvm::Value *rhs) mutable {
+                llvm::Type *, llvm::Value *lhs, llvm::Value *rhs) mutable {
         Region &atomicRegion = decl.atomicReductionRegion();
         moduleTranslation.mapValue(atomicRegion.front().getArgument(0), lhs);
         moduleTranslation.mapValue(atomicRegion.front().getArgument(1), rhs);
@@ -763,8 +764,10 @@
     llvm::OpenMPIRBuilder::AtomicReductionGenTy atomicGen = nullptr;
     if (owningAtomicReductionGens[i])
       atomicGen = owningAtomicReductionGens[i];
+    llvm::Value *variable =
+        moduleTranslation.lookupValue(loop.reduction_vars()[i]);
     reductionInfos.push_back(
-        {moduleTranslation.lookupValue(loop.reduction_vars()[i]),
+        {variable->getType()->getPointerElementType(), variable,
          privateReductionVariables[i], owningReductionGens[i], atomicGen});
   }
 