diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -1396,6 +1396,14 @@
     return NVT;
   }
 
+  /// Return the EVT to use for an inline asm operand of IR type \p Ty. The
+  /// default forwards to getValueType(); targets may override it when asm
+  /// operands need a different type.
+  virtual EVT getAsmOperandValueType(const DataLayout &DL, Type *Ty,
+                                     bool AllowUnknown = false) const {
+    return getValueType(DL, Ty, AllowUnknown);
+  }
+
   /// Return the EVT corresponding to this LLVM type. This is fixed by the LLVM
   /// operations except for the pointer size. If AllowUnknown is true, this
   /// will return MVT::Other for types with no EVT counterpart (e.g. structs),
diff --git a/llvm/lib/CodeGen/GlobalISel/InlineAsmLowering.cpp b/llvm/lib/CodeGen/GlobalISel/InlineAsmLowering.cpp
--- a/llvm/lib/CodeGen/GlobalISel/InlineAsmLowering.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/InlineAsmLowering.cpp
@@ -325,7 +325,8 @@
         return false;
       }
 
-      OpInfo.ConstraintVT = TLI->getValueType(DL, OpTy, true).getSimpleVT();
+      OpInfo.ConstraintVT =
+          TLI->getAsmOperandValueType(DL, OpTy, true).getSimpleVT();
 
     } else if (OpInfo.Type == InlineAsm::isOutput && !OpInfo.isIndirect) {
       assert(!Call.getType()->isVoidTy() && "Bad inline asm!");
@@ -334,13 +335,19 @@
             TLI->getSimpleValueType(DL, STy->getElementType(ResNo));
       } else {
         assert(ResNo == 0 && "Asm only has one result!");
-        OpInfo.ConstraintVT = TLI->getSimpleValueType(DL, Call.getType());
+        OpInfo.ConstraintVT =
+            TLI->getAsmOperandValueType(DL, Call.getType()).getSimpleVT();
       }
       ++ResNo;
     } else {
       OpInfo.ConstraintVT = MVT::Other;
     }
 
+    // i64x8 (LS64 register tuple) operands are not handled here; report
+    // failure so that instruction selection takes the fallback path.
+    if (OpInfo.ConstraintVT == MVT::i64x8)
+      return false;
+
     // Compute the constraint code and ConstraintType to use.
