Skip to content

Commit 3aa35e7

Browse files
kylesayrs and brian-dellabetta
authored and committed
add correctness test, note that precision makes a large difference
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 2f5b1c8 commit 3aa35e7

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import pytest
2+
import torch
3+
from compressed_tensors.transform import apply_transform_config
4+
from transformers import AutoModelForCausalLM
5+
6+
from llmcompressor.modifiers.transform.template.quip import QUIP
7+
8+
9+
@pytest.mark.parametrize(
    "dtype,exp_max,exp_mse",
    [
        # constructing and running transforms in float32 can improve
        # bfloat16 bounds to (~0.6562, ~0.0055)  # noqa: E501
        (torch.bfloat16, 1.1, 0.012),
        (torch.float32, 4e-4, 2e-9),
    ],
)
def test_apply_correctness(dtype, exp_max, exp_mse):
    """Check that applying the QUIP transform config preserves model outputs.

    Runs the same dummy inputs through the model before and after
    ``apply_transform_config`` and asserts the logits agree within
    dtype-dependent tolerances.

    :param dtype: torch dtype the model is loaded and run in
    :param exp_max: upper bound on the max absolute logit deviation
    :param exp_mse: upper bound on the mean squared logit error
    """
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Meta-Llama-3-8B-Instruct", device_map="cuda", torch_dtype=dtype
    )

    # avoid shadowing the builtin `input`
    inputs = {k: v.to("cuda") for k, v in model.dummy_inputs.items()}
    with torch.no_grad():
        true_output = model(**inputs)

    apply_transform_config(model, QUIP)
    with torch.no_grad():
        output = model(**inputs)

    # use the absolute difference: without abs, a large *negative* deviation
    # (output logit above the reference) would pass the max-error bound
    assert torch.max(torch.abs(true_output.logits - output.logits)) <= exp_max
    assert torch.nn.MSELoss()(output.logits, true_output.logits) <= exp_mse

0 commit comments

Comments
 (0)