Skip to content

Autocast #1235

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 51 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
51d1d95
Autocast
haytham2597 Feb 11, 2024
29b4900
Added some features
haytham2597 Feb 17, 2024
defd582
Fix mistake gitignore
haytham2597 Feb 18, 2024
d532402
AMP
haytham2597 Feb 18, 2024
0b839db
Add Print Modules Still in progress
haytham2597 Feb 19, 2024
98cabfa
Add some printing module
haytham2597 Feb 19, 2024
669b4fa
Fix some dotnet build. Need fix tests
haytham2597 Feb 20, 2024
3940414
Fast tensor accessor for ToArray()
haytham2597 Jun 30, 2024
3469d7a
Update local
haytham2597 Jun 30, 2024
5062339
fix local build dotnet
haytham2597 Jun 30, 2024
3a467af
Fast ToArray() TensorAccessor
haytham2597 Jul 2, 2024
18c7528
Fast tensor accesor
haytham2597 Jul 2, 2024
728c9fb
fix accesor for every types
haytham2597 Jul 9, 2024
a9a611a
GradScaler
haytham2597 Jul 12, 2024
4a406ec
Trying fix build for azure
haytham2597 Jul 14, 2024
280c8d5
Range sequential
haytham2597 Jul 17, 2024
3c42a87
AMPManager
haytham2597 Jul 19, 2024
7cd7f9c
Amp
haytham2597 Jul 20, 2024
1293483
update
haytham2597 Jul 20, 2024
0c2769a
fix azure devops?
haytham2597 Jul 21, 2024
eafdd1e
fix test?
haytham2597 Jul 21, 2024
c0883d9
fix mac test?
haytham2597 Jul 21, 2024
9ac78bd
AMP Problem outscope
haytham2597 Jul 24, 2024
d6a0c28
gradscale, device cuda properties, etc.
haytham2597 Sep 3, 2024
21ce055
some gradscaler. Need grad_scale and found_inf attr in optimizer
haytham2597 Sep 3, 2024
e9f34c8
Merge branch 'main' of https://github.com/dotnet/TorchSharp
haytham2597 Sep 3, 2024
c70b523
update v2.4.0
haytham2597 Sep 3, 2024
36b79b9
some advance
haytham2597 Sep 5, 2024
376f4fb
Improve autocastmode
haytham2597 Sep 8, 2024
9f4a48b
Some Autocast f16, f32
haytham2597 Oct 18, 2024
f84392b
fix test jit, it is literally close
haytham2597 Oct 18, 2024
197c1e4
Test and some improve on autocast
haytham2597 Oct 19, 2024
061ec44
cross between tensors, improve grad scaler and add normalize #1382
haytham2597 Oct 21, 2024
851a09e
GELU approximate #1368
haytham2597 Oct 21, 2024
16aba79
Device Properties #462
haytham2597 Oct 21, 2024
441bbdd
tensor backward function signature #1376
haytham2597 Oct 21, 2024
194a1f0
Half, Bfloat16
haytham2597 Oct 21, 2024
63da9c2
some fix THSCuda
haytham2597 Oct 25, 2024
ce679e2
fast copy tensor accessor
haytham2597 Oct 25, 2024
958a187
rollback sln
haytham2597 Oct 25, 2024
abe9990
Merge branch 'main' into fast_tensor_accesor
NiklasGustafsson Oct 25, 2024
0b20f13
Numel
haytham2597 Oct 25, 2024
7df8e46
Merge branch 'fast_tensor_accesor' of https://github.com/haytham2597/…
haytham2597 Oct 25, 2024
1aa1f25
original sln and fix some issue
haytham2597 Oct 26, 2024
572bc3e
some
haytham2597 Oct 28, 2024
2c33985
Test and fix some error
haytham2597 Nov 1, 2024
5a6240c
trying fix comp THSCuda
haytham2597 Nov 4, 2024
0d7a585
updage
haytham2597 Feb 14, 2025
e524239
custom libtorch fullpatch
haytham2597 Feb 15, 2025
8f35385
some update
haytham2597 Mar 26, 2025
05c7efb
imprv
haytham2597 Mar 26, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
AMPManager
  • Loading branch information