computeConstraintToUse(TLI, OpInfo); diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp @@ -8176,7 +8176,7 @@ } } - return TLI.getValueType(DL, OpTy, true); + return TLI.getAsmOperandValueType(DL, OpTy, true); } }; @@ -8479,8 +8479,8 @@ DAG.getDataLayout(), STy->getElementType(ResNo)); } else { assert(ResNo == 0 && "Asm only has one result!"); - OpInfo.ConstraintVT = - TLI.getSimpleValueType(DAG.getDataLayout(), Call.getType()); + OpInfo.ConstraintVT = TLI.getAsmOperandValueType( + DAG.getDataLayout(), Call.getType()).getSimpleVT(); } ++ResNo; } else { diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp --- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp @@ -4687,7 +4687,8 @@ getSimpleValueType(DL, STy->getElementType(ResNo)); } else { assert(ResNo == 0 && "Asm only has one result!"); - OpInfo.ConstraintVT = getSimpleValueType(DL, Call.getType()); + OpInfo.ConstraintVT = + getAsmOperandValueType(DL, Call.getType()).getSimpleVT(); } ++ResNo; break; diff --git a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp --- a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp +++ b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp @@ -653,6 +653,9 @@ case 'x': Reg = getXRegFromWReg(Reg); break; + case 't': + Reg = getXRegFromXRegTuple(Reg); + break; } O << AArch64InstPrinter::getRegisterName(Reg); @@ -749,6 +752,10 @@ AArch64::GPR64allRegClass.contains(Reg)) return printAsmMRegister(MO, 'x', O); + // If this is an x register tuple, print an x register. 
+    if (AArch64::GPR64x8ClassRegClass.contains(Reg))
+      return printAsmMRegister(MO, 't', O);
+
     unsigned AltName = AArch64::NoRegAltName;
     const TargetRegisterClass *RegClass;
     if (AArch64::ZPRRegClass.contains(Reg)) {
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -330,6 +330,10 @@
   // Cast between vectors of the same element type but differ in length.
   REINTERPRET_CAST,
 
+  // Nodes to build an LD64B / ST64B 64-byte quantity out of i64, and vice versa
+  LS64_BUILD,
+  LS64_EXTRACT,
+
   LD1_MERGE_ZERO,
   LD1S_MERGE_ZERO,
   LDNF1_MERGE_ZERO,
@@ -824,6 +828,9 @@
   bool isAllActivePredicate(SDValue N) const;
   EVT getPromotedVTForPredicate(EVT VT) const;
 
+  EVT getAsmOperandValueType(const DataLayout &DL, Type *Ty,
+                             bool AllowUnknown = false) const override;
+
 private:
   /// Keep a pointer to the AArch64Subtarget around so that we can
   /// make the right decision when generating code for different targets.
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -246,6 +246,12 @@
   addRegisterClass(MVT::i32, &AArch64::GPR32allRegClass);
   addRegisterClass(MVT::i64, &AArch64::GPR64allRegClass);
 
+  if (Subtarget->hasLS64()) {
+    addRegisterClass(MVT::i64x8, &AArch64::GPR64x8ClassRegClass);
+    setOperationAction(ISD::LOAD, MVT::i64x8, Custom);
+    setOperationAction(ISD::STORE, MVT::i64x8, Custom);
+  }
+
   if (Subtarget->hasFPARMv8()) {
     addRegisterClass(MVT::f16, &AArch64::FPR16RegClass);
     addRegisterClass(MVT::bf16, &AArch64::FPR16RegClass);
@@ -2023,6 +2029,8 @@
     MAKE_CASE(AArch64ISD::LASTA)
     MAKE_CASE(AArch64ISD::LASTB)
     MAKE_CASE(AArch64ISD::REINTERPRET_CAST)
+    MAKE_CASE(AArch64ISD::LS64_BUILD)
+    MAKE_CASE(AArch64ISD::LS64_EXTRACT)
     MAKE_CASE(AArch64ISD::TBL)
     MAKE_CASE(AArch64ISD::FADD_PRED)
     MAKE_CASE(AArch64ISD::FADDA_PRED)
@@ -4611,17 +4619,60 @@
         {StoreNode->getChain(), Lo, Hi, StoreNode->getBasePtr()},
         StoreNode->getMemoryVT(), StoreNode->getMemOperand());
     return Result;
+  } else if (MemVT == MVT::i64x8) {
+    SDValue Value = StoreNode->getValue();
+    assert(Value->getValueType(0) == MVT::i64x8);
+    SDValue Chain = StoreNode->getChain();
+    SDValue Base = StoreNode->getBasePtr();
+    EVT PtrVT = Base.getValueType();
+    for (unsigned i = 0; i < 8; i++) {
+      // Part i is Offset bytes past Base, so give its memory operand that
+      // offset and only the alignment that is guaranteed there.
+      const unsigned Offset = i * 8;
+      SDValue Part = DAG.getNode(AArch64ISD::LS64_EXTRACT, Dl, MVT::i64,
+                                 Value, DAG.getConstant(i, Dl, MVT::i32));
+      SDValue Ptr = DAG.getNode(ISD::ADD, Dl, PtrVT, Base,
+                                DAG.getConstant(Offset, Dl, PtrVT));
+      Chain = DAG.getStore(
+          Chain, Dl, Part, Ptr,
+          StoreNode->getPointerInfo().getWithOffset(Offset),
+          commonAlignment(StoreNode->getOriginalAlign(), Offset));
+    }
+    return Chain;
   }
 
   return SDValue();
 }
 
