mstrand1/Jax-logistic-regression

Jax-logistic-regression

Logistic regression classifier using Google's JAX to support GPU acceleration.

This class is an update of a logistic regression class used in my intro to machine learning course. The major difference is how the gradient descent operations are handled: they were rewritten using JAX's grad, jit, and vmap functions. The goal of this project is speed. I've found that using JaxReg with GPU acceleration gives a roughly 29x speedup over the original class. I used Google Colab's free GPU when measuring the speedup (see 'Time Comparison').
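As an illustration of how grad, jit, and vmap can fit together in a logistic regression training loop, here is a minimal sketch. It is not the JaxReg class from this repository. The function names (logit, loss, update, fit, predict), the full-batch gradient descent, the learning rate, and the iteration count are all hypothetical choices made for the example.

```python
import jax
import jax.numpy as jnp


# Hypothetical helper: logit for a single example x of shape [d].
def logit(params, x):
    w, b = params
    return jnp.dot(x, w) + b


# vmap turns the single-example logit into a batched one:
# params are shared (None), X is mapped over its first axis (0).
batched_logit = jax.vmap(logit, in_axes=(None, 0))


# Mean binary cross-entropy, written with log_sigmoid for numerical stability.
def loss(params, X, y):
    z = batched_logit(params, X)
    return -jnp.mean(y * jax.nn.log_sigmoid(z) + (1.0 - y) * jax.nn.log_sigmoid(-z))


# One gradient descent step; jit compiles it once and reuses it on later calls.
@jax.jit
def update(params, X, y, lr):
    grads = jax.grad(loss)(params, X, y)  # gradient w.r.t. (w, b)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)


# Hypothetical training loop: plain full-batch gradient descent from zeros.
def fit(X, y, lr=0.1, n_iter=500):
    params = (jnp.zeros(X.shape[1]), jnp.zeros(()))
    for _ in range(n_iter):
        params = update(params, X, y, lr)
    return params


# Class labels from predicted probabilities.
def predict(params, X, threshold=0.5):
    probs = jax.nn.sigmoid(batched_logit(params, X))
    return (probs >= threshold).astype(jnp.int32)


# Usage on a tiny, linearly separable toy dataset.
X = jnp.array([[-2.0], [-1.0], [1.0], [2.0]])
y = jnp.array([0.0, 0.0, 1.0, 1.0])
params = fit(X, y)
print(predict(params, X))  # expected: [0 0 1 1]
```

Writing the logit for one example and letting vmap batch it keeps the model code simple, while jit removes Python overhead from the repeated update step. Because the first call to a jit-compiled function also includes compilation time, GPU timing comparisons are usually fairer after a warm-up call.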
