• Ubuntu下jax安装与使用


    目录

    安装说明

    pip安装

    conda安装

    参考网址


    注:该项目目前仍然没有官方的Windows支持,需要自己编译。

    安装说明

    该库安装时分为两部分:

    1. jaxlib,该库平台相关,目前没有官方的编译
    2. jax,该库依赖jaxlib,平台无关,可以直接安装。

    找到目前一个还活跃的jaxlib非官方编译服务:

    https://github.com/cloudhan/jax-windows-builder

    pip安装

    要安装仅 CPU 版本的 JAX,这可能对在笔记本电脑上进行本地开发很有用,您可以运行

    pip install --upgrade pip
    pip install --upgrade "jax[cpu]"

    在 Linux 上,通常需要先更新pip到支持 manylinux2014轮子的版本。 这些pip安装不适用于 Windows,并且可能会静默失败;见 上文

    如果要安装同时支持 CPU 和 NVidia GPU 的 JAX,则必须首先安装CUDA和 CuDNN(如果尚未安装)。与其他一些流行的深度学习系统不同,JAX 没有将 CUDA 或 CuDNN 捆绑为pip 软件包的一部分。

    JAX仅为 Linux提供预构建的 CUDA 兼容轮子,带有 CUDA 11.1 或更高版本,以及 CuDNN 8.0.5 或更高版本。操作系统、CUDA 和 CuDNN 的其他组合是可能的,但需要从源代码构建

    • 需要CUDA 11.1 或更新版本。
    • 预建轮子支持的 cuDNN 版本是:
      • cuDNN 8.2 或更高版本。如果您的 cuDNN 安装足够新,我们建议使用 cuDNN 8.2 轮,因为它支持附加功能。
      • cuDNN 8.0.5 或更高版本。
    • 必须使用至少与您的 CUDA 工具包的相应驱动程序版本一样新的 NVidia 驱动程序版本。例如,如果您安装了 CUDA 11.4 update 4,则在 Linux 上必须使用 NVidia 驱动程序 470.82.01 或更新版本。这是一个严格的要求,因为 JAX 依赖于 JIT 编译代码;较旧的驱动程序可能会导致故障。
      • 如果您需要将较新的 CUDA 工具包与较旧的驱动程序一起使用,例如在无法轻松更新 NVidia 驱动程序的集群上,您可以使用 NVidia 为此目的提供的CUDA 前向兼容性包。

    接下来,运行

    pip install --upgrade pip
    # Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
    # Note: wheels only available on linux.
    pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

    这些pip安装不适用于 Windows,并且可能会静默失败;

    jaxlib 版本必须与您要使用的现有 CUDA 安装的版本相对应。您可以为 jaxlib 显式指定特定的 CUDA 和 CuDNN 版本:

    pip install --upgrade pip
    
    # Installs the wheel compatible with Cuda >= 11.4 and cudnn >= 8.2
    pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
    
    # Installs the wheel compatible with Cuda >= 11.1 and cudnn >= 8.0.5
    pip install "jax[cuda11_cudnn805]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
    
    具体版本
    pip install --upgrade jax==0.3.15 jaxlib==0.3.15+cuda11.cudnn82 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
    

    conda安装

    有一个社区支持的 Conda 构建jax。要安装 using conda,只需运行

    conda install jax -c conda-forge

    要在具有 NVidia GPU 的机器上安装,请运行

    conda install jax cuda-nvcc -c conda-forge -c nvidia

    请注意cudatoolkitDistributed by conda-forgeis missing ptxas,这是 JAX 要求的。因此,您必须cuda-nvcc从频道安装软件包nvidia,或者在您的机器上单独安装 CUDA,以便ptxas 在您的路径中。上面的频道顺序很重要(conda-forge之前 nvidia)。我们正在努力简化这一点。

    如果您想覆盖 JAX 使用的 CUDA 版本,或者在没有 GPU 的机器上安装 CUDA 构建,请按照 网站提示和技巧 部分中的说明进行操作conda-forge

    参考网址

    https://github.com/google/jax

  • 相关阅读:
    Framework 到底该怎么学习?
    7. Go的map
    AD20~PCB的板层设计和布线
    react native 使用阿里字体图标库
    OpenCV [c++](图像处理基础示例程序汇总)
    frida打印byte数组
    设计模式(2) - 创建型模式
    nginx进程间同步机制-互斥锁
    【Python GUI编程】零基础也能轻松掌握的学习路线与参考资料
    C# 继承,抽象,接口,泛型约束,扩展方法
  • 原文地址:https://blog.csdn.net/zaf0516/article/details/126390534