diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -1127,27 +1127,31 @@
 
   // Convert values and types.
   auto &innerOpList = opInst.getRegion().front().getOperations();
-  if (innerOpList.size() != 2)
-    return opInst.emitError("exactly two operations are allowed inside an "
-                            "atomic update region while lowering to LLVM IR");
-
-  Operation &innerUpdateOp = innerOpList.front();
-
-  if (innerUpdateOp.getNumOperands() != 2 ||
-      !llvm::is_contained(innerUpdateOp.getOperands(),
-                          opInst.getRegion().getArgument(0)))
-    return opInst.emitError(
-        "the update operation inside the region must be a binary operation and "
-        "that update operation must have the region argument as an operand");
-
-  llvm::AtomicRMWInst::BinOp binop = convertBinOpToAtomic(innerUpdateOp);
-
-  bool isXBinopExpr =
-      innerUpdateOp.getNumOperands() > 0 &&
-      innerUpdateOp.getOperand(0) == opInst.getRegion().getArgument(0);
-
-  mlir::Value mlirExpr = (isXBinopExpr ? innerUpdateOp.getOperand(1)
-                                       : innerUpdateOp.getOperand(0));
+  mlir::Value regionArg = opInst.getRegion().getArgument(0);
+  // The region must contain an update operation in addition to its
+  // terminator; a single update operation must read the region argument.
+  if (innerOpList.size() < 2 ||
+      (innerOpList.size() == 2 &&
+       !llvm::is_contained(innerOpList.front().getOperands(), regionArg)))
+    return opInst.emitError("no atomic update operation with region argument"
+                            " as operand found inside atomic.update region");
+
+  // Only a single binary update operation followed by the terminator maps
+  // onto one atomicrmw. A multi-step update such as `x = (x * e) / e` must
+  // not be reduced to its first operation, which would drop the rest of the
+  // computation, so binop stays BAD_BINOP and the OpenMPIRBuilder is left to
+  // emit its compare-exchange loop over the whole region instead.
+  llvm::AtomicRMWInst::BinOp binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
+  bool isXBinopExpr = false;
+  mlir::Value mlirExpr;
+  if (innerOpList.size() == 2 && innerOpList.front().getNumOperands() == 2) {
+    Operation &innerUpdateOp = innerOpList.front();
+    binop = convertBinOpToAtomic(innerUpdateOp);
+    isXBinopExpr = innerUpdateOp.getOperand(0) == regionArg;
+    mlirExpr = (isXBinopExpr ? innerUpdateOp.getOperand(1)
+                             : innerUpdateOp.getOperand(0));
+  }
+
   llvm::Value *llvmExpr = moduleTranslation.lookupValue(mlirExpr);
   llvm::Value *llvmX = moduleTranslation.lookupValue(opInst.getX());
   llvm::Type *llvmXElementType = moduleTranslation.convertType(
@@ -1210,25 +1214,29 @@
     isPostfixUpdate = atomicCaptureOp.getSecondOp() ==
                       atomicCaptureOp.getAtomicUpdateOp().getOperation();
     auto &innerOpList = atomicUpdateOp.getRegion().front().getOperations();
-    if (innerOpList.size() != 2)
-      return atomicUpdateOp.emitError(
-          "exactly two operations are allowed inside an "
-          "atomic update region while lowering to LLVM IR");
-    Operation *innerUpdateOp = atomicUpdateOp.getFirstOp();
-    if (innerUpdateOp->getNumOperands() != 2 ||
-        !llvm::is_contained(innerUpdateOp->getOperands(),
-                            atomicUpdateOp.getRegion().getArgument(0)))
-      return atomicUpdateOp.emitError(
-          "the update operation inside the region must be a binary operation "
-          "and that update operation must have the region argument as an "
-          "operand");
-    binop = convertBinOpToAtomic(*innerUpdateOp);
-
-    isXBinopExpr = innerUpdateOp->getOperand(0) ==
-                   atomicUpdateOp.getRegion().getArgument(0);
-
-    mlirExpr = (isXBinopExpr ? innerUpdateOp->getOperand(1)
-                             : innerUpdateOp->getOperand(0));
+    mlir::Value regionArg = atomicUpdateOp.getRegion().getArgument(0);
+    // The region must contain an update operation in addition to its
+    // terminator; a single update operation must read the region argument.
+    if (innerOpList.size() < 2 ||
+        (innerOpList.size() == 2 &&
+         !llvm::is_contained(innerOpList.front().getOperands(), regionArg)))
+      return atomicUpdateOp.emitError(
+          "no atomic update operation with region argument"
+          " as operand found inside atomic.update region");
+
+    // Mirrors the lowering of a standalone omp.atomic.update: only a single
+    // binary update operation maps onto one atomicrmw; a multi-step region
+    // keeps BAD_BINOP so the OpenMPIRBuilder emits a compare-exchange loop
+    // over the whole region rather than dropping all but one operation.
+    binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
+    if (innerOpList.size() == 2 &&
+        innerOpList.front().getNumOperands() == 2) {
+      Operation &innerUpdateOp = innerOpList.front();
+      binop = convertBinOpToAtomic(innerUpdateOp);
+      isXBinopExpr = innerUpdateOp.getOperand(0) == regionArg;
+      mlirExpr = (isXBinopExpr ? innerUpdateOp.getOperand(1)
+                               : innerUpdateOp.getOperand(0));
+    }
   }
 
   llvm::Value *llvmExpr = moduleTranslation.lookupValue(mlirExpr);