haytham2597 committed Jul 19, 2024
commit 3c42a87bf4770d04fda2f67fc7ce1bca826b5598
89 changes: 89 additions & 0 deletions src/TorchSharp/Amp/AMPManager.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using System.Text;
using Google.Protobuf.WellKnownTypes;
using TorchSharp.PInvoke;
using TorchSharp.Utils;

namespace TorchSharp.Amp
{
    /// <summary>
    /// Process-wide coordinator for automatic mixed precision (AMP).
    /// Tracks the original dtype of every native tensor handle that is cast to the
    /// autocast "fast" dtype, so the casts can be reverted when the AMP scope ends.
    /// </summary>
    public class AMPManager : IDisposable
    {
        // Maps a native tensor handle to the dtype it had BEFORE autocasting,
        // so Revert() can restore it. Must be initialized here: the private
        // constructor does not create it, and Add()/Revert() dereference it.
        public UnorderedMap<IntPtr, torch.ScalarType> TensorPtrs = new UnorderedMap<IntPtr, torch.ScalarType>();

        private readonly AutocastMode autocastMode = AutocastMode.GetInstance();

        // Guards lazy creation of the singleton (addresses the old "Make Singleton THREADSAFE" TODO).
        private static readonly object InstanceLock = new object();
        private static AMPManager Instance;

        private bool disposedValue;

        private AMPManager() { }

        /// <summary>True when autocasting is currently enabled on the underlying <see cref="AutocastMode"/>.</summary>
        public bool IsEnabled => autocastMode.Enabled;

        /// <summary>
        /// Returns the process-wide singleton, creating it on first use. Thread-safe.
        /// </summary>
        public static AMPManager GetInstance()
        {
            // Double-checked locking: lock-free read on the fast path,
            // lock taken only for first-time creation.
            if (Instance == null) {
                lock (InstanceLock) {
                    Instance ??= new AMPManager();
                }
            }
            return Instance;
        }

        // Casts the tensor behind 'ptr' to 'type' via the native layer.
        // NOTE(review): THSTensor_to_type returns a tensor handle that is only
        // checked for failure and then discarded — confirm the conversion is
        // observable through 'ptr'; otherwise the cast is lost and the returned
        // native handle leaks. TODO confirm against THSTensor.cpp.
        private void To(IntPtr ptr, torch.ScalarType type)
        {
            var res = NativeMethods.THSTensor_to_type(ptr, (sbyte)type);
            if (res == IntPtr.Zero)
                torch.CheckForErrors();
        }

        /// <summary>
        /// Restores every registered tensor to its recorded original dtype and
        /// empties the registry.
        /// </summary>
        private void Revert()
        {
            // foreach disposes the enumerator, matching the original explicit 'using'.
            foreach (var entry in TensorPtrs)
                To(entry.Key, entry.Value);
            TensorPtrs.Clear();
        }

        /// <summary>
        /// Registers a native tensor handle with the autocast scope. While autocast
        /// is enabled the tensor is cast to the fast dtype and its original dtype is
        /// remembered; while disabled, a previously-cast tensor is restored instead.
        /// </summary>
        /// <param name="ptr">Native handle of the tensor to (un)cast.</param>
        public void Add(IntPtr ptr)
        {
            if (!autocastMode.Enabled) {
                // Autocast is off: if we cast this tensor earlier, put its dtype back.
                if (TensorPtrs.ContainsKey(ptr))
                    To(ptr, TensorPtrs[ptr]);
                return;
            }

            // Record the original dtype only once. The previous unconditional write
            // meant a second Add() for the same handle stored the already-cast fast
            // dtype, permanently losing the true original.
            if (!TensorPtrs.ContainsKey(ptr))
                TensorPtrs[ptr] = (torch.ScalarType)NativeMethods.THSTensor_type(ptr);
            To(ptr, autocastMode.GetFastType()); //TODO: Set scalar autocast
        }

        /// <summary>
        /// Intended to open an AMP scope. Currently returns null (which a C# 'using'
        /// statement tolerates); TODO: return a real scope object that reverts on dispose.
        /// </summary>
        public IDisposable Enter()
        {
            return null;
        }

        protected virtual void Dispose(bool disposing)
        {
            if (disposedValue)
                return;
            if (disposing) {
                // Managed state (TensorPtrs, autocastMode) may only be touched on an
                // explicit Dispose(): during finalization those objects may already
                // have been finalized themselves.
                Revert();
                autocastMode.Dispose();
            }
            disposedValue = true;
        }

