From 14df45a9d4b0c4016550419b5e9516659fa1fa35 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Szo=C5=82ucha?= Date: Mon, 20 Nov 2023 21:13:28 +0100 Subject: [PATCH] Enhance NVTX analysis capabilities (#106) * Add a Domain to NVTX analysis Signed-off-by: szalpal * Fix pre-commit Signed-off-by: szalpal * Add missing option to CMake Signed-off-by: szalpal * Working around the problem of unlined CUDA Signed-off-by: szalpal * Missing ; Signed-off-by: szalpal * More compilation fixes Signed-off-by: szalpal --------- Signed-off-by: szalpal --- include/triton/common/nvtx.h | 73 ++++++++++++++++++++++++++++++++---- 1 file changed, 66 insertions(+), 7 deletions(-) diff --git a/include/triton/common/nvtx.h b/include/triton/common/nvtx.h index 450736c..95635e2 100644 --- a/include/triton/common/nvtx.h +++ b/include/triton/common/nvtx.h @@ -1,4 +1,4 @@ -// Copyright 2020-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -31,14 +31,61 @@ namespace triton { namespace common { +namespace detail { + +class NvtxTritonDomain { + public: + static nvtxDomainHandle_t& GetDomain() + { + static NvtxTritonDomain inst; + return inst.triton_nvtx_domain_; + } + + private: + NvtxTritonDomain() { triton_nvtx_domain_ = nvtxDomainCreateA("Triton"); } + + ~NvtxTritonDomain() { nvtxDomainDestroy(triton_nvtx_domain_); } + + nvtxDomainHandle_t triton_nvtx_domain_; +}; + +} // namespace detail + // Updates a server stat with duration measured by a C++ scope. 
class NvtxRange { public: - explicit NvtxRange(const char* label) { nvtxRangePushA(label); } + explicit NvtxRange(const char* label, uint32_t rgb = kNvGreen) + { + auto attr = GetAttributes(label, rgb); + nvtxDomainRangePushEx(detail::NvtxTritonDomain::GetDomain(), &attr); + } + + explicit NvtxRange(const std::string& label, uint32_t rgb = kNvGreen) + : NvtxRange(label.c_str(), rgb) + { + } - explicit NvtxRange(const std::string& label) : NvtxRange(label.c_str()) {} + ~NvtxRange() { nvtxDomainRangePop(detail::NvtxTritonDomain::GetDomain()); } - ~NvtxRange() { nvtxRangePop(); } + static constexpr uint32_t kNvGreen = 0x76b900; + static constexpr uint32_t kRed = 0xc1121f; + static constexpr uint32_t kGreen = 0x588157; + static constexpr uint32_t kBlue = 0x023047; + static constexpr uint32_t kYellow = 0xffb703; + static constexpr uint32_t kOrange = 0xfb8500; + + private: + nvtxEventAttributes_t GetAttributes(const char* label, uint32_t rgb) + { + nvtxEventAttributes_t attr = {}; + attr.version = NVTX_VERSION; + attr.size = NVTX_EVENT_ATTRIB_STRUCT_SIZE; + attr.colorType = NVTX_COLOR_ARGB; + attr.color = rgb | 0xff000000; + attr.messageType = NVTX_MESSAGE_TYPE_ASCII; + attr.message.ascii = label; + return attr; + } }; }} // namespace triton::common @@ -46,14 +93,26 @@ class NvtxRange { #endif // TRITON_ENABLE_NVTX // -// Macros to access NVTX functionality +// Macros to access NVTX functionality. +// For `NVTX_RANGE` macro please refer to the usage below. // #ifdef TRITON_ENABLE_NVTX #define NVTX_INITIALIZE nvtxInitialize(nullptr) -#define NVTX_RANGE(V, L) triton::common::NvtxRange V(L) +#define NVTX_RANGE1(V, L) triton::common::NvtxRange V(L) +#define NVTX_RANGE2(V, L, RGB) triton::common::NvtxRange V(L, RGB) #define NVTX_MARKER(L) nvtxMarkA(L) #else #define NVTX_INITIALIZE -#define NVTX_RANGE(V, L) +#define NVTX_RANGE1(V, L) +#define NVTX_RANGE2(V, L, RGB) #define NVTX_MARKER(L) #endif // TRITON_ENABLE_NVTX + +// "Overload" for `NVTX_RANGE` macro. 
+// Usage: +// NVTX_RANGE(nvtx1, "My message") -> Records NVTX range with kNvGreen color. +// NVTX_RANGE(nvtx1, "My message", triton::common::NvtxRange::kRed) -> +// Records NVTX range with kRed color. +#define GET_NVTX_MACRO(_1, _2, _3, NAME, ...) NAME +#define NVTX_RANGE(...) \ + GET_NVTX_MACRO(__VA_ARGS__, NVTX_RANGE2, NVTX_RANGE1)(__VA_ARGS__)