diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -1126,26 +1126,34 @@
 
   // Convert values and types.
   auto &innerOpList = opInst.getRegion().front().getOperations();
-  if (innerOpList.size() != 2)
-    return opInst.emitError("exactly two operations are allowed inside an "
-                            "atomic update region while lowering to LLVM IR");
-  Operation &innerUpdateOp = innerOpList.front();
-
-  if (innerUpdateOp.getNumOperands() != 2 ||
-      !llvm::is_contained(innerUpdateOp.getOperands(),
-                          opInst.getRegion().getArgument(0)))
-    return opInst.emitError(
-        "the update operation inside the region must be a binary operation and "
-        "that update operation must have the region argument as an operand");
-
-  llvm::AtomicRMWInst::BinOp binop = convertBinOpToAtomic(innerUpdateOp);
-
-  bool isXBinopExpr =
-      innerUpdateOp.getNumOperands() > 0 &&
-      innerUpdateOp.getOperand(0) == opInst.getRegion().getArgument(0);
+  bool isRegionArgUsed{false}, isUpdateOpPresent{false}, isXBinopExpr{false};
+  // Keep BAD_BINOP unless the region is a single update op plus terminator,
+  // so the builder does not emit one atomicrmw for a multi-step update.
+  llvm::AtomicRMWInst::BinOp binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
+  mlir::Value mlirExpr;
+  for (Operation &innerOp : innerOpList) {
+    if (innerOp.getNumOperands() != 2)
+      continue;
+    isUpdateOpPresent = true;
+    if (!llvm::is_contained(innerOp.getOperands(),
+                            opInst.getRegion().getArgument(0)))
+      continue;
+    isRegionArgUsed = true;
+    isXBinopExpr = innerOp.getOperand(0) == opInst.getRegion().getArgument(0);
+    mlirExpr = (isXBinopExpr ? innerOp.getOperand(1) : innerOp.getOperand(0));
+    // Only this op's opcode describes the whole update when it is the sole
+    // non-terminator op; otherwise the remaining region ops would be dropped.
+    if (innerOpList.size() == 2)
+      binop = convertBinOpToAtomic(innerOp);
+    break;
+  }
+  if (!isUpdateOpPresent)
+    return opInst.emitError("no atomic update operation found"
+                            " inside atomic.update region");
+  if (!isRegionArgUsed)
+    return opInst.emitError("the update operation inside the region "
+                            "must have the region argument as an operand");
 
-  mlir::Value mlirExpr = (isXBinopExpr ? innerUpdateOp.getOperand(1)
-                                       : innerUpdateOp.getOperand(0));
   llvm::Value *llvmExpr = moduleTranslation.lookupValue(mlirExpr);
   llvm::Value *llvmX = moduleTranslation.lookupValue(opInst.getX());
   llvm::Type *llvmXElementType = moduleTranslation.convertType(
@@ -1209,25 +1217,36 @@
     isPostfixUpdate = atomicCaptureOp.getSecondOp() ==
                       atomicCaptureOp.getAtomicUpdateOp().getOperation();
     auto &innerOpList = atomicUpdateOp.getRegion().front().getOperations();
-    if (innerOpList.size() != 2)
-      return atomicUpdateOp.emitError(
-          "exactly two operations are allowed inside an "
-          "atomic update region while lowering to LLVM IR");
-    Operation *innerUpdateOp = atomicUpdateOp.getFirstOp();
-    if (innerUpdateOp->getNumOperands() != 2 ||
-        !llvm::is_contained(innerUpdateOp->getOperands(),
-                            atomicUpdateOp.getRegion().getArgument(0)))
-      return atomicUpdateOp.emitError(
-          "the update operation inside the region must be a binary operation "
-          "and that update operation must have the region argument as an "
-          "operand");
-    binop = convertBinOpToAtomic(*innerUpdateOp);
+    bool isRegionArgUsed{false}, isUpdateOpPresent{false};
+    // Keep BAD_BINOP unless the region is a single update op plus terminator,
+    // so the builder does not emit one atomicrmw for a multi-step update.
+    binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
+    for (Operation &innerOp : innerOpList) {
+      if (innerOp.getNumOperands() != 2)
+        continue;
+      isUpdateOpPresent = true;
+      if (!llvm::is_contained(innerOp.getOperands(),
+                              atomicUpdateOp.getRegion().getArgument(0)))
+        continue;
+      isRegionArgUsed = true;
+      isXBinopExpr =
+          innerOp.getOperand(0) == atomicUpdateOp.getRegion().getArgument(0);
+      mlirExpr = (isXBinopExpr ? innerOp.getOperand(1) : innerOp.getOperand(0));
+      // Only this op's opcode describes the whole update when it is the sole
+      // non-terminator op; otherwise the remaining region ops would be dropped.
+      if (innerOpList.size() == 2)
+        binop = convertBinOpToAtomic(innerOp);
+      break;
+    }
 
-    isXBinopExpr = innerUpdateOp->getOperand(0) ==
-                   atomicUpdateOp.getRegion().getArgument(0);
+    if (!isUpdateOpPresent)
+      return atomicUpdateOp.emitError("no atomic update operation found"
+                                      " inside atomic.update region");
 
-    mlirExpr = (isXBinopExpr ? innerUpdateOp->getOperand(1)
-                             : innerUpdateOp->getOperand(0));
+    if (!isRegionArgUsed)
+      return atomicUpdateOp.emitError(
+          "the update operation inside the region "
+          "must have the region argument as an operand");
   }
 
   llvm::Value *llvmExpr = moduleTranslation.lookupValue(mlirExpr);