-// Custom lowering for extending v4i8 vector loads.
 SDValue AArch64TargetLowering::LowerLOAD(SDValue Op,
                                          SelectionDAG &DAG) const {
   SDLoc DL(Op);
   LoadSDNode *LoadNode = cast<LoadSDNode>(Op);
   assert(LoadNode && "Expected custom lowering of a load node");
+
+  if (LoadNode->getMemoryVT() == MVT::i64x8) {
+    SmallVector<SDValue, 8> Ops;
+    SDValue Base = LoadNode->getBasePtr();
+    SDValue Chain = LoadNode->getChain();
+    EVT PtrVT = Base.getValueType();
+    for (unsigned i = 0; i < 8; i++) {
+      // Part i is Offset bytes past Base, so give its memory operand that
+      // offset and only the alignment that is guaranteed there.
+      const unsigned Offset = i * 8;
+      SDValue Ptr = DAG.getNode(ISD::ADD, DL, PtrVT, Base,
+                                DAG.getConstant(Offset, DL, PtrVT));
+      SDValue Part =
+          DAG.getLoad(MVT::i64, DL, Chain, Ptr,
+                      LoadNode->getPointerInfo().getWithOffset(Offset),
+                      commonAlignment(LoadNode->getOriginalAlign(), Offset));
+      Ops.push_back(Part);
+      Chain = SDValue(Part.getNode(), 1);
+    }
+    SDValue Loaded = DAG.getNode(AArch64ISD::LS64_BUILD, DL, MVT::i64x8, Ops);
+    return DAG.getMergeValues({Loaded, Chain}, DL);
+  }
+
+  // Custom lowering for extending v4i8 vector loads.
   EVT VT = Op->getValueType(0);
   assert((VT == MVT::v4i16 || VT == MVT::v4i32) && "Expected v4i16 or v4i32");
 
@@ -8179,6 +8230,8 @@
     case 'r':
       if (VT.isScalableVector())
         return std::make_pair(0U, nullptr);
+      if (Subtarget->hasLS64() && VT.getSizeInBits() == 512)
+        return std::make_pair(0U, &AArch64::GPR64x8ClassRegClass);
       if (VT.getFixedSizeInBits() == 64)
         return std::make_pair(0U, &AArch64::GPR64commonRegClass);
       return std::make_pair(0U, &AArch64::GPR32commonRegClass);
@@ -8266,6 +8319,18 @@
   return Res;
 }
 
+/// Map 512-bit integer asm operands to MVT::i64x8 when LS64 is available, the
+/// type registered for the GPR64x8 register class; every other type uses the
+/// default mapping.
+EVT AArch64TargetLowering::getAsmOperandValueType(const DataLayout &DL,
+                                                  llvm::Type *Ty,
+                                                  bool AllowUnknown) const {
+  if (Subtarget->hasLS64() && Ty->isIntegerTy(512))
+    return EVT(MVT::i64x8);
+
+  return TargetLowering::getAsmOperandValueType(DL, Ty, AllowUnknown);
+}
+
 /// LowerAsmOperandForConstraint - Lower the specified operand into the Ops
 /// vector. If it is invalid, don't add anything to Ops.
