1.3 万 Star!
JAX 是机器学习 (ML) 领域的新生力量,它有望使 ML 编程更加直观、结构化和简洁。
在机器学习领域,大家可能对 和 已经耳熟能详,但除了这两个框架,一些新生力量也不容小觑,它就是谷歌推出的 JAX。很对研究者对其寄予厚望,希望它可以取代 等众多机器学习框架。
JAX 最初由谷歌大脑团队的 Matt 、Roy 、 和 Chris Leary 等人发起。
目前,JAX 在 上已累积 13.7K 星。
项目地址:
迅速发展的 JAX
JAX 的前身是 ,其借助 的更新版本,并且结合了 XLA,可对 程序与 NumPy 运算执行自动微分,支持循环、分支、递归、闭包函数求导,也可以求三阶导数;依赖于 XLA,JAX 可以在 GPU 和 TPU 上编译和运行 NumPy 程序;通过 grad,可以支持自动模式反向传播和正向传播,且二者可以任意组合成任何顺序。
开发 JAX 的出发点是什么?说到这,就不得不提 NumPy。NumPy 是 中的一个基础数值运算库,被广泛使用。但是 numpy 不支持 GPU 或其他硬件加速器,也没有对反向传播的内置支持,此外, 本身的速度限制阻碍了 NumPy 使用,所以少有研究者在生产环境下直接用 numpy 训练或部署深度学习模型。
在此情况下,出现了众多的深度学习框架,如 、 等。但是 numpy 具有灵活、调试方便、API 稳定等独特的优势。而 JAX 的主要出发点就是将 numpy 的以上优势与硬件加速结合。
目前,基于 JAX 已有很多优秀的开源项目,如谷歌的神经网络库团队开发了 Haiku,这是一个面向 Jax 的深度学习代码库,通过 Haiku,用户可以在 Jax 上进行面向对象开发;又比如 RLax,这是一个基于 Jax 的强化学习库,用户使用 RLax 就能进行 Q- 模型的搭建和训练;此外还包括基于 JAX 的深度学习库 ,该库一行代码就能定义计算图、可进行 GPU 加速。可以说,在过去几年中,JAX 掀起了深度学习研究的