Index: clang/lib/CodeGen/CGBuiltin.cpp =================================================================== --- clang/lib/CodeGen/CGBuiltin.cpp +++ clang/lib/CodeGen/CGBuiltin.cpp @@ -10396,8 +10396,7 @@ llvm_unreachable("Unexpected builtin ID."); } Value *Result = - Builder.CreateCall(CGM.getIntrinsic(IID), - {Builder.CreatePointerCast(Src, VoidPtrTy), Ldm}); + Builder.CreateCall(CGM.getIntrinsic(IID, Src->getType()), {Src, Ldm}); // Save returned values. for (unsigned i = 0; i < NumResults; ++i) { @@ -10436,10 +10435,9 @@ default: llvm_unreachable("Unexpected builtin ID."); } - Function *Intrinsic = CGM.getIntrinsic(IID); + Function *Intrinsic = CGM.getIntrinsic(IID, Dst->getType()); llvm::Type *ParamType = Intrinsic->getFunctionType()->getParamType(1); - SmallVector Values; - Values.push_back(Builder.CreatePointerCast(Dst, VoidPtrTy)); + SmallVector Values = {Dst}; for (unsigned i = 0; i < NumResults; ++i) { Value *V = Builder.CreateAlignedLoad( Builder.CreateGEP(Src.getPointer(), llvm::ConstantInt::get(IntTy, i)), Index: llvm/include/llvm/IR/IntrinsicsNVVM.td =================================================================== --- llvm/include/llvm/IR/IntrinsicsNVVM.td +++ llvm/include/llvm/IR/IntrinsicsNVVM.td @@ -3884,30 +3884,22 @@ // // WMMA.LOAD -class NVVM_WMMA_LD_ALSTS +class NVVM_WMMA_LD_ALSTS : Intrinsic, NoCapture<0>], "llvm.nvvm.wmma.load."#Abc#".sync."#Layout#".m16n16k16" - #Space #!if(WithStride,".stride","") #"."#Type>; -multiclass NVVM_WMMA_LD_ALST { - def _stride: NVVM_WMMA_LD_ALSTS; - def NAME : NVVM_WMMA_LD_ALSTS; -} - -multiclass NVVM_WMMA_LD_ALT { - defm _global: NVVM_WMMA_LD_ALST; - defm _shared: NVVM_WMMA_LD_ALST; - defm NAME: NVVM_WMMA_LD_ALST; +multiclass NVVM_WMMA_LD_ALT { + def _stride: NVVM_WMMA_LD_ALSTS; + def NAME : NVVM_WMMA_LD_ALSTS; } multiclass NVVM_WMMA_LD_AT { @@ -3915,47 +3907,33 @@ defm _col: NVVM_WMMA_LD_ALT; } -// For some reason ReadOnly and NoCapture confuses tblgen if they are -// passed to Intrinsic<> form inside of a multiclass. Setting them globally -// outside of the multiclass works. -let IntrProperties = [IntrReadMem, IntrArgMemOnly, - ReadOnly<0>, NoCapture<0>] in { - defm int_nvvm_wmma_load_a_f16: NVVM_WMMA_LD_AT<"a", "f16", llvm_v2f16_ty>; - defm int_nvvm_wmma_load_b_f16: NVVM_WMMA_LD_AT<"b", "f16", llvm_v2f16_ty>; - defm int_nvvm_wmma_load_c_f16: NVVM_WMMA_LD_AT<"c", "f16", llvm_v2f16_ty>; - defm int_nvvm_wmma_load_c_f32: NVVM_WMMA_LD_AT<"c", "f32", llvm_float_ty>; -} +defm int_nvvm_wmma_load_a_f16: NVVM_WMMA_LD_AT<"a", "f16", llvm_v2f16_ty>; +defm int_nvvm_wmma_load_b_f16: NVVM_WMMA_LD_AT<"b", "f16", llvm_v2f16_ty>; +defm int_nvvm_wmma_load_c_f16: NVVM_WMMA_LD_AT<"c", "f16", llvm_v2f16_ty>; +defm int_nvvm_wmma_load_c_f32: NVVM_WMMA_LD_AT<"c", "f32", llvm_float_ty>; // WMMA.STORE.D -class NVVM_WMMA_STD_LSTSEmpty=[]> : Intrinsic<[], !listconcat( - [llvm_ptr_ty], + [llvm_anyptr_ty], !if(!eq(Type,"f16"), [regty, regty, regty, regty], [regty, regty, regty, regty, regty, regty, regty, regty]), !if(WithStride, [llvm_i32_ty], Empty)), - [], // Properties must be set during instantiation. + [IntrWriteMem, IntrArgMemOnly, WriteOnly<0>, NoCapture<0>], "llvm.nvvm.wmma.store.d.sync."#Layout - #".m16n16k16"#Space + #".m16n16k16" #!if(WithStride,".stride","") #"."#Type>; -multiclass NVVM_WMMA_STD_LST { - def _stride: NVVM_WMMA_STD_LSTS; - def NAME: NVVM_WMMA_STD_LSTS; -} - multiclass NVVM_WMMA_STD_LT { - defm _global: NVVM_WMMA_STD_LST; - defm _shared: NVVM_WMMA_STD_LST; - defm NAME: NVVM_WMMA_STD_LST; + def _stride: NVVM_WMMA_STD_LSTS; + def NAME: NVVM_WMMA_STD_LSTS; } multiclass NVVM_WMMA_STD_T { @@ -3963,11 +3941,8 @@ defm _col: NVVM_WMMA_STD_LT<"col", Type, regty>; } -let IntrProperties = [IntrWriteMem, IntrArgMemOnly, - WriteOnly<0>, NoCapture<0>] in { - defm int_nvvm_wmma_store_d_f16: NVVM_WMMA_STD_T<"f16", llvm_v2f16_ty>; - defm int_nvvm_wmma_store_d_f32: NVVM_WMMA_STD_T<"f32", llvm_float_ty>; -} +defm int_nvvm_wmma_store_d_f16: NVVM_WMMA_STD_T<"f16", llvm_v2f16_ty>; +defm int_nvvm_wmma_store_d_f32: NVVM_WMMA_STD_T<"f32", llvm_float_ty>; // WMMA.MMA class NVVM_WMMA_MMA_ABDCS : EmptyNVPTXInst, Requires<[hasPTX60, hasSM70]> { - // Intrinsic that matches this instruction. - Intrinsic Intr = !cast("int_nvvm_wmma_load_" - # Abc - # "_" # Type - # "_" # Layout - # !subst(".","_",Space) - # !if(WithStride,"_stride", "")); + // Pattern (created by WMMA_LOAD_INTR_HELPER below) that matches the intrinsic + // for this function. + PatFrag IntrMatcher = !cast("INT_WMMA_LOAD_" + # !subst("a", "A", + !subst("b", "B", + !subst("c", "C_" # Type, Abc))) + # "_" # Layout + # !subst(".", "_", Space) + # !if(WithStride,"_stride", "") + # "_Intr"); dag OutsR03 = (outs regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3); dag OutsR47 = (outs regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7); dag Outs = !if(!eq(Abc#Type,"cf16"), OutsR03, !con(OutsR03, OutsR47)); @@ -7451,7 +7454,7 @@ !subst(imem, ADDRvar, !subst(MEMri64, ADDRri64, !subst(MEMri, ADDRri, - !subst(ins, Intr, tmp))))); + !subst(ins, IntrMatcher, tmp))))); // Finally, consatenate both parts together. !con() requires both dags to have // the same operator, so we wrap PatArgs in a (set ...) dag. let Pattern = [!con(PatOuts, (set PatArgs))]; @@ -7466,20 +7469,53 @@ #";"; } -multiclass WMMA_LOAD_ALSTO { - def _stride: WMMA_LOAD_ALSTOS; - def NAME: WMMA_LOAD_ALSTOS; +class WMMA_LOAD_INTR_HELPER + : PatFrag <(ops),(ops)> { + // Intrinsic that matches this instruction. + Intrinsic Intr = !cast("int_nvvm_wmma_load_" + # Abc + # "_" # Type + # "_" # Layout + # !if(WithStride,"_stride", "")); + code match_generic = [{ + return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GENERIC); + }]; + code match_shared = [{ + return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_SHARED); + }]; + code match_global = [{ + return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GLOBAL); + }]; + + int tmp; + let Operands = !if(WithStride, (ops node:$src, node:$ldm), (ops node:$src)); + let Fragment = !foreach(tmp, Operands, !subst(ops, Intr, tmp)); + let PredicateCode = !if(!eq(Space, ".shared"), match_shared, + !if(!eq(Space, ".global"), match_global, match_generic)); +} + +multiclass WMMA_LOAD_ALSTS { + def _avar: WMMA_LOAD_ALSTOS; + def _areg: WMMA_LOAD_ALSTOS; + def _areg64: WMMA_LOAD_ALSTOS; + def _ari: WMMA_LOAD_ALSTOS; + def _ari64: WMMA_LOAD_ALSTOS; +} + +multiclass WMMA_LOAD_ALSTSh { + // Define a PatFrag that matches appropriate intrinsic that loads from the + // given address space. + def _Intr : WMMA_LOAD_INTR_HELPER; + defm NAME: WMMA_LOAD_ALSTS; } multiclass WMMA_LOAD_ALST { - defm _avar: WMMA_LOAD_ALSTO; - defm _areg: WMMA_LOAD_ALSTO; - defm _areg64: WMMA_LOAD_ALSTO; - defm _ari: WMMA_LOAD_ALSTO; - defm _ari64: WMMA_LOAD_ALSTO; + string Type, NVPTXRegClass regclass> { + defm _stride: WMMA_LOAD_ALSTSh; + defm NAME: WMMA_LOAD_ALSTSh; } multiclass WMMA_LOAD_ALT + bit WithStride, DAGOperand DstOp> : EmptyNVPTXInst, Requires<[hasPTX60, hasSM70]> { - Intrinsic Intr = !cast("int_nvvm_wmma_store_d_" - # Type - # "_" # Layout - # !subst(".","_",Space) - # !if(WithStride,"_stride", "")); + PatFrag IntrMatcher = !cast("INT_WMMA_STORE_D" + # "_" # Type + # "_" # Layout + # !subst(".", "_", Space) + # !if(WithStride,"_stride", "") + # "_Intr"); dag InsR03 = (ins DstOp:$src, regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3); dag InsR47 = (ins regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7); @@ -7525,7 +7562,7 @@ !subst(imem, ADDRvar, !subst(MEMri64, ADDRri64, !subst(MEMri, ADDRri, - !subst(ins, Intr, tmp))))); + !subst(ins, IntrMatcher, tmp))))); let Pattern = [PatArgs]; let OutOperandList = (outs); let InOperandList = Ins; @@ -7543,20 +7580,57 @@ } -multiclass WMMA_STORE_D_LSTO { - def _stride: WMMA_STORE_D_LSTOS; - def NAME: WMMA_STORE_D_LSTOS; +class WMMA_STORE_INTR_HELPER + : PatFrag <(ops),(ops)> { + // Intrinsic that matches this instruction. + Intrinsic Intr = !cast("int_nvvm_wmma_store_d" + # "_" # Type + # "_" # Layout + # !if(WithStride, "_stride", "")); + code match_generic = [{ + return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GENERIC); + }]; + code match_shared = [{ + return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_SHARED); + }]; + code match_global = [{ + return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GLOBAL); + }]; + + int tmp; + dag Args = !if(!eq(Type,"f16"), + (ops node:$dst, node:$r0, node:$r1, node:$r2, node:$r3), + (ops node:$dst, node:$r0, node:$r1, node:$r2, node:$r3, + node:$r4, node:$r5, node:$r6, node:$r7)); + dag StrideArg = !if(WithStride, (ops node:$ldm), (ops)); + let Operands = !con(Args, StrideArg); + let Fragment = !foreach(tmp, Operands, !subst(ops, Intr, tmp)); + let PredicateCode = !if(!eq(Space, ".shared"), match_shared, + !if(!eq(Space, ".global"), match_global, match_generic)); +} + +multiclass WMMA_STORE_D_LSTS { + def _avar: WMMA_STORE_D_LSTSO; + def _areg: WMMA_STORE_D_LSTSO; + def _areg64: WMMA_STORE_D_LSTSO; + def _ari: WMMA_STORE_D_LSTSO; + def _ari64: WMMA_STORE_D_LSTSO; +} + +multiclass WMMA_STORE_D_LSTSh { + // Define a PatFrag that matches appropriate intrinsic that loads from the + // given address space. + def _Intr: WMMA_STORE_INTR_HELPER; + defm NAME: WMMA_STORE_D_LSTS; } multiclass WMMA_STORE_D_LST { - defm _avar: WMMA_STORE_D_LSTO; - defm _areg: WMMA_STORE_D_LSTO; - defm _areg64: WMMA_STORE_D_LSTO; - defm _ari: WMMA_STORE_D_LSTO; - defm _ari64: WMMA_STORE_D_LSTO; + string Type, NVPTXRegClass regclass > { + defm _stride: WMMA_STORE_D_LSTSh; + defm NAME: WMMA_STORE_D_LSTSh; } multiclass WMMA_STORE_D_LT