Implement "pico" GPT example #51
@sbrunk or anyone else, I need some assistance with this work. To code the "pico" example, I need the Embedding operator. In my branch I have added it here. I have also added comments and made sure the ScalaDoc is ok (minus the math expressions). The code I am working on now is:

def apply(t: Tensor[Int64]): Tensor[D] = Tensor(nativeModule.forward(t.native))

And this is a problem because I get the error:

[error] 101 |final class Embedding[D <: DType: Default](
[error] | ^
[error] |class Embedding needs to be abstract, since def apply(v1: T1): R in trait Function1 in package scala is not defined
[error] |(Note that
[error] | parameter T1 in def apply(v1: T1): R in trait Function1 in package scala does not match
[error] | parameter torch.Tensor[torch.Int64] in def apply(t: torch.Tensor[torch.Int64]): torch.Tensor[D] in class Embedding in package torch.nn.modules.embed
[error] | )

I think this is because we extend from Function1, so the apply signatures clash. On a related note, is it possible to constrain the tensor beyond its dtype? TIA |
In order to keep going I have used the following solution:

def apply(t: Tensor[D]): Tensor[D] = Tensor(nativeModule.forward(t.native))

@targetName("apply_T_D")
def apply[T <: DType](t: Tensor[T]): Tensor[D] = Tensor(nativeModule.forward(t.native))

Is this ok for a final solution? |
@hmf You're right. Note that @davoclavo has also added an Embedding implementation in #36:

storch/core/src/main/scala/torch/nn/modules/sparse/Embedding.scala (lines 58 to 68 in 05f7dbd)

storch/core/src/main/scala/torch/nn/modules/Module.scala (lines 125 to 127 in 05f7dbd)

So eventually we need to merge your solutions, but for now you could also just inherit from that and use:

def apply[T<:DType](t: Tensor[T]): Tensor[D] = Tensor(nativeModule.forward(t.native))
Right now, we're tracking only the dtype at compile time. We might add that in the future though. |
@sbrunk I have looked at the embedding class and my version is pretty close to it. I currently cannot search @davoclavo's branch, but I think I can copy and use that code (a minimum set of classes with updated docs). That might be easier on your side. In the meantime, if you do merge it into the main branch, I will update accordingly. Ok with you? |
Sounds good to me 👍 |
Question about the cross entropy functions. The original code uses something like:

import torch
import torch.nn as nn
from torch.nn import functional as F
...
loss = F.cross_entropy(logits, targets)
...
probs = F.softmax(logits, dim=-1) # (B, C)

I see that we have 2 options: one in the Loss package (does not exist yet, only the binary version is available) and the functional variant. What are the advantages/disadvantages of using one or the other? |
PyTorch has a functional and a class/module variant for most of its nn operations, see torch.nn.functional.cross_entropy. The functional variant does not contain any state; you call it directly with the tensor inputs and other arguments. The class/module variant can be initialized first with init parameters, and then later reused for different inputs. If you have modules with learnable weights/parameters, the module variant also helps you manage that state (it makes it easier to update all weights of your model etc.). For stateless ops without weights, like cross entropy, it mostly comes down to preference. |
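To make the difference concrete, here is a minimal sketch of the two styles. The functional crossEntropy is the one discussed in this thread; the module-style criterion is hypothetical (assuming a loss module analogous to PyTorch's nn.CrossEntropyLoss exists), and the factory/conversion signatures are assumed to mirror the Seq-based usages elsewhere in this thread.

import torch.*

// Assumed factory/conversion signatures, mirroring usages elsewhere in this thread.
val logits  = torch.rand(Seq(8, 10))                // (batch, numClasses)
val targets = torch.zeros(Seq(8)).to(dtype = int64) // class indices (all 0 here); `int64` assumed analogous to `float32`

// Functional variant: stateless, called directly with inputs and arguments.
val loss1 = torch.nn.functional.crossEntropy(logits, targets)

// Module variant (hypothetical name/API): configured once, then reused for different inputs.
val criterion = nn.loss.CrossEntropyLoss()
val loss2     = criterion(logits, targets)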
Hello @hmf! Awesome work on implementing Karpathy's examples. I have made some progress as well, but last month I got sidetracked with some things at work, so I wasn't able to prepare the code to share it. I'll leave my progress implementing some of the model building blocks here in case it is helpful in any way to you. As @sbrunk mentioned, there are some new modules implemented in PR #36, such as Embedding. (Btw, you should be able to access my branch via the PR, or via this direct link.)

final case class Head[D <: FloatNN: Default](
numEmbeddings: Int,
headSize: Int,
blockSize: Int,
dropoutProb: Float
) extends TensorModule[D] {
val query = register(nn.Linear(numEmbeddings, headSize))
val key = register(nn.Linear(numEmbeddings, headSize))
val value = register(nn.Linear(numEmbeddings, headSize))
val tril = register(torch.tril(torch.ones(Seq(blockSize, blockSize))))
val dropout = register(Dropout(dropoutProb))
override def apply(input: Tensor[D]): Tensor[D] =
val Seq(batch, timeStep, channels) = input.shape // (B, T, C) (64, 256, 384) [Float32]
assert(blockSize == timeStep, "Block size must be equal to time step")
val k: Tensor[D] = key(input) // (64, 256, 64) [Float32]
val q: Tensor[D] = query(input) // (64, 256, 64) [Float32]
val v: Tensor[D] = value(input) // (64, 256, 64) [Float32]
// TODO Get rid of the `.to(dtype = q.dtype)`
val weight =
torch.matmul(q, torch.transpose(k, -2, -1)) / Tensor(Math.sqrt(channels)).to(dtype = q.dtype) // (64, 256, 256) [Float32]
val weightMasked =
weight.maskedFill(
tril(Slice(0, timeStep), Slice(0, timeStep)) == 0,
Float.NegativeInfinity
) // (64, 256, 256) [Float32]
val attention =
torch.nn.functional.softmax(weightMasked, dim = 2)(
weightMasked.dtype
) // (64, 256, 256) [Float32]
val attentionDropout = dropout(attention) // (64, 256, 256) [Float32]
val output = attentionDropout.matmul(v) // apply the attention weights (after dropout) to the values: (64, 256, 64) [Float32]
output
}
final case class MultiHeadAttention[D <: FloatNN: Default](
numHeads: Int,
numEmbeddings: Int,
headSize: Int,
blockSize: Int,
dropoutProb: Float
) extends TensorModule[D] {
// Multiple heads of self-attention in parallel
val heads = register(nn.ModuleList(Range(0, numHeads).map { _ =>
Head[D](numEmbeddings, headSize, blockSize, dropoutProb)
}*))
val projection = register(nn.Linear(numHeads * headSize, numEmbeddings))
val dropout = register(Dropout(dropoutProb))
override def apply(input: Tensor[D]): Tensor[D] =
val headOutputs = heads.map { head =>
head(input)
} // (6, 64, 256, 384) [Float32]
val headOutputsConcat = torch.cat(headOutputs, dim = -1) // (64, 256, 384) [Float32]
val projectedOutput = projection(headOutputsConcat) // (64, 256, 384) [Float32]
dropout(projectedOutput) // (64, 256, 384) [Float32]
}
final case class FeedForward[D <: FloatNN: Default](numEmbeddings: Int, dropoutProb: Float)
extends TensorModule[D] {
// A simple linear layer followed by a non-linearity
val net = register(nn.Sequential(
nn.Linear(numEmbeddings, numEmbeddings * 4),
nn.ReLU(),
nn.Linear(numEmbeddings * 4, numEmbeddings),
Dropout(dropoutProb)
))
override def apply(input: Tensor[D]): Tensor[D] =
net(input)
}
final case class Block[D <: FloatNN: Default](numEmbeddings: Int, numHeads: Int, blockSize: Int, dropoutProb: Float)
extends TensorModule[D] {
// Transformer block: communication followed by computation
val headSize = numEmbeddings / numHeads // 384 / 6 = 64
val attention = register(MultiHeadAttention(numHeads, numEmbeddings, headSize, blockSize, dropoutProb))
val feedForward = register(FeedForward(numEmbeddings, dropoutProb))
val layerNorm1 = register(nn.LayerNorm(Seq(numEmbeddings)))
val layerNorm2 = register(nn.LayerNorm(Seq(numEmbeddings)))
override def apply(input: Tensor[D]): Tensor[D] =
// (64, 256, 384) [Float32]
val a = input + attention(layerNorm1(input)) // (64, 256, 384) [Float32]
val b = a + feedForward(layerNorm2(a)) // (64, 256, 384) [Float32]
b
}
final case class Dropout[D <: FloatNN: Default](probability: Float) extends TensorModule[D] {
override def apply(x: Tensor[D]): Tensor[D] =
nn.functional.dropout(x, probability)
}

I'm happy to assist you in any way to get this to work. I was able to get some inference going without any runtime errors, but I haven't had time to train the model on the Shakespeare writings yet. I will also be available to continue work on the pending PR to get it merged, in case I can help in any way @sbrunk |
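As a quick, purely illustrative shape check for the Head module above (sizes taken from the inline comments; it assumes the constructor and the Seq-based torch.rand signature shown in this thread):

// Sketch only: hyperparameters follow the shape comments above.
val head = Head[Float32](
  numEmbeddings = 384,
  headSize = 64,
  blockSize = 256,
  dropoutProb = 0.2f
)
val x   = torch.rand(Seq(64, 256, 384)) // (B, T, C)
val out = head(x)                       // expected shape: (64, 256, 64)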
Oh I forgot, there are also some changes needed for pico GPT that I haven't created a PR for, but I have fixed in my local project. I aim to get these changes submitted soon, but here they are in case you need them earlier:
def maskedFill[S <: ScalaType](mask: Tensor[Bool], value: S): Tensor[D] = Tensor(
native.masked_fill(mask.native, toScalar(value))
)
def sqrt = Tensor(native.sqrt())
def tril[D <: DType](input: Tensor[D], diagonal: Int = 0): Tensor[D] =
Tensor(torchNative.tril(input.native, diagonal.toLong))

Fixing split:

def split[D <: DType](
input: Tensor[D],
splitSizeOrSections: Int | Seq[Int],
dim: Int = 0
): Seq[Tensor[D]] = {
val result =
splitSizeOrSections match {
case i: Int => torchNative.split(input.native, i.toLong, dim.toLong)
case s: Seq[Int] => torchNative.split(input.native, s.map(_.toLong).toArray, dim.toLong)
}
(0L until result.size()).map(i => Tensor(result.get(i)).clone())
} |
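A quick usage sketch of those helpers, in case it helps while wiring them up. It only uses the maskedFill, tril, and split definitions shared above plus the Seq-based tensor factories shown elsewhere in this thread; shapes and values are illustrative only.

// Causal mask: keep the lower triangle, push the rest to -inf before softmax.
val scores = torch.rand(Seq(4, 4))
val mask   = tril(torch.ones(Seq(4, 4)))
val masked = scores.maskedFill(mask == 0, Float.NegativeInfinity)

// split: fixed chunk size along dim 0 (last chunk may be smaller) ...
val chunks = split(torch.rand(Seq(5, 3)), 2)         // sizes 2, 2, 1
// ... or explicit section sizes.
val parts  = split(torch.rand(Seq(5, 3)), Seq(2, 3)) // sizes 2, 3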
@davoclavo feel free to take over #36 again if you have capacity. I've merged main into it with some improvements of the native bindings but since Scala Days is only 4 weeks away I'd like to focus on getting my Storch talk ready first. Happy to help/review etc. but I'm not sure I'll be able to actually work on it before the talk. |
@sbrunk sounds good, I'll try to polish the last remaining bits. Best of luck on the Scala Days talk! Hopefully it will be streamed/recorded, I'd love to watch it :D |
Thanks! I'm sure it will be recorded and put on youtube some time after the conference as the videos from the Seattle edition from June are already online. |
@davoclavo Thanks for the assist. Please note that at this time I am working on the very simple "video" version. My aim here is to learn about GPT. I will look at your code and incorporate all I can to make merging easier. |
Questions regarding softmax. The current definition is:

def softmax[In <: DType, Out <: DType](input: Tensor[In], dim: Long)(
dtype: Out = input.dtype
): Tensor[Out] =
val nativeDType =
if dtype == input.dtype then ScalarTypeOptional() else ScalarTypeOptional(dtype.toScalarType)
Tensor(torchNative.softmax(input.native, dim, nativeDType))

This means that we have to explicitly provide the last (usually empty) parameter list. If we don't, we get the error:

[error] 358 | val loss1 = F.crossEntropy(input1, target1)
[error] | ^^^^^^^
[error] |Found: (gpt.BiGram.target1 : torch.DType => torch.Tensor[torch.DType])
[error] |Required: torch.Tensor[O]
[error] |
[error] |where: O is a type variable with constraint <: torch.NumericRealNN

I have made that last parameter an implicit, and did the same for crossEntropy. The original Python example code uses a plain F.softmax(logits, dim=-1) call. Is this an acceptable approach? TIA |
That's fine, but could you give the following variant a try? It's a solution we already use in other places and it avoids both implicits and multiple parameter lists (at the expense of a slightly more verbose type signature).

import Derive.derive
// ...
def softmax[In <: DType, Out <: FloatNN | Derive](
input: Tensor[In],
dim: Long,
dtype: Out = derive
): Tensor[DTypeOrDeriveFromTensor[In, Out]] =
val derivedDType = dtype match
case _: Derive => input.dtype
case d: DType => d
val nativeDType =
if dtype == input.dtype then ScalarTypeOptional()
else ScalarTypeOptional(derivedDType.toScalarType)
Tensor(torchNative.softmax(input.native, dim, nativeDType))
Yes, you can add it as a regular method in Tensor. |
Done (also for
Done:

def shape: Seq[Int] = size
def softmax[Out <: FloatNN | Derive](
dim: Long,
dtype: Out = derive
): Tensor[DTypeOrDeriveFromTensor[D, Out]] = F.softmax(input = this, dim = dim, dtype = dtype)
def square = Tensor(native.square()) |
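For reference, a small usage sketch of the derived-dtype call: with the default, the result keeps the input's dtype; passing a dtype overrides it. It assumes the Seq-based rand factory used elsewhere in this thread and a float64 dtype value analogous to the float32 one shown later.

val logits = torch.rand(Seq(2, 5))                 // Float32 by default
val p1 = logits.softmax(dim = -1)                  // dtype derived: stays Float32
val p2 = logits.softmax(dim = -1, dtype = float64) // explicitly Float64 (`float64` assumed available)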
While trying to replicate the Colaboratory notebook to check that the code is working, I tried to do the following:

// We want x[b,t] = mean_{i<=t} x[b,i]
val xbow = torch.zeros(Seq(b0, t0, c0))
for b <- 0 until b0 do
  for t <- 0 until t0 do
    val xprev = x(b,º`:`t+1) // (t,C)
    xbow(b,t) = torch.mean(xprev, 0)

The assignment in the last line does not seem to be supported. Is there a way to do this? TIA |
The C++ API has a method for assigning values (with indices): See https://pytorch.org/cppdocs/notes/tensor_indexing.html#setter #53 should add support for it. Could you give it a try? |
Found some compiler weirdness with the changes above. These do not compile:

xbow(Seq(b,t)) = torch.mean(input=xprev, dim=0)
xbow(Seq(b,t)) = torch.mean(xprev, dim=0)

The error is:

method mean in trait ReductionOps: (input: torch.Tensor[?], dtype: torch.Float32): torch.Tensor[torch.Float32] does not have a parameter dim

and (for the last one):

Found: (0 : Int)
Required: torch.Float32

But these do compile:

xbow(b,t) += torch.mean(xprev, dim=0)
val c = torch.mean(xprev, dim=0)
xbow(Seq(b,t)) = c
xbow(Seq(b,t)) = torch.mean(input=xprev, dim=0, true, float32)
xbow(Seq(b,t)) = torch.mean(input=xprev, dim=0, true)

Maybe some tweaking of the first definition might get it working, but this seems like a Scala issue. |
It looks like the compiler gets confused by the overloaded variants of mean. I realized that the default dtype argument was part of the problem, so I've changed that. |
@sbrunk Changes work fine. Thanks. |
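For reference, putting the pieces together, the running-mean loop from the notebook can then be written roughly as below. This is a sketch assuming the usual storch imports; the assignment form is the one shown to work above, and the mixed Int/Slice indexing is assumed to behave like the Slice-only indexing used earlier in this thread.

// Toy sizes; x has shape (B, T, C) and xbow(b, t) becomes the mean of x(b, 0..t).
val (b0, t0, c0) = (4, 8, 2)
val x    = torch.rand(Seq(b0, t0, c0))
val xbow = torch.zeros(Seq(b0, t0, c0))
for
  b <- 0 until b0
  t <- 0 until t0
do
  val xprev = x(b, Slice(0, t + 1))      // (t+1, C)
  val c     = torch.mean(xprev, dim = 0) // (C)
  xbow(Seq(b, t)) = c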
I need to use Dropout. In Python this seems to return a constructor of sorts (did not check), which can then be applied to a tensor. I see that we have a functional version only.

EDIT 1: @davoclavo I realized you have already defined a Dropout module above. |
I would like to use register_buffer. According to the Python API doc, we must pass in a name. Looking at the native bindings, we have:

public Tensor register_buffer(BytePointer name, Tensor tensor) { return asModule()._register_buffer(name, tensor); }
private native @ByRef @Name("register_buffer") Tensor _register_buffer(@StdString BytePointer name, @ByVal Tensor tensor);
public Tensor register_buffer(String name, Tensor tensor) { return asModule()._register_buffer(name, tensor); }
private native @ByRef @Name("register_buffer") Tensor _register_buffer(@StdString String name, @ByVal Tensor tensor);

So in Module I could add:

def registerB[D <: DType](n: String, t: Tensor[D]): Tensor[D] =
nativeModule.register_buffer(n, t.native)
t

However, as an example:

def register[D <: DType](t: Tensor[D], requiresGrad: Boolean = true)(using
name: sourcecode.Name
): Tensor[D] =
nativeModule.register_parameter(name.value, t.native, requiresGrad)
t the name is implicitly defined. Is there any way I can keep the implicit but still allow manually setting that name? On a related not, shouldn't these functions return a EDIT 1: we also have the problem of duplicate overload methods due to the use of defaults. What is the way to solve this here? Can I change the names? EDIT 2: In the meantime I will use: def buffer[D <: DType](t: Tensor[D], n: String="")(using
name: sourcecode.Name
): Tensor[D] =
val name_ = if n.trim().isEmpty() then name.value else n.trim()
Tensor(nativeModule.register_buffer(name_, t.native))

TIA |
I think what you found is the functional variant:

torch.nn.functional.dropout(input=torch.rand(Seq(3,3)))
// res2: Tensor[Float32] = tensor dtype=float32, shape=[3, 3], device=CPU
// [[0,4759, 1,4497, 1,7002],
// [1,2299, 0,0000, 1,1805],
// [0,0000, 0,0000, 0,0000]]

It corresponds to torch.nn.functional.dropout in Python. Seems like we're still missing the module variant (torch.nn.Dropout), though. |
We could add an explicit optional name parameter, i.e. defaulting to an empty string, or using an Option.
You're right, it's better to use the tensor returned by the native register method.
Yes please go ahead. Perhaps we can keep
👍 |
Hi @hmf! Apologies for the confusion, I have not committed my changes yet, as I have a bunch of other stuff that needs to be cleaned up. I just shared them in my previous comment to show the progress in case it was useful to you :) You should be able to either drop the code I shared into your script/example, or add it as a new module to storch. I'll keep an ear open in case you need any further help, and hopefully I'll find some time soon to help contribute these modules to storch. |
While trying to implement and debug the multi-head attention mechanism, I see what seems to be unexpected behavior. For a model with the multi-head attention "only", the code:

val nuParams = m.parameters.map(_.numel).sum
println(s"${nuParams} parameters")

reports:

Now to this model I add the following layer:

val ffwd = register( FeedFoward(nEmbed) )

where FeedFoward wraps a Sequential of Linear layers.
Shouldn't that be 4481 + 1056? TIA |
@hmf I have a hunch (not tested). Could you try to wrap your Sequential in a register call?

storch/examples/src/main/scala/gpt/BiGram.scala (lines 1316 to 1326 in 5e1fdf2)

- val net = nn.Sequential(
+ val net = register(nn.Sequential(

Right now it's registering the layers inside the Sequential, but the Sequential itself is not registered on the parent module, so its parameters are not counted. |
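For clarity, the corrected module would then look roughly like this. It is a sketch based on the FeedForward definition shared earlier in this thread (the class name and sizes follow the example code, and the usual storch imports are assumed):

final case class FeedFoward[D <: FloatNN: Default](nEmbed: Int) extends TensorModule[D] {
  // Registering the Sequential itself makes its Linear parameters reachable
  // from this module, so `parameters` (and the optimizer) can see them.
  val net = register(nn.Sequential(
    nn.Linear(nEmbed, 4 * nEmbed),
    nn.ReLU(),
    nn.Linear(4 * nEmbed, nEmbed)
  ))
  override def apply(input: Tensor[D]): Tensor[D] = net(input)
}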
@sbrunk I have confirmed that I need to register the inner modules. As for the macro, maybe a single function that traverses the sub-modules and registers them would do. But we also have parameter and buffer registering, so that would also have to be dealt with. Thanks. |
@davoclavo Thanks for the feedback. Regarding the seed, I have tried to do as in the video, but this may not be correct. Also, some of the constants may be off. For the final version I will try to replicate this code, so general performance should match. As for the causes of the differences in loss: when I tested the MNIST example on Linux, its behavior was not the same as on Mac. In fact the process did not converge, which was strange. @sbrunk changed the code, altering the learning rate so that learning would converge on both OSes. |
@hmf I'll try to look into this to get a better understanding. Is there anything I need to consider if I want to run it? I.e. I've seen you are using mill, right? |
@sbrunk Thanks.
Not really. Simple Scala object. Messy code though. Sorry about that.
Correct, but it just calls the main. Execution is in the object initialization. Something to correct. I was hoping to contribute the Mill script (another issue). It is just missing project publishing. When I get time I want to upgrade it to the latest Laika version to avoid the need to override the Helium templates (currently it overrides the header). |
I have implemented a clean version of this Python code. It is here. I am able to get a validation error below 2.0 (even less), as shown in the tutorial video, although with an increased number of iterations. Unfortunately I am unable to use the exact same parameters due to memory issues. I am using a GPU with a whopping 24 gigabytes, yet as soon as I start training, CUDA runs out of memory. I have looked through the APIs but cannot find the calls to get the CUDA memory stats. Can anyone give me some pointers on how to check where the memory is used and diagnose this issue? TIA |
@hmf yeah right now, we don't really have a good way to do memory profiling. Need to look into that too. Perhaps we can use the JavaCPP Cuda bindings to get better GPU memory usage information. One idea you could try for now is to run only parts of the model (i.e. just the attention layer etc.) inside a training loop (you can use just random inputs of the right size). That might help to isolate better what part consumes so much memory or where it leaks. |
@sbrunk thanks for the suggestions. I have started using a kludge to try and get an idea where the memory is being allocated. What I do is set a
What JavaCPP CUDA bindings are you referring to? A quick look at the API does not reveal too much. I also think that we would need to access this information in a device-independent manner. Maybe via Device?

EDIT: Device does not seem to be helpful. Need to look at torch.cuda.memory_stats for a possible solution. |
Does it allocate too much inside a single iteration already or does it grow over multiple iterations during the training loop?
In Storch itself, I think we only have the image classifier example and #5, but you seem to be already doing it this way. The JavaCPP tests for deallocation and
The Java bindings to the CUDA toolkit itself. But that's a long shot, I'm not sure if it provides something usable for us here.
It looks like LibTorch provides something like this, see: https://discuss.pytorch.org/t/libtorch-equivalent-of-torch-cuda-memory-reserved/165995 |
It grows as it iterates.
I had seen these already. What I have learned is that one of the functions (the one that calculates the validation and training loss) was accumulating memory. I added another
Ok. I agree with you.
Ok. |
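For anyone following along, one general JavaCPP pattern for bounding native memory inside a loop is PointerScope: native tensors allocated while a scope is open are released when it closes. Below is a minimal sketch of wrapping a loss-estimation step this way; the model and validation-batch names are placeholders, the crossEntropy call refers to the function discussed earlier in this thread, and it assumes only the scalar loss value is needed outside the scope.

import org.bytedeco.javacpp.PointerScope
import scala.util.Using
import torch.*

// Placeholder names: `model`, `xVal`, `yVal` stand in for the real model
// and validation batch used in the example.
def estimateLoss(
    model: Tensor[Float32] => Tensor[Float32],
    xVal: Tensor[Float32],
    yVal: Tensor[Int64]
): Float =
  Using.resource(new PointerScope()) { _ =>
    // Native tensors allocated inside this block are freed when the scope closes.
    val logits = model(xVal)
    val loss   = torch.nn.functional.crossEntropy(logits, yVal)
    loss.item // copy the scalar out before the natives are released
  }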
The current Module API does not seem to provide a recursive apply like Python's torch.nn.Module.apply. So what is the best way forward here? Should we include such a method? What should we name it? Do we simply iterate through all layers and apply a function, like Python does? Should I open a new issue to discuss this? TIA |
Good idea. A recursive apply like in Python should be quite useful. And yes, please create a new issue for this. |
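As a starting point for that issue, here is a rough sketch of what a recursive apply could look like, mirroring Python's nn.Module.apply. The childModules accessor is a placeholder for whatever the Module API ends up exposing for direct sub-modules, so this illustrates the idea rather than a working method.

// Hypothetical sketch: visit all sub-modules depth-first, then the module itself.
def applyRecursively(m: Module)(fn: Module => Unit): Unit =
  m.childModules.foreach(child => applyRecursively(child)(fn))
  fn(m)

// Intended usage, e.g. for weight initialisation:
// applyRecursively(model)(initWeights)   // where initWeights: Module => Unit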
Results of v2 are on par with the tutorial (loss below 2.0), but with slower convergence. After a while it diverges. At the end I show an example of its output.
Output (removed initial white spaces, too many):
|
@hmf Amazing work! I got my hands on a larger GPU now and will start playing with it. |
Is it possible to get the code for this example? It would be an excellent example project to learn from if it were posted. |
@sbrunk I have created a clean branch with the changes that implement the example. This holds changes for #51 and #61 (apologies for the comment error in the commit).
I now have another issue. After updating to your latest changes (2.1.1), the GPU does not work on my side. This is also true for the original LeNet example (nvidia-smi shows no process using the GPU). Note that because of this the current code breaks, since it is not usable without a GPU.

EDIT: forgot to mention that with version 2.1.0 I had some memory issues I did not have with the previous version. In particular, I implemented a wrapper for memory_stats. I think it is necessary to add the memory management functions (such as clearing the cache) to allow us to use storch effectively. Could you check this and give me feedback? TIA |
Thanks @hmf for pushing this forward. Could you create a PR from your branch? That should make it easier to do reviewing. I'll try to figure out the new GPU issue. |
@sbrunk The code assumes a GPU is available. The class is |
I can't reproduce the GPU issue at the moment, but I could only try on an RTX 4090 so far, which is Ada architecture, while the 3090 is Ampere. It did work with 2.1.0 for you, right? Could you give it a try with the latest update to PyTorch 2.1.2, by bumping the PyTorch patch version in the build:

- val pytorchVersion = "2.1.1"
+ val pytorchVersion = "2.1.2" |
GPU is working with 2.1.2. Note that with:
In sbt I now get:
But no issues with 2.1.2. Thanks.

EDIT: yes, it was working with 2.1.0. |
@hmf I'm trying to reproduce your results but it's diverging much faster in my case. I get down to a train loss of 2.3 but then it starts to go up again. At some point, the losses even go NaN.
Here's how I ran it: took the branch, disabled the weight init, and ran with a learning rate of 1e-4.

// modules.foreach(init_weights)
// ...
train(model, 1e-4, 67000)

Any idea what could be different? |
@sbrunk I am trying to rerun the test to confirm all is Ok. Unfortunately I have had to recreate the dev container. I have also merged your latest changes from main. I am now running with initialization off. As soon as I get some results I will report back.
At this time, it only occurs to me that the OS libraries (including NVIDIA's stuff) may be different. I am assuming you are using Linux Ubuntu. Below is a list of the setup. We could also try setting up a fixed random number seed for replication. |
@sbrunk I have rerun with PyTorch 2.1.2 and get:
Here is the full output:

I have added:
and am running this again. We can then test and compare with this seed.
I have noticed that your compute time per iteration is about half of mine (0.123 vs 0.061). Nice 8-)

EDIT: this run resulted in an abrupt divergence and a resulting NaN. Can you check that you get the same output? Here is the output I get: |
@hmf I did a run on d8d75b7 and it's much closer to your result than before: There are still numeric differences but I guess that could be due to slightly different hardware/driver. |
@sbrunk I am somewhat skeptical of the results. I noticed that the first loss is indeed the same value (barring precision errors). For that reason I would expect the next values to be the same (the data is the same and should be loaded in the same order). This also does not bode well for unit tests - something we could add to confirm your hypothesis. Having said this, I am not satisfied with the results. I am considering a ViT implementation that has SOTA performance that can be compared. This case is easier to test than NLP because the preprocessing should be simpler. Just a thought. |
That's true. It seems like I get reproducible results on the same hardware/setup though when I run it twice. I'll have access to yet another GPU type next week and I'll try it there too for comparison.
That's a great idea. If you'll give it a try, let me know if I can help you in any way. |
Response to request in issue #44.
Attempt to rewrite the "pico" example from Karpathy's "Let's build GPT: from scratch, in code, spelled out" in storch.