Superior 3INST parameters #26
Conversation
Swapped out original LCG params for a computationally screened optimal MCG param.
0.2-3.9% lower PPL
1.25% faster processing (addition removed)
This is very interesting. I will have to do a lot of tests, also on larger models, just to thoroughly verify this. Because to actually include it would require doubling the number of kernel instances and adding some metadata to quantized models to identify the new codebook (a bit too far along to just invalidate all existing EXL3 models.) I'm surprised there's a performance difference, but it looks like the compiler does produce an extra MOV instruction in the first case since 64248484 is more than 20 bits, and the GEMM kernel is currently running into an ALU bottleneck due to the trellis unpacking overhead (on Ampere mostly). Also, was the 1.25% speed increase just the kernel latency or inference overall, and what hardware did you see this on? |
More testing today with larger model:
Few minor regressions. Mostly breakeven. Biggest gain on 3.0bpw which already appears to be the format's new sweet spot. Improved parameter is most valuable if:
Happy to run further evals. Which ones would be most helpful? Is there a KLD eval? Or another automated benchmark you find correlates well with actual model quality?
Just the decoder's isolated CUDA routines in a vacuum.
2 x 5090. Perhaps older hardware has larger speedup from improving the decoder? |
Cutting out a MOV instruction is definitely worth it, even if it ends up breaking even on accuracy. I think if the difference is this small on perplexity, most likely it will be hard to detect any other way. I didn't write up a KLD test yet, but eval/model_diff.py would be a good place to start. I'm a little tied up with Qwen3 at the moment, but I will get back to working out a nice way to incorporate this. |
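For reference, the core of such a KLD measurement is small. A minimal sketch, assuming matching (num_tokens, vocab) logit tensors from the reference and the quantized model are already in hand; this is not the actual eval/model_diff.py implementation:

```python
import torch
import torch.nn.functional as F

def mean_kld(ref_logits: torch.Tensor, test_logits: torch.Tensor) -> float:
    # Mean per-token KL(ref || test), computed in float32 for stability
    ref_logprobs = F.log_softmax(ref_logits.float(), dim = -1)
    test_logprobs = F.log_softmax(test_logits.float(), dim = -1)
    kld = torch.sum(ref_logprobs.exp() * (ref_logprobs - test_logprobs), dim = -1)
    return kld.mean().item()
```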
More evidence for 0xFF7A83E5

Out of curiosity, I coded my own routines to reproduce the QTIP paper's distortion rate calculations. Got slightly different results than what they reported:
First 3 are nitpicks... essentially just replicating their work with 1 more sig fig.
However 3INST is more confusing:
Their code shows they tested another decoder called 2MAD. My guess is 2MAD has a measured distortion rate of 0.069 too, and they were initially planning to report both 1MAD and 2MAD results. But then they found 3INST and probably decided that, since 2MAD was just slower without being measurably better, they would instead only report 1MAD and 3INST data. Somehow, though, the data for the 3INST distortion rate never got updated in Table 1 of the paper.

In any case, I've sampled the richest chunk of the MCG multiplier space, and using abstract MSE measurements (rather than full end-to-end model quantizing / perplexity measures), 0xFF7A83E5 once again ranks at the top with the lowest avg distortion rate.

3INST (MCG-0xFF7A83E5): 0.0726 (0.0017 lower distortion rate than original 3INST params)

Next 3 best 3INST MCG multipliers:
So far I've only tested ~10% of the spectrally good MCG multipliers from Vigna's work for MSE improvements. Should complete exhaustive rankings in the next couple of days. But it's good seeing a totally different evaluation criterion, using completely separate code, also identify the same exact multiplier as the best one. And relative to the original 3INST baseline, this multiplier produces 2.3% lower distortion, which is roughly the average improvement I'm seeing in 2-bit model perplexity too. So this is a good sanity check that this particular parameter isn't just "getting lucky" on the specific models I've tested, but actually would be expected to be 2.3% better at 2 bits. MSE distortion rate at k = 1, 3, 4, and 5 bits all improved as well, by roughly 3.3%, 2.0%, 1.6%, and 1.2% vs baselines.

LOP3 not currently used?

The QTIP paper mentions how 3INST is especially good precisely because the mask and xor steps can be combined into one low-level LOP3 instruction. So theoretically nvcc should render these two lines:

```cpp
x &= 0b10001111111111111000111111111111u;
x ^= 0b00111011011000000011101101100000u;
```

down into something like this PTX instruction:

```cpp
asm("lop3.b32 %0, %1, 0x8FFF8FFF, 0x3B603B60, 0x6A;" : "=r"(x) : "r"(x));
```

But nvcc never does this (at least on my system with sm120). Maybe it's not kosher to pass in the same register for an input and an output with lop3? If so, changing the CUDA code from "x ^= ..." into "y = x ^ ..." might be enough to help the compiler successfully find this optimization. I haven't found a painless way to reliably implement LOP3 yet, but it's probably a trivial patch for a smart CUDA wizard. We have any of those around here?

If LOP3 isn't getting emitted in any of the EXL3 decoder kernels and it's using 2 separate [and(...) and xor(...)] instructions every pass, that's another 8-10% speedup being left on the table. Maybe this will finally make Ampere perform better?

There might also be another clever way to shave off a MOV in the final re-packing steps by ordering it slightly differently as well. But we can explore that later. It's the highest-hanging fruit and can be optimized down any time. Removing the addition, using the right multiplier, and making sure LOP3 gets used are all more important for getting better speed and quality.

Let me know when the KLD eval is ready so we can triple-verify this change actually provides 1-3% avg higher fidelity via lower MSE, lower perplexity, and lower KLD at all bitrates (especially lower ones). The speedup may grow to +10% too once LOP3 is used. |
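As a side note, the marginal (dim = 1) behavior of any candidate multiplier is easy to reproduce offline. A rough numpy sketch of that kind of screening, decoding all 2^16 possible trellis states for a given 3INST parameter set and reporting the RMS of the outputs (the halves are summed in float32 here for simplicity; this is not the actual ranking pipeline, which measured true 256-dimension trellis distortion):

```python
import numpy as np

def decode_3inst(states: np.ndarray, mult: int, add: int = 0) -> np.ndarray:
    # LCG/MCG step (mod 2^32), then mask + xor, then sum the two fp16 halves
    x = (states.astype(np.uint64) * mult + add).astype(np.uint32)
    x = (x & np.uint32(0x8FFF8FFF)) ^ np.uint32(0x3B603B60)
    lo = (x & np.uint32(0xFFFF)).astype(np.uint16).view(np.float16)
    hi = (x >> np.uint32(16)).astype(np.uint16).view(np.float16)
    return lo.astype(np.float32) + hi.astype(np.float32)

states = np.arange(1 << 16, dtype = np.uint32)
for name, mult, add in [("LCG 89226354 + 64248484", 89226354, 64248484),
                        ("MCG 0xFF7A83E5", 0xFF7A83E5, 0)]:
    vals = decode_3inst(states, mult, add)
    print(f"{name}: rms = {np.sqrt(np.mean(vals ** 2)):.4f}")
```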
Top 6 3INST MCG parameters found through exhaustive MSE sampling of all Vigna 32-bit multipliers (plotted with your visualization and RMS calculation):

These plots are a fine tool for getting a general feel for each codebook by examining part of its distribution. But do keep in mind that this is only plotting raw input <--> output for dim=1 at 8 different bitrates [1-bit - 8-bit]. The QTIP trellis codes, however, operate in 256 dimensions, so these charts are only showing 1 of the 256 distributions that affect overall performance. Also, unlike most applications where dim=1 is by far the most meaningful distribution, there's nothing special about dim=1 within a 256-dimension trellis coder, so the quality of the 255 higher dimensions (aka higher-order spectral quality) should be expected to matter just as much for this application. Just pointing this out so we remember that RMS=1.17 vs RMS=1.20 in these charts only means ~0.4% of all distributions are ~2.6% better. That doesn't automatically mean the other 99.6% of the distribution quality will also be 2.6% better. Higher-order behavior is often quite jagged and random, and good distributions in one dimension only loosely correlate with distribution quality in other dimensions.

That's why I tested these six 3INST MCG values for actual MSE distortion using real 256-dimension trellis coding:

Y-axis re-scaled to distortion relative to an optimal codebook [2^(-2*bitrate)]. Average distortion and variance were plotted by sampling each decoder with 145-195 different 8192-value [16-bit] Gaussian data chunks. This effectively simulates the precise error that occurs when quantizing and decoding 100+ of the most difficult-to-quantize tensors it's possible to construct. This is the standard model (worst case) for distortion rate calculations. It's broadly assumed that codebooks which are strongest in this worst-case setting "degrade gracefully" into also being the most robust in less challenging settings. One of those less challenging settings is modern LLM weights, which empirically only have ~76%-79% entropy levels (lower bound = Llama 2; upper bound = Llama 3). Modern LLMs haven't reached 100% entropy yet (and possibly won't/can't within the current transformer paradigm). In any case, just pointing out in passing that the codebooks I'm locating are optimized to minimize distortion in a slightly more challenging domain than what we intend to use them for here.

As for the data, my hunch that 1MAD had higher variance was right. That said, its range still looks better by most metrics (max distortion, avg distortion, etc.) than even the improved 3INST decoders at 2-bit and 3-bit. However, the variance in the distortion gets worse at higher bitwidths and may become intolerable.

Is there a reason 3INST was initially chosen over 1MAD? I know ik_llama.cpp also chose 3INST for their experimental IQx_KT quants, so I assume there's a good reason? It's just weird because in theory I'd expect 1MAD to both run faster and provide better results on average (below 4 bits). Is it just the inconsistency (higher variance) in 1MAD quantization quality that was the problem? If that's all that's wrong, I could try to find better params for 1MAD to reduce distortion and variance some more. Default 1MAD LCG params certainly aren't optimal either. Would that make it a more attractive option again?

Current 3INST decoder is only giving:
SUMMARY:
|
Still held up by other stuff. Currently completely refactoring the kernels, and probably the next thing I'll be occupied with will be kernel fusion, an alternative GEMV path and maybe Ampere optimizations. It's just a bad time right now to also have the complexity of a second quantization format to worry about, though there's no reason it couldn't be plugged in a little later. I'll write a quick KLD test in a little bit, at any rate.

As for compiling two steps into a single LOP3, the compiler may simply be choosing not to do it because it isn't efficient. MOV+LOP3 should have half the throughput of LOP3 on its own, but it may have lower latency overall, and scheduling the MOV instruction might be free if there's some other pipe that's stalled at that point anyway. It's very hard to know exactly what the compiler's deal is sometimes since SASS documentation is so sparse.

A point about the codebook worth noting, perhaps, is that the quantizer doesn't always end up using all of it. This may be due to insufficient regularization (128D Hadamard rotations for the sake of kernel fusion and the ability to split models for tensor parallelism at load time), but especially at higher bitrates where there is more coverage, it ends up being preferable to scale down the input by a factor of 10-40% to make better use of the denser middle part of the distribution (and possibly widen that "spike" that always appears in the middle?)

I was initially testing with 1MAD and did see promising results. The distribution was a perfectly smooth Gaussian too, albeit with some gaps and very bad correlations at 1 bpw. I'll see if I can dig up the results in a little bit, or recreate them. The paper shows this plot for 2 bpw, with narrow bands that end up not really mattering:

I definitely didn't explore this fully, since there was also a whole framework to build around the quantization algorithm. The overall idea with the project is to make these algorithms accessible in a usable format. And QTIP, despite being SOTA, still remains largely unavailable, with only a handful of quantized models on HF (all made in relation to the paper, it seems) that I still can't get working despite quite a bit of effort. So I had to make a decision eventually or be stuck endlessly obsessing over the details. That ultimately came down to 3INST achieving better perplexity on actual models in my tests, though not by much, and looking more efficient at a glance. Also, the QTIP paper describes the two methods as roughly equivalent, with a slight edge to 3INST at 2 bpw:

Can't say for sure if this was the same reason the llama.cpp people chose 3INST for their experiments, but at least when it comes to performance I'm not convinced 1MAD would end up being faster. It's essentially:

1MAD:
3INST:
At least on the surface, 3INST seems more efficient. I believe PRMT executes on the SFU pipeline and LOP3 has 4x the throughput (?), and also the float conversion takes 4 cycles or something. The real bottleneck is unpacking the trellis, though. You end up needing a lot of registers, you want to stick to 32-bit SMEM access which means reading any 16-bit field requires loading two 32-bit values, doing a funnel shift and then masking (would be interesting maybe to explore LCGs/MCGs that are indifferent to the high 16 bits of the input.) All that said it's possible 1MAD (or some other function) could be or could be made faster at the end of the day. That's definitely worth exploring. And if indeed there's something on the order of 0.5 bpw to be gained, even in theory, that needs to be looked into. It feels unlikely, though, noting that EXL3 already matches (sometimes outperforms) methods that rely on finetuned codebooks. Wish I could get QTIP inference working in Transformers. Then I could just plot the finetuned hybrid code models onto the same graph and get a better idea. :) 🤷 |
Like you, I also can't make the actual QTIP code work. 😂 The requirements file is a lie. There's no combination of torch, cuda, numpy, fast-hadamard-transform, transformers, and qtip that can compile simultaneously. I gave up after trying every public fork of fast-hadamard-transform, and creating my own fork that compiles in more than zero versions of torch+cuda, only to still not have it work. I love QTIP (mathematically) but their research code is diabolical. The fact I can't make it run in the era of infinite LLM assistance is mind-blowing.

Their paper mentions using vabsdiff4 to do the 4 adds -- which can be emulated with PRMT -- or achieved using other neat methods people found a few years ago. Of course, even this random post about VABSDIFF4 incidentally cites Vigna. We should really just ask Vigna what the best 256-dimensional static codebook function is. You know it's already on a napkin on his desk and it's only 2 instructions somehow.

Got some prettier-looking 1MAD distributions a few weeks ago, when I first looked into this:
That's VERY interesting! What file does the EXL3 code make these sorts of scaling decisions in? I assumed in practice that rotations would probably limit codebooks from ever reaching 100% usage, but the fact you're able to scale inputs going into a TC and book a "profit" in reduced distortion is a stronger indictment of the default codebook params than anything I've found so far. Maybe I could add some verbose output to the quantizer to track how often and how aggressively this sort of scaling hack wins out over the default full codebook? That's probably an excellent surrogate measure of absolute codebook quality.

I'll try to tighten up all these threads into more concrete answers later this week. It just takes time to fully characterize the param space for these computed codebooks. Appreciate you need to stay on task with other components too, so thanks for giving this attention when you have time. I'll keep chugging along and we'll sort this all out soon. |
All the scaling and quant logic (aside from the Viterbi kernel and some other CUDA support functions) is in quantize.py. I tried to keep it neat but it always becomes a little messy when it needs to also work in less than an unlimited amount of VRAM. And of course every new model throws a few surprises at you. So 🤷 Main function is quantize_exl3, and regularize is the function that tries to make the input tensor as Gaussian as possible. The input and output channel scales were added later on in development. They're not necessary for all models but in a few cases the Hadamard rotations just aren't enough to deal with extreme outlier channels. This might not be an issue if I was rotating full rows and columns, but doing so would both be less efficient and as mentioned result in tensors that can't be split for tensor parallelism at load time. You'd have to requantize the model for any given hardware configuration and I'm trying to avoid that. I did a bunch of tests at a range of scales but couldn't find a way to predict the best scale for any given tensor at a given bitrate. So I ended up with a golden section search on a sample of the regularized input tensor, which should be solid under a couple of assumptions:
At least for the last assumption I did do some plots (blue is 2 bpw quantization and red is 5 bpw): I think it's very likely scaling wouldn't improve anything if the regularized tensor was perfectly IID Gaussian, but it generally isn't. There may be outliers that simply can't be encoded and need to be clamped, for instance, so downscaling the tensor reduces the error from those outliers. And the higher the bitrate, the less penalty you incur from the mismatched distributions, so you end up with this tradeoff. It's coincidental that 1.0 is the ideal scale for 2 bpw in this example. It all shifts around a bit depending on how amenable a given tensor is to the regularization process, or maybe how Gaussian it was to begin with. |
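For what it's worth, the scale search itself is the easy part. A bare-bones golden-section search over a single input scale, with a hypothetical quantize_mse(tile, scale) standing in for the real sampled quantization error (this is a sketch of the idea, not the code in quantize.py):

```python
import math

def golden_section_search(f, lo: float, hi: float, tol: float = 1e-3) -> float:
    # Minimize a (roughly) unimodal 1-D function f on [lo, hi]
    invphi = (math.sqrt(5.0) - 1.0) / 2.0
    a, b = lo, hi
    while abs(b - a) > tol:
        c = b - (b - a) * invphi
        d = a + (b - a) * invphi
        if f(c) < f(d):
            b = d
        else:
            a = c
    return (a + b) / 2.0

# usage, with a hypothetical error function:
# best_scale = golden_section_search(lambda s: quantize_mse(sample_tile, s), 0.6, 1.0)
```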
I added a KLD test to model_diff.py. To run it: python eval/model_diff.py -ma /mnt/models/test_model -mb /mnt/models/ref_model -r 5 This would do 5 rows of 2048 wiki2 tokens. You can do more rows of course, but you end up with two big (num_rows, 2048, vocab_size) tensors to manage and it's not very clever about that. Example output:
|
Fantastic! I'll try out the KLD eval and collect data on various codebook params across different models. Looks like your quantization function has lots of good instrumentation in it to gauge relative quality across different codebook functions too.

Thanks for sharing more details about how you constructed the regularization function. Very interesting! Trying hard to wrap my head around it. Forgive me if I'm not grasping all of the nuances right away. It just seems to violate so much of what I thought I knew about Hadamard rotations. My initial thoughts are something like:
|
I did try to rescale channels after rotation but couldn't get good results from it. It's also less efficient since you'd still need the separate input/output sign flips, and this way you can combine them into a single operation. I think a better solution might be to scale each 128x128 block of the regularized tensor independently. But this would require some modifications to the GEMM kernel, followed by days and days of testing to make sure this new method still works across some large enough selection of models. And all the while people will be begging me to add support for this or that new architecture, make it faster, support ROCm, when is multimodal ready, etc. I just have to prioritize and build on what works instead of going back to the drawing board every other week.
This might be some sort of overflow, I'm not sure yet.
|
|
I'm pretty sure the Hadamard rotations function correctly because I implemented them in several different ways and they're mutually compatible. The quantizer uses a standard 128x128 matrix (from Sylvester's method) and a Torch matmul, and the kernels use warp shuffling shenanigans. In all combinations H^T@H = I, etc. And I'm not so much saying that outliers "slip by", but rather that they still end up dominating in a few extreme cases. So for an extreme example, if you have a 4D vector like (1, 1, 1, 1000), the rotation is going to be (501.5, -499.5, -499.5, 499.5). Quantizing to a grid after that you're still going to lose the weaker signal to rounding. Or, if you scale it for quantization with some Gaussian-ish codebook, all of the rotated values end up being in the tails of the distribution, because the single outlier is added to or subtracted from every other channel.
```python
in_channel_scales = block_rms(weight, dim = 1, keepdim = True)
su = (su * in_channel_scales / (-codebook_scale) + 1e-10).float()  # mustn't be inplace
weight /= su
```

->

```python
su = (su / (-codebook_scale) + 1e-10).float()
weight /= su
```

TBH I've long since forgotten why it flips all the signs again there, and I can't imagine why it's needed. Adding the small eps seems pointless too. :D
|
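A quick numpy check of the 4D example above, using the Sylvester construction scaled by 1/2 so the matrix is orthonormal (just to make the outlier-smearing concrete):

```python
import numpy as np

H2 = np.array([[1, 1], [1, -1]])
H4 = np.kron(H2, H2) / 2.0                 # orthonormal 4x4 Hadamard (Sylvester)
v = np.array([1.0, 1.0, 1.0, 1000.0])
print(H4 @ v)                              # [ 501.5 -499.5 -499.5  499.5]
print(np.allclose(H4.T @ H4, np.eye(4)))   # True: H^T @ H = I
```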
Yep, turns out the compiler just doesn't figure out how to combine the two ops into a LOP3. Free small performance increase with some inline PTX. Go figure. (: |
Initial KLD (forward pass) +PPL (10 rows; initial PPL data at start of thread is 100). These plots show performance with quantize.py codebook_scale = real std dev (RMS) Switching to % KLD and % PPL to better see differences. Also removed 2 lowest performing multipliers: Hard to pick clear winner from L3.2 1B alone. But encouraging to see KLD dropping 4-8% with most multipliers at all bitrates. 0xCAF6A435 looks promising. If these KLD improvements are actually indicative of:
REMAINING TODO:
|
0x00D0AA65 testing showed fewer improvements and more regressions than the best multipliers. It's the only good 24-bit multiplier, so in theory a sicko could: asm volatile("mul24.lo.u32 %0, %1, 0x00D0AA65;" ... I'm not seeing a speedup on Blackwell from this. Supposedly mul24 isn't faster anymore on modern cards. So unless you see big speed boosts on Ampere, there's no compelling reason to consider 0x00D0AA65 over the other options.

70B: 0xCAF6A435 looks best. There's some merit in considering 0xFF7A83E5, which typically has better 2-bit and 3-bit performance, but 0xCAF6A435 is more well-rounded at all bitrates. 1.0-bit and 1.5-bit look much better with 0xCAF6A435. The KLD regression at 5-bit is only 0.0001. On the other hand, the improvement at 1 bit was KLD: 4.82 --> 4.61, PPL: 1506 --> 1126. Not implying sub-2bit quants will be good now, but for narrow use cases where a 1.8bpw quant fits a specific GPU and 2.0bpw can't, 0xCAF6A435 degrades more gracefully than default 3INST.

Can you start testing 0xCAF6A435 to confirm it's a good candidate? I think either that or 0xFF7A83E5 would be best. Remember to update codebook_scale = std dev in quantize.py for best performance. 0xCAF6A435 = 1.206441 |
Currently everything's tied up testing block-sparse layers. I've got scripts queued for the next several days just quantizing and testing. Not looking forward to the electricity bill :| As for At any rate, the main thing I'm worried about right now is compilation time (of all things). Currently there are 128 unique instances of the GEMM kernel, and I'm exploring ways to make it more flexible without doubling that number every time a new feature is added. But there's no reason in principle the MCG parameter couldn't be variable at runtime, I'll know more soon. |
QTIP paper overstated 2-bit RPTC distortion rate too.
Real L=16 RPTC distortion rate from the original 2010 RPTC citation was 0.06595 (or ~0.066) [calculated based off the final table in the paper]. This is 36% closer to the theoretical optimum, so not a small difference! If 0.068 distortion were correct, it shouldn't be possible to make computed codebooks much better than standard 1MAD. But since the 2-bit RPTC limit is actually below 0.066, there are plausibly 1MAD params a few % better. And I've found a few! These don't use an additive factor either -- should they be called 1MUL now?

Average quality improvements with better 1MAD params are small, but the reduction in overall variability and worst-case distortion is a bigger deal. It makes 1MAD higher quality at most bitrates with fewer compromises and regressions than standard 1MAD. It also removes the addition step. Honestly I can't measure any speed difference between 1MAD and 3INST. Perhaps my final multiplication/conversion step was more efficient? If 1MAD and 3INST really are equally fast, it could make sense to primarily use 1MAD instead. Default 1MAD was more marginal, but these alternate 1MAD constructions outperform 3INST in theoretical MSE reduction (at least up to 4-bit quants). At 5-8 bit, 3INST technically has lower variance and could be better in worst-case scenarios like trying to encode pure noise, but for Hadamard-rotated, 80% entropy LLM weights, there's no daylight between their actual performance.

Quick test runs with Llama 3.2 1B and Mistral 7B showed lower PPL at all bitrates. I'll test Llama 3.1 70B next.

Any chance the smoother Gaussian of 0xAD9A2EC5 1MAD is enough to reduce or eliminate the Command-R-Plus outlier you're currently struggling with using regular 3INST? Could you dump that one 128x128 block of Command-R-Plus with that outlier to a file (post-rotation)? Right now I'm using a completely synthetic memoryless Gaussian source (read: pure noise) as the heart of my multiplier ranking pipeline. Perhaps instead I should rank decoder multipliers against the actual catastrophic distributions they need to be most resilient to -- outlier spikes (presumably from attention sinks?) washing out variability in certain blocks post-rotation.

BTW - this uses 0.00677 as the final multiplier (implied divisor of 147.71, as opposed to the 148.80 in the QTIP paper or the (implied) 118.84 in the current codebook.cuh 2MAD code) |
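For reference, the addition-free 1MAD decode being described comes down to something like this (numpy sketch; 0xAD9A2EC5 is the candidate multiplier mentioned above, and 147.71 the implied divisor, neither of which is a settled default):

```python
import numpy as np

def decode_mul1(states: np.ndarray, mult: int, divisor: float = 147.71) -> np.ndarray:
    # MCG step (mod 2^32), then sum the four output bytes (Irwin-Hall, approx. Gaussian),
    # then center on the mean (510) and scale to roughly unit std dev
    x = (states.astype(np.uint64) * mult).astype(np.uint32)
    b = x.view(np.uint8).reshape(-1, 4).astype(np.float32)
    return (b.sum(axis = 1) - 510.0) / divisor

vals = decode_mul1(np.arange(1 << 16, dtype = np.uint32), 0xAD9A2EC5)
print(vals.mean(), vals.std())   # should land near 0 and 1
```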
Okay, so I'm working on making it switchable. So far I've added MCG mode as an option to the quantizer script, with an arbitrary multiplier. This way the framework handles both, and I can make some comparisons relatively easily. I'm not seeing a clear winner, though. Here's perplexity for L3-1B:

And KL divergence:

The latter is definitely the better measure to be focusing on, since it measures difference to the output of the unquantized model directly. Perplexity is simpler to measure but it can be misleading, e.g. by allowing for the quantized model to score better than the original (and I've seen this in a few cases with very large MoE models.)

Now, it's very hard to pick a clear winner from these. Even the original LCG wins at 2.25 bpw. That's produced as a mix of 2, 3 and 4 bpw tensors, all of which work better individually with the 0xCAF6A435 variant, which kinda says to me there's some chaotic dynamics going on and it's going to be very hard to predict the E2E error from the error on individual linear layers. 0xCAF6A435 does look like it has the best overall performance, at least for this model. The setting is stored per tensor, though, so there's room for using individual multipliers per tensor (must be the same across all parallel tensors in block-sparse layers though), so potentially some E2E optimization could produce an even better final model.

I'll do some tests with the addition-less 1MAD next. 1MUL is fine I guess, though I'd probably prefer MUL1 or something, to avoid labels starting with a numeric character. |
Okay, so some preliminary results with the addition-free 1MAD variant, 0xAD9A2EC5 multiplier:

Very smooth distribution at least, despite the gaps, but that's a given when summing four more-or-less uniform random numbers. MCG indeed seems to be just as good as an LCG for this purpose, which I guess also makes sense since adding a constant isn't going to make any of those four components more uniform anyway. At K=1, neighboring states are still highly correlated, though that probably doesn't mean much for L3-1B, where perplexity at 1 bpw is on the order of 1000 regardless. So it's not plotted here, either:

KLD looks a little nicer:

So it does fairly well at 2bpw, though the difference isn't as pronounced at higher bitrates. Presumably this is because the quantizer still scales the values down to use the denser middle portion of the codebook, since the weights don't end up being perfectly Gaussian.

I'll need to clean up the code a bit and split the three different modes into individual kernel instances, since E2E performance seems to take a 5% hit if I have a single kernel branching off to three different matmul functions. Then I'll push the changes probably tomorrow, with some cleanup and warnings to make it clear models created with these settings may or may not be supported in future versions.

Next step I guess is to evaluate the performance of each CB function. The MUL1 version might end up being faster, I'm not sure. It does reduce to very few instructions:

```cpp
template <int cb>
__device__ inline half decode_pcb(uint32_t x, uint32_t mult)
{
if constexpr (cb == 0) // 3INST LCG, original
{
x *= 89226354u;
x += 64248484u;
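// the LOP3 below computes (x & 0x8fff8fff) ^ 0x3b603b60 in one instruction (immLut 0x6a = (a & b) ^ c)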
asm volatile ("lop3.b32 %0, %0, 0x8fff8fff, 0x3b603b60, 0x6a;" : "+r"(x));
half2_uint32 xu(x);
return __hadd(__low2half(xu.as_half2), __high2half(xu.as_half2));
}
if constexpr (cb == 1) // 3INST MCG
{
x *= mult;
asm volatile ("lop3.b32 %0, %0, 0x8fff8fff, 0x3b603b60, 0x6a;" : "+r"(x));
half2_uint32 xu(x);
return __hadd(__low2half(xu.as_half2), __high2half(xu.as_half2));
}
if constexpr (cb == 2) // 1MAD MCG
{
x *= mult;
uint32_t sum;
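// vabsdiff4 with b = 0 and c = 0 sums the four bytes of x (sum of four ~uniform bytes approximates a Gaussian)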
asm volatile ("vabsdiff4.u32.u32.u32.add %0, %1, %2, %3;" : "=r"(sum) : "r"(x), "r"(0), "r"(0) : );
const __half k_inv_h = __ushort_as_half(0x1eee); // 0.00677 = 1/147.7
const __half k_bias_h = __ushort_as_half(0xc2e7); // -3.452 = -510.0 * k_inv_h
return __hfma(__int2half_rn(sum), k_inv_h, k_bias_h);
}
}
```

But I still haven't profiled yet or checked the SASS it actually compiles to in the end. |
Very happy with the compromise to allow per-tensor decoder + multiplier selection. Gives lots of latitude to people like me (or even more motivated tinkerers) to search for both better general and better specific multipliers.

Are you actually seeing a speed boost with vabsdiff4 over dp4a? I struggle to measure a difference. It seemed dp4a was actually faster(??) in my synthetic benchmarks. Does vabsdiff4 just play nicer with other machinery in EXL3 during actual execution?

In my attempts to find a convincingly better multiplier across the board, I tried some totally new ideas, like scrambling bits with an MCG then using atanh to make the Gaussian, or using a 64-bit multiplication and doing a nested dp4a inside the other dp4a to get a smoother distribution. The 64-bit version does perform a bit better, but less than I would have hoped. It should be a 28% better Gaussian, which I think it is on average, but the best multipliers in this setting (a double-wide multiplication with a nested dp4a to sum the high half of the 64-bit product) aren't 28% better than the best 32-bit ones.

A less costly idea I had was to simply use different coefficients on the dot product (if we're already using dp4a). It's easy to show that w = {1, 1, 1, 1} is "optimal" in terms of Gaussian shape, and all other sets of mismatched coefficient magnitudes have strictly worse, less Gaussian distributions. But that's only for all-positive coefficients! If negative values are allowed, there's an entire family of dot product weights where the magnitudes are still all 1 (but some are -1, so the distribution can be slightly improved). Each of these variants has equal theoretical potential to make good Gaussians, but in practice, some contain more distributions that have naturally lower MSE over the limited 2^16 input range we care about:
[ Side note: MUL1_8 - MUL1_F = MUL1_0 - MUL1_7, but mirrored across the y-axis. I assume mirrored distributions perform indistinguishably, but I haven't confirmed. ]

Avg MSE of the top-performing multipliers in each family shows measurable differences. This is only possible because of the condensed input range we're dealing with. If we were sampling all 2^32 inputs, all these variants would be equivalent, just rearrangements of the same points in slightly different orders. But because we're only using 0.0015% of the total input range of the MCG, these variants are different for us. MUL1_6 looks particularly promising as a place to find an all-around better multiplier. A neat side effect is there's no bias to account for in some cases (not that it matters with dp4a). But MUL1_6 is a little more elegant as pseudocode: x = x * a

Not certain I support using my specific 147.71 std deviation factor in a fully generalized algorithm template, by the way. Might make sense to fall back to 147.8 so more multipliers will have lower MSE error across the entire range of possible multipliers. If you're seeing better performance with the best values, feel free to keep it though.

Let me know if dp4a is faster in your testing. If so, can you allow the full MUL1_x family of variants with both adding and subtracting of each byte in the sum? This extra knob gives a bigger search space to find better performing distributions for the decoder, and there's already preliminary reason to believe the default MUL1_0 space is actually the least promising of all the possible variants (in terms of avg MSE / min MSE over our 2^16 input space). They should all be the same speed too.

Of course, a few days ago I also found a potentially better MUL1_0 multiplier too (0xBCAD47C5):

Doesn't look much better on the synthetic MSE charts, but when I measured PPL and KLD I saw more improvements. A lot were sub-2bit and over 4bit. But maybe it can buck the trend and thrive in the 2.25-bit Bermuda Triangle where all the other multipliers struggle so far?

Hope to have some great MUL1_6 multipliers in another day or so. Can we keep dp4a for 1MAD and get the other families of MUL1 for free? Or is vabsdiff4 actually faster? |
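To make the coefficient-family idea concrete, here's a small numpy illustration: the same MCG output bytes dotted with a ±1 coefficient vector, with MUL1_6 taken as the [+1, -1, -1, +1] pattern (matching the 0x01FFFF01 operand that comes up in the reply below). Purely illustrative; the actual scoring was done with a CUDA dp4a kernel:

```python
import numpy as np

def decode_mul1_family(states: np.ndarray, mult: int, coeffs) -> np.ndarray:
    # MCG step, then a signed dot product over the four output bytes
    x = (states.astype(np.uint64) * mult).astype(np.uint32)
    b = x.view(np.uint8).reshape(-1, 4).astype(np.float32)
    return b @ np.asarray(coeffs, dtype = np.float32)

states = np.arange(1 << 16, dtype = np.uint32)
for name, coeffs in [("MUL1_0", [+1, +1, +1, +1]),    # plain byte sum, needs the 510 bias
                     ("MUL1_6", [+1, -1, -1, +1])]:   # unbiased variant, no bias term
    v = decode_mul1_family(states, 0xAD9A2EC5, coeffs)
    print(f"{name}: mean = {v.mean():8.2f}  std = {v.std():7.2f}")
```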
I'm having a real hard time figuring out which of VABSDIFF4 or DP4A is meant to be faster. NVIDIA are still very reluctant to document any of this, and the best I can find is this table which doesn't clarify all that much. VABSDIFF4 apparently has a throughput of 64 results per clock cycle per SM on recent GPUs, but there's also evidence to the contrary so who knows, really... with documentation like this, I guess you can't blame ChatGPT for hallucinating sometimes. (: DP4A might be the way to go. It requires loading a second 32-bit operand to get the sum as a dot product where VABSDIFF4 can use rz, in theory, if it actually still exists. The ability to flip two of the coefficients to remove the bias would probably make up for that. However, in your MUL1_6 version, 0x01FFFF01 is a signed char4 operand, while the input presumably is still meant to be unsigned. I don't see a version of __dp4a that mixes signed and unsigned operands, but maybe inline PTX would allow that? It wouldn't work if they're both unsigned, and if they're both signed then each byte is uniform in the range of -128 to 127 so it shouldn't matter if you're multiplying by 1 or -1..? Anyway, I'll profile it all I guess. Even if VABSDIFF4 is emulated in multiple instructions it could still be faster in theory, depending on latencies, register pressure, pipelining and all that. We'll see. As for the scale, it probably doesn't matter at all. The main concern is keeping all the intermediate values in the GEMM kernel within a reasonable range for FP16. Since scaling is dynamic, a difference of a few percent won't be the thing that causes any over- or underflows in the kernel, and in the end the global scale will just adjust to compensate for the lowest immediate MSE. Overall, I'm a little skeptical of relying too much on synthetic data. Even aside from the crazy outliers in a few models, I'm not convinced random noise is a suitable proxy for the regularized 128x128 tiles. I'm also cautious about adding too many layers of complexity for marginally better results. The whole point is to make SOTA quants more accessible, so in that regard creating a bunch of new hyperparameters really needs to be justified. And as for that justification, converting one of these larger models can take upwards of 10 hours per attempt, followed by potentially hours of benchmarking, just to get one data point on a graph that ultimately may not end up having a trend line at all, two weeks later. Then back to the drawing board, try another LCG/MCG, find that it works better for model A but not for model B, then do a bunch of experiments to see if maybe it's possible to automatically distinguish between the two types... After all that, the next challenge is communicating it all. I still suffer from the mistake I made adding the calibration dataset as an easily accessible parameter in ExLlamaV2. So many people immediately mistook it for a finetuning option and started making models "calibrated" for roleplaying, coding, etc. The misconception survives to this day. So while it's entirely possible to have per-tensor options, there needs to be a compelling reason for the added complexity and a lot of thought put into how the functionality is exposed. Ideally it would be automatic. Like an ideal codebook per bitrate or something, and then it's just the default from now on. 
But getting overly ambitious with it, like "this codebook works great for the attn Q projection" is kind of problematic, since for instance Deepseek uses a low-rank matrix product for that operation, with a norm layer in the middle just to keep it interesting. Who knows what's next. :| |
Swept across most of the parameter space for generalized MUL1_x decoders. Here are the 4 top multipliers:

Using a radar plot in lieu of 6 boxplots. Note: the scale is unbalanced for each of the 6 bitrates. That's warranted specifically because there's very little that can be done to improve upon 1MAD at the lowest bitrates. For example, here are some of the values I ran more in-depth testing on at 1-bit:

Getting a 0.5% MSE reduction vs the 1MAD baseline is a big chunk of the remaining gap between regular 1MAD and the best possible code that could be created with a perfectly hand-tuned L=16 codebook.
The radar plot above shows around a third of the gap between 1MAD and L=16 optimal performance at each bitrate. So the ones with better coverage really are showing ~20% of all the MSE reduction that can be had beyond default 1MAD.

Interestingly, MUL1_0 0xAD9A2EC5 (the multiplier I found a few days ago) is one of the 4 best versions among all possible MUL1_x variants. And there's another top-4 "standard" MUL1_0 variant with excellent properties: 0x83DCD12D. Don't want to blindly fetishize synthetic performance characteristics, but 0x83DCD12D has a remarkable "figure of merit" across the entire 256-dimensional space (0.461265), which is tied for 2nd among all 32-bit MCG multipliers. This was calculated separately, so the fact it came up again with such pristine MSE properties when actually decoding QTIP feels promising. The statistical quality of 0x83DCD12D has no weak dimensions, so if it performs well in benchmarking, there's reason to expect this particular MCG will be robust in any possible setting.

The top variant and the 4th best variant both happened to be "unbiased" MUL1 variants that subtract 2 bytes. These are definitely possible to run directly with __dp4a; I was only able to score all these MUL1 variants so quickly by using a CUDA kernel with __dp4a that allowed negative coefficients. Any of these 4 variants could be great. They each have different strengths, and all of them are better than regular 3INST or regular 1MAD with few exceptions.

As far as exposing per-tensor calibration options, yeah, there will be a few dopes who use it to make poetry-calibrated quants or cringe stuff like that. But it's probably more likely you'll attract high-quality quant tinkerers like @bartowski1182 or @ubergarm and get them interested in experimenting with EXL3 improvements if they see extra knobs to play with to put together thoughtful recipes that intelligently shave more KLD & PPL off their quants. The good from what the best people do with it should outweigh what others do with it. The entire quantizing space is getting much more intelligent, and having good partners who can tweak and extend EXL3 as it matures could be just as valuable as having an easily understandable format.

It also doesn't require us to be certain about these particular multipliers right now. As a middle ground, as long as the multipliers are stored in the models and 1MAD and 3INST are supported, smart folks can always compile their own EXL3 dev branches to make these specialized quants. Then we can change the default multiplier down the line if someone ever finds a surprisingly better one, and people who know enough to compile their own EXL3 branch can make custom quants to experiment and find better recipes (down to particular decoder parameters in different blocks across the same quant).

Let me know what you think of the 3 new variants. If none of them are perfect across the board, my suggestion for a "fixed" strategy would be something like 1MAD with 0xAD9A2EC5 for 3-bit and lower, plus 3INST 0xCAF6A435 for 3.5-bit and higher. Let people who are technical enough to compile EXL3 on their own quant with whatever per-tensor settings they want until we find slightly more nuanced decoder settings that improve the entire EXL3 format in a few months. Happy to sweep through the search space again with less abstract inputs if you can help provide a small library of 128 x 128 post-rotation tiles to optimize parameters against. |
I'm almost ready to push the changes, but there are some complications, sadly.

First off, the MCG variant of 3INST is noticeably slower with a non-constant multiplier. The difference end-to-end is on the order of 4%, which is not a very nice tradeoff. The slowdown I think comes from the added register pressure and scheduling difficulties around it, which apparently more than outweighs the savings from one extra constant load in the LCG version. That's not a total dealbreaker if we can narrow it down to one multiplier and use that as a constant for the MCG path, but chasing marginal gains at a significant cost to performance, that I can't really agree with. The feedback I'm getting elsewhere is never about the precision of the format (which is already SOTA) but always about performance. Especially the ALU bottleneck on Ampere, which is affected by this.

Worse still, I've tried and profiled both the DP4A and VABSDIFF4 versions of the MUL1 kernel, and while there does appear to be a VABSDIFF4 in the instruction set from Ampere onwards, and it is considerably faster than DP4A, it's still some 15% slower than 3INST with an MCG. That's measured end-to-end, so the GEMM kernel itself is maybe 30% slower overall with this change. DP4A knocks off another 5% or so. The way I interpret that, either the int-to-half conversion is itself slow or it stalls because VABSDIFF4 has high latency, and there isn't enough to do otherwise to hide it. There might be a bit of time to shave off by using the magic number trick (0x6400+n maps to 1024.0+n, for precisely the 0..1023 interval that would be relevant here), but it's not looking like 1MAD in general is going to match the performance of 3INST.

I'll push the changes in a little bit either way, and it's in there as an experimental feature, with big red warnings when enabled. (: Then at least it'll be easy to make some test models with different settings. And since it's a per-tensor setting (for now!), you can mix and match by adding a little bit of logic to the conversion script, e.g. to select a method based on bitrate or whatever. |
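Incidentally, the magic-number mapping itself is easy to sanity-check offline (this only confirms the fp16 bit pattern; it says nothing about kernel timings):

```python
import numpy as np

n = np.arange(1024, dtype = np.uint16)
as_half = (np.uint16(0x6400) + n).view(np.float16)   # 0x6400 is fp16 1024.0; the low 10 bits are the mantissa
assert np.array_equal(as_half.astype(np.float32), 1024.0 + n)
```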
Okay, so I pushed it now. It wasn't a small change :)

Now, there's been a surprising development with MUL1. It seems like it was in fact the int-to-float16 conversion that was problematic, and using the magic number trick put the speed on par with 3INST after all:

So it may be viable still. Might still be best to hunt for one ideal setting per bitrate, but at least it's there to experiment with now. Heed the big red warning of course. I added two new arguments to convert.py:
The setting is still applied per-tensor, though, so there's room to experiment with mixing and matching. The settings are passed down here, in exllamav3/conversion/convert_model.py (which is what convert.py wraps):

```python
for linear in linears:
    quant_args = {
        "seed": idx,
        "K": strategy[linear.key],  # <-- bitrate here
        "devices": devices,
        "device_ratios": device_ratios,
        "apply_out_scales": args["apply_out_scales"],
    }
    if args.get("mcg_multiplier"):
        quant_args.update({
            "mcg_mult": int(args["mcg_multiplier"], 0)
        })
    if args.get("mul1_multiplier"):
        quant_args.update({
            "mul1_mult": int(args["mul1_multiplier"], 0)
        })
    ...
```
|
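If anyone wants to reproduce the per-bitrate mix-and-match idea, one way is a small patch around that loop. A hypothetical sketch (the mapping values are placeholders pulled from the discussion above, not a tested recipe):

```python
# Hypothetical per-bitrate selection, to be applied inside the loop above
BPW_RECIPE = {
    2.0: ("mul1_mult", 0xAD9A2EC5),
    3.0: ("mul1_mult", 0xAD9A2EC5),
    4.0: ("mcg_mult", 0xCAF6A435),
    5.0: ("mcg_mult", 0xCAF6A435),
}

def pick_multiplier(bpw: float) -> dict:
    # choose the recipe entry with the closest bitrate
    key = min(BPW_RECIPE, key = lambda b: abs(b - bpw))
    name, value = BPW_RECIPE[key]
    return {name: value}

# inside the loop, instead of the fixed mcg/mul1 arguments:
# quant_args.update(pick_multiplier(strategy[linear.key]))
```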
I made a quick-n-dirty mapping of bitrate to multiplier:

👈 Secret Recipe
This secret recipe shaved Qwen3-14B so hard that the exl3 3.47bpw perplexity is better than the bf16. 🤠🪒🐃

For calculating perplexity on the exl3's I used

It definitely takes longer to cook one of these exl3 bad bois on my 3090TI FE 24GB VRAM than even ik's iqN_k quants. Anecdotally, inferencing with the chat.py app seems slower, but I haven't looked into proper speed benchmarking yet or if there is even some kind of API endpoint haha... Nope, just did

I also kicked the tires on facebookresearch/ParetoQ 2bit QAT fine-tuning a bit and compared an exl3 2.0bpw. Not sure if something like ParetoQ could/should ever be combined with exl3, but with numbers this good there's probably somewhat less urgent need for me personally to research QAT haha...

I have a rough guide from my experience getting this setup and might share in a discussion after cleaning up a bit and looking around more at what functionality exists. Thanks for sharing your work openly, I definitely enjoy all the extra knobs to play with. 😆 Cheers!

† I am not serious about the name, quant naming convention, nor use of poetry for calibration datasets. It's kind of late here and I am feeling silly. 😹 |
I wrote the eval/compareq.py script specifically to compare perplexity across different quant methods, because there's no agreed-upon methodology between frameworks. So be mindful that llama.cpp does a very different test. wiki.test.raw should be roughly comparable to what I use in compareq.py and ppl.py (which is the wikitext2 test split served by the datasets library, all "text" fields joined by "\n\n"). But llama.cpp still takes a very different approach. For instance it evaluates every token with a warm context, which makes a huge difference. So not sure if the comparison you're doing is apples-to-apples..?

Either way, it's also not really a good sign if perplexity ends up lower than the original. In most cases lower perplexity is good, and to some extent you can use a model's decreased ability to predict a given test set as a proxy measure for how damaged it is. But it's never meant to improve, and the fact that it sometimes does only goes to show that perplexity is kind of measuring the wrong thing. Suppose you have a multilingual model, for instance, and you somehow quantize it in a way that completely nukes its ability to work in languages other than English. Normally the model would always have some residual signals from "circuits" that deal with French, German, Chinese etc., but if those signals are suppressed, or noise is introduced that amounts to a kind of destructive interference, you might end up with a more focused "English signal." And since wikitext is an English dataset, the score you measure isn't a good indication of how damaged the model actually is. Same with other aspects like instruction following, alignment, Star Trek trivia, or whatever else isn't being tested for.

Here's an example with Qwen3-30B-A3B, an MoE model that's very hard to calibrate because it routes the hidden state very sparsely through only 8 of 128 experts on each layer:

It's all over the place. If you test something like KL divergence you get a much clearer picture of what's going on:

KLD of 0.022 for the 4bpw model is still very low, but it's not negative. The KLD test is more robust in the sense that it doesn't allow for a negative score. It assumes the original model is the ground truth and then measures the cost you would incur for pretending the quantized model is true instead.

As for an API endpoint, that's what TabbyAPI is for. It's an OAI-compatible server you can plug into whatever frontend to serve EXL2, EXL3 and GPTQ models (and all non-quantized HF models of architectures supported by V2 or V3.)

I'm not a big fan of any method that involves finetuning, partly because such methods are very expensive to use (converting one large model shouldn't cost more than buying a whole GPU) and partly because it's very hard to judge their performance. The line between "it's just more intelligent rounding" and "we're actually just retraining the model" becomes blurred, and it's too easy to target a benchmark and lose other functionality in the process. I saw this happen in a bad way a while back with (I think it was) AutoRound, which managed to improve perplexity on Llama3.1-8B-instruct well below the level of the original model while also tanking its performance on HumanEval. Because it was essentially finetuning on wikitext as it was quantizing the model, and for L3-instruct especially, much of its coding ability seems to be realized in finetuning/RLHF (compare to L3-base and you'll see the very stark difference.)

I haven't looked into ParetoQ specifically, but it could be worth adding to the charts.
At least if it works with Transformers and there are a few reference models available. Couldn't seem to find any..? |
Thanks for the detailed response! I've had some time to read through more of the documentation and code and to figure out some more things.
Agreed it probably isn't exactly comparable. I see EDIT: I re-ran some of these using a consistent
Yeah, I've seen perplexity on some ~4bpw quants go below the "baseline" bf16 on While I use PPL and KLD to inform a specific architecture's quants relative to each other, it is definitely nice to get some other benchmark data too for confirmation and not fall into over-training a specific test corpus etc. totes.
Ahh right, I got that fired up and tried the exl3 branch but was getting some gibberish, possibly because my quant was cooked with these hours-old experimental features against all the warnings haha... I'll mess with it more eventually and update my rough getting-started-with-exl3 guide with notes.
I can't find any ParetoQ models either, so asked the github repo author for information. Thanks again! I'll play around some more trying to compare exl3 with various GGUF quants and benchmarks. Finally a few things I am quite happy about:
Cheers! |
Depending on how you install Tabby, it may be using the latest release build which doesn't support the new quant formats and will quietly assume that everything is still encoded to the old 3INST codebook. Make sure the dev branch is installed in Tabby's venv and you should be okay, I think. There's also YALS by the same guy who made TabbyAPI. It's an OAI server too, but it wraps llama.cpp instead, so you could set the two up next to each other and swap between them to compare EXL3 and GGUF models that way, perhaps.
Yes. The compareq.py script also lets you compute KLD, though it requires loading the base model. Set "out_logits" in the spec, pointing to a directory where logits from the first 10 rows will be saved (this is usually a huge file, hence the 10 row limit for now). Then KLD will be computed against those logits for all other models in the same run, and you can plot that curve instead of perplexity by adding
EXL3 isn't strictly weights only. It uses a built-in dataset which attempts to be very broad, mixing a bunch of different languages and types of text with strings of random tokens. It uses LDLQ which is equivalent to OPTQ/GPTQ but more efficient. It's the same method used in QTIP and QUIP#. So it still needs some input data to build a Hessian matrix for each linear layer, but all it means in practice is that it more intelligently manages the "rounding" error when tiles can't be perfectly encoded, as opposed to just discarding it. It doesn't affect the selection of bitrates per tensor. |
Yup that fixed it, I just needed to install my exllamav3 dev branch into the venv of tabby and all is well. Yay! Thanks for heads up on YALS too.
Great, I was able to use

Managed to run a few more numbers and it looks fairly reasonable.

👈 Some KLD and PPL Graphs of Qwen3-14B exl3's and GGUFs

For the 3.5bpw quant here, it is using a dynamic mapping of either the mcg or mul1 multiplier based on the bpw allocation chosen by the strategy for the given tensor. The mapping comes from the above chart for Llama-3.2-1B-Instruct, picking the multiplier with the lowest PPL/KLD for each bpw datapoint in the chart. I realize you have already implemented some clever "dynamic" quantization going on with "allocation" permutations for

Okay, thanks for letting me bomb in on this PR. Looking forward to setting up some better scripts to explore more exl3 quants alongside some of my fav GGUF flavors. Cheers and happy weekend! |
Used the dev branch some over the weekend. Independently made my own scaffolding to test all the existing and several new multipliers on each tensor. Saw you added some pytest routines. Are those just for actual pass/fail testing to make sure the code works? Maybe we could add some more instrumentation there so it actually saves the results and produces reports showing how much better different multipliers are than others?

I've gone through all the layers on a ton of different models. Haven't found a single tensor where default 3INST ever wins. Is that what you're seeing too? At a minimum, 3INST 0xCAF6A435 has strictly lower proxy error, lower MSE, lower KLD, and (usually) lower PPL. The occasional micro-regressions in PPL tend to come with strictly better KLD and MSE, so it just seems unambiguously better in all cases. Interestingly, one of the places it helps the most (% wise) is in the 6-bit lm_head, where it usually reduces error 2-3% more.

Anyway, hopefully we're moving towards replacing the default QTIP params before full launch, right? |
While I didn't exhaustively compare mcg and mul1 parameters, I did compare a lot of

Two over-packed graphs for PPL and KLD as well as raw data below if anyone is curious. I found it interesting that the GGUF bf16 did have a very slight KL divergence from the baseline HF transformers safetensors. You can compare these scores against the Qwen3-30B-A3B provided above by turboderp, which presumably did not use the

👈 Qwen3-30B-A3B exl3 and GGUF plots and data for PPL and KLD

```json
[
{
"label": "HF BF16",
"layer_bpw": 16.0,
"head_bpw": 16,
"vram_gb": 56.29052734375,
"ppl": 8.896963416807232,
"kld": 0.0
},
{
"label": "EXL3 2.0bpw H6 mcg 0xCAF6A435",
"layer_bpw": 2.0341227864035547,
"head_bpw": 6.007917910337247,
"vram_gb": 7.300313476473093,
"ppl": 67.25162091132263,
"kld": 2.2775095161326933
},
{
"label": "EXL3 2.5bpw H6 mcg 0xCAF6A435",
"layer_bpw": 2.5287238241261742,
"head_bpw": 6.007917910337247,
"vram_gb": 9.022481445223093,
"ppl": 10.095238979406716,
"kld": 0.1239086788918659
},
{
"label": "EXL3 3.0bpw H6 mcg 0xCAF6A435",
"layer_bpw": 3.0337020880442784,
"head_bpw": 6.007917910337247,
"vram_gb": 10.780782226473093,
"ppl": 9.182588255526754,
"kld": 0.0687279472209428
},
{
"label": "EXL3 3.5bpw H6 mcg 0xCAF6A435",
"layer_bpw": 3.528303125766898,
"head_bpw": 6.007917910337247,
"vram_gb": 12.502950195223093,
"ppl": 9.247539261362629,
"kld": 0.0343391250661969
},
{
"label": "EXL3 4.0bpw H6 mcg 0xCAF6A435",
"layer_bpw": 4.033281389685002,
"head_bpw": 6.007917910337247,
"vram_gb": 14.261250976473093,
"ppl": 9.042701868479641,
"kld": 0.024368329403625617
},
{
"label": "EXL3 4.5bpw H6 mcg 0xCAF6A435",
"layer_bpw": 4.527882427407621,
"head_bpw": 6.007917910337247,
"vram_gb": 15.983418945223093,
"ppl": 8.927166917597463,
"kld": 0.013147117029470714
},
{
"label": "EXL3 5.0bpw H6 mcg 0xCAF6A435",
"layer_bpw": 5.032860691325726,
"head_bpw": 6.007917910337247,
"vram_gb": 17.741719726473093,
"ppl": 8.942674618981622,
"kld": 0.009545976939601178
},
{
"label": "EXL3 6.0bpw H6 mcg 0xCAF6A435",
"layer_bpw": 6.03243999296645,
"head_bpw": 6.007917910337247,
"vram_gb": 21.222188476473093,
"ppl": 8.914982554441801,
"kld": 0.005315908464860837
},
{
"label": "GGUF BF16",
"layer_bpw": 16.00710575546037,
"head_bpw": 16.0,
"vram_gb": 29.64445686340332,
"ppl": 8.90307838931857,
"kld": 0.0018382440041680701
},
{
"label": "GGUF Q8_0 imat",
"layer_bpw": 8.510052079542378,
"head_bpw": 8.5,
"vram_gb": 29.939552307128906,
"ppl": 8.895864961372073,
"kld": 0.00376526105450302
},
{
"label": "GGUF UD-Q4_K_XL imat",
"layer_bpw": 4.620486385263758,
"head_bpw": 6.5625,
"vram_gb": 16.493436813354492,
"ppl": 9.006967324027388,
"kld": 0.02983071104526298
},
{
"label": "GGUF UD-IQ2_M imat",
"layer_bpw": 2.8071527555603777,
"head_bpw": 6.5625,
"vram_gb": 10.113798141479492,
"ppl": 10.04768989571527,
"kld": 0.1389737648072192
},
{
"label": "GGUF UD-Q4_K_XL 128k imat",
"layer_bpw": 4.620486385263758,
"head_bpw": 6.5625,
"vram_gb": 16.493436813354492,
"ppl": 9.147260263984773,
"kld": 0.06502522703666595
},
{
"label": "GGUF UD-IQ4_XS 128k imat",
"layer_bpw": 4.266403692463578,
"head_bpw": 6.5625,
"vram_gb": 15.247709274291992,
"ppl": 8.98860554080009,
"kld": 0.06928144145354515
},
{
"label": "GGUF Q4_K_S imat",
"layer_bpw": 4.691649893635455,
"head_bpw": 6.5625,
"vram_gb": 16.743803024291992,
"ppl": 9.039136994185384,
"kld": 0.024144081457776672
},
{
"label": "GGUF IQ2_M imat",
"layer_bpw": 2.702843329508759,
"head_bpw": 5.5,
"vram_gb": 9.708330154418945,
"ppl": 9.637042524624235,
"kld": 0.13621351244574528
},
{
"label": "GGUF Q4_K_M imat",
"layer_bpw": 4.863105026067801,
"head_bpw": 6.5625,
"vram_gb": 17.347013473510742,
"ppl": 9.038357988835012,
"kld": 0.01969047055205346
},
{
"label": "GGUF IQ3_M ed imat",
"layer_bpw": 3.681662516695381,
"head_bpw": 3.4375,
"vram_gb": 13.077281951904297,
"ppl": 9.36637604254197,
"kld": 0.08605257654904011
},
{
"label": "GGUF Q4_K_M ed imat",
"layer_bpw": 4.473586414830737,
"head_bpw": 4.5,
"vram_gb": 15.90190315246582,
"ppl": 9.246278950648936,
"kld": 0.037276722648089344
},
{
"label": "GGUF Q5_K_M ed imat",
"layer_bpw": 5.33928476917054,
"head_bpw": 5.5,
"vram_gb": 18.98381233215332,
"ppl": 9.045580650795506,
"kld": 0.01762934465542933
}
]
```
|
Thanks for sharing Qwen 30B-A3B results for mcg=0xCAF6A435. EXL3 2.00 bpw data seems shockingly high though? Hmm... New 3INST mcg=0xB83EA16D (located today) seems to slightly outperform even 0xCAF6A435 in real world EXL3 quantization. Another step towards a more widely dominant default? |
Yeah not sure what happened there. It was one of the first I made, but my procedure was consistent for this set. Could be 2.0bpw is more sensitive to the calibration given the challenges turboderp mentioned above:
I quantized a

I'll spare you the crowded graph and just give the raw data:

```json
{
"label": "EXL3 2.25bpw H6 mcg 0xB83EA16D",
"layer_bpw": 2.283176215095183,
"head_bpw": 6.007917910337247,
"vram_gb": 8.167500976473093,
"ppl": 9.630357382083758,
"kld": 0.1378025879524722
},
```

Okay, everyone is busy with R1-05-28 tonight lol... Anyway, thanks again for the freshest of exl3 magic numbers! |
Summary: Swapped out original LCG params in 3INST for a better, computationally screened, optimal MCG multiplier
0.2-3.9% lower PPL
1.25% faster processing (addition removed)
Details:
3INST's default LCG params are knowably suboptimal:
x = 89226354 * x + 64248484
See Computationally easy, spectrally good multipliers for congruential pseudorandom number generators for deeper discussion
Demanding that LCGs have good spectral properties is not always necessary, but for a 256-dimensional Trellis Decoder, spectral quality in higher dimensions is actually a key performance requirement.
Screened all 45458 spectrally good LCG multipliers and all 49507 spectrally good MCG multipliers uploaded to github by @vigna
Many LCG multipliers had strictly lower RMS in the decoder but none had universally superior PPL when actually quantizing and decoding real models.
But several MCG multipliers had both lower RMS in a vacuum AND universally better PPL on several models @ multiple bitrates. The best one appears to be 0xFF7A83E5.
Performance Comparison of Multiplier 0xFF7A83E5 Across Models
Also inspected how this impacts the CUDA (removing addition):
Original CUDA PTX:
mad.lo.s32 %r4, %r3, -1997118179, 64248484;
and.b32 %r5, %r4, -1879076865;
xor.b32 %r2, %r5, 996162400;
New CUDA PTX:
mul.lo.s32 %r4, %r3, -8748059;
and.b32 %r5, %r4, -1879076865;
xor.b32 %r2, %r5, 996162400;
Benchmarking reveals small but measurable 1.25% speed increase by eliminating addition.
Even a modest 1.25% speedup and a small 0.2-3.9% reduction in PPL are hard to come by elsewhere, given how efficient EXL3 is shaping up to be.