diff --git a/mlir/include/mlir/Dialect/GPU/GPUBase.td b/mlir/include/mlir/Dialect/GPU/GPUBase.td
--- a/mlir/include/mlir/Dialect/GPU/GPUBase.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUBase.td
@@ -102,4 +102,34 @@
   ];
 }
+// Cases of the string enum attribute SubgroupMmaOpLayout, representing
+// the layouts of the operands supported by the ops that use this attribute.
+def RowMajor: StrEnumAttrCase<"RowMajor", 0>;
+def ColMajor: StrEnumAttrCase<"ColMajor", 1>;
+
+// Specifies a string enum attribute for warp-wide matrix operations,
+// representing the layout of the respective operands in the GPU dialect. The
+// layout later governs the lowering to the appropriate intrinsics.
+def SubgroupMmaOpLayout: StrEnumAttr<"Layout",
+  "Specifies whether the operand is row- or column-major",
+  [RowMajor, ColMajor]> {
+  let stringToSymbolFnName = "LayoutStrToEnum";
+  let symbolToStringFnName = "EnumToLayoutStr";
+}
+
+// Cases of the string enum attribute SubgroupMmaOperand, representing
+// the operand tags supported by the ops that use this attribute.
+def DOp: StrEnumAttrCase<"DOp", 0>;
+def AOp: StrEnumAttrCase<"AOp", 1>;
+def BOp: StrEnumAttrCase<"BOp", 2>;
+def COp: StrEnumAttrCase<"COp", 3>;
+
+// Specifies a string enum attribute for warp-wide matrix operations,
+// representing the operand tag of an operation. The tag later governs
+// the lowering to the appropriate intrinsics.
+def SubgroupMmaOperand: StrEnumAttr<"WmmaOp",
+  "Specifies which subgroup MMA operand this is",
+  [DOp, AOp, BOp, COp]> {
+  let stringToSymbolFnName = "OpStrToEnum";
+  let symbolToStringFnName = "EnumToOpStr";
+}
+
 #endif // GPU_BASE
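Editorial note: the operand tags defined above are consumed as plain strings by the NVVM lowering added later in this patch. A minimal sketch of that dispatch, using the same intrinsic IDs that appear in the NVVMOps.td changes below (the helper name is made up for illustration and is not part of this patch):

```cpp
#include "llvm/ADT/StringRef.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#include <cassert>

// Illustrative helper mirroring the llvmBuilder snippets added to NVVMOps.td:
// map a SubgroupMmaOperand tag to the corresponding f16 row-major load
// intrinsic.
static llvm::Intrinsic::ID wmmaLoadIntrinsicFor(llvm::StringRef operandTag) {
  if (operandTag == "AOp")
    return llvm::Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_row_stride;
  if (operandTag == "BOp")
    return llvm::Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_row_stride;
  assert(operandTag == "COp" && "unexpected WMMA operand tag");
  return llvm::Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row_stride;
}
```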
diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td
--- a/mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -914,4 +914,130 @@
   let verifier = [{ return ::verify(*this); }];
 }
 
+def GPU_SubgroupMmaLoadMatrixOp : GPU_Op<"subgroup_mma_load_matrix",
+    [MemoryEffects<[MemRead, MemWrite]>]> {
+
+  let summary = "GPU warp synchronous matrix load";
+
+  let description = [{
+    The `gpu.subgroup_mma_load_matrix` operation loads a matrix collectively
+    using all the threads in a subgroup.
+
+    This operation takes two memrefs as arguments: the source matrix from which
+    data is loaded and the destination. The source memref can be in the global
+    or shared memory space. The destination memref is required to be in the
+    thread-private memory space and its rank cannot be greater than `1`. The
+    starting load address is determined by the provided indices. The size of
+    the matrix being loaded depends on the `operand` attribute. The `ldm`
+    attribute specifies the leading dimension of the source matrix.
+
+    This op is meant to be used along with `gpu.subgroup_mma_store_matrix` and
+    `gpu.subgroup_mma_compute`.
+
+    Example:
+
+    ```mlir
+    gpu.subgroup_mma_load_matrix %src[%i, %j], %A[%c1] {operand = "AOp", ldm = 32 : i64}
+      : memref<32x32xf16, 3>, memref<1xvector<16xf16>, 5>
+    ```
+  }];
+
+  let arguments = (ins Arg<AnyMemRef, "", [MemRead]>:$srcMemref,
+                   Arg<AnyMemRef, "", [MemWrite]>:$dstMemref,
+                   Variadic<Index>:$indices,
+                   I64:$dstIndex,
+                   I64Attr:$ldm,
+                   SubgroupMmaOperand:$operand);
+
+  let assemblyFormat = [{
+    $srcMemref`[`$indices`]``,` $dstMemref`[`$dstIndex`]` attr-dict `:` type($srcMemref)`,` type($dstMemref)
+  }];
+
+  let verifier = [{ return ::verify(*this); }];
+}
+
+def GPU_SubgroupMmaStoreMatrixOp : GPU_Op<"subgroup_mma_store_matrix",
+    [MemoryEffects<[MemRead, MemWrite]>]> {
+
+  let summary = "GPU warp synchronous matrix store";
+
+  let description = [{
+    The `gpu.subgroup_mma_store_matrix` operation stores a matrix collectively
+    using all the threads in a subgroup.
+
+    This operation takes two memrefs as arguments: the source, which holds the
+    data to be stored, and the destination. The source memref is required to be
+    in the thread-private memory space and its rank cannot be greater than `1`.
+    The destination memref can be in the global or shared memory space. The
+    starting store address is determined by the provided indices. The `ldm`
+    attribute specifies the leading dimension of the destination matrix.
+
+    This op is meant to be used along with `gpu.subgroup_mma_load_matrix` and
+    `gpu.subgroup_mma_compute`.
+
+    Example:
+
+    ```mlir
+    gpu.subgroup_mma_store_matrix %D[%c0], %sg[%i, %j] {ldm = 32 : i64}
+      : memref<1xvector<8xf16>, 5>, memref<32x32xf16, 3>
+    ```
+  }];
+
+  let arguments = (ins Arg<AnyMemRef, "", [MemRead]>:$srcMemref,
+                   Arg<AnyMemRef, "", [MemWrite]>:$dstMemref,
+                   Variadic<Index>:$indices,
+                   I64:$srcIndex,
+                   I64Attr:$ldm);
+
+  let assemblyFormat = [{
+    $srcMemref`[`$srcIndex`]``,` $dstMemref`[`$indices`]` attr-dict `:` type($srcMemref)`,` type($dstMemref)
+  }];
+
+  let verifier = [{ return ::verify(*this); }];
+}
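Editorial note: both ops above take an `ldm` (leading dimension) attribute. As a quick illustration of what it means for addressing, element `(i, j)` of a row-major tile whose rows are `ldm` elements apart lives at offset `i * ldm + j` from the tile's base address. A one-line sketch (illustrative, not code from this patch):

```cpp
#include <cstddef>

// Offset, in elements (not bytes), of element (i, j) of a row-major matrix
// whose consecutive rows are `ldm` elements apart in memory.
inline std::size_t tileElementOffset(std::size_t i, std::size_t j,
                                     std::size_t ldm) {
  return i * ldm + j;
}
```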
+def GPU_SubgroupMmaComputeOp : GPU_Op<"subgroup_mma_compute",
+    [MemoryEffects<[MemRead, MemWrite]>]> {
+
+  let summary = "GPU warp synchronous matrix multiply accumulate";
+
+  let description = [{
+    The `gpu.subgroup_mma_compute` operation performs a matrix-multiply
+    accumulate (MMA) operation using all the threads in a subgroup.
+
+    This operation takes four memrefs as arguments, holding the `A`, `B`, `C`
+    and `D` operands of the MMA operation. The operation performed is
+    `D = A * B + C`. All the memrefs are required to be in the thread-private
+    memory space and their rank cannot be greater than `1`. Each memref is
+    supplied with an index selecting the element to be used in the computation.
+
+    This op is meant to be used along with `gpu.subgroup_mma_load_matrix` and
+    `gpu.subgroup_mma_store_matrix`.
+
+    Example:
+
+    ```mlir
+    gpu.subgroup_mma_compute %A[%c0], %B[%c0], %C[%c0], %D[%c0] :
+      memref<1xvector<16xf16>, 5>, memref<1xvector<16xf16>, 5>,
+      memref<1xvector<8xf16>, 5>, memref<1xvector<8xf16>, 5>
+    ```
+  }];
+
+  let arguments = (ins Arg<AnyMemRef, "", [MemRead]>:$opA,
+                   Arg<AnyMemRef, "", [MemRead]>:$opB,
+                   Arg<AnyMemRef, "", [MemRead]>:$opC,
+                   Arg<AnyMemRef, "", [MemWrite]>:$opD,
+                   I64:$AIndex,
+                   I64:$BIndex,
+                   I64:$CIndex,
+                   I64:$DIndex);
+
+  let assemblyFormat = [{
+    $opA`[`$AIndex`]``,` $opB`[`$BIndex`]``,` $opC`[`$CIndex`]``,` $opD`[`$DIndex`]` attr-dict `:` type($opA)`,` type($opB)`,` type($opC)`,` type($opD)
+  }];
+
+  let verifier = [{ return ::verify(*this); }];
+}
+
 #endif // GPU_OPS
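Editorial note: taken together, the three GPU ops above are intended to be used as a unit. A small end-to-end sketch (illustrative only; the SSA values and constants are assumed to be defined elsewhere, and the types follow the examples in the op documentation above):

```mlir
// Load the A, B and C tiles into thread-private fragments.
gpu.subgroup_mma_load_matrix %src[%i, %j], %fragA[%c0] {operand = "AOp", ldm = 32 : i64}
  : memref<32x32xf16, 3>, memref<1xvector<16xf16>, 5>
gpu.subgroup_mma_load_matrix %src[%i, %j], %fragB[%c0] {operand = "BOp", ldm = 32 : i64}
  : memref<32x32xf16, 3>, memref<1xvector<16xf16>, 5>
gpu.subgroup_mma_load_matrix %acc[%i, %j], %fragC[%c0] {operand = "COp", ldm = 32 : i64}
  : memref<32x32xf16, 3>, memref<1xvector<8xf16>, 5>

// D = A * B + C, then write the D tile back out.
gpu.subgroup_mma_compute %fragA[%c0], %fragB[%c0], %fragC[%c0], %fragD[%c0] :
  memref<1xvector<16xf16>, 5>, memref<1xvector<16xf16>, 5>,
  memref<1xvector<8xf16>, 5>, memref<1xvector<8xf16>, 5>
gpu.subgroup_mma_store_matrix %fragD[%c0], %dst[%i, %j] {ldm = 32 : i64}
  : memref<1xvector<8xf16>, 5>, memref<32x32xf16, 3>
```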
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -16,6 +16,7 @@
 include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Dialect/GPU/GPUBase.td"
 
 //===----------------------------------------------------------------------===//
 // NVVM dialect definitions
 //===----------------------------------------------------------------------===//
@@ -144,4 +145,129 @@
   let verifier = [{ return ::verify(*this); }];
 }
 
+def NVVM_WMMALoadOp :
+  NVVM_Op<"wmma.m16n16k16.load">,
+  Results<(outs LLVM_AnyStruct:$res)>,
+  Arguments<(ins Variadic<LLVM_Type>:$args,
+             SubgroupMmaOperand:$operand)> {
+  string llvmBuilder = [{
+    if ($operand.compare("AOp") == 0)
+      $res = createNvvmIntrinsicCall(
+          builder, llvm::Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_row_stride, $args);
+    else if ($operand.compare("BOp") == 0)
+      $res = createNvvmIntrinsicCall(
+          builder, llvm::Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_row_stride, $args);
+    else if ($operand.compare("COp") == 0)
+      $res = createNvvmIntrinsicCall(
+          builder, llvm::Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row_stride, $args);
+  }];
+
+  let summary = "Warp synchronous matrix load";
+
+  let description = [{
+    The `nvvm.wmma.m16n16k16.load` operation loads a matrix collectively using
+    all the threads in a warp.
+
+    The operation takes two arguments: the address from which the matrix
+    elements are to be loaded, and a stride. The stride argument represents the
+    leading dimension of the source matrix. The address and the stride are
+    required to be the same across all threads in the warp. Each thread in the
+    warp holds a certain number of elements; the op returns an LLVM struct that
+    holds the elements of the matrix held by this thread.
+
+    This op is meant to be used along with `nvvm.wmma.m16n16k16.store` and
+    `nvvm.wmma.m16n16k16.mma`.
+
+    Example:
+
+    ```mlir
+    %2 = nvvm.wmma.m16n16k16.load %0, %1 : !llvm.ptr<i32, 3>, !llvm.i32 ->
+      !llvm.struct<(vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>,
+      vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>)>
+    ```
+  }];
+
+  let assemblyFormat = "$args attr-dict `:` type($args) `->` type($res)";
+
+  let verifier = [{ return ::verify(*this); }];
+}
+
+def NVVM_WMMAStoreOp :
+  NVVM_Op<"wmma.m16n16k16.store">,
+  Arguments<(ins Variadic<LLVM_Type>:$args)> {
+  string llvmBuilder = [{
+    createNvvmIntrinsicCall(
+        builder, llvm::Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row_stride, $args);
+  }];
+
+  let summary = "Warp synchronous matrix store";
+
+  let description = [{
+    The `nvvm.wmma.m16n16k16.store` operation stores a matrix collectively using
+    all the threads in a warp.
+
+    The operation takes as arguments the address to which the matrix elements
+    are to be stored, the elements to store (held by the current thread), and a
+    stride. The stride argument represents the leading dimension of the
+    destination matrix. The address and the stride are required to be the same
+    across all threads in the warp.
+
+    This op is meant to be used along with `nvvm.wmma.m16n16k16.load` and
+    `nvvm.wmma.m16n16k16.mma`.
+
+    Example:
+
+    ```mlir
+    nvvm.wmma.m16n16k16.store %0, %1, %2, %3, %4, %5 : !llvm.ptr<i32, 3>,
+      !llvm.vec<2 x half>, !llvm.vec<2 x half>, !llvm.vec<2 x half>,
+      !llvm.vec<2 x half>, !llvm.i32
+    ```
+  }];
+
+  let assemblyFormat = "$args attr-dict `:` type($args)";
+
+  let verifier = [{ return ::verify(*this); }];
+}
+
+def NVVM_WMMAMmaOp :
+  NVVM_Op<"wmma.m16n16k16.mma">,
+  Results<(outs LLVM_AnyStruct:$res)>,
+  Arguments<(ins Variadic<LLVM_Type>:$args)> {
+  string llvmBuilder = [{
+    $res = createNvvmIntrinsicCall(
+        builder, llvm::Intrinsic::nvvm_wmma_m16n16k16_mma_row_row_f16_f16, $args);
+  }];
+
+  let summary = "Warp synchronous matrix-multiply accumulate using tensor cores";
+
+  let description = [{
+    The `nvvm.wmma.m16n16k16.mma` operation performs a matrix-multiply
+    accumulate (MMA) operation using all the threads in a warp.
+
+    The operation performed is `D = A * B + C`. The operation takes as
+    arguments the elements of the matrices `A`, `B` and `C` held by the current
+    thread, and returns an LLVM struct which holds the part of the result `D`
+    held by the current thread.
+
+    This op is meant to be used along with `nvvm.wmma.m16n16k16.load` and
+    `nvvm.wmma.m16n16k16.store`.
+
+    Example:
+
+    ```mlir
+    %20 = nvvm.wmma.m16n16k16.mma %0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10,
+      %11, %12, %13, %14, %15, %16, %17, %18, %19
+      : !llvm.vec<2 x half>, !llvm.vec<2 x half>, !llvm.vec<2 x half>,
+        !llvm.vec<2 x half>, !llvm.vec<2 x half>, !llvm.vec<2 x half>,
+        !llvm.vec<2 x half>, !llvm.vec<2 x half>, !llvm.vec<2 x half>,
+        !llvm.vec<2 x half>, !llvm.vec<2 x half>, !llvm.vec<2 x half>,
+        !llvm.vec<2 x half>, !llvm.vec<2 x half>, !llvm.vec<2 x half>,
+        !llvm.vec<2 x half>, !llvm.vec<2 x half>, !llvm.vec<2 x half>,
+        !llvm.vec<2 x half>, !llvm.vec<2 x half>
+      -> !llvm.struct<(vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>)>
+    ```
+  }];
+
+  let assemblyFormat = "$args attr-dict `:` type($args) `->` type($res)";
+
+  let verifier = [{ return ::verify(*this); }];
+}
+
 #endif // NVVMIR_OPS
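Editorial note: to connect the two halves of this patch, the GPU-level load is intended to become the NVVM op defined above once a GPU-to-NVVM conversion (not part of this diff) has run. Roughly, with illustrative SSA names and a shared-memory address space:

```mlir
// GPU dialect, before lowering: load the "AOp" tile.
gpu.subgroup_mma_load_matrix %src[%i, %j], %fragA[%c0] {operand = "AOp", ldm = 32 : i64}
  : memref<32x32xf16, 3>, memref<1xvector<16xf16>, 5>

// NVVM dialect, after lowering: pointer and stride already in LLVM types.
%r = nvvm.wmma.m16n16k16.load %ptr, %stride {operand = "AOp"}
  : !llvm.ptr<i32, 3>, !llvm.i32 -> !llvm.struct<(vec<2 x half>, vec<2 x half>,
    vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>,
    vec<2 x half>)>
```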
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -885,6 +885,138 @@
   printer << "]";
 }
 
+//===----------------------------------------------------------------------===//
+// GPU_SubgroupMmaLoadMatrixOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(SubgroupMmaLoadMatrixOp op) {
+  auto srcType = op.srcMemref().getType();
+  auto dstType = op.dstMemref().getType();
+  auto srcMemrefType = srcType.cast<MemRefType>();
+  auto dstMemrefType = dstType.cast<MemRefType>();
+  auto srcElemType = srcMemrefType.getElementType();
+  auto dstElemType = dstMemrefType.getElementType();
+  auto srcMemSpace = srcMemrefType.getMemorySpace();
+  auto dstMemSpace = dstMemrefType.getMemorySpace();
+
+  if ((srcMemSpace != 0 && srcMemSpace != 3) || dstMemSpace != 5)
+    return op.emitError("expected source memref in memory space 0 or 3 and "
+                        "destination memref in memory space 5");
+
+  if (srcMemrefType.getRank() != 2 || !srcElemType.isF16())
+    return op.emitError(
+        "source memref should be of rank 2 and have f16 elements");
+
+  if (auto dstVecTy = dstElemType.dyn_cast<VectorType>()) {
+    if (op.operand().equals("AOp") || op.operand().equals("BOp")) {
+      if (!dstVecTy.getElementType().isF16() || dstVecTy.getRank() != 1 ||
+          dstVecTy.getDimSize(0) != 16 || dstMemrefType.getRank() != 1)
+        return op.emitError("destination for AOp and BOp should be of shape "
+                            "memref<1xvector<16xf16>>");
+    }
+    if (op.operand().equals("COp")) {
+      if (!dstVecTy.getElementType().isF16() || dstVecTy.getRank() != 1 ||
+          dstVecTy.getDimSize(0) != 8 || dstMemrefType.getRank() != 1)
+        return op.emitError(
+            "destination for COp should be of shape memref<1xvector<8xf16>>");
+    }
+  } else {
+    return op.emitError(
+        "element type of destination memref should be a vector type");
+  }
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// GPU_SubgroupMmaStoreMatrixOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(SubgroupMmaStoreMatrixOp op) {
+  auto srcType = op.srcMemref().getType();
+  auto dstType = op.dstMemref().getType();
+  auto srcMemrefType = srcType.cast<MemRefType>();
+  auto dstMemrefType = dstType.cast<MemRefType>();
+  auto srcElemType = srcMemrefType.getElementType();
+  auto dstElemType = dstMemrefType.getElementType();
+  auto srcMemSpace = srcMemrefType.getMemorySpace();
+  auto dstMemSpace = dstMemrefType.getMemorySpace();
+
+  if ((dstMemSpace != 0 && dstMemSpace != 3) || srcMemSpace != 5)
+    return op.emitError("expected source memref in memory space 5 and "
+                        "destination memref in memory space 0 or 3");
+
+  if (dstMemrefType.getRank() != 2 || !dstElemType.isF16())
+    return op.emitError(
+        "destination memref should be of rank 2 and have f16 elements");
+
+  if (auto srcVecTy = srcElemType.dyn_cast<VectorType>()) {
+    if (!srcVecTy.getElementType().isF16() || srcVecTy.getRank() != 1 ||
+        srcVecTy.getDimSize(0) != 8 || srcMemrefType.getRank() != 1)
+      return op.emitError(
+          "source memref should be of shape memref<1xvector<8xf16>>");
+  } else {
+    return op.emitError(
+        "source memref should be of shape memref<1xvector<8xf16>>");
+  }
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// GPU_SubgroupMmaComputeOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(SubgroupMmaComputeOp op) {
+  enum operandMap { A, B, C, D };
+  SmallVector<MemRefType, 4> opTypes;
+  SmallVector<Type, 4> elemTypes;
+  SmallVector<ArrayRef<int64_t>, 4> opShapes;
+  SmallVector<unsigned, 4> opMemSpace;
+
+  auto populateOpInfo = [&opTypes, &elemTypes, &opShapes, &opMemSpace, &op]() {
+    opTypes.push_back(op.opA().getType().cast<MemRefType>());
+    opTypes.push_back(op.opB().getType().cast<MemRefType>());
+    opTypes.push_back(op.opC().getType().cast<MemRefType>());
+    opTypes.push_back(op.opD().getType().cast<MemRefType>());
+
+    for (MemRefType opType : opTypes) {
+      elemTypes.push_back(opType.getElementType());
+      opShapes.push_back(opType.getShape());
+      opMemSpace.push_back(opType.getMemorySpace());
+    }
+  };
+  populateOpInfo();
+
+  if (opMemSpace[A] != 5 || opMemSpace[B] != 5 || opMemSpace[C] != 5 ||
+      opMemSpace[D] != 5)
+    return op.emitError("all operands should be in memory space 5");
+
+  auto verifyElems = [](Type &elemType, MemRefType &memrefType,
+                        unsigned numElems) {
+    if (auto srcVecTy = elemType.dyn_cast<VectorType>()) {
+      if (!srcVecTy.getElementType().isF16() || srcVecTy.getRank() != 1 ||
+          srcVecTy.getDimSize(0) != numElems || memrefType.getRank() != 1) {
+        return false;
+      }
+    } else {
+      return false;
+    }
+    return true;
+  };
+
+  if (!verifyElems(elemTypes[A], opTypes[A], 16) ||
+      !verifyElems(elemTypes[B], opTypes[B], 16) ||
+      !verifyElems(elemTypes[C], opTypes[C], 8) ||
+      !verifyElems(elemTypes[D], opTypes[D], 8))
+    return op.emitError(
+        "A and B must be of type memref<1xvector<16xf16>>, C and "
+        "D must be of type memref<1xvector<8xf16>>");
+
+  return success();
+}
+
 #include "mlir/Dialect/GPU/GPUOpInterfaces.cpp.inc"
 
 #define GET_OP_CLASSES
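Editorial note: to make the verifier's intent concrete, an invalid example in the style of the dialect's -verify-diagnostics tests (illustrative only, not one of this patch's test changes); the destination below lives in the default memory space instead of the private space 5:

```mlir
func @bad_load(%src : memref<32x32xf16, 3>, %dst : memref<1xvector<16xf16>>,
               %i : index, %j : index, %k : i64) {
  // expected-error @+1 {{expected source memref in memory space 0 or 3 and destination memref in memory space 5}}
  gpu.subgroup_mma_load_matrix %src[%i, %j], %dst[%k] {operand = "AOp", ldm = 32 : i64}
    : memref<32x32xf16, 3>, memref<1xvector<16xf16>>
  return
}
```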
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -131,6 +131,88 @@
   return op.emitOpError("unimplemented mma.sync variant");
 }
 
+static LogicalResult verify(WMMALoadOp op) {
+  MLIRContext *context = op.getContext();
+  auto I32Ty = LLVM::LLVMType::getInt32Ty(context);
+  auto I32Ptr1Ty = LLVM::LLVMPointerType::get(I32Ty, 1);
+  auto I32Ptr3Ty = LLVM::LLVMPointerType::get(I32Ty, 3);
+  auto I32Ptr0Ty = LLVM::LLVMPointerType::get(I32Ty, 0);
+  auto f16Ty = LLVM::LLVMType::getHalfTy(context);
+  auto f16x2Ty = LLVM::LLVMType::getVectorTy(f16Ty, 2);
+  auto f16x2x4StructTy = LLVM::LLVMType::getStructTy(
+      context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
+  auto f16x2x8StructTy = LLVM::LLVMType::getStructTy(
+      context,
+      {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
+
+  SmallVector<Type, 2> operand_types(op.getOperandTypes().begin(),
+                                     op.getOperandTypes().end());
+  if (operand_types != SmallVector<Type, 2>{I32Ptr1Ty, I32Ty} &&
+      operand_types != SmallVector<Type, 2>{I32Ptr3Ty, I32Ty} &&
+      operand_types != SmallVector<Type, 2>{I32Ptr0Ty, I32Ty}) {
+    return op.emitOpError("expected operands to be a source pointer in memory "
+                          "space 0, 1 or 3 followed by ldm of the source");
+  }
+
+  if (op.operand().equals("AOp") || op.operand().equals("BOp")) {
+    if (op.getType() != f16x2x8StructTy) {
+      return op.emitOpError("expected result type for AOp and BOp to be a "
+                            "struct of 8 <2 x half>s");
+    }
+  } else if (op.operand().equals("COp")) {
+    if (op.getType() != f16x2x4StructTy) {
+      return op.emitOpError(
+          "expected result type for COp to be a struct of 4 <2 x half>s");
+    }
+  }
+
+  return success();
+}
+
+static LogicalResult verify(WMMAStoreOp op) {
+  MLIRContext *context = op.getContext();
+  auto I32Ty = LLVM::LLVMType::getInt32Ty(context);
+  auto I32Ptr1Ty = LLVM::LLVMPointerType::get(I32Ty, 1);
+  auto I32Ptr3Ty = LLVM::LLVMPointerType::get(I32Ty, 3);
+  auto I32Ptr0Ty = LLVM::LLVMPointerType::get(I32Ty, 0);
+  auto f16Ty = LLVM::LLVMType::getHalfTy(context);
+  auto f16x2Ty = LLVM::LLVMType::getVectorTy(f16Ty, 2);
+
+  SmallVector<Type, 6> operand_types(op.getOperandTypes().begin(),
+                                     op.getOperandTypes().end());
+  if (operand_types != SmallVector<Type, 6>{I32Ptr1Ty, f16x2Ty, f16x2Ty,
+                                            f16x2Ty, f16x2Ty, I32Ty} &&
+      operand_types != SmallVector<Type, 6>{I32Ptr3Ty, f16x2Ty, f16x2Ty,
+                                            f16x2Ty, f16x2Ty, I32Ty} &&
+      operand_types != SmallVector<Type, 6>{I32Ptr0Ty, f16x2Ty, f16x2Ty,
+                                            f16x2Ty, f16x2Ty, I32Ty}) {
+    return op.emitOpError("expected operands to be a destination pointer in "
+                          "memory space 0, 1 or 3, followed by 4 <2 x half> "
+                          "values and ldm of the destination");
+  }
+
+  return success();
+}
+
+static LogicalResult verify(WMMAMmaOp op) {
+  MLIRContext *context = op.getContext();
+  auto f16Ty = LLVM::LLVMType::getHalfTy(context);
+  auto f16x2Ty = LLVM::LLVMType::getVectorTy(f16Ty, 2);
+  auto f16x2x4StructTy = LLVM::LLVMType::getStructTy(
+      context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
+
+  SmallVector<Type, 20> operand_types(op.getOperandTypes().begin(),
+                                      op.getOperandTypes().end());
+  if (operand_types != SmallVector<Type, 20>(20, f16x2Ty))
+    return op.emitOpError("expected 20 <2 x half> operands: 8 for A, 8 for B "
+                          "and 4 for C");
+
+  if (op.getType() != f16x2x4StructTy)
+    return op.emitOpError(
+        "expected result type to be a struct of 4 <2 x half>s");
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // NVVMDialect initialization, type parsing, and registration.
 //===----------------------------------------------------------------------===//
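Editorial note: the 20-operand check above encodes how the f16 m16n16k16 fragments are packed per thread: A and B each occupy 8 `<2 x half>` values and the f16 accumulator C occupies 4. A small constant table makes the arithmetic explicit (illustrative only, not code from this patch):

```cpp
// Per-thread fragment sizes (in <2 x half> values) for the f16 m16n16k16 WMMA
// shape used by these ops; 8 + 8 + 4 == 20 matches the verifier above, and the
// 4-element D fragment matches the expected result struct.
constexpr unsigned kWmmaAFragmentSize = 8;
constexpr unsigned kWmmaBFragmentSize = 8;
constexpr unsigned kWmmaCFragmentSize = 4;
constexpr unsigned kWmmaMmaOperandCount =
    kWmmaAFragmentSize + kWmmaBFragmentSize + kWmmaCFragmentSize; // == 20
```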
source"); + + if (op.getType() != f16x2x4StructTy) + return op.emitOpError( + "expected result type of AOp and BOp to be a struct of 8 s"); + + return success(); +} + //===----------------------------------------------------------------------===// // NVVMDialect initialization, type parsing, and registration. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp --- a/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp @@ -35,6 +35,27 @@ return builder.CreateCall(fn, args); } +static llvm::Value *createNvvmIntrinsicCall(llvm::IRBuilder<> &builder, + llvm::Intrinsic::ID intrinsic, + ArrayRef args = {}) { + llvm::Module *module = builder.GetInsertBlock()->getModule(); + llvm::Function *fn; + if (llvm::Intrinsic::isOverloaded(intrinsic)) { + if (intrinsic != llvm::Intrinsic::nvvm_wmma_m16n16k16_mma_row_row_f16_f16) { + // NVVM load and store instrinsic names are overloaded on the + // source/destination pointer type. Pointer is the first argument in the + // corresponding NVVM Op. + fn = llvm::Intrinsic::getDeclaration(module, intrinsic, + {args[0]->getType()}); + } else { + fn = llvm::Intrinsic::getDeclaration(module, intrinsic, {}); + } + } else { + fn = llvm::Intrinsic::getDeclaration(module, intrinsic); + } + return builder.CreateCall(fn, args); +} + static llvm::Intrinsic::ID getShflBflyIntrinsicId(llvm::Type *resultType, bool withPredicate) { if (withPredicate) {