Skip to content

Commit

Permalink
[SDAG] Fixups required for InferAS change
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexMaclean committed Jan 14, 2025
1 parent 7b849a6 commit c6db77f
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 50 deletions.
6 changes: 6 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -954,6 +954,12 @@ static void AddNodeIDCustom(FoldingSetNodeID &ID, const SDNode *N) {
ID.AddInteger(M);
break;
}
case ISD::ADDRSPACECAST: {
const AddrSpaceCastSDNode *ASC = cast<AddrSpaceCastSDNode>(N);
ID.AddInteger(ASC->getSrcAddressSpace());
ID.AddInteger(ASC->getDestAddressSpace());
break;
}
case ISD::TargetBlockAddress:
case ISD::BlockAddress: {
const BlockAddressSDNode *BA = cast<BlockAddressSDNode>(N);
Expand Down
48 changes: 27 additions & 21 deletions llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Target/TargetIntrinsicInfo.h"
#include <optional>

using namespace llvm;

Expand Down Expand Up @@ -334,29 +335,34 @@ bool NVPTXDAGToDAGISel::tryEXTRACT_VECTOR_ELEMENT(SDNode *N) {
return true;
}

static unsigned int getCodeAddrSpace(MemSDNode *N) {
const Value *Src = N->getMemOperand()->getValue();

if (!Src)
static std::optional<unsigned> convertAS(unsigned AS) {
switch (AS) {
case llvm::ADDRESS_SPACE_LOCAL:
return NVPTX::AddressSpace::Local;
case llvm::ADDRESS_SPACE_GLOBAL:
return NVPTX::AddressSpace::Global;
case llvm::ADDRESS_SPACE_SHARED:
return NVPTX::AddressSpace::Shared;
case llvm::ADDRESS_SPACE_GENERIC:
return NVPTX::AddressSpace::Generic;

if (auto *PT = dyn_cast<PointerType>(Src->getType())) {
switch (PT->getAddressSpace()) {
case llvm::ADDRESS_SPACE_LOCAL:
return NVPTX::AddressSpace::Local;
case llvm::ADDRESS_SPACE_GLOBAL:
return NVPTX::AddressSpace::Global;
case llvm::ADDRESS_SPACE_SHARED:
return NVPTX::AddressSpace::Shared;
case llvm::ADDRESS_SPACE_GENERIC:
return NVPTX::AddressSpace::Generic;
case llvm::ADDRESS_SPACE_PARAM:
return NVPTX::AddressSpace::Param;
case llvm::ADDRESS_SPACE_CONST:
return NVPTX::AddressSpace::Const;
default: break;
}
case llvm::ADDRESS_SPACE_PARAM:
return NVPTX::AddressSpace::Param;
case llvm::ADDRESS_SPACE_CONST:
return NVPTX::AddressSpace::Const;
default:
return std::nullopt;
}
}

static unsigned int getCodeAddrSpace(const MemSDNode *N) {
if (const Value *Src = N->getMemOperand()->getValue())
if (auto *PT = dyn_cast<PointerType>(Src->getType()))
if (auto AS = convertAS(PT->getAddressSpace()))
return AS.value();

if (auto AS = convertAS(N->getMemOperand()->getAddrSpace()))
return AS.value();

return NVPTX::AddressSpace::Generic;
}

Expand Down
22 changes: 18 additions & 4 deletions llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1398,6 +1398,19 @@ static bool shouldConvertToIndirectCall(const CallBase *CB,
return false;
}

/// Refine the address space of \p Ptr when it can be read off the node itself.
///
/// \p Ptr is taken by reference and may be rewritten. When it is an
/// ISD::FrameIndex node it is replaced by an address-space cast of itself from
/// ADDRESS_SPACE_GENERIC to ADDRESS_SPACE_LOCAL, typed with the pointer type
/// \p TL reports for ADDRESS_SPACE_LOCAL under \p DL, and the returned
/// MachinePointerInfo names ADDRESS_SPACE_LOCAL. For every other opcode
/// \p Ptr is left untouched and a default-constructed MachinePointerInfo is
/// returned.
static MachinePointerInfo refinePtrAS(SDValue &Ptr, SelectionDAG &DAG,
                                      const DataLayout &DL,
                                      const TargetLowering &TL) {
  // Only frame-index pointers are refined; anything else is passed through.
  if (Ptr->getOpcode() != ISD::FrameIndex)
    return MachinePointerInfo();

  // The cast result must use the pointer type of the destination (local)
  // address space, which may differ from the generic pointer type.
  const auto LocalPtrVT = TL.getPointerTy(DL, ADDRESS_SPACE_LOCAL);

  // The cast node is created with an empty SDLoc, i.e. it carries no source
  // location of its own.
  Ptr = DAG.getAddrSpaceCast(SDLoc(), LocalPtrVT, Ptr, ADDRESS_SPACE_GENERIC,
                             ADDRESS_SPACE_LOCAL);

  return MachinePointerInfo(ADDRESS_SPACE_LOCAL);
}

SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
SmallVectorImpl<SDValue> &InVals) const {

Expand Down Expand Up @@ -1562,11 +1575,12 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
}

if (IsByVal) {
auto PtrVT = getPointerTy(DL);
SDValue srcAddr = DAG.getNode(ISD::ADD, dl, PtrVT, StVal,
auto MPI = refinePtrAS(StVal, DAG, DL, *this);
const EVT PtrVT = StVal.getValueType();
SDValue SrcAddr = DAG.getNode(ISD::ADD, dl, PtrVT, StVal,
DAG.getConstant(CurOffset, dl, PtrVT));
StVal = DAG.getLoad(EltVT, dl, TempChain, srcAddr, MachinePointerInfo(),
PartAlign);

StVal = DAG.getLoad(EltVT, dl, TempChain, SrcAddr, MPI, PartAlign);
} else if (ExtendIntegerParam) {
assert(VTs.size() == 1 && "Scalar can't have multiple parts.");
// zext/sext to i32
Expand Down
20 changes: 11 additions & 9 deletions llvm/test/CodeGen/NVPTX/indirect_byval.ll
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,20 @@ define internal i32 @foo() {
; CHECK-NEXT: .reg .b64 %SPL;
; CHECK-NEXT: .reg .b16 %rs<2>;
; CHECK-NEXT: .reg .b32 %r<3>;
; CHECK-NEXT: .reg .b64 %rd<3>;
; CHECK-NEXT: .reg .b64 %rd<5>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0: // %entry
; CHECK-NEXT: mov.u64 %SPL, __local_depot0;
; CHECK-NEXT: cvta.local.u64 %SP, %SPL;
; CHECK-NEXT: ld.global.u64 %rd1, [ptr];
; CHECK-NEXT: ld.u8 %rs1, [%SP+1];
; CHECK-NEXT: add.u64 %rd2, %SP, 0;
; CHECK-NEXT: add.u64 %rd3, %SPL, 1;
; CHECK-NEXT: ld.local.u8 %rs1, [%rd3];
; CHECK-NEXT: add.u64 %rd4, %SP, 0;
; CHECK-NEXT: { // callseq 0, 0
; CHECK-NEXT: .param .align 1 .b8 param0[1];
; CHECK-NEXT: st.param.b8 [param0], %rs1;
; CHECK-NEXT: .param .b64 param1;
; CHECK-NEXT: st.param.b64 [param1], %rd2;
; CHECK-NEXT: st.param.b64 [param1], %rd4;
; CHECK-NEXT: .param .b32 retval0;
; CHECK-NEXT: prototype_0 : .callprototype (.param .b32 _) _ (.param .align 1 .b8 _[1], .param .b64 _);
; CHECK-NEXT: call (retval0),
Expand Down Expand Up @@ -59,19 +60,20 @@ define internal i32 @bar() {
; CHECK-NEXT: .reg .b64 %SP;
; CHECK-NEXT: .reg .b64 %SPL;
; CHECK-NEXT: .reg .b32 %r<3>;
; CHECK-NEXT: .reg .b64 %rd<4>;
; CHECK-NEXT: .reg .b64 %rd<6>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0: // %entry
; CHECK-NEXT: mov.u64 %SPL, __local_depot1;
; CHECK-NEXT: cvta.local.u64 %SP, %SPL;
; CHECK-NEXT: ld.global.u64 %rd1, [ptr];
; CHECK-NEXT: ld.u64 %rd2, [%SP+8];
; CHECK-NEXT: add.u64 %rd3, %SP, 0;
; CHECK-NEXT: add.u64 %rd3, %SPL, 8;
; CHECK-NEXT: ld.local.u64 %rd4, [%rd3];
; CHECK-NEXT: add.u64 %rd5, %SP, 0;
; CHECK-NEXT: { // callseq 1, 0
; CHECK-NEXT: .param .align 8 .b8 param0[8];
; CHECK-NEXT: st.param.b64 [param0], %rd2;
; CHECK-NEXT: st.param.b64 [param0], %rd4;
; CHECK-NEXT: .param .b64 param1;
; CHECK-NEXT: st.param.b64 [param1], %rd3;
; CHECK-NEXT: st.param.b64 [param1], %rd5;
; CHECK-NEXT: .param .b32 retval0;
; CHECK-NEXT: prototype_1 : .callprototype (.param .b32 _) _ (.param .align 8 .b8 _[8], .param .b64 _);
; CHECK-NEXT: call (retval0),
Expand Down
4 changes: 2 additions & 2 deletions llvm/test/CodeGen/NVPTX/variadics-backend.ll
Original file line number Diff line number Diff line change
Expand Up @@ -397,8 +397,8 @@ define dso_local void @qux() {
; CHECK-PTX-NEXT: st.local.u64 [%rd2+8], %rd6;
; CHECK-PTX-NEXT: mov.b64 %rd7, 1;
; CHECK-PTX-NEXT: st.u64 [%SP+16], %rd7;
; CHECK-PTX-NEXT: ld.u64 %rd8, [%SP];
; CHECK-PTX-NEXT: ld.u64 %rd9, [%SP+8];
; CHECK-PTX-NEXT: ld.local.u64 %rd8, [%rd2];
; CHECK-PTX-NEXT: ld.local.u64 %rd9, [%rd2+8];
; CHECK-PTX-NEXT: add.u64 %rd10, %SP, 16;
; CHECK-PTX-NEXT: { // callseq 3, 0
; CHECK-PTX-NEXT: .param .align 8 .b8 param0[16];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,10 @@ define dso_local void @caller_St8x4(ptr nocapture noundef readonly byval(%struct
; CHECK-NEXT: .reg .b32 %SP;
; CHECK-NEXT: .reg .b32 %SPL;
; CHECK-NEXT: .reg .b32 %r<4>;
; CHECK-NEXT: .reg .b64 %rd<17>;
; CHECK-NEXT: .reg .b64 %rd<13>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: mov.u32 %SPL, __local_depot0;
; CHECK-NEXT: cvta.local.u32 %SP, %SPL;
; CHECK-NEXT: ld.param.u32 %r1, [caller_St8x4_param_1];
; CHECK-NEXT: add.u32 %r3, %SPL, 0;
; CHECK-NEXT: ld.param.u64 %rd1, [caller_St8x4_param_0+24];
Expand All @@ -25,27 +24,23 @@ define dso_local void @caller_St8x4(ptr nocapture noundef readonly byval(%struct
; CHECK-NEXT: st.local.u64 [%r3+8], %rd3;
; CHECK-NEXT: ld.param.u64 %rd4, [caller_St8x4_param_0];
; CHECK-NEXT: st.local.u64 [%r3], %rd4;
; CHECK-NEXT: ld.u64 %rd5, [%SP+8];
; CHECK-NEXT: ld.u64 %rd6, [%SP];
; CHECK-NEXT: ld.u64 %rd7, [%SP+24];
; CHECK-NEXT: ld.u64 %rd8, [%SP+16];
; CHECK-NEXT: { // callseq 0, 0
; CHECK-NEXT: .param .align 16 .b8 param0[32];
; CHECK-NEXT: st.param.v2.b64 [param0], {%rd6, %rd5};
; CHECK-NEXT: st.param.v2.b64 [param0+16], {%rd8, %rd7};
; CHECK-NEXT: st.param.v2.b64 [param0], {%rd4, %rd3};
; CHECK-NEXT: st.param.v2.b64 [param0+16], {%rd2, %rd1};
; CHECK-NEXT: .param .align 16 .b8 retval0[32];
; CHECK-NEXT: call.uni (retval0),
; CHECK-NEXT: callee_St8x4,
; CHECK-NEXT: (
; CHECK-NEXT: param0
; CHECK-NEXT: );
; CHECK-NEXT: ld.param.v2.b64 {%rd9, %rd10}, [retval0];
; CHECK-NEXT: ld.param.v2.b64 {%rd11, %rd12}, [retval0+16];
; CHECK-NEXT: ld.param.v2.b64 {%rd5, %rd6}, [retval0];
; CHECK-NEXT: ld.param.v2.b64 {%rd7, %rd8}, [retval0+16];
; CHECK-NEXT: } // callseq 0
; CHECK-NEXT: st.u64 [%r1], %rd9;
; CHECK-NEXT: st.u64 [%r1+8], %rd10;
; CHECK-NEXT: st.u64 [%r1+16], %rd11;
; CHECK-NEXT: st.u64 [%r1+24], %rd12;
; CHECK-NEXT: st.u64 [%r1], %rd5;
; CHECK-NEXT: st.u64 [%r1+8], %rd6;
; CHECK-NEXT: st.u64 [%r1+16], %rd7;
; CHECK-NEXT: st.u64 [%r1+24], %rd8;
; CHECK-NEXT: ret;
%call = tail call fastcc [4 x i64] @callee_St8x4(ptr noundef nonnull byval(%struct.St8x4) align 8 %in) #2
%.fca.0.extract = extractvalue [4 x i64] %call, 0
Expand Down

0 comments on commit c6db77f

Please sign in to comment.