        ~AMPManager()
        {
            Dispose(false);
        }

        public void Dispose()
        {
            // Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method
            Dispose(disposing: true);
            GC.SuppressFinalize(this);
        }
    }
}
29 changes: 0 additions & 29 deletions src/TorchSharp/Amp/AutocastDisposeManager.cs

This file was deleted.

23 changes: 0 additions & 23 deletions src/TorchSharp/Amp/AutocastDisposeScope.cs

This file was deleted.

11 changes: 0 additions & 11 deletions src/TorchSharp/Amp/AutocastManager.cs

This file was deleted.

97 changes: 67 additions & 30 deletions src/TorchSharp/Amp/AutocastMode.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Security.Cryptography;
using System.Text;
using System.Threading.Tasks;

Expand All @@ -17,22 +18,33 @@ public static torch.Tensor AutoCast(this torch.Tensor input)
public sealed class AutocastMode : IDisposable
{
//NEED "Register" all tensor in scope for uncasting outer-scope
private bool Enabled, Prev;
internal bool Enabled, Prev;
//private torch.ScalarType Dtype = torch.ScalarType.Float32;
private torch.ScalarType fast_dtype = torch.ScalarType.Float32;
private torch.Device Device = new torch.Device(DeviceType.CUDA);
internal torch.ScalarType fast_dtype = torch.ScalarType.Float32;
public torch.Device Device = new torch.Device(DeviceType.CUDA);
private static AutocastMode instance;
bool disposedValue;

/*public static AutocastMode GetInstance(torch.Device dev, torch.ScalarType? dtype = null, bool enabled = true, bool? cache_enabled = null)
{
if(instance ==null)
instance = new AutocastMode(dev, dtype, enabled, cache_enabled);
return instance;
}*/
{
if(instance ==null)
instance = new AutocastMode(dev, dtype, enabled, cache_enabled);
return instance;
}*/
public static AutocastMode GetInstance()
{
return instance ??= new AutocastMode(torch.CUDA, cache_enabled:true);
}

public torch.ScalarType GetFastType()
{
var ft = torch.ScalarType.Float32;
if (Device.type == DeviceType.CUDA)
ft = torch.get_autocast_gpu_dtype();
if (Device.type == DeviceType.CPU)
ft = torch.get_autocast_cpu_dtype();
return ft;
}
private AutocastMode(torch.Device dev, torch.ScalarType? dtype = null, bool enabled=true, bool? cache_enabled = null)
{
//var la = torch.tensor(9);
Expand Down Expand Up @@ -78,32 +90,57 @@ internal torch.Tensor CastTensor(torch.Tensor tensor)
return tensor;
return tensor.to(fast_dtype, tensor.device);
}
/*public IDisposable Enter()
{

return this;
}*/
public void Dispose()
private void Dispose(bool disposing)
{
this.Enabled = false;
if (Device.type == DeviceType.CUDA) {
if(torch.autocast_decrement_nesting() == 0)
torch.clear_autocast_cache();
torch.set_autocast_gpu_dtype(this.fast_dtype);
//torch.set_autocast_enabled(this.Prev);
torch.set_autocast_enabled(false);
torch.set_autocast_cache_enabled(false);
}
if (!disposedValue) {
if (disposing) {

if (Device.type == DeviceType.CPU) {
if (torch.autocast_decrement_nesting() == 0)
torch.clear_autocast_cache();
//torch.set_autocast_enabled(this.Prev);
torch.set_autocast_cpu_dtype(this.fast_dtype);
torch.set_autocast_enabled(false);
torch.set_autocast_cache_enabled(false);
this.Enabled = false;
if (Device.type == DeviceType.CUDA) {
if (torch.autocast_decrement_nesting() == 0)
torch.clear_autocast_cache();
torch.set_autocast_gpu_dtype(this.fast_dtype);
//torch.set_autocast_enabled(this.Prev);
torch.set_autocast_enabled(false);
torch.set_autocast_cache_enabled(false);
}

if (Device.type == DeviceType.CPU) {
if (torch.autocast_decrement_nesting() == 0)
torch.clear_autocast_cache();
//torch.set_autocast_enabled(this.Prev);
torch.set_autocast_cpu_dtype(this.fast_dtype);
torch.set_autocast_enabled(false);
torch.set_autocast_cache_enabled(false);
}
//throw new NotImplementedException();
// TODO: dispose managed state (managed objects)
}

// TODO: free unmanaged resources (unmanaged objects) and override finalizer
// TODO: set large fields to null
disposedValue = true;
}
//throw new NotImplementedException();
}

// // TODO: override finalizer only if 'Dispose(bool disposing)' has code to free unmanaged resources
// ~AutocastMode()
// {
// // Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method
// Dispose(disposing: false);
// }

public void Dispose()
{
// Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method
Dispose(disposing: true);
GC.SuppressFinalize(this);
}
/*public IDisposable Enter()
{

return this;
}*/
}
}
7 changes: 2 additions & 5 deletions src/TorchSharp/Amp/GradScaler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ public class GradScaler
private bool Enabled;
private torch.Tensor _scale, _growth_tracker;
private float InitScale, GrowthFactor, BackoffFactor, GrowthInterval, InitGrowthTracker;

