Skip to content

Commit

Permalink
Merge pull request #77 from zsogitbe/master
Browse files Browse the repository at this point in the history
Normalizing and simplifying Cuda precompilation logging
  • Loading branch information
zhongkaifu authored Oct 20, 2023
2 parents 53d517b + 101731d commit 3e99859
Show file tree
Hide file tree
Showing 6 changed files with 145 additions and 27 deletions.
122 changes: 107 additions & 15 deletions AdvUtils/Logger.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,63 @@
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text;

namespace AdvUtils
{
/// <summary>
/// Progress Callback delegate with three functionalities:
/// 1. post a callback message to the caller routine
/// 2. post a callback progress value (%) to the caller routine for long operations
/// 3. Signal if the long process must be canceled for stopping long operations on request of the caller
/// </summary>
/// <param name="value">progress value in % (0-100)</param>
/// <param name="log">progress message</param>
/// <param name="type">type of message, for example, 0: log, 1: error, etc.</param>
/// <param name="color">request a specific color for the message</param>
/// <returns>+1 if the process should be canceled, -1 if not</returns>
[UnmanagedFunctionPointer(CallingConvention.StdCall)]
public delegate int ProgressCallback(
int value,
StringBuilder log,
int type,
int color = 0
);

public class Logger
{
public enum Level { err, warn, info};
public enum LogVerbose {None, Normal, Details, Debug };

public enum LogVerbose {None, Normal, Details, Debug, Callback, Logfileonly, Progress };

public static LogVerbose Verbose = LogVerbose.Normal;

private static ProgressCallback? s_callback = null;

private static LogVerbose s_logverbosebackup = LogVerbose.Normal;

/// <summary>
/// Set the callback routine in your code with this to automatically redirect all messages to your callback function
/// </summary>
public static ProgressCallback? Callback
{
get => s_callback;
set
{
s_callback = value;
if (s_callback != null)
{
s_logverbosebackup = Verbose;
Verbose = LogVerbose.Callback;
}
else
{
Verbose = s_logverbosebackup;
}
}
}

public static void WriteLine(string s, params object[] args)
{
if (Verbose == LogVerbose.None)
Expand Down Expand Up @@ -40,10 +86,27 @@ public static void WriteLine(Level level, string s, params object[] args)

string sLine = sb.ToString();

if (level != Level.info)
Console.Error.WriteLine(sLine);
else
Console.WriteLine(sLine);
if (Callback != null && Verbose == LogVerbose.Callback)
{ // let the caller handle the message
StringBuilder sbl = new StringBuilder(sLine);
Callback(0, sbl, (int)level);
}
else if (Callback != null && Verbose == LogVerbose.Progress)
{ // inform the caller about the progress
if (args.Length > 0)
{
StringBuilder sbl0 = new StringBuilder("");
Callback((int)args[0], sbl0, (int)level);
return;
}
}
else if (Verbose != LogVerbose.Logfileonly)
{ // only print on the Console if Logfileonly is not requested
if (level != Level.info)
Console.Error.WriteLine(sLine);
else
Console.WriteLine(sLine);
}

try
{
Expand All @@ -52,10 +115,17 @@ public static void WriteLine(Level level, string s, params object[] args)
}
catch (Exception err)
{
Console.Error.WriteLine($"Failed to output log to file '{LogFile}'. Error = '{err.Message}'");
if (Callback != null && Verbose == LogVerbose.Callback)
{ // let the caller handle the message
StringBuilder sbl = new StringBuilder($"Failed to write to log file '{LogFile}'. Error = '{err.Message}'");
Callback(0, sbl, (int)level);
}
else
{
Console.Error.WriteLine($"Failed to write to log file '{LogFile}'. Error = '{err.Message}'");
}
s_sw = null;
}

}

public static void WriteLine(Level level, ConsoleColor color, string s, params object[] args)
Expand All @@ -75,14 +145,28 @@ public static void WriteLine(Level level, ConsoleColor color, string s, params o

string sLine = sb.ToString();

Console.ForegroundColor = color;
if (Callback != null && Verbose == LogVerbose.Callback)
{ // let the caller handle the message
StringBuilder sbl = new StringBuilder(sLine);
Callback(0, sbl, (int)level, (int)color);
}
else if (Callback != null && Verbose == LogVerbose.Progress)
{ // inform the caller about the progress
StringBuilder sbl0 = new StringBuilder("");
Callback((int)args[0], sbl0, (int)level);
return;
}
else if (Verbose != LogVerbose.Logfileonly)
{ // only print on the Console if Logfileonly is not requested
Console.ForegroundColor = color;

if (level != Level.info)
Console.Error.WriteLine(sLine);
else
Console.WriteLine(sLine);
if (level != Level.info)
Console.Error.WriteLine(sLine);
else
Console.WriteLine(sLine);

Console.ResetColor();
Console.ResetColor();
}

try
{
Expand All @@ -91,10 +175,18 @@ public static void WriteLine(Level level, ConsoleColor color, string s, params o
}
catch (Exception err)
{
Console.Error.WriteLine($"Failed to output log to file '{LogFile}'. Error = '{err.Message}'");
if (Callback != null && Verbose == LogVerbose.Callback)
{ // let the caller handle the message
StringBuilder sbl = new StringBuilder($"Failed to write to log file '{LogFile}'. Error = '{err.Message}'");
Callback(0, sbl, (int)level);
}
else
{
Console.Error.WriteLine($"Failed to write to log file '{LogFile}'. Error = '{err.Message}'");
}

s_sw = null;
}

}

public static void Close()
Expand Down
7 changes: 6 additions & 1 deletion ExternalProjects/managedCuda/ManagedCUDA/BasicTypesEnum.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1156,7 +1156,12 @@ public enum CUJITTarget
/// <summary>
/// Compute device class 8.6.
/// </summary>
Compute_86 = 86
Compute_86 = 86,

/// <summary>
/// Compute device class 8.9.
/// </summary>
Compute_89 = 89
}

/// <summary>
Expand Down
25 changes: 25 additions & 0 deletions Seq2SeqSharp/Utils/Misc.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
using TensorSharp;
using M = System.Runtime.CompilerServices.MethodImplAttribute;
using O = System.Runtime.CompilerServices.MethodImplOptions;
using ManagedCuda;

namespace Seq2SeqSharp.Utils
{
Expand Down Expand Up @@ -51,6 +52,30 @@ public static string GetTimeStamp(DateTime timeStamp)
{
return string.Format("{0:yyyy}_{0:MM}_{0:dd}_{0:HH}h_{0:mm}m_{0:ss}s", timeStamp);
}

/// <summary>
/// Get the number of GPU's or CPU cores in the system
/// </summary>
/// <param name="GPU">true: get the number of GPUs in the system (default), false: get the number of CPU cores in the system</param>
/// <returns>number of GPUs or CPU cores in the system</returns>
public static int GetDeviceCount(bool GPU = true)
{
try
{
if (GPU)
{
return CudaContext.GetDeviceCount();
}
else
{
return Environment.ProcessorCount;
}
}
catch (Exception)
{
return 0;
}
}
}

public static class Misc
Expand Down
2 changes: 1 addition & 1 deletion Seq2SeqSharp/Utils/TensorAllocator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public static void InitDevices(ProcessorTypeEnums archType, int[] ids, float mem
if (m_archType == ProcessorTypeEnums.GPU)
{
m_cudaContext = new TSCudaContext(m_deviceIds, memoryUsageRatio, compilerOptions, allocatorType, elementType);
m_cudaContext.Precompile(Console.Write);
m_cudaContext.Precompile();
m_cudaContext.CleanUnusedPTX();

foreach (int deviceId in m_deviceIds)
Expand Down
5 changes: 3 additions & 2 deletions TensorSharp.CUDA/PrecompileAttribute.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System;
using AdvUtils;
using System;
using System.Reflection;
using TensorSharp.CUDA.RuntimeCompiler;

Expand Down Expand Up @@ -28,7 +29,7 @@ public static void PrecompileAllFields(object instance, CudaCompiler compiler)
if (typeof(IPrecompilable).IsAssignableFrom(field.FieldType))
{
IPrecompilable precompilableField = (IPrecompilable)field.GetValue(instance);
Console.WriteLine("Compiling field " + field.Name);
Logger.WriteLine("Compiling field " + field.Name);
precompilableField.Precompile(compiler);
}
}
Expand Down
11 changes: 3 additions & 8 deletions TensorSharp.CUDA/TSCudaContext.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using ManagedCuda;
using AdvUtils;
using ManagedCuda;
using ManagedCuda.BasicTypes;
using ManagedCuda.CudaBlas;
using System;
Expand Down Expand Up @@ -189,18 +190,12 @@ private static bool EnablePeers(CudaContext src, CudaContext target)
}
}


public void Precompile()
{
Precompile(Console.Write);
}

public void Precompile(Action<string> precompileProgressWriter)
{
Assembly assembly = Assembly.GetExecutingAssembly();
foreach (Tuple<Type, IEnumerable<PrecompileAttribute>> applyType in assembly.TypesWithAttribute<PrecompileAttribute>(true).Where(x => !x.Item1.IsAbstract))
{
precompileProgressWriter("Precompiling " + applyType.Item1.Name + "\n");
Logger.WriteLine("Precompiling " + applyType.Item1.Name);

IPrecompilable instance = (IPrecompilable)Activator.CreateInstance(applyType.Item1);
instance.Precompile(Compiler);
Expand Down

0 comments on commit 3e99859

Please sign in to comment.