Fork · Nous Research · published Sep 11, 2024

NousResearch/tch-rs

forked from LaurentMazare/tch-rs

Description: Rust bindings for the C++ api of PyTorch.

Language: Rust

License: Apache-2.0

Stars: 6

Forks: 2

Open issues: 1

Created: 2024-09-11T18:35:57Z

Pushed: 2026-03-10T15:40:39Z

Default branch: main

Fork: yes

Parent repository: LaurentMazare/tch-rs

Archived: no

README:

tch-rs

Rust bindings for the C++ API of PyTorch. The goal of the tch crate is to provide thin wrappers around the C++ PyTorch API (a.k.a. libtorch). It aims to stay as close as possible to the original C++ API. More idiomatic Rust bindings can then be developed on top of it. The documentation can be found on docs.rs.

The code generation part for the C API on top of libtorch comes from ocaml-torch.

Getting Started

This crate requires the C++ PyTorch library (libtorch) in version *v2.7.0* to be available on your system. You can either:

  • Use the system-wide libtorch installation (default).
  • Install libtorch manually and let the build script know about it via the LIBTORCH environment variable.
  • Use a Python PyTorch install; to do this, set LIBTORCH_USE_PYTORCH=1.
  • When a system-wide libtorch can't be found and LIBTORCH is not set, the build script can download a pre-built binary version of libtorch by using the download-libtorch feature. By default a CPU version is used. The TORCH_CUDA_VERSION environment variable can be set to cu117 in order to get a pre-built binary using CUDA 11.7.

System-wide Libtorch

On Linux platforms, the build script will look for a system-wide libtorch library in /usr/lib/libtorch.so.

Python PyTorch Install

If the LIBTORCH_USE_PYTORCH environment variable is set, the active Python interpreter is called to retrieve information about the torch Python package. The build then links against that version.
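
To see what the active interpreter would report before building, a small standalone probe can help. The sketch below is an assumption-laden illustration, not the crate's build script: the exact query the build script runs is not shown in this README, the interpreter name `python3` is a guess, and `torch_probe_command` and `probe_torch` are hypothetical helper names. It only relies on `torch.__version__` and `torch.__file__` from the Python package and on `std::process::Command`.

```rust
use std::process::Command;

/// Builds (without running) a command that asks `interpreter` for the
/// version and location of its installed `torch` package.
/// Hypothetical helper, not part of tch.
fn torch_probe_command(interpreter: &str) -> Command {
    let mut cmd = Command::new(interpreter);
    cmd.args(["-c", "import torch; print(torch.__version__); print(torch.__file__)"]);
    cmd
}

/// Runs the probe and returns `(version, file)`.
/// Returns `None` if the interpreter cannot be started, exits with an
/// error (e.g. torch is not installed), or prints unexpected output.
fn probe_torch(interpreter: &str) -> Option<(String, String)> {
    let output = torch_probe_command(interpreter).output().ok()?;
    if !output.status.success() {
        return None;
    }
    let stdout = String::from_utf8(output.stdout).ok()?;
    let mut lines = stdout.lines();
    let version = lines.next()?.trim().to_string();
    let file = lines.next()?.trim().to_string();
    Some((version, file))
}

fn main() {
    // "python3" is an assumption; use whichever interpreter is active.
    match probe_torch("python3") {
        Some((version, file)) => println!("torch {version} found at {file}"),
        None => println!("no usable torch install found for python3"),
    }
}
```

If the reported version does not match the libtorch version this crate expects, the build is likely to fail or misbehave, so checking it first can save a long compile.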

Libtorch Manual Install

  • Get libtorch from the PyTorch website download section and extract the content of the zip file.
  • For Linux and macOS users, add the following to your .bashrc or equivalent, where /path/to/libtorch is the path to the directory that was created when unzipping the file.

export LIBTORCH=/path/to/libtorch

The header files location can also be specified separately from the shared library via the following:

# LIBTORCH_INCLUDE must contain `include` directory.
export LIBTORCH_INCLUDE=/path/to/libtorch/
# LIBTORCH_LIB must contain `lib` directory.
export LIBTORCH_LIB=/path/to/libtorch/

  • For Windows users, assuming that X:\path\to\libtorch is the unzipped libtorch directory:
      • Navigate to Control Panel -> View advanced system settings -> Environment variables.
      • Create the LIBTORCH variable and set it to X:\path\to\libtorch.
      • Append X:\path\to\libtorch\lib to the Path variable.

If you prefer to temporarily set environment variables, in PowerShell you can run

$Env:LIBTORCH = "X:\path\to\libtorch"
$Env:Path += ";X:\path\to\libtorch\lib"

  • You should now be able to run some examples, e.g. cargo run --example basics.
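
A frequent mistake with the manual install is pointing LIBTORCH at the wrong level of the unpacked archive. As a rough sanity check, the sketch below verifies that the directory named by LIBTORCH contains the `include` and `lib` sub-directories mentioned above. It uses only the standard library; `missing_libtorch_dirs` is a hypothetical helper name, and the check is an assumption about what a usable layout looks like rather than what the build script itself verifies.

```rust
use std::env;
use std::path::{Path, PathBuf};

/// Returns the expected sub-directories (`include`, `lib`) that are
/// missing under `root`. An empty result means the layout looks plausible.
/// Hypothetical helper, not part of tch.
fn missing_libtorch_dirs(root: &Path) -> Vec<PathBuf> {
    ["include", "lib"]
        .iter()
        .map(|name| root.join(name))
        .filter(|dir| !dir.is_dir())
        .collect()
}

fn main() {
    match env::var_os("LIBTORCH") {
        None => println!("LIBTORCH is not set"),
        Some(root) => {
            let root = Path::new(&root);
            let missing = missing_libtorch_dirs(root);
            if missing.is_empty() {
                println!("{} has both include and lib", root.display());
            } else {
                for dir in missing {
                    println!("missing directory: {}", dir.display());
                }
            }
        }
    }
}
```

When LIBTORCH_INCLUDE and LIBTORCH_LIB are set separately, the same idea applies to each variable on its own: the first should contain `include`, the second `lib`.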

Windows Specific Notes

As per the PyTorch docs, the Windows debug and release builds are not ABI-compatible. This could lead to segfaults if the incorrect version of libtorch is used.

It is recommended to use the MSVC Rust toolchain (e.g. by installing stable-x86_64-pc-windows-msvc via rustup) rather than a MinGW-based one, as PyTorch has compatibility issues with MinGW.

Static Linking

When the LIBTORCH_STATIC=1 environment variable is set, libtorch is linked statically instead of through the dynamic libraries. The pre-compiled artifacts don't seem to include libtorch.a by default, so it would have to be compiled manually, e.g. via the following:

git clone -b v2.7.0 --recurse-submodules https://github.com/pytorch/pytorch.git pytorch-static --depth 1
cd pytorch-static
USE_CUDA=OFF BUILD_SHARED_LIBS=OFF python setup.py build
# export LIBTORCH to point at the build directory in pytorch-static.

Examples

Basic Tensor Operations

This crate provides a tensor type which wraps PyTorch tensors. Here is a minimal example of how to perform some tensor operations.

use tch::Tensor;

fn main() {
    let t = Tensor::from_slice(&[3, 1, 4, 1, 5]);
    let t = t * 2;
    t.print();
}

Training a Model via Gradient Descent

PyTorch provides automatic differentiation for most tensor operations it supports. This is commonly used to train models using gradient descent. The optimization is performed over variables which are created via a nn::VarStore by defining their shapes and initializations.

In the example below, my_module uses two variables x1 and x2 whose initial values are 0. The forward pass applied to tensor xs returns xs * x1 + exp(xs) * x2.

Once the model has been generated, a nn::Sgd optimizer is created. Then on each step of the training loop:

  • The forward pass is applied to a mini-batch of data.
  • A loss is computed from the squared error between the model output and the mini-batch ground truth (the code sums the squared differences rather than averaging them).
  • Finally an optimization step is performed: gradients are computed and variables from the VarStore are modified accordingly.

use tch::nn::{Module, OptimizerConfig};
use tch::{kind, nn, Device, Tensor};

fn my_module(p: nn::Path, dim: i64) -> impl nn::Module {
    let x1 = p.zeros("x1", &[dim]);
    let x2 = p.zeros("x2", &[dim]);
    nn::func(move |xs| xs * &x1 + xs.exp() * &x2)
}

fn gradient_descent() {
    let vs = nn::VarStore::new(Device::Cpu);
    let my_module = my_module(vs.root(), 7);
    let mut opt = nn::Sgd::default().build(&vs, 1e-2).unwrap();
    for _idx in 1..50 {
        // Dummy mini-batches made of zeros.
        let xs = Tensor::zeros(&[7], kind::FLOAT_CPU);
        let ys = Tensor::zeros(&[7], kind::FLOAT_CPU);
        let loss = (my_module.forward(&xs) - ys).pow_tensor_scalar(2).sum(kind::Kind::Float);
        opt.backward_step(&loss);
    }
}
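
To make the optimization step concrete, here is a plain-Rust sketch of the same arithmetic without tch or libtorch. It is an illustration only: the gradients are written out by hand instead of coming from automatic differentiation, the data values are made up, and `forward`, `loss`, and `sgd_step` are hypothetical names rather than tch APIs. For the loss L = Σ (xs·x1 + exp(xs)·x2 − ys)², the gradients are ∂L/∂x1 = 2·err·xs and ∂L/∂x2 = 2·err·exp(xs), where err is the per-element difference between prediction and target.

```rust
/// Element-wise forward pass from the text: xs * x1 + exp(xs) * x2.
fn forward(xs: &[f64], x1: &[f64], x2: &[f64]) -> Vec<f64> {
    xs.iter()
        .zip(x1)
        .zip(x2)
        .map(|((x, a), b)| x * a + x.exp() * b)
        .collect()
}

/// Sum of squared differences between predictions and targets.
fn loss(pred: &[f64], ys: &[f64]) -> f64 {
    pred.iter().zip(ys).map(|(p, y)| (p - y).powi(2)).sum()
}

/// One plain SGD step with hand-derived gradients.
/// Returns the loss measured before the update.
fn sgd_step(xs: &[f64], ys: &[f64], x1: &mut [f64], x2: &mut [f64], lr: f64) -> f64 {
    let pred = forward(xs, x1, x2);
    let current = loss(&pred, ys);
    for i in 0..xs.len() {
        let err = 2.0 * (pred[i] - ys[i]);
        x1[i] -= lr * err * xs[i]; // dL/dx1 = 2 * err * xs
        x2[i] -= lr * err * xs[i].exp(); // dL/dx2 = 2 * err * exp(xs)
    }
    current
}

fn main() {
    // Made-up data; the example above uses all-zero batches instead.
    let xs = [0.0, 0.5, 1.0];
    let ys = [1.0, 2.0, 3.0];
    let mut x1 = [0.0; 3];
    let mut x2 = [0.0; 3];
    for step in 0..50 {
        let l = sgd_step(&xs, &ys, &mut x1, &mut x2, 1e-2);
        if step % 10 == 0 {
            println!("step {step}: loss {l}");
        }
    }
}
```

One consequence worth noting: with the all-zero dummy batches used in the tch example, the prediction equals x2, which starts at zero, so the loss and both gradients are zero from the first step and the variables never move. Non-zero data, as in this sketch, is needed to see the loss decrease.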

Writing a Simple Neural Network

The nn api can be used to create…

Notability

notability 1.0/10

Routine fork with minimal community interest.