diff --git a/mlir/test/Target/LLVMIR/openmp-llvm-invalid.mlir b/mlir/test/Target/LLVMIR/openmp-llvm-invalid.mlir
--- a/mlir/test/Target/LLVMIR/openmp-llvm-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-llvm-invalid.mlir
@@ -2,9 +2,7 @@
 
 // Checking translation when the update is carried out by using more than one op
 // in the region.
-llvm.func @omp_atomic_update_multiple_step_update(%x: !llvm.ptr, %expr: i32) {
-  // expected-error @+2 {{exactly two operations are allowed inside an atomic update region while lowering to LLVM IR}}
-  // expected-error @+1 {{LLVM Translation failed for operation: omp.atomic.update}}
+llvm.func @omp_atomic_update_multiple_step_update(%x: !llvm.ptr, %expr: i32) {
   omp.atomic.update %x : !llvm.ptr {
   ^bb0(%xval: i32):
     %t1 = llvm.mul %xval, %expr : i32
@@ -20,7 +18,7 @@
 // Checking translation when the captured variable is not used in the inner
 // update operation
 llvm.func @omp_atomic_update_multiple_step_update(%x: !llvm.ptr, %expr: i32) {
-  // expected-error @+2 {{the update operation inside the region must be a binary operation and that update operation must have the region argument as an operand}}
+  // expected-error @+2 {{the update operation inside the region must have the region argument as an operand}}
   // expected-error @+1 {{LLVM Translation failed for operation: omp.atomic.update}}
   omp.atomic.update %x : !llvm.ptr {
   ^bb0(%xval: i32):
@@ -38,7 +36,7 @@ // expected-error @+1 {{LLVM Translation failed for operation: omp.atomic.capture}} omp.atomic.capture memory_order(seq_cst) { omp.atomic.read %v = %x : !llvm.ptr, i32 - // expected-error @+1 {{the update operation inside the region must be a binary operation and that update operation must have the region argument as an operand}} + // expected-error @+1 {{the update operation inside the region must have the region argument as an operand}} omp.atomic.update %x : !llvm.ptr { ^bb0(%xval: i32): %newval = llvm.mul %expr, %expr : i32 @@ -52,11 +50,9 @@ // Checking translation when the captured variable is not used in the inner // update operation -llvm.func @omp_atomic_update_multiple_step_update(%x: !llvm.ptr, %v: !llvm.ptr, %expr: i32) { - // expected-error @+1 {{LLVM Translation failed for operation: omp.atomic.capture}} +llvm.func @omp_atomic_update_multiple_step_update(%x: !llvm.ptr, %v: !llvm.ptr, %expr: i32) { omp.atomic.capture memory_order(seq_cst) { omp.atomic.read %v = %x : !llvm.ptr, i32 - // expected-error @+1 {{exactly two operations are allowed inside an atomic update region while lowering to LLVM IR}} omp.atomic.update %x : !llvm.ptr { ^bb0(%xval: i32): %t1 = llvm.mul %xval, %expr : i32