JAX

JAX

Google推出的用于变换数值函数的机器学习框架

JAX是什么

JAX是Google推出的高性能数值计算库,提供类似NumPy的API,支持GPU/TPU加速、自动微分、即时编译(JIT)和向量化等功能。JAX通过XLA(加速线性代数)编译器优化代码,显著提升运行效率,在大规模数据处理和机器学习中表现突出。JAX支持自动微分,能轻松计算函数梯度,适用于优化算法。JAX的异步执行模式和不可变数组设计使其在性能和可靠性上优于传统NumPy,是现代科学计算和机器学习研究中的重要工具。

JAX

JAX的主要功能

  • 自动微分:通过jax.grad等函数自动计算函数的梯度,支持高阶导数,广泛应用在机器学习中的模型训练。
  • 即时编译(JIT):用jax.jit将Python函数编译成优化后的机器代码,显著提升运行效率,在大规模计算中效果显著。
  • 向量化:通过jax.vmap自动将函数向量化,避免手动循环,提高代码效率和可读性。
  • 并行化:用jax.pmap支持跨多个设备(如GPU、TPU)的并行计算,加速大规模任务处理。
  • 硬件加速:支持在CPU、GPU和TPU上运行代码,充分利用硬件的并行计算能力。
  • 程序变换:提供丰富的程序变换工具,如jax.lax,用在构建更复杂的程序逻辑,提升代码灵活性和扩展性。

如何使用JAX

  • 环境配置与安装
    • 创建Python环境:用conda创建一个专用的Python环境。
conda create <span class="token parameter variable">-n</span> jax_test <span class="token assign-left variable">python</span><span class="token operator">=</span><span class="token number">3.13</span> <span class="token parameter variable">-y</span>conda activate jax_test
    • 安装JAX库:根据硬件配置选择合适的JAX版本。
pip <span class="token function">install</span> jupyter numpy <span class="token string">"jax[cuda12]"</span> matplotlib pillow
  • 使用JAX的主要功能
    • 自动微分:使用jax.grad自动计算函数的梯度。
<span class="token keyword">import</span> jax<span class="token keyword">import</span> jax<span class="token punctuation">.</span>numpy <span class="token keyword">as</span> jnp<span class="token keyword">def</span> <span class="token function">cubic_sum</span><span class="token punctuation">(</span>x<span class="token punctuation">)</span><span class="token punctuation">:</span>    <span class="token keyword">return</span> jnp<span class="token punctuation">.</span><span class="token builtin">sum</span><span class="token punctuation">(</span>x<span class="token operator">**</span><span class="token number">3</span><span class="token punctuation">)</span>grad_cubic_sum <span class="token operator">=</span> jax<span class="token punctuation">.</span>grad<span class="token punctuation">(</span>cubic_sum<span class="token punctuation">)</span>x_input <span class="token operator">=</span> jnp<span class="token punctuation">.</span>arange<span class="token punctuation">(</span><span class="token number">1.0</span><span class="token punctuation">,</span> <span class="token number">5.0</span><span class="token punctuation">)</span>gradient <span class="token operator">=</span> grad_cubic_sum<span class="token punctuation">(</span>x_input<span class="token punctuation">)</span><span class="token keyword">print</span><span class="token punctuation">(</span><span class="token string">"梯度 df/dx:"</span><span class="token punctuation">,</span> gradient<span class="token punctuation">)</span>
    • 即时编译(JIT):用jax.jit将函数编译成优化后的机器代码。