void AArch64TargetLowering::LowerAsmOperandForConstraint( diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td --- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td @@ -8104,6 +8104,20 @@ // FIXME: add SVE dot-product patterns. } +// Custom DAG nodes and isel rules to make a 64-byte block out of eight GPRs, +// so that it can be used as input to inline asm, and vice versa. +def LS64_BUILD : SDNode<"AArch64ISD::LS64_BUILD", SDTypeProfile<1, 8, []>>; +def LS64_EXTRACT : SDNode<"AArch64ISD::LS64_EXTRACT", SDTypeProfile<1, 2, []>>; +def : Pat<(i64x8 (LS64_BUILD GPR64:$x0, GPR64:$x1, GPR64:$x2, GPR64:$x3, + GPR64:$x4, GPR64:$x5, GPR64:$x6, GPR64:$x7)), + (REG_SEQUENCE GPR64x8Class, + $x0, x8sub_0, $x1, x8sub_1, $x2, x8sub_2, $x3, x8sub_3, + $x4, x8sub_4, $x5, x8sub_5, $x6, x8sub_6, $x7, x8sub_7)>; +foreach i = 0-7 in { + def : Pat<(i64 (LS64_EXTRACT (i64x8 GPR64x8:$val), (i32 i))), + (EXTRACT_SUBREG $val, !cast("x8sub_"#i))>; +} + let Predicates = [HasLS64] in { def LD64B: LoadStore64B<0b101, "ld64b", (ins GPR64sp:$Rn), (outs GPR64x8:$Rt)>; diff --git a/llvm/lib/Target/AArch64/AArch64RegisterInfo.td b/llvm/lib/Target/AArch64/AArch64RegisterInfo.td --- a/llvm/lib/Target/AArch64/AArch64RegisterInfo.td +++ b/llvm/lib/Target/AArch64/AArch64RegisterInfo.td @@ -732,7 +732,9 @@ !foreach(i, [0,1,2,3,4,5,6,7], !cast("x8sub_"#i)), !foreach(i, [0,1,2,3,4,5,6,7], (trunc (decimate (rotl GPR64, i), 2), 12))>; -def GPR64x8Class : RegisterClass<"AArch64", [i64], 64, (trunc Tuples8X, 12)>; +def GPR64x8Class : RegisterClass<"AArch64", [i64x8], 512, (trunc Tuples8X, 12)> { + let Size = 512; +} def GPR64x8AsmOp : AsmOperandClass { let Name = "GPR64x8"; let ParserMethod = "tryParseGPR64x8"; diff --git a/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h b/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h --- a/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h +++ 
b/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h @@ -106,6 +106,25 @@ return Reg; } +inline static unsigned getXRegFromXRegTuple(unsigned RegTuple) { + switch (RegTuple) { + case AArch64::X0_X1_X2_X3_X4_X5_X6_X7: return AArch64::X0; + case AArch64::X2_X3_X4_X5_X6_X7_X8_X9: return AArch64::X2; + case AArch64::X4_X5_X6_X7_X8_X9_X10_X11: return AArch64::X4; + case AArch64::X6_X7_X8_X9_X10_X11_X12_X13: return AArch64::X6; + case AArch64::X8_X9_X10_X11_X12_X13_X14_X15: return AArch64::X8; + case AArch64::X10_X11_X12_X13_X14_X15_X16_X17: return AArch64::X10; + case AArch64::X12_X13_X14_X15_X16_X17_X18_X19: return AArch64::X12; + case AArch64::X14_X15_X16_X17_X18_X19_X20_X21: return AArch64::X14; + case AArch64::X16_X17_X18_X19_X20_X21_X22_X23: return AArch64::X16; + case AArch64::X18_X19_X20_X21_X22_X23_X24_X25: return AArch64::X18; + case AArch64::X20_X21_X22_X23_X24_X25_X26_X27: return AArch64::X20; + case AArch64::X22_X23_X24_X25_X26_X27_X28_FP: return AArch64::X22; + } + // For anything else, return it unchanged. 
+  return RegTuple;
+}
+
 static inline unsigned getBRegFromDReg(unsigned Reg) {
   switch (Reg) {
   case AArch64::D0:  return AArch64::B0;
diff --git a/llvm/test/CodeGen/AArch64/GlobalISel/arm64-fallback.ll b/llvm/test/CodeGen/AArch64/GlobalISel/arm64-fallback.ll
--- a/llvm/test/CodeGen/AArch64/GlobalISel/arm64-fallback.ll
+++ b/llvm/test/CodeGen/AArch64/GlobalISel/arm64-fallback.ll
@@ -126,7 +126,32 @@
   ret void
 }
 
+%struct.foo = type { [8 x i64] }
+
+; FALLBACK-WITH-REPORT-ERR: remark: <unknown>:0:0: unable to translate instruction:{{.*}}ld64b{{.*}}asm_output_ls64
+; FALLBACK-WITH-REPORT-ERR: warning: Instruction selection used fallback path for asm_output_ls64
+; FALLBACK-WITH-REPORT-OUT-LABEL: asm_output_ls64
+define void @asm_output_ls64(%struct.foo* %output, i8* %addr) #2 {
+entry:
+  %val = call i512 asm sideeffect "ld64b $0,[$1]", "=r,r,~{memory}"(i8* %addr)
+  %outcast = bitcast %struct.foo* %output to i512*
+  store i512 %val, i512* %outcast, align 8
+  ret void
+}
+
+; FALLBACK-WITH-REPORT-ERR: remark: <unknown>:0:0: unable to translate instruction:{{.*}}st64b{{.*}}asm_input_ls64
+; FALLBACK-WITH-REPORT-ERR: warning: Instruction selection used fallback path for asm_input_ls64
+; FALLBACK-WITH-REPORT-OUT-LABEL: asm_input_ls64
+define void @asm_input_ls64(%struct.foo* %input, i8* %addr) #2 {
+entry:
+  %incast = bitcast %struct.foo* %input to i512*
+  %val = load i512, i512* %incast, align 8
+  call void asm sideeffect "st64b $0,[$1]", "r,r,~{memory}"(i512 %val, i8* %addr)
+  ret void
+}
+
 attributes #1 = { "target-features"="+sve" }