private Dictionary<int, Dictionary<string, object>> _per_optimizer_states = new Dictionary<int, Dictionary<string, object>>();
//https://github.com/pytorch/pytorch/blob/main/torch/amp/grad_scaler.py
public GradScaler(torch.Device dev, float init_scale = 2.0e16f, float growth_factor = 2.0f,
Expand Down Expand Up @@ -54,9 +53,9 @@ public torch.Tensor scale(torch.Tensor output)
}
private class MultiDeviceReplicator
{
private torch.Tensor master;
private readonly torch.Tensor master;

internal Dictionary<torch.Device, torch.Tensor> per_device_tensors = new Dictionary<torch.Device, torch.Tensor>();
internal readonly Dictionary<torch.Device, torch.Tensor> per_device_tensors = new Dictionary<torch.Device, torch.Tensor>();
public MultiDeviceReplicator(torch.Tensor master_tensor)
{
master = master_tensor;
Expand Down Expand Up @@ -155,8 +154,6 @@ public void unscale(torch.optim.Optimizer optimizer)
return;

check_scale_growth_tracker(nameof(unscale));


}
}
}
28 changes: 26 additions & 2 deletions src/TorchSharp/NN/Convolution/Conv1D.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ namespace Modules
{
public abstract class Convolution : torch.nn.Module<Tensor, Tensor>
{
internal long _dimension, _in_channel, _out_channel, _kernel,_stride, _padding,_dilation,_groups;
internal PaddingModes _paddingModes;
internal (long, long)? _kernels, _strides, _paddings, _dilations;
internal bool _bias;
protected Convolution(IntPtr handle, IntPtr boxedHandle, long input_channels) : base(handle, boxedHandle)
{
this.input_channels = input_channels;
Expand Down Expand Up @@ -113,7 +117,17 @@ public static Conv1d Conv1d(long in_channels, long out_channels, long kernelSize
{
var res = THSNN_Conv1d_ctor(in_channels, out_channels, kernelSize, stride, padding, dilation, (long)padding_mode, groups, bias, out var boxedHandle);
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
return new Conv1d(res, boxedHandle, in_channels).MoveModule<Conv1d>(device, dtype);
return new Conv1d(res, boxedHandle, in_channels) {
_in_channel = in_channels,
_out_channel = out_channels,
_kernel = kernelSize,
_stride = stride,
_padding = padding,
_dilation = dilation,
_paddingModes = padding_mode,
_groups = groups,
_bias = bias
}.MoveModule<Conv1d>(device, dtype);
}

/// <summary>
Expand All @@ -135,7 +149,17 @@ public static Conv1d Conv1d(long in_channels, long out_channels, long kernelSize
{
var res = THSNN_Conv1d_ctor(in_channels, out_channels, kernelSize, stride, padding == Padding.Valid ? 0 : -1, dilation, (long)padding_mode, groups, bias, out var boxedHandle);
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
return new Conv1d(res, boxedHandle, in_channels).MoveModule<Conv1d>(device, dtype);
return new Conv1d(res, boxedHandle, in_channels) {
_in_channel = in_channels,
_out_channel = out_channels,
_kernel = kernelSize,
_stride = stride,
_padding = (long)padding,
_dilation = dilation,
_paddingModes = padding_mode,
_groups = groups,
_bias = bias
}.MoveModule<Conv1d>(device, dtype);
}

public static partial class functional
Expand Down
Loading