NousResearch/tch-rs
forked from LaurentMazare/tch-rs
Description: Rust bindings for the C++ api of PyTorch.
Language: Rust
License: Apache-2.0
Stars: 6
Forks: 2
Open issues: 1
Created: 2024-09-11T18:35:57Z
Pushed: 2026-03-10T15:40:39Z
Default branch: main
Fork: yes
Parent repository: LaurentMazare/tch-rs
Archived: no
README:
tch-rs
Rust bindings for the C++ api of PyTorch. The goal of the tch crate is to provide some thin wrappers around the C++ PyTorch api (a.k.a. libtorch). It aims to stay as close as possible to the original C++ api. More idiomatic Rust bindings could then be developed on top of this. The documentation can be found on docs.rs.
The code generation part for the C api on top of libtorch comes from ocaml-torch.
Getting Started
This crate requires the C++ PyTorch library (libtorch) in version *v2.7.0* to be available on your system. You can either:
- Use the system-wide libtorch installation (default).
- Install libtorch manually and let the build script know about it via the LIBTORCH environment variable.
- Use a Python PyTorch install, to do this set LIBTORCH_USE_PYTORCH=1.
- When a system-wide libtorch can't be found and LIBTORCH is not set, the build script can download a pre-built binary version of libtorch by using the download-libtorch feature. By default a CPU version is used. The TORCH_CUDA_VERSION environment variable can be set to cu117 in order to get a pre-built binary using CUDA 11.7.
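The options above can be read as a small decision procedure. The sketch below is illustrative only: `LibtorchSource` and `resolve_source` are hypothetical names that do not exist in tch, and the order in which the options are tried is an assumption, since the README does not spell out a precedence. It only uses the environment variable names mentioned above.

```rust
use std::collections::HashMap;

/// Where libtorch could come from, mirroring the options listed above.
#[derive(Debug, PartialEq)]
enum LibtorchSource {
    PythonInstall,
    Manual(String),
    SystemWide,
    Download { cuda: Option<String> },
}

/// Hypothetical helper (not part of tch): pick a source from environment
/// variables. The precedence used here is an assumption for illustration.
fn resolve_source(
    env: &HashMap<String, String>,
    system_wide_found: bool,
    download_feature: bool,
) -> Option<LibtorchSource> {
    if env.contains_key("LIBTORCH_USE_PYTORCH") {
        return Some(LibtorchSource::PythonInstall);
    }
    if let Some(path) = env.get("LIBTORCH") {
        return Some(LibtorchSource::Manual(path.clone()));
    }
    if system_wide_found {
        return Some(LibtorchSource::SystemWide);
    }
    if download_feature {
        // No TORCH_CUDA_VERSION means the CPU build, per the text above.
        let cuda = env.get("TORCH_CUDA_VERSION").cloned();
        return Some(LibtorchSource::Download { cuda });
    }
    None
}

fn main() {
    let mut env = HashMap::new();
    env.insert("TORCH_CUDA_VERSION".to_string(), "cu117".to_string());
    // Nothing installed locally, download-libtorch feature enabled.
    let source = resolve_source(&env, false, true);
    println!("{:?}", source);
}
```

The real build script may check these conditions differently; the sketch is only meant to make the interaction between the four options easier to follow.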
System-wide Libtorch
On linux platforms, the build script will look for a system-wide libtorch library in /usr/lib/libtorch.so.
Python PyTorch Install
If the LIBTORCH_USE_PYTORCH environment variable is set, the active python interpreter is called to retrieve information about the torch python package. This version is then linked against.
Libtorch Manual Install
- Get libtorch from the PyTorch website download section and extract the content of the zip file.
- For Linux and macOS users, add the following to your .bashrc or equivalent, where /path/to/libtorch is the path to the directory that was created when unzipping the file.
export LIBTORCH=/path/to/libtorch
The header files location can also be specified separately from the shared library via the following:
# LIBTORCH_INCLUDE must contain `include` directory.
export LIBTORCH_INCLUDE=/path/to/libtorch/
# LIBTORCH_LIB must contain `lib` directory.
export LIBTORCH_LIB=/path/to/libtorch/
- For Windows users, assuming that X:\path\to\libtorch is the unzipped libtorch directory:
  - Navigate to Control Panel -> View advanced system settings -> Environment variables.
  - Create the LIBTORCH variable and set it to X:\path\to\libtorch.
  - Append X:\path\to\libtorch\lib to the Path variable.
If you prefer to temporarily set environment variables, in PowerShell you can run
$Env:LIBTORCH = "X:\path\to\libtorch"
$Env:Path += ";X:\path\to\libtorch\lib"
- You should now be able to run some examples, e.g. cargo run --example basics.
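Before running the examples, it can help to check that the variables from the manual install steps point where expected. The sketch below is a standalone check, not part of tch: `LibtorchDirs` and `libtorch_dirs` are hypothetical names, and it assumes that LIBTORCH_INCLUDE and LIBTORCH_LIB, when set, take the place of LIBTORCH for headers and libraries respectively.

```rust
use std::path::PathBuf;

/// Directories a build step would need: headers and shared libraries.
#[derive(Debug, PartialEq)]
struct LibtorchDirs {
    include: PathBuf,
    lib: PathBuf,
}

/// Hypothetical helper (not part of tch): derive both directories from the
/// values of LIBTORCH, LIBTORCH_INCLUDE and LIBTORCH_LIB.
/// Assumption: a specific variable, when set, overrides LIBTORCH.
fn libtorch_dirs(
    libtorch: Option<&str>,
    include_override: Option<&str>,
    lib_override: Option<&str>,
) -> Option<LibtorchDirs> {
    let include_root = include_override.or(libtorch)?;
    let lib_root = lib_override.or(libtorch)?;
    Some(LibtorchDirs {
        // LIBTORCH_INCLUDE must contain an `include` directory.
        include: PathBuf::from(include_root).join("include"),
        // LIBTORCH_LIB must contain a `lib` directory.
        lib: PathBuf::from(lib_root).join("lib"),
    })
}

fn main() {
    let libtorch = std::env::var("LIBTORCH").ok();
    let include = std::env::var("LIBTORCH_INCLUDE").ok();
    let lib = std::env::var("LIBTORCH_LIB").ok();
    match libtorch_dirs(libtorch.as_deref(), include.as_deref(), lib.as_deref()) {
        Some(dirs) => {
            println!("headers:   {} (exists: {})", dirs.include.display(), dirs.include.is_dir());
            println!("libraries: {} (exists: {})", dirs.lib.display(), dirs.lib.is_dir());
        }
        None => println!("LIBTORCH is not set, so one of the other install options would apply"),
    }
}
```

If either directory is reported as missing, the exported path most likely points one level too high or too low relative to the unzipped libtorch directory.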
Windows Specific Notes
As per the PyTorch docs, the Windows debug and release builds are not ABI-compatible. This could lead to segfaults if the incorrect version of libtorch is used.
It is recommended to use the MSVC Rust toolchain (e.g. by installing stable-x86_64-pc-windows-msvc via rustup) rather than a MinGW-based one, as PyTorch has compatibility issues with MinGW.
Static Linking
When the environment variable LIBTORCH_STATIC=1 is set, libtorch is statically linked rather than using the dynamic libraries. The pre-compiled artifacts don't seem to include libtorch.a by default, so this would have to be compiled manually, e.g. via the following:
git clone -b v2.7.0 --recurse-submodule https://github.com/pytorch/pytorch.git pytorch-static --depth 1
cd pytorch-static
USE_CUDA=OFF BUILD_SHARED_LIBS=OFF python setup.py build
# export LIBTORCH to point at the build directory in pytorch-static.
Examples
Basic Tensor Operations
This crate provides a tensor type which wraps PyTorch tensors. Here is a minimal example of how to perform some tensor operations.
use tch::Tensor;

fn main() {
    let t = Tensor::from_slice(&[3, 1, 4, 1, 5]);
    let t = t * 2;
    t.print();
}
Training a Model via Gradient Descent
PyTorch provides automatic differentiation for most tensor operations it supports. This is commonly used to train models using gradient descent. The optimization is performed over variables which are created via a nn::VarStore by defining their shapes and initializations.
In the example below, my_module uses two variables x1 and x2 whose initial values are 0. The forward pass applied to tensor xs returns xs * x1 + exp(xs) * x2.
Once the model has been generated, a nn::Sgd optimizer is created. Then on each step of the training loop:
- The forward pass is applied to a mini-batch of data.
- A loss is computed as the squared error between the model output and the mini-batch ground truth, summed over the elements.
- Finally an optimization step is performed: gradients are computed and variables from the VarStore are modified accordingly.
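To see what the optimization step amounts to for this particular model, the same loop can be written out by hand in plain Rust, without libtorch. This is a sketch, not how tch works internally: the gradients are derived analytically rather than by automatic differentiation, the function names are hypothetical, values are f64 rather than tensors, and the non-zero inputs and targets are made up so that the loss visibly decreases.

```rust
/// Forward pass of the example model, per element: xs * x1 + exp(xs) * x2.
fn forward(xs: &[f64], x1: &[f64], x2: &[f64]) -> Vec<f64> {
    xs.iter()
        .zip(x1.iter().zip(x2))
        .map(|(x, (a, b))| x * a + x.exp() * b)
        .collect()
}

/// Sum of squared errors between predictions and targets.
fn loss(pred: &[f64], ys: &[f64]) -> f64 {
    pred.iter().zip(ys).map(|(p, y)| (p - y).powi(2)).sum()
}

/// One hand-written SGD step: compute gradients analytically, then update in place.
fn sgd_step(xs: &[f64], ys: &[f64], x1: &mut [f64], x2: &mut [f64], lr: f64) {
    let pred = forward(xs, x1, x2);
    for i in 0..xs.len() {
        let residual = pred[i] - ys[i];
        // d(loss)/d(x1[i]) = 2 * residual * xs[i]
        // d(loss)/d(x2[i]) = 2 * residual * exp(xs[i])
        x1[i] -= lr * 2.0 * residual * xs[i];
        x2[i] -= lr * 2.0 * residual * xs[i].exp();
    }
}

/// Runs the loop on hypothetical non-zero data and returns (initial loss, final loss).
fn train(steps: usize, lr: f64) -> (f64, f64) {
    let dim = 7;
    let xs: Vec<f64> = (0..dim).map(|i| i as f64 * 0.1).collect();
    // Hypothetical targets generated from x1 = 0.5 and x2 = -0.25.
    let ys = forward(&xs, &vec![0.5; dim], &vec![-0.25; dim]);
    // Both variables start at zero, as in the example model.
    let (mut x1, mut x2) = (vec![0.0; dim], vec![0.0; dim]);
    let initial = loss(&forward(&xs, &x1, &x2), &ys);
    for _ in 0..steps {
        sgd_step(&xs, &ys, &mut x1, &mut x2, lr);
    }
    (initial, loss(&forward(&xs, &x1, &x2), &ys))
}

fn main() {
    // 49 steps with learning rate 1e-2, matching a `1..50` loop.
    let (initial, last) = train(49, 1e-2);
    println!("loss went from {initial:.6} to {last:.6}");
}
```

With tch, the gradient computation and the in-place update are both handled by the optimizer, so none of the derivative formulas need to be written by hand. The tch version of the loop follows.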
use tch::nn::{Module, OptimizerConfig};
use tch::{kind, nn, Device, Tensor};

fn my_module(p: nn::Path, dim: i64) -> impl nn::Module {
    let x1 = p.zeros("x1", &[dim]);
    let x2 = p.zeros("x2", &[dim]);
    nn::func(move |xs| xs * &x1 + xs.exp() * &x2)
}

fn gradient_descent() {
    let vs = nn::VarStore::new(Device::Cpu);
    let my_module = my_module(vs.root(), 7);
    let mut opt = nn::Sgd::default().build(&vs, 1e-2).unwrap();
    for _idx in 1..50 {
        // Dummy mini-batches made of zeros.
        let xs = Tensor::zeros(&[7], kind::FLOAT_CPU);
        let ys = Tensor::zeros(&[7], kind::FLOAT_CPU);
        let loss = (my_module.forward(&xs) - ys).pow_tensor_scalar(2).sum(kind::Kind::Float);
        opt.backward_step(&loss);
    }
}
Writing a Simple Neural Network
The nn api can be used to create…
Notability
notability 1.0/10: Routine fork with minimal community interest.