Skip to content

Commit

Permalink
Merge commit 'ac61cb05cc9dd1b2592056a36f4ebff11d774d29'
Browse files Browse the repository at this point in the history
  • Loading branch information
whitneywhtsang committed Jan 30, 2025
2 parents 3bed1bc + ac61cb0 commit b24d2fe
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 23 deletions.
34 changes: 21 additions & 13 deletions third_party/proton/csrc/lib/Profiler/Cupti/CuptiPCSampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,15 +135,23 @@ CUpti_PCSamplingData allocPCSamplingData(size_t collectNumPCs,
uint32_t libVersion = 0;
cupti::getVersion<true>(&libVersion);
size_t pcDataSize = sizeof(CUpti_PCSamplingPCData);
// Check cupti api version < 12.4 but cupti header version >= 12.4
// If so, we subtract 4 bytes from the size of CUpti_PCSamplingPCData
// because it introduces a new field (i.e., correlationId) at the end of the
// struct, which is not compatible with the previous versions.
if (libVersion < CUPTI_CUDA12_4_VERSION &&
CUPTI_API_VERSION >= CUPTI_CUDA12_4_VERSION)
pcDataSize -= CUPTI_CUDA12_4_PC_DATA_PADDING_SIZE;
// Since CUPTI 12.4, a new field (i.e., correlationId) has been added to
// CUpti_PCSamplingPCData, which breaks ABI compatibility.
// Instead of working around the mismatch, we throw an error that reports
// the incompatible versions.
if ((libVersion < CUPTI_CUDA12_4_VERSION &&
CUPTI_API_VERSION >= CUPTI_CUDA12_4_VERSION) ||
(libVersion >= CUPTI_CUDA12_4_VERSION &&
CUPTI_API_VERSION < CUPTI_CUDA12_4_VERSION)) {
throw std::runtime_error(
"[PROTON] CUPTI API version: " + std::to_string(CUPTI_API_VERSION) +
" and CUPTI driver version: " + std::to_string(libVersion) +
" are not compatible. Please set the environment variable "
" TRITON_CUPTI_INCLUDE_PATH and TRITON_CUPTI_LIB_PATH to resolve the "
"problem.");
}
CUpti_PCSamplingData pcSamplingData{
/*size=*/pcDataSize,
/*size=*/sizeof(CUpti_PCSamplingData),
/*collectNumPcs=*/collectNumPCs,
/*totalSamples=*/0,
/*droppedSamples=*/0,
Expand Down Expand Up @@ -372,16 +380,16 @@ void CuptiPCSampling::processPCSamplingData(ConfigureData *configureData,
auto *stallReason = &pcData->stallReason[j];
if (!configureData->stallReasonIndexToMetricIndex.count(
stallReason->pcSamplingStallReasonIndex))
throw std::runtime_error("Invalid stall reason index");
throw std::runtime_error("[PROTON] Invalid stall reason index");
for (auto *data : dataSet) {
auto scopeId = externId;
if (isAPI)
scopeId = data->addOp(externId, lineInfo.functionName);
if (lineInfo.fileName.size())
scopeId = data->addOp(scopeId,
lineInfo.dirName + "/" + lineInfo.fileName +
":" + lineInfo.functionName + "@" +
std::to_string(lineInfo.lineNumber));
scopeId = data->addOp(
scopeId, lineInfo.dirName + "/" + lineInfo.fileName + ":" +
std::to_string(lineInfo.lineNumber) + "@" +
lineInfo.functionName);
auto metricKind = static_cast<PCSamplingMetric::PCSamplingMetricKind>(
configureData->stallReasonIndexToMetricIndex
[stallReason->pcSamplingStallReasonIndex]);
Expand Down
4 changes: 2 additions & 2 deletions third_party/proton/csrc/lib/Profiler/Cupti/CuptiProfiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ void CuptiProfiler::CuptiProfilerPimpl::allocBuffer(uint8_t **buffer,
size_t *maxNumRecords) {
*buffer = static_cast<uint8_t *>(aligned_alloc(AlignSize, BufferSize));
if (*buffer == nullptr) {
throw std::runtime_error("aligned_alloc failed");
throw std::runtime_error("[PROTON] aligned_alloc failed");
}
*bufferSize = BufferSize;
*maxNumRecords = 0;
Expand All @@ -253,7 +253,7 @@ void CuptiProfiler::CuptiProfilerPimpl::completeBuffer(CUcontext ctx,
} else if (status == CUPTI_ERROR_MAX_LIMIT_REACHED) {
break;
} else {
throw std::runtime_error("cupti::activityGetNextRecord failed");
throw std::runtime_error("[PROTON] cupti::activityGetNextRecord failed");
}
} while (true);

Expand Down
3 changes: 2 additions & 1 deletion third_party/proton/proton/viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,8 @@ def format_frames(gf, format):
elif format == "function_line":
gf.dataframe["name"] = gf.dataframe["name"].apply(lambda x: x.split(":")[-1])
elif format == "file_function":
gf.dataframe["name"] = gf.dataframe["name"].apply(lambda x: x.split("/")[-1].split("@")[0])
gf.dataframe["name"] = gf.dataframe["name"].apply(
lambda x: f"{x.split('/')[-1].split(':')[0]}@{x.split('@')[-1].split(':')[0]}")
return gf


Expand Down
2 changes: 1 addition & 1 deletion third_party/proton/test/examples/frame.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
{
"children": [],
"frame": {
"name": "/home/user/projects/example.py/test.py:foo@1",
"name": "/home/user/projects/example.py/test.py:1@foo",
"type": "function"
},
"metrics": {
Expand Down
8 changes: 4 additions & 4 deletions third_party/proton/test/test_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,13 @@ def test_format_frames(option):
gf, _, _, _ = get_raw_metrics(f)
gf = format_frames(gf, option)
if option == "full":
idx = gf.dataframe["name"] == "/home/user/projects/example.py/test.py:foo@1"
idx = gf.dataframe["name"] == "/home/user/projects/example.py/test.py:1@foo"
elif option == "file_function_line":
idx = gf.dataframe["name"] == "test.py:foo@1"
idx = gf.dataframe["name"] == "test.py:1@foo"
elif option == "function_line":
idx = gf.dataframe["name"] == "foo@1"
idx = gf.dataframe["name"] == "1@foo"
elif option == "file_function":
idx = gf.dataframe["name"] == "test.py:foo"
idx = gf.dataframe["name"] == "test.py@foo"
assert idx.sum() == 1


Expand Down
9 changes: 7 additions & 2 deletions third_party/proton/tutorials/matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ def grid(META):

argparser = argparse.ArgumentParser()
argparser.add_argument("--profile", action="store_true")
argparser.add_argument("--pcsampling", action="store_true", default=False)
argparser.add_argument("--cudagraph", action="store_true", default=False)
args = argparser.parse_args()

Expand Down Expand Up @@ -305,9 +306,13 @@ def perf(ms):


if args.profile:
proton.start("matmul", hook="triton")
if args.pcsampling:
# proton-viewer -m num_samples/%,time/s ./matmul.hatchet
proton.start("matmul", hook="triton", backend="cupti_pcsampling")
else:
# proton-viewer -m tflop/s,time/s ./matmul.hatchet
proton.start("matmul", hook="triton")
benchmark.run(show_plots=True, print_data=True)
proton.finalize()
# proton-viewer -m tflop/s,time/s ./matmul.hatchet
else:
benchmark.run(show_plots=True, print_data=True)

0 comments on commit b24d2fe

Please sign in to comment.