diff --git a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
--- a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
+++ b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
@@ -631,6 +631,9 @@
   case 'x':
     Reg = getXRegFromWReg(Reg);
     break;
+  case 't':
+    Reg = getXRegFromXRegTuple(Reg);
+    break;
   }
 
   O << AArch64InstPrinter::getRegisterName(Reg);
@@ -726,6 +729,10 @@
         AArch64::GPR64allRegClass.contains(Reg))
       return printAsmMRegister(MO, 'x', O);
 
+    // If this is an x register tuple, print an x register.
+    if (AArch64::GPR64x8ClassRegClass.contains(Reg))
+      return printAsmMRegister(MO, 't', O);
+
     unsigned AltName = AArch64::NoRegAltName;
     const TargetRegisterClass *RegClass;
     if (AArch64::ZPRRegClass.contains(Reg)) {
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -312,6 +312,10 @@
   REINTERPRET_CAST,
 
+  // Nodes to build an LD64B / ST64B 64-byte quantity out of i64, and vice versa
+  LS64_BUILD,
+  LS64_EXTRACT,
+
   LD1_MERGE_ZERO,
   LD1S_MERGE_ZERO,
   LDNF1_MERGE_ZERO,
@@ -809,6 +813,7 @@
                           SmallVectorImpl<SDValue> &InVals, bool isThisReturn,
                           SDValue ThisVal) const;
 
+  SDValue LowerLOAD(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerSTORE(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerMGATHER(SDValue Op, SelectionDAG &DAG) const;
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -222,6 +222,14 @@
   addRegisterClass(MVT::i32, &AArch64::GPR32allRegClass);
   addRegisterClass(MVT::i64, &AArch64::GPR64allRegClass);
 
+  if (Subtarget->hasLS64()) {
+    addRegisterClass(MVT::i64x8, &AArch64::GPR64x8ClassRegClass);
+    setOperationAction(ISD::BITCAST, MVT::i64x8, Legal);
+    setOperationAction(ISD::UNDEF, MVT::i64x8, Legal);
+    setOperationAction(ISD::LOAD, MVT::i64x8, Custom);
+    setOperationAction(ISD::STORE, MVT::i64x8, Custom);
+  }
+
   if (Subtarget->hasFPARMv8()) {
     addRegisterClass(MVT::f16, &AArch64::FPR16RegClass);
     addRegisterClass(MVT::bf16, &AArch64::FPR16RegClass);
@@ -1817,6 +1825,8 @@
     MAKE_CASE(AArch64ISD::LASTB)
     MAKE_CASE(AArch64ISD::REV)
     MAKE_CASE(AArch64ISD::REINTERPRET_CAST)
+    MAKE_CASE(AArch64ISD::LS64_BUILD)
+    MAKE_CASE(AArch64ISD::LS64_EXTRACT)
     MAKE_CASE(AArch64ISD::TBL)
     MAKE_CASE(AArch64ISD::FADD_PRED)
     MAKE_CASE(AArch64ISD::FADDA_PRED)
@@ -4059,6 +4069,38 @@
                       ST->getBasePtr(), ST->getMemOperand());
 }
 
+SDValue AArch64TargetLowering::LowerLOAD(SDValue Op,
+                                         SelectionDAG &DAG) const {
+  LoadSDNode *LoadNode = cast<LoadSDNode>(Op);
+  assert(LoadNode && "Can only custom lower load nodes");
+
+  if (useSVEForFixedLengthVectorVT(Op.getValueType()))
+    return LowerFixedLengthVectorLoadToSVE(Op, DAG);
+
+  if (LoadNode->getMemoryVT() == MVT::i64x8) {
+    SDLoc Dl(Op);
+    SmallVector<SDValue, 8> Ops;
+    SDValue Base = LoadNode->getBasePtr();
+    SDValue Chain = LoadNode->getChain();
+    EVT PtrVT = Base.getValueType();
+    for (unsigned i = 0; i < 8; i++) {
+      SDValue Ptr = DAG.getNode(ISD::ADD, Dl, PtrVT, Base,
+                                DAG.getConstant(i * 8, Dl, PtrVT));
+      SDValue Part = DAG.getLoad(MVT::i64, Dl, Chain, Ptr,
+                                 LoadNode->getPointerInfo(),
+                                 LoadNode->getOriginalAlign());
+      Ops.push_back(Part);
+      Chain = SDValue(Part.getNode(), 1);
+    }
+    SDValue Loaded = DAG.getNode(AArch64ISD::LS64_BUILD, Dl, MVT::i64x8, Ops);
+    return DAG.getMergeValues({Loaded, Chain}, Dl);
+  }
+
+  llvm_unreachable("Unexpected request to lower ISD::LOAD");
+
+  return SDValue();
+}
+
 // Custom lowering for any store, vector or scalar and/or default or with
 // a truncate operations. Currently only custom lower truncate operation
 // from vector v4i16 to v4i8 or volatile stores of i128.
@@ -4127,6 +4169,22 @@
         {StoreNode->getChain(), Lo, Hi, StoreNode->getBasePtr()},
         StoreNode->getMemoryVT(), StoreNode->getMemOperand());
     return Result;
+  } else if (MemVT == MVT::i64x8) {
+    SDValue Value = StoreNode->getValue();
+    assert(Value->getValueType(0) == MVT::i64x8);
+    SDValue Chain = StoreNode->getChain();
+    SDValue Base = StoreNode->getBasePtr();
+    EVT PtrVT = Base.getValueType();
+    for (unsigned i = 0; i < 8; i++) {
+      SDValue Part = DAG.getNode(AArch64ISD::LS64_EXTRACT, Dl, MVT::i64,
+                                 Value, DAG.getConstant(i, Dl, MVT::i32));
+      SDValue Ptr = DAG.getNode(ISD::ADD, Dl, PtrVT, Base,
+                                DAG.getConstant(i * 8, Dl, PtrVT));
+      Chain = DAG.getStore(Chain, Dl, Part, Ptr, StoreNode->getPointerInfo(),
+                           StoreNode->getOriginalAlign());
+    }
+    return Chain;
   }
 
   return SDValue();
@@ -4356,9 +4414,7 @@
   case ISD::TRUNCATE:
     return LowerTRUNCATE(Op, DAG);
   case ISD::LOAD:
-    if (useSVEForFixedLengthVectorVT(Op.getValueType()))
-      return LowerFixedLengthVectorLoadToSVE(Op, DAG);
-    llvm_unreachable("Unexpected request to lower ISD::LOAD");
+    return LowerLOAD(Op, DAG);
   case ISD::ADD:
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::ADD_PRED);
   case ISD::AND:
@@ -7591,6 +7647,8 @@
     case 'r':
       if (VT.isScalableVector())
        return std::make_pair(0U, nullptr);
+      if (Subtarget->hasLS64() && VT.getSizeInBits() == 512)
+        return std::make_pair(0U, &AArch64::GPR64x8ClassRegClass);
      if (VT.getFixedSizeInBits() == 64)
        return std::make_pair(0U, &AArch64::GPR64commonRegClass);
      return std::make_pair(0U, &AArch64::GPR32commonRegClass);
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -7751,6 +7751,20 @@
   // FIXME: add SVE dot-product patterns.
 }
 
+// Custom DAG nodes and isel rules to make a 64-byte block out of eight GPRs,
+// so that it can be used as input to inline asm, and vice versa.
+def LS64_BUILD : SDNode<"AArch64ISD::LS64_BUILD", SDTypeProfile<1, 8, []>>;
+def LS64_EXTRACT : SDNode<"AArch64ISD::LS64_EXTRACT", SDTypeProfile<1, 2, []>>;
+def : Pat<(i64x8 (LS64_BUILD GPR64:$x0, GPR64:$x1, GPR64:$x2, GPR64:$x3,
+                             GPR64:$x4, GPR64:$x5, GPR64:$x6, GPR64:$x7)),
+          (REG_SEQUENCE GPR64x8Class,
+              $x0, x8sub_0, $x1, x8sub_1, $x2, x8sub_2, $x3, x8sub_3,
+              $x4, x8sub_4, $x5, x8sub_5, $x6, x8sub_6, $x7, x8sub_7)>;
+foreach i = 0-7 in {
+  def : Pat<(i64 (LS64_EXTRACT (i64x8 GPR64x8:$val), (i32 i))),
+            (EXTRACT_SUBREG $val, !cast<SubRegIndex>("x8sub_"#i))>;
+}
+
 let Predicates = [HasLS64] in {
   def LD64B: LoadStore64B<0b101, "ld64b", (ins GPR64sp:$Rn),
                                           (outs GPR64x8:$Rt)>;
diff --git a/llvm/lib/Target/AArch64/AArch64RegisterInfo.td b/llvm/lib/Target/AArch64/AArch64RegisterInfo.td
--- a/llvm/lib/Target/AArch64/AArch64RegisterInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64RegisterInfo.td
@@ -724,7 +724,9 @@
   !foreach(i, [0,1,2,3,4,5,6,7], !cast<SubRegIndex>("x8sub_"#i)),
   !foreach(i, [0,1,2,3,4,5,6,7], (trunc (decimate (rotl GPR64, i), 2), 12))>;
 
-def GPR64x8Class : RegisterClass<"AArch64", [i64], 64, (trunc Tuples8X, 12)>;
+def GPR64x8Class : RegisterClass<"AArch64", [i64x8], 512, (trunc Tuples8X, 12)> {
+  let Size = 512;
+}
 def GPR64x8AsmOp : AsmOperandClass {
   let Name = "GPR64x8";
   let ParserMethod = "tryParseGPR64x8";
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h b/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h
--- a/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h
+++ b/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h
@@ -106,6 +106,25 @@
   return Reg;
 }
 
+inline static unsigned getXRegFromXRegTuple(unsigned RegTuple) {
+  switch (RegTuple) {
+  case AArch64::X0_X1_X2_X3_X4_X5_X6_X7: return AArch64::X0;
+  case AArch64::X2_X3_X4_X5_X6_X7_X8_X9: return AArch64::X2;
+  case AArch64::X4_X5_X6_X7_X8_X9_X10_X11: return AArch64::X4;
+  case AArch64::X6_X7_X8_X9_X10_X11_X12_X13: return AArch64::X6;
+  case AArch64::X8_X9_X10_X11_X12_X13_X14_X15: return AArch64::X8;
+  case AArch64::X10_X11_X12_X13_X14_X15_X16_X17: return AArch64::X10;
+  case AArch64::X12_X13_X14_X15_X16_X17_X18_X19: return AArch64::X12;
+  case AArch64::X14_X15_X16_X17_X18_X19_X20_X21: return AArch64::X14;
+  case AArch64::X16_X17_X18_X19_X20_X21_X22_X23: return AArch64::X16;
+  case AArch64::X18_X19_X20_X21_X22_X23_X24_X25: return AArch64::X18;
+  case AArch64::X20_X21_X22_X23_X24_X25_X26_X27: return AArch64::X20;
+  case AArch64::X22_X23_X24_X25_X26_X27_X28_FP: return AArch64::X22;
+  }
+  // For anything else, return it unchanged.
+  return RegTuple;
+}
+
 static inline unsigned getBRegFromDReg(unsigned Reg) {
   switch (Reg) {
   case AArch64::D0: return AArch64::B0;
diff --git a/llvm/test/CodeGen/AArch64/ls64-inline-asm.ll b/llvm/test/CodeGen/AArch64/ls64-inline-asm.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/ls64-inline-asm.ll
@@ -0,0 +1,38 @@
+; RUN: llc -mtriple=aarch64 -mattr=+ls64 -verify-machineinstrs -o - %s | FileCheck %s
+; RUN: llc -mtriple=aarch64_be -mattr=+ls64 -verify-machineinstrs -o - %s | FileCheck %s
+
+define void @test_ld64b({ i64, i64, i64, i64, i64, i64, i64, i64 }* %out, i8* %addr) {
+; CHECK-LABEL: test_ld64b:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    //APP
+; CHECK-NEXT:    ld64b x2, [x1]
+; CHECK-NEXT:    //NO_APP
+; CHECK-NEXT:    stp x8, x9, [x0, #48]
+; CHECK-NEXT:    stp x6, x7, [x0, #32]
+; CHECK-NEXT:    stp x4, x5, [x0, #16]
+; CHECK-NEXT:    stp x2, x3, [x0]
+; CHECK-NEXT:    ret
+entry:
+  %val = tail call aarch64_ls64 asm sideeffect "ld64b $0,[$1]", "=r,r"(i8* %addr)
+  %outcast = bitcast { i64, i64, i64, i64, i64, i64, i64, i64 }* %out to aarch64_ls64*
+  store aarch64_ls64 %val, aarch64_ls64* %outcast, align 8
+  ret void
+}
+
+define void @test_st64b({ i64, i64, i64, i64, i64, i64, i64, i64 }* %in, i8* %addr) {
+; CHECK-LABEL: test_st64b:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ldp x8, x9, [x0, #48]
+; CHECK-NEXT:    ldp x6, x7, [x0, #32]
+; CHECK-NEXT:    ldp x4, x5, [x0, #16]
+; CHECK-NEXT:    ldp x2, x3, [x0]
+; CHECK-NEXT:    //APP
+; CHECK-NEXT:    st64b x2, [x1]
+; CHECK-NEXT:    //NO_APP
+; CHECK-NEXT:    ret
+entry:
+  %incast = bitcast { i64, i64, i64, i64, i64, i64, i64, i64 }* %in to aarch64_ls64*
+  %val = load aarch64_ls64, aarch64_ls64* %incast, align 8
+  tail call void asm sideeffect "st64b $0,[$1]", "r,r"(aarch64_ls64 %val, i8* %addr)
+  ret void
+}
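
Usage sketch (editor's note, not part of the patch): the kind of C-level inline asm this backend change is aimed at, corresponding to the "=r,r" / "r,r" asm calls in ls64-inline-asm.ll above. It assumes the companion front-end change that lets a 64-byte aggregate bind to an "r" constraint as the new 512-bit register class; the type and helper names below are illustrative, not defined by this patch.

  /* Sketch only: relies on the front end mapping a 64-byte struct passed to an
     "r" constraint onto the i64x8 handling added here (companion Clang change,
     not shown). Build with e.g. -march=armv8.7-a+ls64. */
  #include <stdint.h>

  typedef struct { uint64_t val[8]; } ls64_data_t; /* illustrative name */

  static inline ls64_data_t load64b(void *addr) {
    ls64_data_t data;
    /* The whole eight-register tuple is bound to %0; the AsmPrinter change
       makes the default operand modifier print the tuple's first x register. */
    __asm__ volatile("ld64b %0, [%1]" : "=r"(data) : "r"(addr) : "memory");
    return data;
  }

  static inline void store64b(void *addr, ls64_data_t data) {
    __asm__ volatile("st64b %0, [%1]" : : "r"(data), "r"(addr) : "memory");
  }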