diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 0dfd0302ae5438..743ae4895a1b1c 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -954,6 +954,12 @@ static void AddNodeIDCustom(FoldingSetNodeID &ID, const SDNode *N) {
       ID.AddInteger(M);
     break;
   }
+  case ISD::ADDRSPACECAST: {
+    const AddrSpaceCastSDNode *ASC = cast<AddrSpaceCastSDNode>(N);
+    ID.AddInteger(ASC->getSrcAddressSpace());
+    ID.AddInteger(ASC->getDestAddressSpace());
+    break;
+  }
   case ISD::TargetBlockAddress:
   case ISD::BlockAddress: {
     const BlockAddressSDNode *BA = cast<BlockAddressSDNode>(N);
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index 2e66b67dfdcc76..b3bca33e5d0fec 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -24,6 +24,7 @@
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Target/TargetIntrinsicInfo.h"
+#include <optional>
 
 using namespace llvm;
 
@@ -334,29 +335,34 @@ bool NVPTXDAGToDAGISel::tryEXTRACT_VECTOR_ELEMENT(SDNode *N) {
   return true;
 }
 
-static unsigned int getCodeAddrSpace(MemSDNode *N) {
-  const Value *Src = N->getMemOperand()->getValue();
-
-  if (!Src)
+static std::optional<unsigned> convertAS(unsigned AS) {
+  switch (AS) {
+  case llvm::ADDRESS_SPACE_LOCAL:
+    return NVPTX::AddressSpace::Local;
+  case llvm::ADDRESS_SPACE_GLOBAL:
+    return NVPTX::AddressSpace::Global;
+  case llvm::ADDRESS_SPACE_SHARED:
+    return NVPTX::AddressSpace::Shared;
+  case llvm::ADDRESS_SPACE_GENERIC:
     return NVPTX::AddressSpace::Generic;
-
-  if (auto *PT = dyn_cast<PointerType>(Src->getType())) {
-    switch (PT->getAddressSpace()) {
-    case llvm::ADDRESS_SPACE_LOCAL:
-      return NVPTX::AddressSpace::Local;
-    case llvm::ADDRESS_SPACE_GLOBAL:
-      return NVPTX::AddressSpace::Global;
-    case llvm::ADDRESS_SPACE_SHARED:
-      return NVPTX::AddressSpace::Shared;
-    case llvm::ADDRESS_SPACE_GENERIC:
-      return NVPTX::AddressSpace::Generic;
-    case llvm::ADDRESS_SPACE_PARAM:
-      return NVPTX::AddressSpace::Param;
-    case llvm::ADDRESS_SPACE_CONST:
-      return NVPTX::AddressSpace::Const;
-    default: break;
-    }
+  case llvm::ADDRESS_SPACE_PARAM:
+    return NVPTX::AddressSpace::Param;
+  case llvm::ADDRESS_SPACE_CONST:
+    return NVPTX::AddressSpace::Const;
+  default:
+    return std::nullopt;
   }
+}
+
+static unsigned int getCodeAddrSpace(const MemSDNode *N) {
+  if (const Value *Src = N->getMemOperand()->getValue())
+    if (auto *PT = dyn_cast<PointerType>(Src->getType()))
+      if (auto AS = convertAS(PT->getAddressSpace()))
+        return AS.value();
+
+  if (auto AS = convertAS(N->getMemOperand()->getAddrSpace()))
+    return AS.value();
 
   return NVPTX::AddressSpace::Generic;
 }
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 208d724f7ae283..916d8cbeb5d6e1 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -1398,6 +1398,19 @@ static bool shouldConvertToIndirectCall(const CallBase *CB,
   return false;
 }
 
+static MachinePointerInfo refinePtrAS(SDValue &Ptr, SelectionDAG &DAG,
+                                      const DataLayout &DL,
+                                      const TargetLowering &TL) {
+  if (Ptr->getOpcode() == ISD::FrameIndex) {
+    auto Ty = TL.getPointerTy(DL, ADDRESS_SPACE_LOCAL);
+    Ptr = DAG.getAddrSpaceCast(SDLoc(), Ty, Ptr, ADDRESS_SPACE_GENERIC,
+                               ADDRESS_SPACE_LOCAL);
+
+    return MachinePointerInfo(ADDRESS_SPACE_LOCAL);
+  }
+  return MachinePointerInfo();
+}
+
 SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
                                        SmallVectorImpl<SDValue> &InVals) const {
@@ -1562,11 +1575,12 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
       }
 
       if (IsByVal) {
-        auto PtrVT = getPointerTy(DL);
-        SDValue srcAddr = DAG.getNode(ISD::ADD, dl, PtrVT, StVal,
+        auto MPI = refinePtrAS(StVal, DAG, DL, *this);
+        const EVT PtrVT = StVal.getValueType();
+        SDValue SrcAddr = DAG.getNode(ISD::ADD, dl, PtrVT, StVal,
                                       DAG.getConstant(CurOffset, dl, PtrVT));
-        StVal = DAG.getLoad(EltVT, dl, TempChain, srcAddr, MachinePointerInfo(),
-                            PartAlign);
+
+        StVal = DAG.getLoad(EltVT, dl, TempChain, SrcAddr, MPI, PartAlign);
       } else if (ExtendIntegerParam) {
         assert(VTs.size() == 1 && "Scalar can't have multiple parts.");
         // zext/sext to i32
diff --git a/llvm/test/CodeGen/NVPTX/indirect_byval.ll b/llvm/test/CodeGen/NVPTX/indirect_byval.ll
index d6c6e032f032fd..3ae6300d8767d6 100644
--- a/llvm/test/CodeGen/NVPTX/indirect_byval.ll
+++ b/llvm/test/CodeGen/NVPTX/indirect_byval.ll
@@ -17,19 +17,20 @@ define internal i32 @foo() {
 ; CHECK-NEXT:    .reg .b64 %SPL;
 ; CHECK-NEXT:    .reg .b16 %rs<2>;
 ; CHECK-NEXT:    .reg .b32 %r<3>;
-; CHECK-NEXT:    .reg .b64 %rd<3>;
+; CHECK-NEXT:    .reg .b64 %rd<5>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0: // %entry
 ; CHECK-NEXT:    mov.u64 %SPL, __local_depot0;
 ; CHECK-NEXT:    cvta.local.u64 %SP, %SPL;
 ; CHECK-NEXT:    ld.global.u64 %rd1, [ptr];
-; CHECK-NEXT:    ld.u8 %rs1, [%SP+1];
-; CHECK-NEXT:    add.u64 %rd2, %SP, 0;
+; CHECK-NEXT:    add.u64 %rd3, %SPL, 1;
+; CHECK-NEXT:    ld.local.u8 %rs1, [%rd3];
+; CHECK-NEXT:    add.u64 %rd4, %SP, 0;
 ; CHECK-NEXT:    { // callseq 0, 0
 ; CHECK-NEXT:    .param .align 1 .b8 param0[1];
 ; CHECK-NEXT:    st.param.b8 [param0], %rs1;
 ; CHECK-NEXT:    .param .b64 param1;
-; CHECK-NEXT:    st.param.b64 [param1], %rd2;
+; CHECK-NEXT:    st.param.b64 [param1], %rd4;
 ; CHECK-NEXT:    .param .b32 retval0;
 ; CHECK-NEXT:    prototype_0 : .callprototype (.param .b32 _) _ (.param .align 1 .b8 _[1], .param .b64 _);
 ; CHECK-NEXT:    call (retval0),
@@ -59,19 +60,20 @@ define internal i32 @bar() {
 ; CHECK-NEXT:    .reg .b64 %SP;
 ; CHECK-NEXT:    .reg .b64 %SPL;
 ; CHECK-NEXT:    .reg .b32 %r<3>;
-; CHECK-NEXT:    .reg .b64 %rd<4>;
+; CHECK-NEXT:    .reg .b64 %rd<6>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0: // %entry
 ; CHECK-NEXT:    mov.u64 %SPL, __local_depot1;
 ; CHECK-NEXT:    cvta.local.u64 %SP, %SPL;
 ; CHECK-NEXT:    ld.global.u64 %rd1, [ptr];
-; CHECK-NEXT:    ld.u64 %rd2, [%SP+8];
-; CHECK-NEXT:    add.u64 %rd3, %SP, 0;
+; CHECK-NEXT:    add.u64 %rd3, %SPL, 8;
+; CHECK-NEXT:    ld.local.u64 %rd4, [%rd3];
+; CHECK-NEXT:    add.u64 %rd5, %SP, 0;
 ; CHECK-NEXT:    { // callseq 1, 0
 ; CHECK-NEXT:    .param .align 8 .b8 param0[8];
-; CHECK-NEXT:    st.param.b64 [param0], %rd2;
+; CHECK-NEXT:    st.param.b64 [param0], %rd4;
 ; CHECK-NEXT:    .param .b64 param1;
-; CHECK-NEXT:    st.param.b64 [param1], %rd3;
+; CHECK-NEXT:    st.param.b64 [param1], %rd5;
 ; CHECK-NEXT:    .param .b32 retval0;
 ; CHECK-NEXT:    prototype_1 : .callprototype (.param .b32 _) _ (.param .align 8 .b8 _[8], .param .b64 _);
 ; CHECK-NEXT:    call (retval0),
diff --git a/llvm/test/CodeGen/NVPTX/variadics-backend.ll b/llvm/test/CodeGen/NVPTX/variadics-backend.ll
index f5c1e238f553a5..c3296dd5298fc0 100644
--- a/llvm/test/CodeGen/NVPTX/variadics-backend.ll
+++ b/llvm/test/CodeGen/NVPTX/variadics-backend.ll
@@ -397,8 +397,8 @@ define dso_local void @qux() {
 ; CHECK-PTX-NEXT:    st.local.u64 [%rd2+8], %rd6;
 ; CHECK-PTX-NEXT:    mov.b64 %rd7, 1;
 ; CHECK-PTX-NEXT:    st.u64 [%SP+16], %rd7;
-; CHECK-PTX-NEXT:    ld.u64 %rd8, [%SP];
-; CHECK-PTX-NEXT:    ld.u64 %rd9, [%SP+8];
+; CHECK-PTX-NEXT:    ld.local.u64 %rd8, [%rd2];
+; CHECK-PTX-NEXT:    ld.local.u64 %rd9, [%rd2+8];
 ; CHECK-PTX-NEXT:    add.u64 %rd10, %SP, 16;
 ; CHECK-PTX-NEXT:    { // callseq 3, 0
 ; CHECK-PTX-NEXT:    .param .align 8 .b8 param0[16];
diff --git a/llvm/test/tools/UpdateTestChecks/update_llc_test_checks/Inputs/nvptx-basic.ll.expected b/llvm/test/tools/UpdateTestChecks/update_llc_test_checks/Inputs/nvptx-basic.ll.expected
index b0346f4db5ba19..820ade631dd640 100644
--- a/llvm/test/tools/UpdateTestChecks/update_llc_test_checks/Inputs/nvptx-basic.ll.expected
+++ b/llvm/test/tools/UpdateTestChecks/update_llc_test_checks/Inputs/nvptx-basic.ll.expected
@@ -10,11 +10,10 @@ define dso_local void @caller_St8x4(ptr nocapture noundef readonly byval(%struct
 ; CHECK-NEXT:    .reg .b32 %SP;
 ; CHECK-NEXT:    .reg .b32 %SPL;
 ; CHECK-NEXT:    .reg .b32 %r<4>;
-; CHECK-NEXT:    .reg .b64 %rd<17>;
+; CHECK-NEXT:    .reg .b64 %rd<13>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0:
 ; CHECK-NEXT:    mov.u32 %SPL, __local_depot0;
-; CHECK-NEXT:    cvta.local.u32 %SP, %SPL;
 ; CHECK-NEXT:    ld.param.u32 %r1, [caller_St8x4_param_1];
 ; CHECK-NEXT:    add.u32 %r3, %SPL, 0;
 ; CHECK-NEXT:    ld.param.u64 %rd1, [caller_St8x4_param_0+24];
@@ -25,27 +24,23 @@ define dso_local void @caller_St8x4(ptr nocapture noundef readonly byval(%struct
 ; CHECK-NEXT:    st.local.u64 [%r3+8], %rd3;
 ; CHECK-NEXT:    ld.param.u64 %rd4, [caller_St8x4_param_0];
 ; CHECK-NEXT:    st.local.u64 [%r3], %rd4;
-; CHECK-NEXT:    ld.u64 %rd5, [%SP+8];
-; CHECK-NEXT:    ld.u64 %rd6, [%SP];
-; CHECK-NEXT:    ld.u64 %rd7, [%SP+24];
-; CHECK-NEXT:    ld.u64 %rd8, [%SP+16];
 ; CHECK-NEXT:    { // callseq 0, 0
 ; CHECK-NEXT:    .param .align 16 .b8 param0[32];
-; CHECK-NEXT:    st.param.v2.b64 [param0], {%rd6, %rd5};
-; CHECK-NEXT:    st.param.v2.b64 [param0+16], {%rd8, %rd7};
+; CHECK-NEXT:    st.param.v2.b64 [param0], {%rd4, %rd3};
+; CHECK-NEXT:    st.param.v2.b64 [param0+16], {%rd2, %rd1};
 ; CHECK-NEXT:    .param .align 16 .b8 retval0[32];
 ; CHECK-NEXT:    call.uni (retval0),
 ; CHECK-NEXT:    callee_St8x4,
 ; CHECK-NEXT:    (
 ; CHECK-NEXT:    param0
 ; CHECK-NEXT:    );
-; CHECK-NEXT:    ld.param.v2.b64 {%rd9, %rd10}, [retval0];
-; CHECK-NEXT:    ld.param.v2.b64 {%rd11, %rd12}, [retval0+16];
+; CHECK-NEXT:    ld.param.v2.b64 {%rd5, %rd6}, [retval0];
+; CHECK-NEXT:    ld.param.v2.b64 {%rd7, %rd8}, [retval0+16];
 ; CHECK-NEXT:    } // callseq 0
-; CHECK-NEXT:    st.u64 [%r1], %rd9;
-; CHECK-NEXT:    st.u64 [%r1+8], %rd10;
-; CHECK-NEXT:    st.u64 [%r1+16], %rd11;
-; CHECK-NEXT:    st.u64 [%r1+24], %rd12;
+; CHECK-NEXT:    st.u64 [%r1], %rd5;
+; CHECK-NEXT:    st.u64 [%r1+8], %rd6;
+; CHECK-NEXT:    st.u64 [%r1+16], %rd7;
+; CHECK-NEXT:    st.u64 [%r1+24], %rd8;
 ; CHECK-NEXT:    ret;
   %call = tail call fastcc [4 x i64] @callee_St8x4(ptr noundef nonnull byval(%struct.St8x4) align 8 %in) #2
   %.fca.0.extract = extractvalue [4 x i64] %call, 0