diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -3349,6 +3349,7 @@
     InsertPointTy AllocaIP, Value *X, Type *XElemTy, Value *Expr,
     AtomicOrdering AO, AtomicRMWInst::BinOp RMWOp,
     AtomicUpdateCallbackTy &UpdateOp, bool VolatileX, bool IsXBinopExpr) {
+  // TODO: handle the case where XElemTy is not byte-sized or not a power of 2.
   bool DoCmpExch =
       (RMWOp == AtomicRMWInst::BAD_BINOP) || (RMWOp == AtomicRMWInst::FAdd) ||
       (RMWOp == AtomicRMWInst::FSub) ||
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -966,6 +966,102 @@
   return success();
 }
 
+/// Converts an LLVM dialect binary operation to the corresponding enum value
+/// for `atomicrmw` supported binary operation.
+llvm::AtomicRMWInst::BinOp convertBinOpToAtomic(Operation &op) {
+  return llvm::TypeSwitch<Operation *, llvm::AtomicRMWInst::BinOp>(&op)
+      .Case([&](LLVM::AddOp) { return llvm::AtomicRMWInst::BinOp::Add; })
+      .Case([&](LLVM::SubOp) { return llvm::AtomicRMWInst::BinOp::Sub; })
+      .Case([&](LLVM::AndOp) { return llvm::AtomicRMWInst::BinOp::And; })
+      .Case([&](LLVM::OrOp) { return llvm::AtomicRMWInst::BinOp::Or; })
+      .Case([&](LLVM::XOrOp) { return llvm::AtomicRMWInst::BinOp::Xor; })
+      .Case([&](LLVM::UMaxOp) { return llvm::AtomicRMWInst::BinOp::UMax; })
+      .Case([&](LLVM::UMinOp) { return llvm::AtomicRMWInst::BinOp::UMin; })
+      .Case([&](LLVM::FAddOp) { return llvm::AtomicRMWInst::BinOp::FAdd; })
+      .Case([&](LLVM::FSubOp) { return llvm::AtomicRMWInst::BinOp::FSub; })
+      .Default(llvm::AtomicRMWInst::BinOp::BAD_BINOP);
+}
+
+/// Converts an OpenMP update operation using OpenMPIRBuilder.
+static LogicalResult
+convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst,
+                       llvm::IRBuilderBase &builder,
+                       LLVM::ModuleTranslation &moduleTranslation) {
+  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
+  llvm::DISubprogram *subprogram =
+      builder.GetInsertBlock()->getParent()->getSubprogram();
+  const llvm::DILocation *diLoc =
+      moduleTranslation.translateLoc(opInst.getLoc(), subprogram);
+  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder.saveIP(),
+                                                    llvm::DebugLoc(diLoc));
+
+  auto &innerOpList = opInst.region().front().getOperations();
+  if (innerOpList.size() != 2) {
+    return opInst.emitError()
+           << "the update region must have only two operations "
+              "(binop and terminator)";
+  }
+
+  // Convert values and types.
+  Operation &innerUpdateOp = innerOpList.front();
+  bool isXBinopExpr = false;
+  if (innerUpdateOp.getNumOperands() == 2) {
+    isXBinopExpr =
+        innerUpdateOp.getOperand(0) == opInst.getRegion().getArgument(0);
+  }
+  mlir::Value mlirExpr = (isXBinopExpr ? innerUpdateOp.getOperand(1)
+                                       : innerUpdateOp.getOperand(0));
+  llvm::Value *llvmExpr = moduleTranslation.lookupValue(mlirExpr);
+  llvm::Value *llvmX = moduleTranslation.lookupValue(opInst.x());
+  LLVM::LLVMPointerType mlirXType =
+      opInst.x().getType().cast<LLVM::LLVMPointerType>();
+  llvm::Type *llvmXElementType =
+      moduleTranslation.convertType(mlirXType.getElementType());
+  llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
+                                                      /*isSigned=*/false,
+                                                      /*isVolatile=*/false};
+
+  llvm::AtomicRMWInst::BinOp binop = convertBinOpToAtomic(innerUpdateOp);
+  llvm::AtomicOrdering atomicOrdering =
+      convertAtomicOrdering(opInst.memory_order());
+
+  // Generate update code.
+  LogicalResult updateGenStatus = success();
+  auto updateFn = [&opInst, &moduleTranslation, &updateGenStatus](
+                      llvm::Value *atomicx,
+                      llvm::IRBuilder<> &builder) -> llvm::Value * {
+    Block &bb = *opInst.region().begin();
+    moduleTranslation.mapValue(*opInst.region().args_begin(), atomicx);
+    moduleTranslation.mapBlock(&bb, builder.GetInsertBlock());
+    if (failed(moduleTranslation.convertBlock(bb, true, builder))) {
+      updateGenStatus = (opInst.emitError()
+                         << "unable to convert update operation to llvm IR");
+      return nullptr;
+    }
+    omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.getTerminator());
+    assert(yieldop && yieldop.results().size() == 1 &&
+           "terminator must be omp.yield op and it must have exactly one "
+           "argument");
+    return moduleTranslation.lookupValue(yieldop.results()[0]);
+  };
+
+  // Handle ambiguous alloca, if any.
+  auto allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
+  llvm::UnreachableInst *unreachableInst;
+  if (allocaIP.getPoint() == ompLoc.IP.getPoint()) {
+    // Same point => split basic block and make them unambigous.
+    unreachableInst = builder.CreateUnreachable();
+    builder.SetInsertPoint(builder.GetInsertBlock()->splitBasicBlock(
+        unreachableInst, "alloca_split"));
+    ompLoc.IP = builder.saveIP();
+    unreachableInst->removeFromParent();
+  }
+  builder.restoreIP(ompBuilder->createAtomicUpdate(
+      ompLoc, findAllocaInsertPoint(builder, moduleTranslation), llvmAtomicX,
+      llvmExpr, atomicOrdering, binop, updateFn, isXBinopExpr));
+  return updateGenStatus;
+}
+
 /// Converts an OpenMP reduction operation using OpenMPIRBuilder. Expects the
 /// mapping between reduction variables and their private equivalents to have
 /// been stored on the ModuleTranslation stack.
