Skip to content

Commit

Permalink
Create tutorial (#26)
Browse files Browse the repository at this point in the history
  • Loading branch information
justinchuby authored Feb 28, 2025
1 parent f07bec7 commit 46427aa
Show file tree
Hide file tree
Showing 4 changed files with 372 additions and 0 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,7 @@ model_with_external_data = onnx_safetensors.save_file(model, data_path, base_dir
# This model is a valid ONNX model using external data from the safetensors file
onnx.save(model_with_external_data, os.path.join(base_dir, "model_using_safetensors.onnx"))
```

## Examples

- [Tutorial notebook](examples/tutorial.ipynb)
38 changes: 38 additions & 0 deletions examples/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import numpy as np
import onnx
from onnx import TensorProto, helper

# Create a simple model: Y = X + W, where W is an initializer
input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, [1, 3])
output_tensor = helper.make_tensor_value_info("output", TensorProto.FLOAT, [1, 3])

# Create an initializer tensor
weights = np.array([1.0, 2.0, 3.0], dtype=np.float32)
weights_initializer = helper.make_tensor(
name="weights",
data_type=TensorProto.FLOAT,
dims=weights.shape,
vals=weights.flatten().tolist(),
)

# Create a node (Add operation)
node_def = helper.make_node(
"Add",
inputs=["input", "weights"],
outputs=["output"],
)

# Create the graph
graph_def = helper.make_graph(
nodes=[node_def],
name="SimpleGraph",
inputs=[helper.make_tensor_value_info("input", TensorProto.FLOAT, [1, 3])],
outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, [1, 3])],
initializer=[weights_initializer],
)

# Create the model
model_def = helper.make_model(graph_def, producer_name="onnx-safetensors-example")

# Save the model
onnx.save(model_def, "model.textproto")
54 changes: 54 additions & 0 deletions examples/model.textproto
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
ir_version: 10
producer_name: "onnx-safetensors-example"
graph {
node {
input: "input"
input: "weights"
output: "output"
op_type: "Add"
}
name: "SimpleGraph"
initializer {
dims: 3
data_type: 1
float_data: 1.0
float_data: 2.0
float_data: 3.0
name: "weights"
}
input {
name: "input"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 3
}
}
}
}
}
output {
name: "output"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 3
}
}
}
}
}
}
opset_import {
version: 21
}
276 changes: 276 additions & 0 deletions examples/tutorial.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,276 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "3f786336",
"metadata": {},
"source": [
"# ONNX Safetensors Tutorial\n",
"\n",
"This notebook demonstrates how to use the public API of the `onnx_safetensors` package to load and save ONNX weights using safetensors."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "1985be7f",
"metadata": {},
"outputs": [],
"source": [
"# !pip install --upgrade onnx-safetensors"
]
},
{
"cell_type": "markdown",
"id": "bedf1490",
"metadata": {},
"source": [
"## Load ONNX model"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "a54e0bdc",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<\n",
" ir_version: 10,\n",
" opset_import: [\"\" : 21],\n",
" producer_name: \"onnx-safetensors-example\"\n",
">\n",
"SimpleGraph (float[1,3] input) => (float[1,3] output) \n",
" <float[3] weights = {1,2,3}>\n",
"{\n",
" output = Add (input, weights)\n",
"}\n"
]
}
],
"source": [
"import onnx\n",
"\n",
"model = onnx.load(\"model.textproto\")\n",
"print(onnx.printer.to_text(model))"
]
},
{
"cell_type": "markdown",
"id": "04626ef6",
"metadata": {},
"source": [
"## Loading tensors from a safetensors file into an ONNX model\n",
"\n",
"We first create a safetensors file with compatible weights, then load these weights into the ONNX model."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "b77ddb75",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<\n",
" ir_version: 10,\n",
" opset_import: [\"\" : 21],\n",
" producer_name: \"onnx-safetensors-example\"\n",
">\n",
"SimpleGraph (float[1,3] input) => (float[1,3] output) \n",
" <float[3] weights = {4,5,6}>\n",
"{\n",
" output = Add (input, weights)\n",
"}\n"
]
}
],
"source": [
"import numpy as np\n",
"import safetensors.numpy\n",
"\n",
"import onnx_safetensors\n",
"\n",
"# Create a safetensors file with compatible weights\n",
"# Note that the tensor key \"weights\" matches the name of the tensor in the model\n",
"weights_dict = {\"weights\": np.array([4.0, 5.0, 6.0], dtype=np.float32)}\n",
"safetensors.numpy.save_file(weights_dict, \"weights.safetensors\")\n",
"\n",
"# Now you can replace the weights in the model\n",
"replaced_model = onnx_safetensors.load_file(model, \"weights.safetensors\")\n",
"\n",
"# Notice how the weights have been replaced to [4, 5, 6]\n",
"print(onnx.printer.to_text(replaced_model))"
]
},
{
"cell_type": "markdown",
"id": "cecd56d8",
"metadata": {},
"source": [
"Use `load_file_as_external_data` to load safetensors as external data and replace weights in the model"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "72d642a5",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<\n",
" ir_version: 10,\n",
" opset_import: [\"\" : 21],\n",
" producer_name: \"onnx-safetensors-example\"\n",
">\n",
"SimpleGraph (float[1,3] input) => (float[1,3] output) \n",
" <float[3] weights = [\"location\": \"weights.safetensors\", \"offset\": \"72\", \"length\": \"12\"]>\n",
"{\n",
" output = Add (input, weights)\n",
"}\n"
]
}
],
"source": [
"replaced_model_with_external_data = onnx_safetensors.load_file_as_external_data(model, \"weights.safetensors\")\n",
"\n",
"print(onnx.printer.to_text(replaced_model_with_external_data))"
]
},
{
"cell_type": "markdown",
"id": "e7ff107f",
"metadata": {},
"source": [
"### Using safetensors as external data for ONNX\n",
"\n",
"You can also save the ONNX model to use safetensors as external data."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "6f42a4a5",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<\n",
" ir_version: 10,\n",
" opset_import: [\"\" : 21],\n",
" producer_name: \"onnx-safetensors-example\"\n",
">\n",
"SimpleGraph (float[1,3] input) => (float[1,3] output) \n",
" <float[3] weights = {1,2,3}>\n",
"{\n",
" output = Add (input, weights)\n",
"}\n"
]
}
],
"source": [
"# First take the onnx model\n",
"print(onnx.printer.to_text(model))"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "f7d0bf03",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Weights saved: {'weights': array([1., 2., 3.], dtype=float32)}\n",
"\n",
"model_with_external_data:\n",
"<\n",
" ir_version: 10,\n",
" opset_import: [\"\" : 21],\n",
" producer_name: \"onnx-safetensors-example\"\n",
">\n",
"SimpleGraph (float[1,3] input) => (float[1,3] output) \n",
" <float[3] weights = [\"location\": \"model.safetensors\", \"offset\": \"72\", \"length\": \"12\"]>\n",
"{\n",
" output = Add (input, weights)\n",
"}\n"
]
}
],
"source": [
"# Save the model to use safetensors as external data. It should contain 1, 2, 3\n",
"model_with_external_data = onnx_safetensors.save_file(model, 'model.safetensors', base_dir='.', replace_data=True)\n",
"print(\"Weights saved:\", safetensors.numpy.load_file('model.safetensors'))\n",
"\n",
"# This is a model referencing safetensors as external data\n",
"print(\"\\nmodel_with_external_data:\")\n",
"print(onnx.printer.to_text(model_with_external_data))"
]
},
{
"cell_type": "markdown",
"id": "8eb20a2e",
"metadata": {},
"source": [
"# Inference with ONNX Runtime"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "39c668a4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Output: [array([[2., 4., 6.]], dtype=float32)]\n"
]
}
],
"source": [
"import onnxruntime as ort\n",
"\n",
"onnx.save(model_with_external_data, \"model_with_external_data.onnx\")\n",
"session = ort.InferenceSession(\"model_with_external_data.onnx\")\n",
"output = session.run(None, {\"input\": np.array([[1.0, 2.0, 3.0]], dtype=np.float32)})\n",
"print(\"Output:\", output)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "onnx",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

0 comments on commit 46427aa

Please sign in to comment.