Generate the rust wrapping code.
LaurentMazare committed Feb 17, 2019
1 parent f1f816f commit 2d84da9
Showing 15 changed files with 13,222 additions and 265 deletions.
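Only a subset of the 15 changed files is rendered below; the bulk of the roughly 13,000 added lines is presumably the generated Rust wrapper itself. As a rough, hypothetical sketch of the shape such a generated layer takes — an unsafe extern declaration mirroring one of the `atg_*` C shims from the header below, plus a thin method on a `Tensor` handle — with all names and the raw-pointer representation being assumptions rather than the commit's actual output:

```rust
use std::os::raw::c_void;

// Assumed opaque handles to the underlying C++ objects.
type CTensor = *mut c_void;
type CScalar = *mut c_void;

extern "C" {
    // Mirrors `void atg_pow(tensor *, tensor self, scalar exponent);`
    // from c/torch_api_generated.h below.
    fn atg_pow(out: *mut CTensor, self_: CTensor, exponent: CScalar);
}

pub struct Tensor {
    c_tensor: CTensor,
}

impl Tensor {
    // The kind of wrapper a generator could emit: reserve an output slot,
    // call the C shim, and wrap the returned handle.
    pub fn pow(&self, exponent: CScalar) -> Tensor {
        let mut out: CTensor = std::ptr::null_mut();
        unsafe { atg_pow(&mut out as *mut CTensor, self.c_tensor, exponent) };
        Tensor { c_tensor: out }
    }
}
```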
2 changes: 2 additions & 0 deletions .gitignore
@@ -1,5 +1,7 @@
 /target
 _build/
+data/
+gen/.merlin
 **/*.rs.bk
 *.swp
 Cargo.lock
8 changes: 4 additions & 4 deletions c/torch_api_generated.cpp.h
@@ -4038,9 +4038,9 @@ void atg_pow1(tensor *out__, tensor self, tensor exponent) {
   )
 }
 
-void atg_pow2(tensor *out__, scalar self, tensor exponent) {
+void atg_pow2(tensor *out__, scalar self_scalar, tensor exponent) {
   PROTECT(
-    auto outputs__ = torch::pow(*self, *exponent);
+    auto outputs__ = torch::pow(*self_scalar, *exponent);
     out__[0] = new torch::Tensor(outputs__);
   )
 }
@@ -4073,9 +4073,9 @@ void atg_pow_out1(tensor *out__, tensor result, tensor self, tensor exponent) {
   )
 }
 
-void atg_pow_out2(tensor *out__, tensor result, scalar self, tensor exponent) {
+void atg_pow_out2(tensor *out__, tensor result, scalar self_scalar, tensor exponent) {
   PROTECT(
-    auto outputs__ = torch::pow_out(*result, *self, *exponent);
+    auto outputs__ = torch::pow_out(*result, *self_scalar, *exponent);
     out__[0] = new torch::Tensor(outputs__);
   )
 }
4 changes: 2 additions & 2 deletions c/torch_api_generated.h
@@ -568,12 +568,12 @@ void atg_potrs(tensor *, tensor self, tensor input2, int upper);
 void atg_potrs_out(tensor *, tensor result, tensor self, tensor input2, int upper);
 void atg_pow(tensor *, tensor self, scalar exponent);
 void atg_pow1(tensor *, tensor self, tensor exponent);
-void atg_pow2(tensor *, scalar self, tensor exponent);
+void atg_pow2(tensor *, scalar self_scalar, tensor exponent);
 void atg_pow_(tensor *, tensor self, scalar exponent);
 void atg_pow_1(tensor *, tensor self, tensor exponent);
 void atg_pow_out(tensor *, tensor result, tensor self, scalar exponent);
 void atg_pow_out1(tensor *, tensor result, tensor self, tensor exponent);
-void atg_pow_out2(tensor *, tensor result, scalar self, tensor exponent);
+void atg_pow_out2(tensor *, tensor result, scalar self_scalar, tensor exponent);
 void atg_prelu(tensor *, tensor self, tensor weight);
 void atg_prelu_backward(tensor *, tensor grad_output, tensor self, tensor weight);
 void atg_prod(tensor *, tensor self);
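The only substantive change in the two generated C files above is renaming the `scalar self` argument to `self_scalar`. A plausible reason — an assumption, not stated in the commit message — is that the Rust generator reuses the C parameter names, and `self` is a reserved keyword in Rust that cannot name an ordinary parameter, so the old declaration could not be reproduced verbatim on the Rust side. A sketch of a foreign declaration that does compile with the new name (handle types are assumed):

```rust
use std::os::raw::c_void;

// Assumed opaque handle types; the actual generated code may model these differently.
type CTensor = *mut c_void;
type CScalar = *mut c_void;

extern "C" {
    // Mirrors the updated `void atg_pow2(tensor *, scalar self_scalar, tensor exponent);`.
    // With the old C name this parameter would have to be called `self`, which
    // Rust only accepts as a method receiver, not as a regular argument name.
    fn atg_pow2(out: *mut CTensor, self_scalar: CScalar, exponent: CTensor);
}
```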
1 change: 1 addition & 0 deletions dune-project
@@ -0,0 +1 @@
+(lang dune 1.6)
6 changes: 3 additions & 3 deletions examples/basics.rs
@@ -1,9 +1,9 @@
 extern crate torchr;
-use torchr::{Kind, Tensor};
+use torchr::{Device, Kind, Tensor};
 
 fn grad_example() {
     let mut x = Tensor::from(2.0).set_requires_grad(true);
-    let mut y = &x * &x + &x + 36;
+    let y = &x * &x + &x + 36;
     println!("{}", y.double_value(&[]));
     x.zero_grad();
     y.backward();
@@ -14,7 +14,7 @@ fn grad_example() {
 fn main() {
     let t = Tensor::int_vec(&[3, 1, 4, 1, 5]);
     t.print();
-    let t = Tensor::randn(&[5, 4], Kind::Float);
+    let t = Tensor::randn(&[5, 4], (Kind::Float, Device::Cpu));
     t.print();
     (&t + 1.5).print();
     (&t + 2.5).print();
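The two changes in this example are the `(Kind, Device)` tuple for tensor construction and dropping the unneeded `mut` on `y`. A consolidated sketch of the post-change code, using only names that appear in the diff:

```rust
extern crate torchr;
use torchr::{Device, Kind, Tensor};

fn main() {
    // `x` stays mutable because `zero_grad` mutates it; `y` is only read,
    // so it no longer needs `mut`.
    let mut x = Tensor::from(2.0).set_requires_grad(true);
    let y = &x * &x + &x + 36;
    x.zero_grad();
    y.backward();

    // Construction now takes the element kind and the target device together.
    let t = Tensor::randn(&[5, 4], (Kind::Float, Device::Cpu));
    t.print();
}
```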
13 changes: 7 additions & 6 deletions examples/mnist.rs
@@ -10,7 +10,7 @@
 */
 
 extern crate torchr;
-use torchr::{no_grad, vision, Kind, Tensor};
+use torchr::{no_grad, vision, Device, Kind, Tensor};
 
 static IMAGE_DIM: i64 = 784;
 static LABELS: i64 = 10;
@@ -21,8 +21,9 @@ fn main() {
     println!("train-labels: {:?}", m.train_labels.size());
     println!("test-images: {:?}", m.test_images.size());
     println!("test-labels: {:?}", m.test_labels.size());
-    let mut ws = Tensor::zeros(&[IMAGE_DIM, LABELS], Kind::Float).set_requires_grad(true);
-    let mut bs = Tensor::zeros(&[LABELS], Kind::Float).set_requires_grad(true);
+    let mut ws =
+        Tensor::zeros(&[IMAGE_DIM, LABELS], (Kind::Float, Device::Cpu)).set_requires_grad(true);
+    let mut bs = Tensor::zeros(&[LABELS], (Kind::Float, Device::Cpu)).set_requires_grad(true);
     for epoch in 1..200 {
         let logits = m.train_images.mm(&ws) + &bs;
         let loss = logits.log_softmax(-1).nll_loss(&m.train_labels);
@@ -35,9 +36,9 @@ fn main() {
         });
         let test_logits = m.test_images.mm(&ws) + &bs;
         let test_accuracy = test_logits
-            .argmax(-1)
-            .eq(&m.test_labels)
-            .to_kind(Kind::Float)
+            .argmax1(-1, false)
+            .eq1(&m.test_labels)
+            .to_kind(&Kind::Float)
             .mean()
             .double_value(&[]);
         println!(
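The accuracy computation moves to the generated method names: `argmax1` with an explicit keep-dim flag, `eq1` for tensor-to-tensor comparison, and `to_kind` taking the `Kind` by reference. A small helper in the same style, with signatures taken only from the calls visible in this diff:

```rust
extern crate torchr;
use torchr::{Kind, Tensor};

// Fraction of rows whose argmax matches the target label.
fn accuracy(logits: &Tensor, labels: &Tensor) -> f64 {
    logits
        .argmax1(-1, false)     // index of the largest logit per row, without keeping the dim
        .eq1(labels)            // element-wise equality against the label tensor
        .to_kind(&Kind::Float)  // booleans -> floats so they can be averaged
        .mean()
        .double_value(&[])      // extract the resulting scalar as f64
}
```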
3 changes: 3 additions & 0 deletions gen/dune
@@ -0,0 +1,3 @@
+(executables
+  (names gen)
+  (libraries base stdio yaml))