+attributes #2 = { "target-features"="+ls64" }
 
 declare <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32 %pattern)
 declare <vscale x 16 x i8> @llvm.aarch64.sve.ld1.nxv16i8(<vscale x 16 x i1>, i8*)
diff --git a/llvm/test/CodeGen/AArch64/ls64-inline-asm.ll b/llvm/test/CodeGen/AArch64/ls64-inline-asm.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/ls64-inline-asm.ll
@@ -0,0 +1,107 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=aarch64 -mattr=+ls64 -verify-machineinstrs -o - %s | FileCheck %s
+
+%struct.foo = type { [8 x i64] }
+
+define void @load(%struct.foo* %output, i8* %addr) {
+; CHECK-LABEL: load:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    //APP
+; CHECK-NEXT:    ld64b x2, [x1]
+; CHECK-NEXT:    //NO_APP
+; CHECK-NEXT:    stp x8, x9, [x0, #48]
+; CHECK-NEXT:    stp x6, x7, [x0, #32]
+; CHECK-NEXT:    stp x4, x5, [x0, #16]
+; CHECK-NEXT:    stp x2, x3, [x0]
+; CHECK-NEXT:    ret
+entry:
+  %val = call i512 asm sideeffect "ld64b $0,[$1]", "=r,r,~{memory}"(i8* %addr)
+  %outcast = bitcast %struct.foo* %output to i512*
+  store i512 %val, i512* %outcast, align 8
+  ret void
+}
+
+define void @store(%struct.foo* %input, i8* %addr) {
+; CHECK-LABEL: store:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ldp x8, x9, [x0, #48]
+; CHECK-NEXT:    ldp x6, x7, [x0, #32]
+; CHECK-NEXT:    ldp x4, x5, [x0, #16]
+; CHECK-NEXT:    ldp x2, x3, [x0]
+; CHECK-NEXT:    //APP
+; CHECK-NEXT:    st64b x2, [x1]
+; CHECK-NEXT:    //NO_APP
+; CHECK-NEXT:    ret
+entry:
+  %incast = bitcast %struct.foo* %input to i512*
+  %val = load i512, i512* %incast, align 8
+  call void asm sideeffect "st64b $0,[$1]", "r,r,~{memory}"(i512 %val, i8* %addr)
+  ret void
+}
+
+define void @store2(i32* %in, i8* %addr) {
+; CHECK-LABEL: store2:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    sub sp, sp, #64 // =64
+; CHECK-NEXT:    .cfi_def_cfa_offset 64
+; CHECK-NEXT:    ldpsw x2, x3, [x0]
+; CHECK-NEXT:    ldrsw x4, [x0, #16]
+; CHECK-NEXT:    ldrsw x5, [x0, #64]
+; CHECK-NEXT:    ldrsw x6, [x0, #100]
+; CHECK-NEXT:    ldrsw x7, [x0, #144]
+; CHECK-NEXT:    ldrsw x8, [x0, #196]
+; CHECK-NEXT:    ldrsw x9, [x0, #256]
+; CHECK-NEXT:    //APP
+; CHECK-NEXT:    st64b x2, [x1]
+; CHECK-NEXT:    //NO_APP
+; CHECK-NEXT:    add sp, sp, #64 // =64
+; CHECK-NEXT:    ret
+entry:
+  %0 = load i32, i32* %in, align 4
+  %conv = sext i32 %0 to i64
+  %arrayidx1 = getelementptr inbounds i32, i32* %in, i64 1
+  %1 = load i32, i32* %arrayidx1, align 4
+  %conv2 = sext i32 %1 to i64
+  %arrayidx4 = getelementptr inbounds i32, i32* %in, i64 4
+  %2 = load i32, i32* %arrayidx4, align 4
+  %conv5 = sext i32 %2 to i64
+  %arrayidx7 = getelementptr inbounds i32, i32* %in, i64 16
+  %3 = load i32, i32* %arrayidx7, align 4
+  %conv8 = sext i32 %3 to i64
+  %arrayidx10 = getelementptr inbounds i32, i32* %in, i64 25
+  %4 = load i32, i32* %arrayidx10, align 4
+  %conv11 = sext i32 %4 to i64
+  %arrayidx13 = getelementptr inbounds i32, i32* %in, i64 36
+  %5 = load i32, i32* %arrayidx13, align 4
+  %conv14 = sext i32 %5 to i64
+  %arrayidx16 = getelementptr inbounds i32, i32* %in, i64 49
+  %6 = load i32, i32* %arrayidx16, align 4
+  %conv17 = sext i32 %6 to i64
+  %arrayidx19 = getelementptr inbounds i32, i32* %in, i64 64
+  %7 = load i32, i32* %arrayidx19, align 4
+  %conv20 = sext i32 %7 to i64
+  %s.sroa.10.0.insert.ext = zext i64 %conv20 to i512
+  %s.sroa.10.0.insert.shift = shl nuw i512 %s.sroa.10.0.insert.ext, 448
+  %s.sroa.9.0.insert.ext = zext i64 %conv17 to i512
+  %s.sroa.9.0.insert.shift = shl nuw nsw i512 %s.sroa.9.0.insert.ext, 384
+  %s.sroa.9.0.insert.insert = or i512 %s.sroa.10.0.insert.shift, %s.sroa.9.0.insert.shift
+  %s.sroa.8.0.insert.ext = zext i64 %conv14 to i512
+  %s.sroa.8.0.insert.shift = shl nuw nsw i512 %s.sroa.8.0.insert.ext, 320
+  %s.sroa.8.0.insert.insert = or i512 %s.sroa.9.0.insert.insert, %s.sroa.8.0.insert.shift
+  %s.sroa.7.0.insert.ext = zext i64 %conv11 to i512
+  %s.sroa.7.0.insert.shift = shl nuw nsw i512 %s.sroa.7.0.insert.ext, 256
+  %s.sroa.7.0.insert.insert = or i512 %s.sroa.8.0.insert.insert, %s.sroa.7.0.insert.shift
+  %s.sroa.6.0.insert.ext = zext i64 %conv8 to i512
+  %s.sroa.6.0.insert.shift = shl nuw nsw i512 %s.sroa.6.0.insert.ext, 192
+  %s.sroa.6.0.insert.insert = or i512 %s.sroa.7.0.insert.insert, %s.sroa.6.0.insert.shift
+  %s.sroa.5.0.insert.ext = zext i64 %conv5 to i512
+  %s.sroa.5.0.insert.shift = shl nuw nsw i512 %s.sroa.5.0.insert.ext, 128
+  %s.sroa.4.0.insert.ext = zext i64 %conv2 to i512
+  %s.sroa.4.0.insert.shift = shl nuw nsw i512 %s.sroa.4.0.insert.ext, 64
+  %s.sroa.4.0.insert.mask = or i512 %s.sroa.6.0.insert.insert, %s.sroa.5.0.insert.shift
+  %s.sroa.0.0.insert.ext = zext i64 %conv to i512
+  %s.sroa.0.0.insert.mask = or i512 %s.sroa.4.0.insert.mask, %s.sroa.4.0.insert.shift
+  %s.sroa.0.0.insert.insert = or i512 %s.sroa.0.0.insert.mask, %s.sroa.0.0.insert.ext
+  call void asm sideeffect "st64b $0,[$1]", "r,r,~{memory}"(i512 %s.sroa.0.0.insert.insert, i8* %addr)
+  ret void
+}