<span class="token decorator annotation punctuation">@jax<span class="token punctuation">.</span>jit</span><span class="token keyword">def</span> <span class="token function">selu_jax_jit</span><span class="token punctuation">(</span>x<span class="token punctuation">)</span><span class="token punctuation">:</span>    <span class="token keyword">return</span> <span class="token number">1.0507</span> <span class="token operator">*</span> jnp<span class="token punctuation">.</span>where<span class="token punctuation">(</span>x <span class="token operator">></span> <span class="token number">0</span><span class="token punctuation">,</span> x<span class="token punctuation">,</span> <span class="token number">1.67326</span> <span class="token operator">*</span> jnp<span class="token punctuation">.</span>exp<span class="token punctuation">(</span>x<span class="token punctuation">)</span> <span class="token operator">-</span> <span class="token number">1.67326</span><span class="token punctuation">)</span>x_jax <span class="token operator">=</span> jnp<span class="token punctuation">.</span>random<span class="token punctuation">.</span>normal<span class="token punctuation">(</span>jax<span class="token punctuation">.</span>random<span class="token punctuation">.</span>PRNGKey<span class="token punctuation">(</span><span class="token number">0</span><span class="token punctuation">)</span><span class="token punctuation">,</span> <span class="token punctuation">(</span><span class="token number">10000</span><span class="token punctuation">,</span> <span class="token number">10000</span><span class="token punctuation">)</span><span class="token punctuation">)</span>result_jax_jit <span class="token operator">=</span> selu_jax_jit<span class="token punctuation">(</span>x_jax<span class="token punctuation">)</span>
    • 向量化:使用jax.vmap自动将函数向量化。
<span class="token keyword">def</span> <span class="token function">mat_vec_product</span><span class="token punctuation">(</span>matrix<span class="token punctuation">,</span> vector<span class="token punctuation">)</span><span class="token punctuation">:</span>    <span class="token keyword">return</span> jnp<span class="token punctuation">.</span>dot<span class="token punctuation">(</span>matrix<span class="token punctuation">,</span> vector<span class="token punctuation">)</span>batched_mat_vec <span class="token operator">=</span> jax<span class="token punctuation">.</span>vmap<span class="token punctuation">(</span>mat_vec_product<span class="token punctuation">,</span> in_axes<span class="token operator">=</span><span class="token punctuation">(</span><span class="token boolean">None</span><span class="token punctuation">,</span> <span class="token number">0</span><span class="token punctuation">)</span><span class="token punctuation">)</span>matrix_jax <span class="token operator">=</span> jnp<span class="token punctuation">.</span>random<span class="token punctuation">.</span>normal<span class="token punctuation">(</span>jax<span class="token punctuation">.</span>random<span class="token punctuation">.</span>PRNGKey<span class="token punctuation">(</span><span class="token number">0</span><span class="token punctuation">)</span><span class="token punctuation">,</span> <span class="token punctuation">(</span><span class="token number">10000</span><span class="token punctuation">,</span> <span class="token number">10000</span><span class="token punctuation">)</span><span class="token punctuation">)</span>vectors_jax <span class="token operator">=</span> jnp<span class="token punctuation">.</span>random<span class="token punctuation">.</span>normal<span class="token punctuation">(</span>jax<span class="token punctuation">.</span>random<span class="token punctuation">.</span>PRNGKey<span class="token punctuation">(</span><span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">,</span> <span class="token punctuation">(</span><span class="token number">128</span><span class="token punctuation">,</span> <span class="token number">10000</span><span class="token punctuation">)</span><span class="token punctuation">)</span>result_vmap <span class="token operator">=</span> batched_mat_vec<span class="token punctuation">(</span>matrix_jax<span class="token punctuation">,</span> vectors_jax<span class="token punctuation">)</spa

JAX的应用场景

  • 机器学习和深度学习:JAX的自动微分和硬件加速功能,能高效训练和推理神经网络,提升模型性能。
  • 科学计算:JAX能计算复杂物理方程导数,优化物理、化学和材料科学中的模拟和预测。
  • 数据分析和处理:借助向量化和并行化,JAX能快速处理大规模数据,适用图像、信号处理等领域。
  • 金融建模:用在金融风险评估和高频交易,高效计算助力实时数据分析和决策。
  • 计算生物学:处理基因组数据、预测蛋白质结构,加速生物医学研究和应用。