diff --git a/mlir/test/Target/LLVMIR/openmp-llvm-invalid.mlir b/mlir/test/Target/LLVMIR/openmp-llvm-invalid.mlir
--- a/mlir/test/Target/LLVMIR/openmp-llvm-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-llvm-invalid.mlir
@@ -3,8 +3,6 @@
 // Checking translation when the update is carried out by using more than one op
 // in the region.
llvm.func @omp_atomic_update_multiple_step_update(%x: !llvm.ptr, %expr: i32) { - // expected-error @+2 {{exactly two operations are allowed inside an atomic update region while lowering to LLVM IR}} - // expected-error @+1 {{LLVM Translation failed for operation: omp.atomic.update}} omp.atomic.update %x : !llvm.ptr { ^bb0(%xval: i32): %t1 = llvm.mul %xval, %expr : i32 @@ -20,7 +18,7 @@ // Checking translation when the captured variable is not used in the inner // update operation llvm.func @omp_atomic_update_multiple_step_update(%x: !llvm.ptr, %expr: i32) { - // expected-error @+2 {{the update operation inside the region must be a binary operation and that update operation must have the region argument as an operand}} + // expected-error @+2 {{no atomic update operation with region argument as operand found inside atomic.update region}} // expected-error @+1 {{LLVM Translation failed for operation: omp.atomic.update}} omp.atomic.update %x : !llvm.ptr { ^bb0(%xval: i32): @@ -38,7 +36,7 @@ // expected-error @+1 {{LLVM Translation failed for operation: omp.atomic.capture}} omp.atomic.capture memory_order(seq_cst) { omp.atomic.read %v = %x : !llvm.ptr, i32 - // expected-error @+1 {{the update operation inside the region must be a binary operation and that update operation must have the region argument as an operand}} + // expected-error @+1 {{no atomic update operation with region argument as operand found inside atomic.update region}} omp.atomic.update %x : !llvm.ptr { ^bb0(%xval: i32): %newval = llvm.mul %expr, %expr : i32 @@ -53,10 +51,8 @@ // Checking translation when the captured variable is not used in the inner // update operation llvm.func @omp_atomic_update_multiple_step_update(%x: !llvm.ptr, %v: !llvm.ptr, %expr: i32) { - // expected-error @+1 {{LLVM Translation failed for operation: omp.atomic.capture}} omp.atomic.capture memory_order(seq_cst) { - omp.atomic.read %v = %x : !llvm.ptr, i32 - // expected-error @+1 {{exactly two operations are allowed inside an 
atomic update region while lowering to LLVM IR}} + omp.atomic.read %v = %x : !llvm.ptr, i32 omp.atomic.update %x : !llvm.ptr { ^bb0(%xval: i32): %t1 = llvm.mul %xval, %expr : i32