diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -714,6 +714,7 @@ ID.AddInteger(LD->getMemoryVT().getRawBits()); ID.AddInteger(LD->getRawSubclassData()); ID.AddInteger(LD->getPointerInfo().getAddrSpace()); + ID.AddInteger(LD->getMemOperand()->getFlags()); break; } case ISD::STORE: { @@ -721,6 +722,7 @@ ID.AddInteger(ST->getMemoryVT().getRawBits()); ID.AddInteger(ST->getRawSubclassData()); ID.AddInteger(ST->getPointerInfo().getAddrSpace()); + ID.AddInteger(ST->getMemOperand()->getFlags()); break; } case ISD::VP_LOAD: { @@ -728,6 +730,7 @@ ID.AddInteger(ELD->getMemoryVT().getRawBits()); ID.AddInteger(ELD->getRawSubclassData()); ID.AddInteger(ELD->getPointerInfo().getAddrSpace()); + ID.AddInteger(ELD->getMemOperand()->getFlags()); break; } case ISD::VP_STORE: { @@ -735,6 +738,7 @@ ID.AddInteger(EST->getMemoryVT().getRawBits()); ID.AddInteger(EST->getRawSubclassData()); ID.AddInteger(EST->getPointerInfo().getAddrSpace()); + ID.AddInteger(EST->getMemOperand()->getFlags()); break; } case ISD::VP_GATHER: { @@ -742,6 +746,7 @@ ID.AddInteger(EG->getMemoryVT().getRawBits()); ID.AddInteger(EG->getRawSubclassData()); ID.AddInteger(EG->getPointerInfo().getAddrSpace()); + ID.AddInteger(EG->getMemOperand()->getFlags()); break; } case ISD::VP_SCATTER: { @@ -749,6 +754,7 @@ ID.AddInteger(ES->getMemoryVT().getRawBits()); ID.AddInteger(ES->getRawSubclassData()); ID.AddInteger(ES->getPointerInfo().getAddrSpace()); + ID.AddInteger(ES->getMemOperand()->getFlags()); break; } case ISD::MLOAD: { @@ -756,6 +762,7 @@ ID.AddInteger(MLD->getMemoryVT().getRawBits()); ID.AddInteger(MLD->getRawSubclassData()); ID.AddInteger(MLD->getPointerInfo().getAddrSpace()); + ID.AddInteger(MLD->getMemOperand()->getFlags()); break; } case ISD::MSTORE: { @@ -763,6 +770,7 @@ ID.AddInteger(MST->getMemoryVT().getRawBits()); 
ID.AddInteger(MST->getRawSubclassData()); ID.AddInteger(MST->getPointerInfo().getAddrSpace()); + ID.AddInteger(MST->getMemOperand()->getFlags()); break; } case ISD::MGATHER: { @@ -770,6 +778,7 @@ ID.AddInteger(MG->getMemoryVT().getRawBits()); ID.AddInteger(MG->getRawSubclassData()); ID.AddInteger(MG->getPointerInfo().getAddrSpace()); + ID.AddInteger(MG->getMemOperand()->getFlags()); break; } case ISD::MSCATTER: { @@ -777,6 +786,7 @@ ID.AddInteger(MS->getMemoryVT().getRawBits()); ID.AddInteger(MS->getRawSubclassData()); ID.AddInteger(MS->getPointerInfo().getAddrSpace()); + ID.AddInteger(MS->getMemOperand()->getFlags()); break; } case ISD::ATOMIC_CMP_SWAP: @@ -799,11 +809,13 @@ ID.AddInteger(AT->getMemoryVT().getRawBits()); ID.AddInteger(AT->getRawSubclassData()); ID.AddInteger(AT->getPointerInfo().getAddrSpace()); + ID.AddInteger(AT->getMemOperand()->getFlags()); break; } case ISD::PREFETCH: { const MemSDNode *PF = cast(N); ID.AddInteger(PF->getPointerInfo().getAddrSpace()); + ID.AddInteger(PF->getMemOperand()->getFlags()); break; } case ISD::VECTOR_SHUFFLE: { @@ -823,9 +835,13 @@ } } // end switch (N->getOpcode()) - // Target specific memory nodes could also have address spaces to check. - if (N->isTargetMemoryOpcode()) - ID.AddInteger(cast(N)->getPointerInfo().getAddrSpace()); + // Target specific memory nodes could also have address spaces and flags + // to check. 
+ if (N->isTargetMemoryOpcode()) { + const MemSDNode *MN = cast(N); + ID.AddInteger(MN->getPointerInfo().getAddrSpace()); + ID.AddInteger(MN->getMemOperand()->getFlags()); + } } /// AddNodeIDNode - Generic routine for adding a nodes info to the NodeID @@ -7315,6 +7331,7 @@ ID.AddInteger(MemVT.getRawBits()); AddNodeIDNode(ID, Opcode, VTList, Ops); ID.AddInteger(MMO->getPointerInfo().getAddrSpace()); + ID.AddInteger(MMO->getFlags()); void* IP = nullptr; if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) { cast(E)->refineAlignment(MMO); @@ -7427,6 +7444,7 @@ ID.AddInteger(getSyntheticNodeSubclassData( Opcode, dl.getIROrder(), VTList, MemVT, MMO)); ID.AddInteger(MMO->getPointerInfo().getAddrSpace()); + ID.AddInteger(MMO->getFlags()); void *IP = nullptr; if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) { cast(E)->refineAlignment(MMO); @@ -7599,6 +7617,7 @@ ID.AddInteger(getSyntheticNodeSubclassData( dl.getIROrder(), VTs, AM, ExtType, MemVT, MMO)); ID.AddInteger(MMO->getPointerInfo().getAddrSpace()); + ID.AddInteger(MMO->getFlags()); void *IP = nullptr; if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) { cast(E)->refineAlignment(MMO); @@ -7700,6 +7719,7 @@ ID.AddInteger(getSyntheticNodeSubclassData( dl.getIROrder(), VTs, ISD::UNINDEXED, false, VT, MMO)); ID.AddInteger(MMO->getPointerInfo().getAddrSpace()); + ID.AddInteger(MMO->getFlags()); void *IP = nullptr; if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) { cast(E)->refineAlignment(MMO); @@ -7766,6 +7786,7 @@ ID.AddInteger(getSyntheticNodeSubclassData( dl.getIROrder(), VTs, ISD::UNINDEXED, true, SVT, MMO)); ID.AddInteger(MMO->getPointerInfo().getAddrSpace()); + ID.AddInteger(MMO->getFlags()); void *IP = nullptr; if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) { cast(E)->refineAlignment(MMO); @@ -7794,6 +7815,7 @@ ID.AddInteger(ST->getMemoryVT().getRawBits()); ID.AddInteger(ST->getRawSubclassData()); ID.AddInteger(ST->getPointerInfo().getAddrSpace()); + ID.AddInteger(ST->getMemOperand()->getFlags()); void *IP = nullptr; 
if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) return SDValue(E, 0); @@ -7851,6 +7873,7 @@ ID.AddInteger(getSyntheticNodeSubclassData( dl.getIROrder(), VTs, AM, ExtType, IsExpanding, MemVT, MMO)); ID.AddInteger(MMO->getPointerInfo().getAddrSpace()); + ID.AddInteger(MMO->getFlags()); void *IP = nullptr; if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) { cast(E)->refineAlignment(MMO); @@ -7943,6 +7966,7 @@ ID.AddInteger(getSyntheticNodeSubclassData( dl.getIROrder(), VTs, AM, IsTruncating, IsCompressing, MemVT, MMO)); ID.AddInteger(MMO->getPointerInfo().getAddrSpace()); + ID.AddInteger(MMO->getFlags()); void *IP = nullptr; if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) { cast(E)->refineAlignment(MMO); @@ -8013,6 +8037,7 @@ ID.AddInteger(getSyntheticNodeSubclassData( dl.getIROrder(), VTs, ISD::UNINDEXED, true, IsCompressing, SVT, MMO)); ID.AddInteger(MMO->getPointerInfo().getAddrSpace()); + ID.AddInteger(MMO->getFlags()); void *IP = nullptr; if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) { cast(E)->refineAlignment(MMO); @@ -8043,6 +8068,7 @@ ID.AddInteger(ST->getMemoryVT().getRawBits()); ID.AddInteger(ST->getRawSubclassData()); ID.AddInteger(ST->getPointerInfo().getAddrSpace()); + ID.AddInteger(ST->getMemOperand()->getFlags()); void *IP = nullptr; if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) return SDValue(E, 0); @@ -8070,6 +8096,7 @@ ID.AddInteger(getSyntheticNodeSubclassData( dl.getIROrder(), VTs, VT, MMO, IndexType)); ID.AddInteger(MMO->getPointerInfo().getAddrSpace()); + ID.AddInteger(MMO->getFlags()); void *IP = nullptr; if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) { cast(E)->refineAlignment(MMO); @@ -8113,6 +8140,7 @@ ID.AddInteger(getSyntheticNodeSubclassData( dl.getIROrder(), VTs, VT, MMO, IndexType)); ID.AddInteger(MMO->getPointerInfo().getAddrSpace()); + ID.AddInteger(MMO->getFlags()); void *IP = nullptr; if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) { cast(E)->refineAlignment(MMO); @@ -8162,6 +8190,7 @@ 
ID.AddInteger(getSyntheticNodeSubclassData( dl.getIROrder(), VTs, AM, ExtTy, isExpanding, MemVT, MMO)); ID.AddInteger(MMO->getPointerInfo().getAddrSpace()); + ID.AddInteger(MMO->getFlags()); void *IP = nullptr; if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) { cast(E)->refineAlignment(MMO); @@ -8209,6 +8238,7 @@ ID.AddInteger(getSyntheticNodeSubclassData( dl.getIROrder(), VTs, AM, IsTruncating, IsCompressing, MemVT, MMO)); ID.AddInteger(MMO->getPointerInfo().getAddrSpace()); + ID.AddInteger(MMO->getFlags()); void *IP = nullptr; if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) { cast(E)->refineAlignment(MMO); @@ -8250,6 +8280,7 @@ ID.AddInteger(getSyntheticNodeSubclassData( dl.getIROrder(), VTs, MemVT, MMO, IndexType, ExtTy)); ID.AddInteger(MMO->getPointerInfo().getAddrSpace()); + ID.AddInteger(MMO->getFlags()); void *IP = nullptr; if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) { cast(E)->refineAlignment(MMO); @@ -8297,6 +8328,7 @@ ID.AddInteger(getSyntheticNodeSubclassData( dl.getIROrder(), VTs, MemVT, MMO, IndexType, IsTrunc)); ID.AddInteger(MMO->getPointerInfo().getAddrSpace()); + ID.AddInteger(MMO->getFlags()); void *IP = nullptr; if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) { cast(E)->refineAlignment(MMO); diff --git a/llvm/test/CodeGen/AMDGPU/mmo-target-flags-folding.ll b/llvm/test/CodeGen/AMDGPU/mmo-target-flags-folding.ll new file mode 100644 --- /dev/null +++ b/llvm/test/CodeGen/AMDGPU/mmo-target-flags-folding.ll @@ -0,0 +1,26 @@ +; RUN: llc -march=amdgcn -mcpu=gfx900 < %s | FileCheck --check-prefix=GCN %s + +; This used to crash due to a mismatch of MMO target flags when folding +; LOAD SDNodes with different flags. 
+ +; GCN-LABEL: {{^}}test_load_folding_mmo_flags: +; GCN: global_load_dwordx2 +define amdgpu_kernel void @test_load_folding_mmo_flags(<2 x float> addrspace(1)* %arg) { +entry: + %id = tail call i32 @llvm.amdgcn.workitem.id.x() + %arrayidx = getelementptr inbounds <2 x float>, <2 x float> addrspace(1)* %arg, i32 %id + %i1 = bitcast <2 x float> addrspace(1)* %arrayidx to i64 addrspace(1)* + %i2 = getelementptr <2 x float>, <2 x float> addrspace(1)* %arrayidx, i64 0, i32 0 + %i3 = load float, float addrspace(1)* %i2, align 4 + %idx = getelementptr inbounds <2 x float>, <2 x float> addrspace(1)* %arrayidx, i64 0, i32 1 + %i4 = load float, float addrspace(1)* %idx, align 4 + %i5 = load i64, i64 addrspace(1)* %i1, align 4, !amdgpu.noclobber !0 + store i64 %i5, i64 addrspace(1)* undef, align 4 + %mul = fmul float %i3, %i4 + store float %mul, float addrspace(1)* undef, align 4 + unreachable +} + +declare i32 @llvm.amdgcn.workitem.id.x() + +!0 = !{}