-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
f07bec7
commit 46427aa
Showing
4 changed files
with
372 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
import numpy as np | ||
import onnx | ||
from onnx import TensorProto, helper | ||
|
||
# Create a simple model: Y = X + W, where W is an initializer | ||
input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, [1, 3]) | ||
output_tensor = helper.make_tensor_value_info("output", TensorProto.FLOAT, [1, 3]) | ||
|
||
# Create an initializer tensor | ||
weights = np.array([1.0, 2.0, 3.0], dtype=np.float32) | ||
weights_initializer = helper.make_tensor( | ||
name="weights", | ||
data_type=TensorProto.FLOAT, | ||
dims=weights.shape, | ||
vals=weights.flatten().tolist(), | ||
) | ||
|
||
# Create a node (Add operation) | ||
node_def = helper.make_node( | ||
"Add", | ||
inputs=["input", "weights"], | ||
outputs=["output"], | ||
) | ||
|
||
# Create the graph | ||
graph_def = helper.make_graph( | ||
nodes=[node_def], | ||
name="SimpleGraph", | ||
inputs=[helper.make_tensor_value_info("input", TensorProto.FLOAT, [1, 3])], | ||
outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, [1, 3])], | ||
initializer=[weights_initializer], | ||
) | ||
|
||
# Create the model | ||
model_def = helper.make_model(graph_def, producer_name="onnx-safetensors-example") | ||
|
||
# Save the model | ||
onnx.save(model_def, "model.textproto") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
ir_version: 10 | ||
producer_name: "onnx-safetensors-example" | ||
graph { | ||
node { | ||
input: "input" | ||
input: "weights" | ||
output: "output" | ||
op_type: "Add" | ||
} | ||
name: "SimpleGraph" | ||
initializer { | ||
dims: 3 | ||
data_type: 1 | ||
float_data: 1.0 | ||
float_data: 2.0 | ||
float_data: 3.0 | ||
name: "weights" | ||
} | ||
input { | ||
name: "input" | ||
type { | ||
tensor_type { | ||
elem_type: 1 | ||
shape { | ||
dim { | ||
dim_value: 1 | ||
} | ||
dim { | ||
dim_value: 3 | ||
} | ||
} | ||
} | ||
} | ||
} | ||
output { | ||
name: "output" | ||
type { | ||
tensor_type { | ||
elem_type: 1 | ||
shape { | ||
dim { | ||
dim_value: 1 | ||
} | ||
dim { | ||
dim_value: 3 | ||
} | ||
} | ||
} | ||
} | ||
} | ||
} | ||
opset_import { | ||
version: 21 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,276 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "3f786336", | ||
"metadata": {}, | ||
"source": [ | ||
"# ONNX Safetensors Tutorial\n", | ||
"\n", | ||
"This notebook demonstrates how to use the public API of the `onnx_safetensors` package to load and save ONNX weights using safetensors." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"id": "1985be7f", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# !pip install --upgrade onnx-safetensors" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "bedf1490", | ||
"metadata": {}, | ||
"source": [ | ||
"## Load ONNX model" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"id": "a54e0bdc", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"<\n", | ||
" ir_version: 10,\n", | ||
" opset_import: [\"\" : 21],\n", | ||
" producer_name: \"onnx-safetensors-example\"\n", | ||
">\n", | ||
"SimpleGraph (float[1,3] input) => (float[1,3] output) \n", | ||
" <float[3] weights = {1,2,3}>\n", | ||
"{\n", | ||
" output = Add (input, weights)\n", | ||
"}\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"import onnx\n", | ||
"\n", | ||
"model = onnx.load(\"model.textproto\")\n", | ||
"print(onnx.printer.to_text(model))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "04626ef6", | ||
"metadata": {}, | ||
"source": [ | ||
"## Loading tensors from a safetensors file into an ONNX model\n", | ||
"\n", | ||
"We first create a safetensors file with compatible weights, then load these weights into the ONNX model." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"id": "b77ddb75", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"<\n", | ||
" ir_version: 10,\n", | ||
" opset_import: [\"\" : 21],\n", | ||
" producer_name: \"onnx-safetensors-example\"\n", | ||
">\n", | ||
"SimpleGraph (float[1,3] input) => (float[1,3] output) \n", | ||
" <float[3] weights = {4,5,6}>\n", | ||
"{\n", | ||
" output = Add (input, weights)\n", | ||
"}\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"import numpy as np\n", | ||
"import safetensors.numpy\n", | ||
"\n", | ||
"import onnx_safetensors\n", | ||
"\n", | ||
"# Create a safetensors file with compatible weights\n", | ||
"# Note that the tensor key \"weights\" matches the name of the tensor in the model\n", | ||
"weights_dict = {\"weights\": np.array([4.0, 5.0, 6.0], dtype=np.float32)}\n", | ||
"safetensors.numpy.save_file(weights_dict, \"weights.safetensors\")\n", | ||
"\n", | ||
"# Now you can replace the weights in the model\n", | ||
"replaced_model = onnx_safetensors.load_file(model, \"weights.safetensors\")\n", | ||
"\n", | ||
"# Notice how the weights have been replaced to [4, 5, 6]\n", | ||
"print(onnx.printer.to_text(replaced_model))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "cecd56d8", | ||
"metadata": {}, | ||
"source": [ | ||
"Use `load_file_as_external_data` to load safetensors as external data and replace weights in the model" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"id": "72d642a5", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"<\n", | ||
" ir_version: 10,\n", | ||
" opset_import: [\"\" : 21],\n", | ||
" producer_name: \"onnx-safetensors-example\"\n", | ||
">\n", | ||
"SimpleGraph (float[1,3] input) => (float[1,3] output) \n", | ||
" <float[3] weights = [\"location\": \"weights.safetensors\", \"offset\": \"72\", \"length\": \"12\"]>\n", | ||
"{\n", | ||
" output = Add (input, weights)\n", | ||
"}\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"replaced_model_with_external_data = onnx_safetensors.load_file_as_external_data(model, \"weights.safetensors\")\n", | ||
"\n", | ||
"print(onnx.printer.to_text(replaced_model_with_external_data))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "e7ff107f", | ||
"metadata": {}, | ||
"source": [ | ||
"### Using safetensors as external data for ONNX\n", | ||
"\n", | ||
"You can also save the ONNX model to use safetensors as external data." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"id": "6f42a4a5", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"<\n", | ||
" ir_version: 10,\n", | ||
" opset_import: [\"\" : 21],\n", | ||
" producer_name: \"onnx-safetensors-example\"\n", | ||
">\n", | ||
"SimpleGraph (float[1,3] input) => (float[1,3] output) \n", | ||
" <float[3] weights = {1,2,3}>\n", | ||
"{\n", | ||
" output = Add (input, weights)\n", | ||
"}\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# First take the onnx model\n", | ||
"print(onnx.printer.to_text(model))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"id": "f7d0bf03", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Weights saved: {'weights': array([1., 2., 3.], dtype=float32)}\n", | ||
"\n", | ||
"model_with_external_data:\n", | ||
"<\n", | ||
" ir_version: 10,\n", | ||
" opset_import: [\"\" : 21],\n", | ||
" producer_name: \"onnx-safetensors-example\"\n", | ||
">\n", | ||
"SimpleGraph (float[1,3] input) => (float[1,3] output) \n", | ||
" <float[3] weights = [\"location\": \"model.safetensors\", \"offset\": \"72\", \"length\": \"12\"]>\n", | ||
"{\n", | ||
" output = Add (input, weights)\n", | ||
"}\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# Save the model to use safetensors as external data. It should contain 1, 2, 3\n", | ||
"model_with_external_data = onnx_safetensors.save_file(model, 'model.safetensors', base_dir='.', replace_data=True)\n", | ||
"print(\"Weights saved:\", safetensors.numpy.load_file('model.safetensors'))\n", | ||
"\n", | ||
"# This is a model referencing safetensors as external data\n", | ||
"print(\"\\nmodel_with_external_data:\")\n", | ||
"print(onnx.printer.to_text(model_with_external_data))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "8eb20a2e", | ||
"metadata": {}, | ||
"source": [ | ||
"# Inference with ONNX Runtime" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 7, | ||
"id": "39c668a4", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Output: [array([[2., 4., 6.]], dtype=float32)]\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"import onnxruntime as ort\n", | ||
"\n", | ||
"onnx.save(model_with_external_data, \"model_with_external_data.onnx\")\n", | ||
"session = ort.InferenceSession(\"model_with_external_data.onnx\")\n", | ||
"output = session.run(None, {\"input\": np.array([[1.0, 2.0, 3.0]], dtype=np.float32)})\n", | ||
"print(\"Output:\", output)" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "onnx", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.13.0" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |