Index: mlir/include/mlir/Dialect/GPU/GPUOps.td =================================================================== --- mlir/include/mlir/Dialect/GPU/GPUOps.td +++ mlir/include/mlir/Dialect/GPU/GPUOps.td @@ -647,13 +647,21 @@ } def GPU_ShuffleOpXor : StrEnumAttrCase<"XOR", -1, "xor">; +def GPU_ShuffleOpDown : StrEnumAttrCase<"DOWN", -1, "down">; +def GPU_ShuffleOpUp : StrEnumAttrCase<"UP", -1, "up">; +def GPU_ShuffleOpIdx : StrEnumAttrCase<"IDX", -1, "idx">; def GPU_ShuffleModeAttr : StrEnumAttr<"ShuffleModeAttr", "Indexing modes supported by gpu.shuffle.", [ - GPU_ShuffleOpXor, + GPU_ShuffleOpXor, GPU_ShuffleOpUp, GPU_ShuffleOpDown, GPU_ShuffleOpIdx, ]>{ let cppNamespace = "::mlir::gpu"; + let storageType = "mlir::StringAttr"; + let returnType = "::mlir::gpu::ShuffleModeAttr"; + let convertFromStorage = + "*symbolizeEnum<::mlir::gpu::ShuffleModeAttr>($_self.getValue())"; + let constBuilderCall = "$_builder.getStringAttr(stringifyEnum($0))"; } Index: mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td =================================================================== --- mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -97,22 +97,36 @@ let assemblyFormat = "attr-dict"; } -def NVVM_ShflBflyOp : - NVVM_Op<"shfl.sync.bfly">, +def ShflKindBfly : StrEnumAttrCase<"bfly">; +def ShflKindUp : StrEnumAttrCase<"up">; +def ShflKindDown : StrEnumAttrCase<"down">; +def ShflKindIdx : StrEnumAttrCase<"idx">; + +/// Enum attribute of the different shuffle kinds. +def ShflKind : StrEnumAttr<"ShflKind", "NVVM shuffle kind", + [ShflKindBfly, ShflKindUp, ShflKindDown, ShflKindIdx]> { + let cppNamespace = "::mlir::NVVM"; + let storageType = "mlir::StringAttr"; + let returnType = "NVVM::ShflKind"; + let convertFromStorage = "*symbolizeEnum($_self.getValue())"; + let constBuilderCall = "$_builder.getStringAttr(stringifyEnum($0))"; +} + +def NVVM_ShflOp : + NVVM_Op<"shfl.sync">, Results<(outs LLVM_Type:$res)>, - Arguments<(ins LLVM_Type:$dst, + Arguments<(ins I32:$dst, LLVM_Type:$val, - LLVM_Type:$offset, - LLVM_Type:$mask_and_clamp, + I32:$offset, + I32:$mask_and_clamp, + ShflKind:$kind, OptionalAttr:$return_value_and_is_valid)> { string llvmBuilder = [{ - auto intId = getShflBflyIntrinsicId( - $_resultType, static_cast($return_value_and_is_valid)); + auto intId = getShflIntrinsicId( + $_resultType, $kind, static_cast($return_value_and_is_valid)); $res = createIntrinsicCall(builder, intId, {$dst, $val, $offset, $mask_and_clamp}); }]; - let parser = [{ return parseNVVMShflSyncBflyOp(parser, result); }]; - let printer = [{ printNVVMIntrinsicOp(p, this->getOperation()); }]; let verifier = [{ if (!(*this)->getAttrOfType("return_value_and_is_valid")) return success(); @@ -125,6 +139,10 @@ "i1 as the second element"); return success(); }]; + let assemblyFormat = [{ + $dst `,` $val `,` $offset `,` $mask_and_clamp attr-dict + `:` type($val) `->` type($res) + }]; } def NVVM_VoteBallotOp : Index: mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp =================================================================== --- mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -39,6 +39,21 @@ namespace { +/// Convert gpu dialect shfl mode enum to the equivalent nvvm one. +static NVVM::ShflKind convertShflKind(gpu::ShuffleModeAttr mode) { + switch (mode) { + case gpu::ShuffleModeAttr::XOR: + return NVVM::ShflKind::bfly; + case gpu::ShuffleModeAttr::UP: + return NVVM::ShflKind::up; + case gpu::ShuffleModeAttr::DOWN: + return NVVM::ShflKind::down; + case gpu::ShuffleModeAttr::IDX: + return NVVM::ShflKind::idx; + } + llvm_unreachable("unknown shuffle mode"); +} + struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; @@ -81,9 +96,9 @@ rewriter.create(loc, int32Type, adaptor.width(), one); auto returnValueAndIsValidAttr = rewriter.getUnitAttr(); - Value shfl = rewriter.create( + Value shfl = rewriter.create( loc, resultTy, activeMask, adaptor.value(), adaptor.offset(), - maskAndClamp, returnValueAndIsValidAttr); + maskAndClamp, convertShflKind(op.mode()), returnValueAndIsValidAttr); Value shflValue = rewriter.create( loc, valueTy, shfl, rewriter.getIndexArrayAttr(0)); Value isActiveSrcLane = rewriter.create( Index: mlir/lib/Dialect/GPU/IR/GPUDialect.cpp =================================================================== --- mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -302,7 +302,7 @@ } static void printShuffleOp(OpAsmPrinter &p, ShuffleOp op) { - p << ' ' << op.getOperands() << ' ' << op.mode() << " : " + p << ' ' << op.getOperands() << ' ' << stringifyEnum(op.mode()) << " : " << op.value().getType(); } Index: mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp =================================================================== --- mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -43,33 +43,6 @@ p << " : " << op->getResultTypes(); } -// ::= -// `llvm.nvvm.shfl.sync.bfly %dst, %val, %offset, %clamp_and_mask` -// ({return_value_and_is_valid})? : result_type -static ParseResult parseNVVMShflSyncBflyOp(OpAsmParser &parser, - OperationState &result) { - SmallVector ops; - Type resultType; - if (parser.parseOperandList(ops) || - parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(resultType) || - parser.addTypeToList(resultType, result.types)) - return failure(); - - for (auto &attr : result.attributes) { - if (attr.getName() != "return_value_and_is_valid") - continue; - auto structType = resultType.dyn_cast(); - if (structType && !structType.getBody().empty()) - resultType = structType.getBody()[0]; - break; - } - - auto int32Ty = IntegerType::get(parser.getContext(), 32); - return parser.resolveOperands(ops, {int32Ty, resultType, int32Ty, int32Ty}, - parser.getNameLoc(), result.operands); -} - // ::= `llvm.nvvm.vote.ballot.sync %mask, %pred` : result_type static ParseResult parseNVVMVoteBallotOp(OpAsmParser &parser, OperationState &result) { Index: mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp =================================================================== --- mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp +++ mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp @@ -23,15 +23,45 @@ using namespace mlir::LLVM; using mlir::LLVM::detail::createIntrinsicCall; -static llvm::Intrinsic::ID getShflBflyIntrinsicId(llvm::Type *resultType, - bool withPredicate) { +static llvm::Intrinsic::ID getShflIntrinsicId(llvm::Type *resultType, + NVVM::ShflKind kind, + bool withPredicate) { + if (withPredicate) { resultType = cast(resultType)->getElementType(0); - return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32p - : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32p; + switch (kind) { + case NVVM::ShflKind::bfly: + return resultType->isFloatTy() + ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32p + : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32p; + case NVVM::ShflKind::up: + return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32p + : llvm::Intrinsic::nvvm_shfl_sync_up_i32p; + case NVVM::ShflKind::down: + return resultType->isFloatTy() + ? llvm::Intrinsic::nvvm_shfl_sync_down_f32p + : llvm::Intrinsic::nvvm_shfl_sync_down_i32p; + case NVVM::ShflKind::idx: + return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32p + : llvm::Intrinsic::nvvm_shfl_sync_idx_i32p; + } + } else { + switch (kind) { + case NVVM::ShflKind::bfly: + return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32 + : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32; + case NVVM::ShflKind::up: + return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32 + : llvm::Intrinsic::nvvm_shfl_sync_up_i32; + case NVVM::ShflKind::down: + return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_down_f32 + : llvm::Intrinsic::nvvm_shfl_sync_down_i32; + case NVVM::ShflKind::idx: + return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32 + : llvm::Intrinsic::nvvm_shfl_sync_idx_i32; + } } - return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32 - : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32; + llvm_unreachable("unknown shuffle kind"); } namespace { Index: mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir =================================================================== --- mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir +++ mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir @@ -78,7 +78,7 @@ gpu.func @gpu_all_reduce_op() { %arg0 = arith.constant 1.0 : f32 // TODO: Check full IR expansion once lowering has settled. - // CHECK: nvvm.shfl.sync.bfly + // CHECK: nvvm.shfl.sync {{.*}} "bfly" // CHECK: nvvm.barrier0 // CHECK: llvm.fadd %result = "gpu.all_reduce"(%arg0) ({}) {op = "add"} : (f32) -> (f32) @@ -94,7 +94,7 @@ gpu.func @gpu_all_reduce_region() { %arg0 = arith.constant 1 : i32 // TODO: Check full IR expansion once lowering has settled. - // CHECK: nvvm.shfl.sync.bfly + // CHECK: nvvm.shfl.sync {{.*}} "bfly" // CHECK: nvvm.barrier0 %result = "gpu.all_reduce"(%arg0) ({ ^bb(%lhs : i32, %rhs : i32): @@ -109,7 +109,7 @@ gpu.module @test_module { // CHECK-LABEL: func @gpu_shuffle() - builtin.func @gpu_shuffle() -> (f32) { + builtin.func @gpu_shuffle() -> (f32, f32, f32, f32) { // CHECK: %[[#VALUE:]] = llvm.mlir.constant(1.000000e+00 : f32) : f32 %arg0 = arith.constant 1.0 : f32 // CHECK: %[[#OFFSET:]] = llvm.mlir.constant(4 : i32) : i32 @@ -120,12 +120,18 @@ // CHECK: %[[#SHL:]] = llvm.shl %[[#ONE]], %[[#WIDTH]] : i32 // CHECK: %[[#MASK:]] = llvm.sub %[[#SHL]], %[[#ONE]] : i32 // CHECK: %[[#CLAMP:]] = llvm.sub %[[#WIDTH]], %[[#ONE]] : i32 - // CHECK: %[[#SHFL:]] = nvvm.shfl.sync.bfly %[[#MASK]], %[[#VALUE]], %[[#OFFSET]], %[[#CLAMP]] : !llvm.struct<(f32, i1)> + // CHECK: %[[#SHFL:]] = nvvm.shfl.sync %[[#MASK]], %[[#VALUE]], %[[#OFFSET]], %[[#CLAMP]] {kind = "bfly", return_value_and_is_valid} : f32 -> !llvm.struct<(f32, i1)> // CHECK: llvm.extractvalue %[[#SHFL]][0 : index] : !llvm.struct<(f32, i1)> // CHECK: llvm.extractvalue %[[#SHFL]][1 : index] : !llvm.struct<(f32, i1)> %shfl, %pred = "gpu.shuffle"(%arg0, %arg1, %arg2) { mode = "xor" } : (f32, i32, i32) -> (f32, i1) - - std.return %shfl : f32 + // CHECK: nvvm.shfl.sync {{.*}} {kind = "up", return_value_and_is_valid} : f32 -> !llvm.struct<(f32, i1)> + %shflu, %predu = "gpu.shuffle"(%arg0, %arg1, %arg2) { mode = "up" } : (f32, i32, i32) -> (f32, i1) + // CHECK: nvvm.shfl.sync {{.*}} {kind = "down", return_value_and_is_valid} : f32 -> !llvm.struct<(f32, i1)> + %shfld, %predd = "gpu.shuffle"(%arg0, %arg1, %arg2) { mode = "down" } : (f32, i32, i32) -> (f32, i1) + // CHECK: nvvm.shfl.sync {{.*}} {kind = "idx", return_value_and_is_valid} : f32 -> !llvm.struct<(f32, i1)> + %shfli, %predi = "gpu.shuffle"(%arg0, %arg1, %arg2) { mode = "idx" } : (f32, i32, i32) -> (f32, i1) + + std.return %shfl, %shflu, %shfld, %shfli : f32, f32,f32, f32 } } Index: mlir/test/Dialect/GPU/ops.mlir =================================================================== --- mlir/test/Dialect/GPU/ops.mlir +++ mlir/test/Dialect/GPU/ops.mlir @@ -55,6 +55,12 @@ %offset = arith.constant 3 : i32 // CHECK: gpu.shuffle %{{.*}}, %{{.*}}, %{{.*}} xor : f32 %shfl, %pred = gpu.shuffle %arg0, %offset, %width xor : f32 + // CHECK: gpu.shuffle %{{.*}}, %{{.*}}, %{{.*}} up : f32 + %shfl1, %pred1 = gpu.shuffle %arg0, %offset, %width up : f32 + // CHECK: gpu.shuffle %{{.*}}, %{{.*}}, %{{.*}} down : f32 + %shfl2, %pred2 = gpu.shuffle %arg0, %offset, %width down : f32 + // CHECK: gpu.shuffle %{{.*}}, %{{.*}}, %{{.*}} idx : f32 + %shfl3, %pred3 = gpu.shuffle %arg0, %offset, %width idx : f32 "gpu.barrier"() : () -> () Index: mlir/test/Dialect/LLVMIR/invalid.mlir =================================================================== --- mlir/test/Dialect/LLVMIR/invalid.mlir +++ mlir/test/Dialect/LLVMIR/invalid.mlir @@ -495,21 +495,21 @@ func @nvvm_invalid_shfl_pred_1(%arg0 : i32, %arg1 : i32, %arg2 : i32, %arg3 : i32) { // expected-error@+1 {{expected return type to be a two-element struct with i1 as the second element}} - %0 = nvvm.shfl.sync.bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : i32 + %0 = nvvm.shfl.sync %arg0, %arg3, %arg1, %arg2 {kind = "bfly", return_value_and_is_valid} : i32 -> i32 } // ----- func @nvvm_invalid_shfl_pred_2(%arg0 : i32, %arg1 : i32, %arg2 : i32, %arg3 : i32) { // expected-error@+1 {{expected return type to be a two-element struct with i1 as the second element}} - %0 = nvvm.shfl.sync.bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : !llvm.struct<(i32)> + %0 = nvvm.shfl.sync %arg0, %arg3, %arg1, %arg2 {kind = "bfly", return_value_and_is_valid} : i32 -> !llvm.struct<(i32)> } // ----- func @nvvm_invalid_shfl_pred_3(%arg0 : i32, %arg1 : i32, %arg2 : i32, %arg3 : i32) { // expected-error@+1 {{expected return type to be a two-element struct with i1 as the second element}} - %0 = nvvm.shfl.sync.bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : !llvm.struct<(i32, i32)> + %0 = nvvm.shfl.sync %arg0, %arg3, %arg1, %arg2 {kind = "bfly", return_value_and_is_valid} : i32 -> !llvm.struct<(i32, i32)> } // ----- Index: mlir/test/Dialect/LLVMIR/nvvm.mlir =================================================================== --- mlir/test/Dialect/LLVMIR/nvvm.mlir +++ mlir/test/Dialect/LLVMIR/nvvm.mlir @@ -37,20 +37,26 @@ func @nvvm_shfl( %arg0 : i32, %arg1 : i32, %arg2 : i32, %arg3 : i32, %arg4 : f32) -> i32 { - // CHECK: nvvm.shfl.sync.bfly %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : i32 - %0 = nvvm.shfl.sync.bfly %arg0, %arg3, %arg1, %arg2 : i32 - // CHECK: nvvm.shfl.sync.bfly %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : f32 - %1 = nvvm.shfl.sync.bfly %arg0, %arg4, %arg1, %arg2 : f32 + // CHECK: nvvm.shfl.sync %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} {kind = "bfly"} : i32 -> i32 + %0 = nvvm.shfl.sync %arg0, %arg3, %arg1, %arg2 {kind = "bfly"} : i32 -> i32 + // CHECK: nvvm.shfl.sync %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} {kind = "bfly"} : f32 -> f32 + %1 = nvvm.shfl.sync %arg0, %arg4, %arg1, %arg2 {kind = "bfly"} : f32 -> f32 + // CHECK: nvvm.shfl.sync %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} {kind = "up"} : f32 -> f32 + %2 = nvvm.shfl.sync %arg0, %arg4, %arg1, %arg2 {kind = "up"} : f32 -> f32 + // CHECK: nvvm.shfl.sync %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} {kind = "down"} : f32 -> f32 + %3 = nvvm.shfl.sync %arg0, %arg4, %arg1, %arg2 {kind = "down"} : f32 -> f32 + // CHECK: nvvm.shfl.sync %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} {kind = "idx"} : f32 -> f32 + %4 = nvvm.shfl.sync %arg0, %arg4, %arg1, %arg2 {kind = "idx"} : f32 -> f32 llvm.return %0 : i32 } func @nvvm_shfl_pred( %arg0 : i32, %arg1 : i32, %arg2 : i32, %arg3 : i32, %arg4 : f32) -> !llvm.struct<(i32, i1)> { - // CHECK: nvvm.shfl.sync.bfly %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm.struct<(i32, i1)> - %0 = nvvm.shfl.sync.bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : !llvm.struct<(i32, i1)> - // CHECK: nvvm.shfl.sync.bfly %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm.struct<(f32, i1)> - %1 = nvvm.shfl.sync.bfly %arg0, %arg4, %arg1, %arg2 {return_value_and_is_valid} : !llvm.struct<(f32, i1)> + // CHECK: nvvm.shfl.sync %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} {kind = "bfly", return_value_and_is_valid} : i32 -> !llvm.struct<(i32, i1)> + %0 = nvvm.shfl.sync %arg0, %arg3, %arg1, %arg2 {kind = "bfly", return_value_and_is_valid} : i32 -> !llvm.struct<(i32, i1)> + // CHECK: nvvm.shfl.sync %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} {kind = "bfly", return_value_and_is_valid} : f32 -> !llvm.struct<(f32, i1)> + %1 = nvvm.shfl.sync %arg0, %arg4, %arg1, %arg2 {kind = "bfly", return_value_and_is_valid} : f32 -> !llvm.struct<(f32, i1)> llvm.return %0 : !llvm.struct<(i32, i1)> }