Python - Differentiating Regular/Analytic Functions
To compute the derivative, I have written the following code:
import matplotlib.pyplot as plt
import numpy as np
from math import *
xi = jnp.linspace(-3,3)
def f(x):
    a = x**3
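For comparison, here is a minimal sketch of the same idea on a simple function. It assumes that f just returns x**3, which the truncated question does not confirm. jax.grad differentiates a scalar function, and jax.vmap applies that derivative element-wise over every point of xi.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Assumption: f simply returns x**3, built only from JAX-traceable operations.
    return x**3

xi = jnp.linspace(-3, 3)  # 50 points by default

df = jax.vmap(jax.grad(f))             # first derivative at each point of xi
d2f = jax.vmap(jax.grad(jax.grad(f)))  # second derivative at each point of xi

print(df(xi)[:3])
print(d2f(xi)[:3])
```

The results can be checked against the analytic derivatives 3*x**2 and 6*x.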
Solution 1:
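As already mentioned in the comments, you can't use functions from outside the JAX library, such as scipy.stats.norm.cdf, inside a function you differentiate with jax.grad: grad calls the function with JAX tracer objects, which SciPy and the math module cannot trace or differentiate through. Use jax.scipy.stats instead. Similarly, replace exp and sqrt from math with their JAX equivalents jnp.exp and jnp.sqrt: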
from jax import jit, grad, vmap
import jax.numpy as jnp
from jax.scipy.stats.norm import cdf

def analytical_call(s0):
    T, q, r, k, sigma = 1.0, 0.0, 0.0, 1.0, 0.4
    Kt = k * jnp.exp((q - r) * T)
    d = (jnp.log(Kt / s0) + (sigma**2) / 2 * T) / sigma
    # Only JAX functions (jnp.exp, jnp.log, jnp.sqrt, jax.scipy cdf) so grad can trace them
    result = (
        (Kt * cdf(d / jnp.sqrt(T), 0.0, 1.0)
         - s0 * cdf((d - sigma * T) / jnp.sqrt(T), 0.0, 1.0)) * jnp.exp(-q * T)
        + jnp.exp(-q * T) * (s0 - Kt)
    )
    return result

# First and second derivatives with respect to s0, vectorized over an array of inputs
g = vmap(grad(analytical_call))
h = vmap(grad(grad(analytical_call)))

xi = jnp.linspace(1, 1.5)
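Then you can evaluate g(xi) and h(xi) to get the first and second derivatives of analytical_call with respect to s0 (the option's delta and gamma) at every point of xi.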
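One way to sanity-check these derivatives is to compare them with the textbook Black-Scholes delta and gamma. The sketch below assumes that analytical_call is meant to be a Black-Scholes call price with dividend yield q (its first term is the put price and the second applies put-call parity). Under that assumption, delta is exp(-q*T)*N(d1) and gamma is exp(-q*T)*phi(d1)/(s0*sigma*sqrt(T)). The reference formulas and the names delta_ref and gamma_ref are additions for illustration and are not part of the original answer.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

T, q, r, k, sigma = 1.0, 0.0, 0.0, 1.0, 0.4

def analytical_call(s0):
    # Same formula as in the answer, using JAX functions only.
    Kt = k * jnp.exp((q - r) * T)
    d = (jnp.log(Kt / s0) + (sigma**2) / 2 * T) / sigma
    put_part = (Kt * norm.cdf(d / jnp.sqrt(T), 0.0, 1.0)
                - s0 * norm.cdf((d - sigma * T) / jnp.sqrt(T), 0.0, 1.0))
    return put_part * jnp.exp(-q * T) + jnp.exp(-q * T) * (s0 - Kt)

g = jax.vmap(jax.grad(analytical_call))            # autodiff delta
h = jax.vmap(jax.grad(jax.grad(analytical_call)))  # autodiff gamma

xi = jnp.linspace(1, 1.5)

# Assumed reference: closed-form Black-Scholes Greeks for a call with dividend yield q.
d1 = (jnp.log(xi / k) + (r - q + sigma**2 / 2) * T) / (sigma * jnp.sqrt(T))
delta_ref = jnp.exp(-q * T) * norm.cdf(d1)
gamma_ref = jnp.exp(-q * T) * norm.pdf(d1) / (xi * sigma * jnp.sqrt(T))

print(jnp.max(jnp.abs(g(xi) - delta_ref)))
print(jnp.max(jnp.abs(h(xi) - gamma_ref)))
```

With float32 arithmetic the printed differences should be small but not exactly zero, so a tolerance-based comparison such as jnp.allclose is the more robust check.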