Currently only supports
@@ -1093,6 +1189,9 @@
       .Case([&](omp::AtomicWriteOp) {
         return convertOmpAtomicWrite(*op, builder, moduleTranslation);
       })
+      .Case([&](omp::AtomicUpdateOp op) {
+        return convertOmpAtomicUpdate(op, builder, moduleTranslation);
+      })
       .Case([&](omp::SectionsOp) {
         return convertOmpSections(*op, builder, moduleTranslation);
       })
diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
--- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
@@ -826,6 +826,85 @@
 // -----
 
+// Checking simple atomicrmw and cmpxchg based translation. This also checks for
+// ambigous alloca insert point by putting llvm.mul as the first update operation.
+// CHECK-LABEL: @omp_atomic_update
+// CHECK-SAME: (i32* %[[x:.*]], i32 %[[expr:.*]], i1* %[[xbool:.*]], i1 %[[exprbool:.*]])
+llvm.func @omp_atomic_update(%x:!llvm.ptr<i32>, %expr: i32, %xbool: !llvm.ptr<i1>, %exprbool: i1) {
+  // CHECK: %[[t1:.*]] = mul i32 %[[x_old:.*]], %[[expr]]
+  // CHECK: store i32 %[[t1]], i32* %[[x_new:.*]]
+  // CHECK: %[[t2:.*]] = load i32, i32* %[[x_new]]
+  // CHECK: cmpxchg i32* %[[x]], i32 %[[x_old]], i32 %[[t2]]
+  omp.atomic.update %x : !llvm.ptr<i32> {
+  ^bb0(%xval: i32):
+    %newval = llvm.mul %xval, %expr : i32
+    omp.yield(%newval : i32)
+  }
+  // CHECK: atomicrmw add i32* %[[x]], i32 %[[expr]] monotonic
+  omp.atomic.update %x : !llvm.ptr<i32> {
+  ^bb0(%xval: i32):
+    %newval = llvm.add %xval, %expr : i32
+    omp.yield(%newval : i32)
+  }
+  llvm.return
+}
+
+// -----
+
+// Checking an order-dependent operation when the order is `expr binop x`
+// CHECK-LABEL: @omp_atomic_update_ordering
+// CHECK-SAME: (i32* %[[x:.*]], i32 %[[expr:.*]])
+llvm.func @omp_atomic_update_ordering(%x:!llvm.ptr<i32>, %expr: i32) {
+  // CHECK: %[[t1:.*]] = shl i32 %[[expr]], %[[x_old:[^ ,]*]]
+  // CHECK: store i32 %[[t1]], i32* %[[x_new:.*]]
+  // CHECK: %[[t2:.*]] = load i32, i32* %[[x_new]]
+  // CHECK: cmpxchg i32* %[[x]], i32 %[[x_old]], i32 %[[t2]]
+  omp.atomic.update %x : !llvm.ptr<i32> {
+  ^bb0(%xval: i32):
+    %newval = llvm.shl %expr, %xval : i32
+    omp.yield(%newval : i32)
+  }
+  llvm.return
+}
+
+// -----
+
+// Checking an order-dependent operation when the order is `x binop expr`
+// CHECK-LABEL: @omp_atomic_update_ordering
+// CHECK-SAME: (i32* %[[x:.*]], i32 %[[expr:.*]])
+llvm.func @omp_atomic_update_ordering(%x:!llvm.ptr<i32>, %expr: i32) {
+  // CHECK: %[[t1:.*]] = shl i32 %[[x_old:.*]], %[[expr]]
+  // CHECK: store i32 %[[t1]], i32* %[[x_new:.*]]
+  // CHECK: %[[t2:.*]] = load i32, i32* %[[x_new]]
+  // CHECK: cmpxchg i32* %[[x]], i32 %[[x_old]], i32 %[[t2]] monotonic
+  omp.atomic.update %x : !llvm.ptr<i32> {
+  ^bb0(%xval: i32):
+    %newval = llvm.shl %xval, %expr : i32
+    omp.yield(%newval : i32)
+  }
+  llvm.return
+}
+
+// -----
+
+// Checking intrinsic translation.
+// CHECK-LABEL: @omp_atomic_update_intrinsic
+// CHECK-SAME: (i32* %[[x:.*]], i32 %[[expr:.*]])
+llvm.func @omp_atomic_update_intrinsic(%x:!llvm.ptr<i32>, %expr: i32) {
+  // CHECK: %[[t1:.*]] = call i32 @llvm.smax.i32(i32 %[[x_old:.*]], i32 %[[expr]])
+  // CHECK: store i32 %[[t1]], i32* %[[x_new:.*]]
+  // CHECK: %[[t2:.*]] = load i32, i32* %[[x_new]]
+  // CHECK: cmpxchg i32* %[[x]], i32 %[[x_old]], i32 %[[t2]]
+  omp.atomic.update %x : !llvm.ptr<i32> {
+  ^bb0(%xval: i32):
+    %newval = "llvm.intr.smax"(%xval, %expr) : (i32, i32) -> i32
+    omp.yield(%newval : i32)
+  }
+  llvm.return
+}
+
+// -----
+
 // CHECK-LABEL: @omp_sections_empty
 llvm.func @omp_sections_empty() -> () {
   omp.sections {