diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -135,6 +135,48 @@
   let assemblyFormat = "$arg attr-dict `:` type($res)";
 }
 
+//===----------------------------------------------------------------------===//
+// NVVM redux op definitions
+//===----------------------------------------------------------------------===//
+
+def ReduxKindNone : I32EnumAttrCase<"NONE", 0, "none">;
+def ReduxKindAdd  : I32EnumAttrCase<"ADD", 1, "add">;
+def ReduxKindAnd  : I32EnumAttrCase<"AND", 2, "and">;
+def ReduxKindMax  : I32EnumAttrCase<"MAX", 3, "max">;
+def ReduxKindMin  : I32EnumAttrCase<"MIN", 4, "min">;
+def ReduxKindOr   : I32EnumAttrCase<"OR", 5, "or">;
+def ReduxKindUmax : I32EnumAttrCase<"UMAX", 6, "umax">;
+def ReduxKindUmin : I32EnumAttrCase<"UMIN", 7, "umin">;
+def ReduxKindXor  : I32EnumAttrCase<"XOR", 8, "xor">;
+
+/// Reduction operations supported by the nvvm.redux.sync op.
+def ReduxKind : I32EnumAttr<"ReduxKind", "NVVM redux kind",
+  [ReduxKindAdd, ReduxKindAnd, ReduxKindMax, ReduxKindMin, ReduxKindOr,
+    ReduxKindUmax, ReduxKindUmin, ReduxKindXor]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::NVVM";
+}
+
+def ReduxKindAttr : EnumAttr<NVVM_Dialect, ReduxKind, "redux_kind">;
+
+/// Translated to the llvm::Intrinsic::nvvm_redux_sync_* intrinsic selected by
+/// getReduxIntrinsicId for `$kind`. `$val` and `$res` are restricted to i32 so
+/// non-i32 IR is rejected by the verifier instead of reaching llvm_unreachable.
+def NVVM_ReduxOp :
+  NVVM_Op<"redux.sync">,
+  Results<(outs I32:$res)>,
+  Arguments<(ins I32:$val,
+                 ReduxKindAttr:$kind,
+                 I32:$mask_and_clamp)> {
+  string llvmBuilder = [{
+    auto intId = getReduxIntrinsicId($_resultType, $kind);
+    $res = createIntrinsicCall(builder, intId, {$val, $mask_and_clamp});
+  }];
+  let assemblyFormat = [{
+    $kind $val `,` $mask_and_clamp attr-dict `:` type($val) `->` type($res)
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // NVVM synchronization op definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -25,6 +25,34 @@
 using namespace mlir::LLVM;
 using mlir::LLVM::detail::createIntrinsicCall;
 
+/// Returns the NVVM intrinsic implementing nvvm.redux.sync for `kind`.
+/// Only i32 results are handled; the op's I32 type constraints guarantee this.
+static llvm::Intrinsic::ID getReduxIntrinsicId(llvm::Type *resultType,
+                                               NVVM::ReduxKind kind) {
+  if (!resultType->isIntegerTy(32))
+    llvm_unreachable("unsupported data type for redux");
+
+  switch (kind) {
+  case NVVM::ReduxKind::ADD:
+    return llvm::Intrinsic::nvvm_redux_sync_add;
+  case NVVM::ReduxKind::UMAX:
+    return llvm::Intrinsic::nvvm_redux_sync_umax;
+  case NVVM::ReduxKind::UMIN:
+    return llvm::Intrinsic::nvvm_redux_sync_umin;
+  case NVVM::ReduxKind::AND:
+    return llvm::Intrinsic::nvvm_redux_sync_and;
+  case NVVM::ReduxKind::OR:
+    return llvm::Intrinsic::nvvm_redux_sync_or;
+  case NVVM::ReduxKind::XOR:
+    return llvm::Intrinsic::nvvm_redux_sync_xor;
+  case NVVM::ReduxKind::MAX:
+    return llvm::Intrinsic::nvvm_redux_sync_max;
+  case NVVM::ReduxKind::MIN:
+    return llvm::Intrinsic::nvvm_redux_sync_min;
+  }
+  llvm_unreachable("unknown redux kind");
+}
+
 static llvm::Intrinsic::ID getShflIntrinsicId(llvm::Type *resultType,
                                               NVVM::ShflKind kind,
                                               bool withPredicate) {
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -310,6 +310,28 @@
   %l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<i32, 3>) -> !llvm.struct<(i32, i32, i32, i32)>
   llvm.return
 }
+
+// CHECK-LABEL: llvm.func @redux_sync
+llvm.func @redux_sync(%value : i32, %mask_and_clamp : i32) -> i32 {
+  // CHECK: nvvm.redux.sync add %{{.*}}, %{{.*}} : i32 -> i32
+  %r1 = nvvm.redux.sync add %value, %mask_and_clamp : i32 -> i32
+  // CHECK: nvvm.redux.sync max %{{.*}}, %{{.*}} : i32 -> i32
+  %r2 = nvvm.redux.sync max %value, %mask_and_clamp : i32 -> i32
+  // CHECK: nvvm.redux.sync min %{{.*}}, %{{.*}} : i32 -> i32
+  %r3 = nvvm.redux.sync min %value, %mask_and_clamp : i32 -> i32
+  // CHECK: nvvm.redux.sync umax %{{.*}}, %{{.*}} : i32 -> i32
+  %r4 = nvvm.redux.sync umax %value, %mask_and_clamp : i32 -> i32
+  // CHECK: nvvm.redux.sync umin %{{.*}}, %{{.*}} : i32 -> i32
+  %r5 = nvvm.redux.sync umin %value, %mask_and_clamp : i32 -> i32
+  // CHECK: nvvm.redux.sync and %{{.*}}, %{{.*}} : i32 -> i32
+  %r6 = nvvm.redux.sync and %value, %mask_and_clamp : i32 -> i32
+  // CHECK: nvvm.redux.sync or %{{.*}}, %{{.*}} : i32 -> i32
+  %r7 = nvvm.redux.sync or %value, %mask_and_clamp : i32 -> i32
+  // CHECK: nvvm.redux.sync xor %{{.*}}, %{{.*}} : i32 -> i32
+  %r8 = nvvm.redux.sync xor %value, %mask_and_clamp : i32 -> i32
+  llvm.return %r1 : i32
+}
+
 // -----
 
 // expected-error@below {{attribute attached to unexpected op}}