JAX NumPy is designed to be highly compatible with NumPy. Most of the basic NumPy functions have their equivalents in JAX NumPy. For example, creating arrays, performing arithmetic operations, and accessing elements follow similar syntax in both libraries.
import numpy as np
import jax.numpy as jnp
# Create a NumPy array
np_array = np.array([1, 2, 3])
# Create a JAX NumPy array
jax_array = jnp.array([1, 2, 3])
print("NumPy array:", np_array)
print("JAX NumPy array:", jax_array)
JAX lets you compile Python functions with jax.jit, which can be applied directly to a function or used as a decorator. The function is compiled with XLA, which can significantly speed up execution, especially for functions that operate on large arrays.
import jax
def add_arrays(x, y):
    return x + y
jit_add_arrays = jax.jit(add_arrays)
x = jnp.array([1, 2, 3])
y = jnp.array([4, 5, 6])
result = jit_add_arrays(x, y)
print("Result after JIT compilation:", result)
One of the most powerful features of JAX is automatic differentiation. You can compute gradients of a function with respect to its inputs using jax.grad.
def square(x):
    return jnp.square(x)
grad_square = jax.grad(square)
x = jnp.array(3.0)
gradient = grad_square(x)
print("Gradient of square function at x = 3:", gradient)
JAX provides the jax.vmap function for automatic vectorization. This allows you to transform a function that operates on single inputs to a function that operates on batches of inputs without writing explicit loops.
def add_one(x):
    return x + 1
vmap_add_one = jax.vmap(add_one)
x_batch = jnp.array([1, 2, 3])
result_batch = vmap_add_one(x_batch)
print("Result after vectorization:", result_batch)
You can create JAX NumPy arrays in much the same way as NumPy arrays, using functions like jnp.array, jnp.zeros, jnp.ones, etc.
# Create an array from a list
arr1 = jnp.array([1, 2, 3])
# Create a zero-filled array
arr2 = jnp.zeros((2, 3))
# Create an array filled with ones
arr3 = jnp.ones((3, 2))
print("Array from list:", arr1)
print("Zero - filled array:", arr2)
print("Ones - filled array:", arr3)
JAX NumPy supports a wide range of array operations, including arithmetic operations, matrix multiplication, and reshaping.
# Arithmetic operations
x = jnp.array([1, 2, 3])
y = jnp.array([4, 5, 6])
z = x + y
print("Sum of arrays:", z)
# Matrix multiplication
A = jnp.array([[1, 2], [3, 4]])
B = jnp.array([[5, 6], [7, 8]])
C = jnp.dot(A, B)
print("Matrix multiplication result:", C)
# Reshaping
arr = jnp.array([1, 2, 3, 4])
reshaped_arr = arr.reshape((2, 2))
print("Reshaped array:", reshaped_arr)
When you have a function that performs a series of numerical operations on large arrays, it is good practice to use jax.jit to compile it. However, keep in mind that there is a compilation overhead, so JIT is most beneficial for functions that are called multiple times.
import time
def slow_function(x, y):
    z = x * y
    w = z + x
    return w
jit_slow_function = jax.jit(slow_function)
x = jnp.ones((1000, 1000))
y = jnp.ones((1000, 1000))
# First call has compilation overhead
start_time = time.time()
jit_slow_function(x, y).block_until_ready()
first_call_time = time.time() - start_time
# Subsequent calls are faster
start_time = time.time()
jit_slow_function(x, y).block_until_ready()
second_call_time = time.time() - start_time
print("First call time:", first_call_time)
print("Second call time:", second_call_time)
Automatic differentiation in JAX makes it easy to implement gradient-based optimization algorithms. For example, you can implement a simple gradient descent algorithm to minimize a function.
def loss_function(params):
    return jnp.square(params).sum()
grad_loss = jax.grad(loss_function)
params = jnp.array([1.0, 2.0, 3.0])
learning_rate = 0.1
for i in range(10):
    gradient = grad_loss(params)
    params = params - learning_rate * gradient
    print(f"Step {i}: params = {params}, loss = {loss_function(params)}")
JIT-compiled functions should be pure functions, meaning they have no side effects such as modifying global variables or printing. This is because jax.jit traces the function and caches the compiled result: side effects run only during tracing (typically the first call with a given input shape), not on subsequent calls, so they do not execute as you might expect.
# Bad practice
global_variable = 0
@jax.jit
def bad_function(x):
    global global_variable
    global_variable += 1
    return x + 1
# Good practice
@jax.jit
def good_function(x):
    return x + 1
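A quick way to see the difference: a Python print inside a JIT-compiled function fires only while the function is being traced, whereas jax.debug.print is designed to run on every call. A short sketch:
@jax.jit
def traced_print(x):
    print("Python print: runs only during tracing")
    jax.debug.print("jax.debug.print: x = {}", x)
    return x + 1
traced_print(jnp.array(1.0))
traced_print(jnp.array(2.0))  # same shape/dtype, so no retrace: only the debug print appears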
When working with large arrays, be mindful of memory usage. JAX allocates device memory (e.g., GPU memory), and running out of it can lead to crashes. Because JAX arrays are immutable, there are no true in-place operations; instead, manage memory by releasing unused arrays (dropping references to them) and, where appropriate, donating input buffers to JIT-compiled functions.
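A hedged sketch of two common tactics: dropping references to large intermediates so their device buffers can be freed, and donating an input buffer to a JIT-compiled function via donate_argnums so the output may reuse its memory (the donated argument must not be used after the call, and donation mainly helps on accelerator backends):
import functools
big = jnp.ones((4096, 4096))
total = (big * 2.0).sum()
del big  # drop the reference so the device buffer backing `big` can be freed
@functools.partial(jax.jit, donate_argnums=(0,))
def scale(x, factor):
    # x is donated: its buffer may be reused for the output
    return x * factor
x = jnp.ones((4096, 4096))
x = scale(x, 0.5)  # rebind the name; do not reuse the original donated array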
JAX NumPy offers a powerful and flexible way to perform numerical operations in Python. With its compatibility with NumPy, JIT compilation, automatic differentiation, and vectorization capabilities, it is a great choice for scientific computing and machine learning applications. By understanding the fundamental concepts, usage methods, common practices, and best practices, you can take full advantage of JAX NumPy and write efficient and high-performance code.