diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/StandardOpsBase.td b/mlir/include/mlir/Dialect/StandardOps/IR/StandardOpsBase.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/StandardOpsBase.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/StandardOpsBase.td
@@ -26,13 +26,16 @@
 def ATOMIC_RMW_KIND_MINU   : I64EnumAttrCase<"minu", 8>;
 def ATOMIC_RMW_KIND_MULF   : I64EnumAttrCase<"mulf", 9>;
 def ATOMIC_RMW_KIND_MULI   : I64EnumAttrCase<"muli", 10>;
+def ATOMIC_RMW_KIND_ORI    : I64EnumAttrCase<"ori", 11>;
+def ATOMIC_RMW_KIND_ANDI   : I64EnumAttrCase<"andi", 12>;
 
 def AtomicRMWKindAttr : I64EnumAttr<
     "AtomicRMWKind", "",
     [ATOMIC_RMW_KIND_ADDF, ATOMIC_RMW_KIND_ADDI, ATOMIC_RMW_KIND_ASSIGN,
      ATOMIC_RMW_KIND_MAXF, ATOMIC_RMW_KIND_MAXS, ATOMIC_RMW_KIND_MAXU,
      ATOMIC_RMW_KIND_MINF, ATOMIC_RMW_KIND_MINS, ATOMIC_RMW_KIND_MINU,
-     ATOMIC_RMW_KIND_MULF, ATOMIC_RMW_KIND_MULI]> {
+     ATOMIC_RMW_KIND_MULF, ATOMIC_RMW_KIND_MULI, ATOMIC_RMW_KIND_ORI,
+     ATOMIC_RMW_KIND_ANDI]> {
   let cppNamespace = "::mlir";
 }
 
diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp
--- a/mlir/lib/Analysis/AffineAnalysis.cpp
+++ b/mlir/lib/Analysis/AffineAnalysis.cpp
@@ -57,6 +57,8 @@
           .Case([](arith::AddFOp) { return AtomicRMWKind::addf; })
           .Case([](arith::MulFOp) { return AtomicRMWKind::mulf; })
           .Case([](arith::AddIOp) { return AtomicRMWKind::addi; })
+          .Case([](arith::AndIOp) { return AtomicRMWKind::andi; })
+          .Case([](arith::OrIOp) { return AtomicRMWKind::ori; })
           .Case([](arith::MulIOp) { return AtomicRMWKind::muli; })
           .Case([](arith::MinFOp) { return AtomicRMWKind::minf; })
           .Case([](arith::MaxFOp) { return AtomicRMWKind::maxf; })
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -792,6 +792,10 @@
     return LLVM::AtomicBinOp::min;
   case AtomicRMWKind::minu:
     return LLVM::AtomicBinOp::umin;
+  case AtomicRMWKind::ori:
+    return LLVM::AtomicBinOp::_or;
+  case AtomicRMWKind::andi:
+    return LLVM::AtomicBinOp::_and;
   default:
     return llvm::None;
   }
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -155,6 +155,8 @@
   case AtomicRMWKind::mins:
   case AtomicRMWKind::minu:
   case AtomicRMWKind::muli:
+  case AtomicRMWKind::ori:
+  case AtomicRMWKind::andi:
     if (!op.getValue().getType().isa<IntegerType>())
       return op.emitOpError() << "with kind '"
                               << stringifyAtomicRMWKind(op.getKind())
@@ -178,7 +180,13 @@
   case AtomicRMWKind::addf:
   case AtomicRMWKind::addi:
   case AtomicRMWKind::maxu:
+  case AtomicRMWKind::ori:
     return builder.getZeroAttr(resultType);
+  case AtomicRMWKind::andi:
+    // The identity of a bitwise AND reduction is the all-ones bit pattern.
+    return builder.getIntegerAttr(
+        resultType,
+        APInt::getAllOnes(resultType.cast<IntegerType>().getWidth()));
   case AtomicRMWKind::maxs:
     return builder.getIntegerAttr(
         resultType,
@@ -240,6 +248,10 @@
     return builder.create<arith::MaxUIOp>(loc, lhs, rhs);
   case AtomicRMWKind::minu:
     return builder.create<arith::MinUIOp>(loc, lhs, rhs);
+  case AtomicRMWKind::ori:
+    return builder.create<arith::OrIOp>(loc, lhs, rhs);
+  case AtomicRMWKind::andi:
+    return builder.create<arith::AndIOp>(loc, lhs, rhs);
   // TODO: Add remaining reduction operations.
   default:
     (void)emitOptionalError(loc, "Reduction operation type not supported");
diff --git a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
--- a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
@@ -502,6 +502,10 @@
   // CHECK: llvm.atomicrmw umin %{{.*}}, %{{.*}} acq_rel
   atomic_rmw addf %fval, %F[%i] : (f32, memref<10xf32>) -> f32
   // CHECK: llvm.atomicrmw fadd %{{.*}}, %{{.*}} acq_rel
+  atomic_rmw ori %ival, %I[%i] : (i32, memref<10xi32>) -> i32
+  // CHECK: llvm.atomicrmw _or %{{.*}}, %{{.*}} acq_rel
+  atomic_rmw andi %ival, %I[%i] : (i32, memref<10xi32>) -> i32
+  // CHECK: llvm.atomicrmw _and %{{.*}}, %{{.*}} acq_rel
   return
 }
 