<?xml version="1.0" encoding="utf-8"?><feed xmlns="http://www.w3.org/2005/Atom" xml:lang="en"><generator uri="https://jekyllrb.com/" version="4.3.3">Jekyll</generator><link href="https://astanziola.github.io/feed.xml" rel="self" type="application/atom+xml"/><link href="https://astanziola.github.io/" rel="alternate" type="text/html" hreflang="en"/><updated>2024-07-29T07:32:28+00:00</updated><id>https://astanziola.github.io/feed.xml</id><title type="html">blank</title><subtitle>Personal Website </subtitle><entry><title type="html">Acoustic Hologram Optimisation Using Automatic Differentiation in JAX</title><link href="https://astanziola.github.io/blog/2022/diff-pat-in-jax/" rel="alternate" type="text/html" title="Acoustic Hologram Optimisation Using Automatic Differentiation in JAX"/><published>2022-08-01T17:00:00+00:00</published><updated>2022-08-01T17:00:00+00:00</updated><id>https://astanziola.github.io/blog/2022/diff-pat-in-jax</id><content type="html" xml:base="https://astanziola.github.io/blog/2022/diff-pat-in-jax/"><![CDATA[<p>In December 2020, Tatsuki Fushimi, Kenta Yamamoto &amp; Yoichi Ochiai <a href="https://arxiv.org/abs/2012.02431">submitted a preprint</a>, later accepted by <a href="https://www.nature.com/articles/s41598-021-91880-2">Scientific Reports</a>, on using <strong>automatic differentiation</strong> for the optimization of acoustic holograms produced by phased arrays.</p> <p>In the following, we will implement the main algorithm discussed in the paper using <a href="https://jax.readthedocs.io/en/latest/">jax</a>: this work provides a good case study for demonstrating the use of <a href="https://jax.readthedocs.io/en/latest/">jax</a> for scientific applications not related to machine learning.</p> <h2 id="problem-setup">Problem setup</h2> <p>Assume we have a transducer located at \(x_t\), transmitting a monochromatic (single frequency) signal with wavenumber</p> \[k = \frac{2\pi f_0}{c_0},\] <p>where \(f_0\) is the transmit 
frequency and \(c_0\) is the speed of sound of the homogeneous medium. Then one can use a simplified version of the <strong>Rayleigh integral</strong> (see <a href="https://pubmed.ncbi.nlm.nih.gov/26428789/">[Sapozhnikov et al., 2015]</a> for a more general discussion) to calculate the pressure field at a location \(x_c\)</p> \[p_{c,t} = \frac{P_{ref}}{\|x_c - x_t\|}D(\theta)e^{j(k\|x_c-x_t\|+ \phi_t)}\] <p>where \(P_{ref}\) is the pressure amplitude at the transducer, \(\phi_t\) is the phase of the transmit wave and</p> \[D(\theta) = \frac{2J_1(k r \sin(\theta))}{k r \sin(\theta)}\] <p>is the directivity factor, which depends on the angle \(\theta\) between the transducer normal and the vector \(x_c-x_t\). Here, \(r\) is the transducer radius and \(J_1\) is the Bessel function of the first kind of order 1.</p> <p>Note that the pressure is expressed as a <strong>complex number</strong>, as is customary for time-harmonic fields, in order to implicitly define the phase relationships between the field at various locations.</p> <p>The last step is to use the <strong>superposition</strong> property deriving from the <a href="https://en.wikipedia.org/wiki/Superposition_principle#Wave_superposition">linearity of the wave equation</a> to sum the contribution of \(M\) transducers in the phased array to the field</p> \[p(x_c) = \sum_{t=1}^{M} p_{c,t}\] <h3 id="optimization">Optimization</h3> <p>All that is left before we can use automatic differentiation is to define a loss function. The authors have chosen to optimize the <strong>amplitude</strong> \(|p(x_c)|\) of the field by matching it against some known positive field \(A(x_c)\). Using a squared error distance, this gives the loss function</p> \[\mathcal L(p) = \frac{1}{|\Omega|}\int_\Omega (A(x) - |p(x)|)^2 dx \propto \sum_{x_c \in X} (A(x_c) - |p(x_c)|)^2\] <p>for an appropriate dense set of positions \(X\), which we will take as equispaced points (i.e. 
pixels) to directly compare the field with a digital image.</p> <h2 id="implementation">Implementation</h2> <p>First of all, let’s import the required libraries:</p> <figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="kn">import</span> <span class="n">jax</span>
<span class="kn">from</span> <span class="n">jax</span> <span class="kn">import</span> <span class="n">numpy</span> <span class="k">as</span> <span class="n">jnp</span>
<span class="kn">from</span> <span class="n">matplotlib</span> <span class="kn">import</span> <span class="n">pyplot</span> <span class="k">as</span> <span class="n">plt</span></code></pre></figure> <p>Afterwards, we define some parameters that we will use throughout the following sections.</p> <figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="c1"># Free parameters
</span><span class="n">c0</span> <span class="o">=</span> <span class="mi">346</span>                <span class="c1"># Air speed of sound
</span><span class="n">f0</span> <span class="o">=</span> <span class="mf">40e3</span>               <span class="c1"># Transmit frequency
</span><span class="n">radius</span> <span class="o">=</span> <span class="mf">0.005</span>          <span class="c1"># Radius of the transducer
</span><span class="n">z_plane</span> <span class="o">=</span> <span class="mf">0.1</span>           <span class="c1"># Distance between transducer and target field
</span><span class="n">k</span> <span class="o">=</span> <span class="mi">2</span><span class="o">*</span><span class="n">jnp</span><span class="p">.</span><span class="n">pi</span><span class="o">*</span><span class="n">f0</span><span class="o">/</span><span class="n">c0</span>      <span class="c1"># Wavenumber
</span>
<span class="c1"># Generating transducer positions
</span><span class="n">x_pos</span><span class="p">,</span> <span class="n">y_pos</span> <span class="o">=</span> <span class="nf">map</span><span class="p">(</span>
    <span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="p">(</span><span class="n">x</span><span class="o">-</span><span class="n">jnp</span><span class="p">.</span><span class="nf">mean</span><span class="p">(</span><span class="n">x</span><span class="p">)).</span><span class="nf">flatten</span><span class="p">()</span><span class="o">*</span><span class="n">radius</span><span class="p">,</span> <span class="n">jnp</span><span class="p">.</span><span class="nf">indices</span><span class="p">((</span><span class="mi">32</span><span class="p">,</span><span class="mi">32</span><span class="p">))</span>
<span class="p">)</span>
<span class="n">z_pos</span> <span class="o">=</span> <span class="n">x_pos</span><span class="o">*</span><span class="mi">0</span>         <span class="c1"># Tranducers on x_y plane, z = 0
</span><span class="n">positions</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="nf">stack</span><span class="p">([</span><span class="n">x_pos</span><span class="p">,</span> <span class="n">y_pos</span><span class="p">,</span> <span class="n">z_pos</span><span class="p">],</span><span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>

<span class="c1"># Evaluating normals
</span><span class="n">normals</span> <span class="o">=</span> <span class="p">(</span><span class="n">positions</span><span class="o">*</span><span class="mi">0</span><span class="p">).</span><span class="n">at</span><span class="p">[:,</span><span class="mi">2</span><span class="p">].</span><span class="nf">set</span><span class="p">(</span><span class="mf">1.</span><span class="p">)</span> <span class="c1"># All normals along the z axis -&gt; (0,0,1)
</span>
<span class="c1"># Initializing phases to zero
</span><span class="n">phases</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="nf">zeros</span><span class="p">((</span><span class="n">positions</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],))</span>

<span class="c1"># Sampling positions at the target plane
</span><span class="n">x_pos</span><span class="p">,</span> <span class="n">y_pos</span> <span class="o">=</span> <span class="nf">map</span><span class="p">(</span>
    <span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="p">(</span><span class="n">x</span><span class="o">-</span><span class="n">jnp</span><span class="p">.</span><span class="nf">mean</span><span class="p">(</span><span class="n">x</span><span class="p">)).</span><span class="nf">flatten</span><span class="p">()</span><span class="o">*</span><span class="n">radius</span><span class="o">/</span><span class="mi">8</span><span class="p">,</span> <span class="n">jnp</span><span class="p">.</span><span class="nf">indices</span><span class="p">((</span><span class="mi">256</span><span class="p">,</span><span class="mi">256</span><span class="p">))</span>
<span class="p">)</span>
<span class="n">z_pos</span> <span class="o">=</span> <span class="n">x_pos</span><span class="o">*</span><span class="mi">0</span> <span class="o">+</span> <span class="n">z_plane</span>  <span class="c1"># Plane parallel to the transducers array
</span><span class="n">plane_positions</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="nf">stack</span><span class="p">([</span><span class="n">x_pos</span><span class="p">,</span> <span class="n">y_pos</span><span class="p">,</span> <span class="n">z_pos</span><span class="p">],</span><span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>

<span class="c1"># Some helper functions
</span><span class="n">norm</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">jnp</span><span class="p">.</span><span class="nf">sqrt</span><span class="p">(</span><span class="n">jnp</span><span class="p">.</span><span class="nf">sum</span><span class="p">(</span><span class="n">jnp</span><span class="p">.</span><span class="nf">abs</span><span class="p">(</span><span class="n">x</span><span class="p">)</span><span class="o">**</span><span class="mi">2</span><span class="p">))</span>
<span class="n">avg_norm</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">jnp</span><span class="p">.</span><span class="nf">sqrt</span><span class="p">(</span><span class="n">jnp</span><span class="p">.</span><span class="nf">mean</span><span class="p">(</span><span class="n">jnp</span><span class="p">.</span><span class="nf">abs</span><span class="p">(</span><span class="n">x</span><span class="p">)</span><span class="o">**</span><span class="mi">2</span><span class="p">))</span>
<span class="n">dot</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">:</span> <span class="n">jnp</span><span class="p">.</span><span class="nf">sum</span><span class="p">(</span><span class="n">x</span><span class="o">*</span><span class="n">y</span><span class="p">)</span></code></pre></figure> <h3 id="forward-functions">Forward functions</h3> <p>We can start implementing some functions! While doing that, we can focus on a single transducer and a single target point, as we will parallelize (actually, <strong>vectorize</strong>) everything later on using <a href="https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html#jax.vmap"><code class="language-plaintext highlighter-rouge">jax.vmap</code></a>.</p> <p>The first function evaluates the angle \(\theta\) between a transducer normal and a target location, simply by using the arc-cosine of their normalized dot product:</p> <figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="k">def</span> <span class="nf">angle_between</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
    <span class="k">return</span> <span class="n">jnp</span><span class="p">.</span><span class="nf">arccos</span><span class="p">(</span><span class="nf">dot</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span><span class="o">/</span><span class="p">(</span><span class="nf">norm</span><span class="p">(</span><span class="n">x</span><span class="p">)</span><span class="o">*</span><span class="nf">norm</span><span class="p">(</span><span class="n">y</span><span class="p">)))</span></code></pre></figure> <p>At this point, we hit the first problem: we need the Bessel function of the first kind of order 1 to evaluate the directivity factor, but it looks like JAX doesn’t have it! However, <a href="https://en.wikipedia.org/wiki/Bessel_function">according to Wikipedia</a> the following relationship holds</p> \[-J_1(x) = \frac{\partial}{\partial x} J_0(x)\] <p>where \(J_0(x)\) is the Bessel function of the first kind of order 0. What jax implements in <code class="language-plaintext highlighter-rouge">jax.numpy.i0</code> is the <strong>modified</strong> Bessel function of the first kind of order 0, \(I_0\), and the two are related by \(J_0(x) = I_0(-jx)\). So we can implement the Bessel function of the first kind of order 1 using autodiff with <code class="language-plaintext highlighter-rouge">jax.grad</code> (I’m 99% sure this is correct, but not completely certain):</p> <figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="k">def</span> <span class="nf">J1</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
    <span class="n">J0</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">jnp</span><span class="p">.</span><span class="nf">i0</span><span class="p">(</span><span class="o">-</span><span class="mf">1j</span><span class="o">*</span><span class="n">x</span><span class="p">)</span>
    <span class="k">return</span> <span class="o">-</span><span class="n">jax</span><span class="p">.</span><span class="nf">grad</span><span class="p">(</span><span class="n">J0</span><span class="p">)(</span><span class="n">x</span><span class="p">).</span><span class="n">real</span></code></pre></figure> <p>At this point, we can write the directivity function \(D(\theta)\) as</p> <figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="k">def</span> <span class="nf">directivity_fun</span><span class="p">(</span><span class="n">theta</span><span class="p">):</span>
    <span class="n">x</span> <span class="o">=</span> <span class="n">k</span><span class="o">*</span><span class="n">radius</span><span class="o">*</span><span class="n">jnp</span><span class="p">.</span><span class="nf">sin</span><span class="p">(</span><span class="n">theta</span><span class="p">)</span>
    <span class="n">D_with_nans</span> <span class="o">=</span> <span class="mi">2</span><span class="o">*</span><span class="nc">J1</span><span class="p">(</span><span class="n">x</span><span class="p">)</span><span class="o">/</span><span class="n">x</span>
    <span class="k">return</span> <span class="n">jnp</span><span class="p">.</span><span class="nf">where</span><span class="p">(</span><span class="n">jnp</span><span class="p">.</span><span class="nf">isnan</span><span class="p">(</span><span class="n">D_with_nans</span><span class="p">),</span> <span class="mf">1.</span><span class="p">,</span> <span class="n">D_with_nans</span><span class="p">)</span></code></pre></figure> <p>With all the main ingredients in place, we can finally write the function that evaluates the beam pattern of a single transducer at a single location. Note that \(P_{ref}\) is implicitly set to 1:</p> <figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="k">def</span> <span class="nf">p_c</span><span class="p">(</span><span class="n">x_c</span><span class="p">,</span> <span class="n">x_t</span><span class="p">,</span> <span class="n">normal_vec</span><span class="p">,</span> <span class="n">phase</span><span class="p">):</span>
    <span class="n">theta</span> <span class="o">=</span> <span class="nf">angle_between</span><span class="p">(</span><span class="n">x_c</span> <span class="o">-</span> <span class="n">x_t</span><span class="p">,</span> <span class="n">normal_vec</span><span class="p">)</span>
    <span class="n">D</span> <span class="o">=</span> <span class="nf">directivity_fun</span><span class="p">(</span><span class="n">theta</span><span class="p">)</span>
    <span class="n">dist</span> <span class="o">=</span> <span class="nf">norm</span><span class="p">(</span><span class="n">x_c</span> <span class="o">-</span> <span class="n">x_t</span><span class="p">)</span>
    <span class="n">output_phase</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="nf">exp</span><span class="p">(</span><span class="mf">1j</span><span class="o">*</span><span class="p">(</span><span class="n">k</span><span class="o">*</span><span class="n">dist</span> <span class="o">+</span> <span class="n">phase</span><span class="p">))</span>
    <span class="k">return</span> <span class="n">D</span><span class="o">*</span><span class="n">output_phase</span><span class="o">/</span><span class="n">dist</span></code></pre></figure> <h3 id="vectorization">Vectorization</h3> <p>Adding the contribution of all the transducers can be easily done by vectorizing the function above with respect to the input positions, using <code class="language-plaintext highlighter-rouge">jax.vmap</code>, and summing them all:</p> <figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="k">def</span> <span class="nf">p_tot</span><span class="p">(</span><span class="n">xc</span><span class="p">,</span> <span class="n">xt</span><span class="p">,</span> <span class="n">n</span><span class="p">,</span> <span class="n">phases</span><span class="p">):</span>
    <span class="k">return</span> <span class="n">jnp</span><span class="p">.</span><span class="nf">sum</span><span class="p">(</span>
        <span class="n">jax</span><span class="p">.</span><span class="nf">vmap</span><span class="p">(</span><span class="n">p_c</span><span class="p">,</span> <span class="p">(</span><span class="bp">None</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span> <span class="mi">0</span><span class="p">)(</span><span class="n">xc</span><span class="p">,</span> <span class="n">xt</span><span class="p">,</span> <span class="n">n</span><span class="p">,</span> <span class="n">phases</span><span class="p">),</span> <span class="mi">0</span><span class="p">)</span></code></pre></figure> <p>Similarly, we can get the field at all the positions by vectorizing the function above with respect to the target location <code class="language-plaintext highlighter-rouge">x_c</code></p> <figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="n">pc_vect</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="nf">vmap</span><span class="p">(</span><span class="n">p_tot</span><span class="p">,</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span><span class="bp">None</span><span class="p">,</span><span class="bp">None</span><span class="p">,</span><span class="bp">None</span><span class="p">),</span><span class="mi">0</span><span class="p">)</span>

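# Illustrative aside (an addition, not part of the original notebook): a toy
# version of the two nested vmaps above, to show what the in_axes tuples do.
# All toy_ names are hypothetical; toy_pair merely stands in for p_c.
import jax
from jax import numpy as jnp

def toy_pair(target, source):
    # Contribution of one source at one target: 1/distance
    return 1.0/jnp.sqrt(jnp.sum((target - source)**2))

def toy_total(target, sources):
    # in_axes=(None, 0): keep the target fixed, map over the rows of sources
    return jnp.sum(jax.vmap(toy_pair, (None, 0), 0)(target, sources), 0)

# in_axes=(0, None): map over the rows of targets, keep all sources fixed
toy_field = jax.vmap(toy_total, (0, None), 0)

toy_targets = jnp.array([[0., 0., 1.], [0., 0., 2.]])
toy_sources = jnp.array([[0., 0., 0.], [0., 0., -1.], [0., 0., -2.]])
toy_values = toy_field(toy_targets, toy_sources)  # one summed value per target
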
<span class="k">def</span> <span class="nf">get_hologram</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
    <span class="k">return</span> <span class="n">jnp</span><span class="p">.</span><span class="nf">reshape</span><span class="p">(</span><span class="nf">pc_vect</span><span class="p">(</span><span class="n">plane_positions</span><span class="p">,</span> <span class="n">positions</span><span class="p">,</span><span class="n">normals</span><span class="p">,</span><span class="n">x</span><span class="p">),</span> <span class="p">(</span><span class="mi">256</span><span class="p">,</span><span class="mi">256</span><span class="p">))</span></code></pre></figure> <p>Let’s look at the hologram for the initial, flat phase distribution</p> <figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="n">p</span> <span class="o">=</span> <span class="nf">get_hologram</span><span class="p">(</span><span class="n">phases</span><span class="p">)</span></code></pre></figure> <div class="row mt-3"> <div class="col-sm mt-3 mt-md-0"> </div> <div class="col-sm mt-3 mt-md-0"> <figure> <picture> <source class="responsive-img-srcset" media="(max-width: 480px)" srcset="/assets/img/diff_pat/flat_beampattern-480.webp"/> <source class="responsive-img-srcset" media="(max-width: 800px)" srcset="/assets/img/diff_pat/flat_beampattern-800.webp"/> <source class="responsive-img-srcset" media="(max-width: 1400px)" srcset="/assets/img/diff_pat/flat_beampattern-1400.webp"/> <img src="/assets/img/diff_pat/flat_beampattern.png" class="img-fluid rounded z-depth-1" width="auto" height="auto" data-zoomable="" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </div> <div class="col-sm mt-3 mt-md-0"> </div> </div> <h2 id="loss-function">Loss function</h2> <p>Let’s start with a very simple image that we are trying to match</p> <figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="kn">from</span> <span class="n">jax.example_libraries</span> <span class="kn">import</span> <span 
class="n">optimizers</span>
<span class="kn">from</span> <span class="n">jax</span> <span class="kn">import</span> <span class="n">random</span>

<span class="c1"># Constructing reference image
</span><span class="n">reference_hologram</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="nf">zeros</span><span class="p">((</span><span class="mi">256</span><span class="p">,</span><span class="mi">256</span><span class="p">))</span>
<span class="n">reference_hologram</span> <span class="o">=</span> <span class="n">reference_hologram</span><span class="p">.</span><span class="n">at</span><span class="p">[</span><span class="mi">32</span><span class="p">:</span><span class="mi">64</span><span class="p">,</span><span class="mi">48</span><span class="p">:</span><span class="mi">164</span><span class="p">].</span><span class="nf">set</span><span class="p">(.</span><span class="mi">5</span><span class="p">)</span>
<span class="n">reference_hologram</span> <span class="o">=</span> <span class="n">reference_hologram</span><span class="p">.</span><span class="n">at</span><span class="p">[</span><span class="mi">64</span><span class="p">:</span><span class="mi">200</span><span class="p">,</span><span class="mi">128</span><span class="p">:</span><span class="mi">164</span><span class="p">].</span><span class="nf">set</span><span class="p">(</span><span class="mf">1.</span><span class="p">)</span>
<span class="n">reference_hologram</span> <span class="o">=</span> <span class="n">reference_hologram</span><span class="p">.</span><span class="n">at</span><span class="p">[</span><span class="mi">96</span><span class="p">:</span><span class="mi">128</span><span class="p">,</span><span class="mi">164</span><span class="p">:</span><span class="mi">150</span><span class="p">].</span><span class="nf">set</span><span class="p">(</span><span class="mf">1.</span><span class="p">)</span>
<span class="n">reference_hologram</span> <span class="o">=</span> <span class="n">reference_hologram</span><span class="p">.</span><span class="n">at</span><span class="p">[</span><span class="mi">150</span><span class="p">:</span><span class="mi">210</span><span class="p">,</span><span class="mi">128</span><span class="p">:</span><span class="mi">164</span><span class="p">].</span><span class="nf">set</span><span class="p">(.</span><span class="mi">3</span><span class="p">)</span>
<span class="n">reference_hologram</span> <span class="o">=</span> <span class="n">reference_hologram</span><span class="p">.</span><span class="n">at</span><span class="p">[</span><span class="mi">150</span><span class="p">:</span><span class="mi">190</span><span class="p">,</span><span class="mi">64</span><span class="p">:</span><span class="mi">100</span><span class="p">].</span><span class="nf">set</span><span class="p">(.</span><span class="mi">7</span><span class="p">)</span>
<span class="n">reference_hologram</span> <span class="o">=</span> <span class="n">reference_hologram</span><span class="p">.</span><span class="n">at</span><span class="p">[</span><span class="mi">200</span><span class="p">:</span><span class="mi">230</span><span class="p">,</span><span class="mi">64</span><span class="p">:</span><span class="mi">200</span><span class="p">].</span><span class="nf">set</span><span class="p">(.</span><span class="mi">2</span><span class="p">)</span></code></pre></figure> <div class="row mt-3"> <div class="col-sm mt-3 mt-md-0"> </div> <div class="col-sm mt-3 mt-md-0"> <figure> <picture> <source class="responsive-img-srcset" media="(max-width: 480px)" srcset="/assets/img/diff_pat/reference_pattern-480.webp"/> <source class="responsive-img-srcset" media="(max-width: 800px)" srcset="/assets/img/diff_pat/reference_pattern-800.webp"/> <source class="responsive-img-srcset" media="(max-width: 1400px)" srcset="/assets/img/diff_pat/reference_pattern-1400.webp"/> <img src="/assets/img/diff_pat/reference_pattern.png" class="img-fluid rounded z-depth-1" width="auto" height="auto" data-zoomable="" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </div> <div class="col-sm mt-3 mt-md-0"> </div> </div> <p>We are going to optimize a slightly different loss function from the one in the Diff-PAT paper, namely the negative normalized cross correlation, defined as</p> \[\mathcal L = -\frac{\sum_i A(x_i) |p(x_i)|}{\sqrt{\sum_i A(x_i)^2}\,\sqrt{\sum_i |p(x_i)|^2}}\] <p>which is implemented as</p> <figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="k">def</span> <span class="nf">xcorr</span><span class="p">(</span><span class="n">a</span><span class="p">,</span><span class="n">b</span><span class="p">):</span>
    <span class="k">return</span> <span class="o">-</span><span class="nf">dot</span><span class="p">(</span><span class="n">a</span><span class="p">,</span><span class="n">b</span><span class="p">)</span><span class="o">/</span><span class="p">(</span><span class="nf">norm</span><span class="p">(</span><span class="n">a</span><span class="p">)</span><span class="o">*</span><span class="nf">norm</span><span class="p">(</span><span class="n">b</span><span class="p">))</span>

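# Illustrative aside (an addition, not part of the original notebook): the
# negative normalized cross correlation ignores the overall scale of its
# inputs, so it rewards the shape of the field rather than its amplitude.
# toy_xcorr is a hypothetical stand-alone variant written with jnp.vdot and
# jnp.linalg.norm so that these lines run on their own.
from jax import numpy as jnp

def toy_xcorr(a, b):
    # jnp.vdot flattens both arrays; the default norm is the Frobenius norm
    return -jnp.vdot(a, b)/(jnp.linalg.norm(a)*jnp.linalg.norm(b))

toy_a = jnp.array([[0., 1.], [2., 3.]])
toy_same = toy_xcorr(toy_a, 5.*toy_a)       # rescaled copy: the minimum, -1
toy_other = toy_xcorr(toy_a, toy_a[::-1])   # rows swapped: -6/14, a worse match
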
<span class="k">def</span> <span class="nf">lossfun</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
    <span class="k">return</span> <span class="nf">xcorr</span><span class="p">(</span><span class="n">jnp</span><span class="p">.</span><span class="nf">abs</span><span class="p">(</span><span class="nf">get_hologram</span><span class="p">(</span><span class="n">x</span><span class="p">)),</span> <span class="n">reference_hologram</span><span class="p">)</span></code></pre></figure> <p>Note that the loss function depends on the vector of transducer phases. To optimize it, we get both its value and its gradient using autodiff</p> <figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="n">loss_with_grad</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="nf">value_and_grad</span><span class="p">(</span><span class="n">lossfun</span><span class="p">)</span></code></pre></figure> <h2 id="optimization-1">Optimization</h2> <p>We are now all set up to optimize the loss function. All we need is an update function that takes the current optimizer state, which holds the vector of phases, and updates it using the gradient. We will use the Adam optimizer, as in the Diff-PAT paper.</p> <p>As is customary in JAX, we can use <code class="language-plaintext highlighter-rouge">jax.jit</code> to just-in-time compile this function for faster execution.</p> <figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="c1"># Initialize optimizer
</span><span class="n">init_fun</span><span class="p">,</span> <span class="n">update_fun</span><span class="p">,</span> <span class="n">get_params</span> <span class="o">=</span> <span class="n">optimizers</span><span class="p">.</span><span class="nf">adam</span><span class="p">(.</span><span class="mi">2</span><span class="p">)</span>
<span class="n">opt_state</span> <span class="o">=</span> <span class="nf">init_fun</span><span class="p">(</span><span class="n">phases</span><span class="p">)</span>

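# Illustrative aside (an addition, not part of the original notebook): the
# same (init, update, get_params) triple applied to a toy quadratic, as a
# minimal sketch of how the optimizer state is threaded through the updates.
# All toy_ names and the step size of .1 are assumptions made for this example.
import jax
from jax import numpy as jnp
from jax.example_libraries import optimizers

toy_init, toy_update, toy_get = optimizers.adam(.1)
toy_loss = lambda x: jnp.sum(x**2)
toy_state = toy_init(jnp.array([3., -2.]))   # the state wraps the parameters

for toy_step in range(200):
    toy_grad = jax.grad(toy_loss)(toy_get(toy_state))
    # update takes (step index, gradient, state) and returns the new state
    toy_state = toy_update(toy_step, toy_grad, toy_state)

toy_final = toy_get(toy_state)               # parameters after the updates
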
<span class="nd">@jax.jit</span>
<span class="k">def</span> <span class="nf">update</span><span class="p">(</span><span class="n">opt_state</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="n">iteration</span><span class="p">):</span>
    <span class="n">params</span> <span class="o">=</span> <span class="nf">get_params</span><span class="p">(</span><span class="n">opt_state</span><span class="p">)</span>
    <span class="n">lossval</span><span class="p">,</span> <span class="n">gradient</span> <span class="o">=</span> <span class="nf">loss_with_grad</span><span class="p">(</span><span class="n">params</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">lossval</span><span class="p">,</span> <span class="nf">update_fun</span><span class="p">(</span><span class="n">iteration</span><span class="p">,</span> <span class="n">gradient</span><span class="p">,</span> <span class="n">opt_state</span><span class="p">)</span></code></pre></figure> <p>All that is left to do now is to wrap the <code class="language-plaintext highlighter-rouge">update</code> function in a loop that runs for a number of iterations. Note that we explicitly seed the random number generator: this aids reproducibility, and explicit keys are in any case how JAX handles randomness, even though the deterministic <code class="language-plaintext highlighter-rouge">update</code> above does not actually consume the key.</p> <figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="n">losshistory</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">key</span> <span class="o">=</span> <span class="n">random</span><span class="p">.</span><span class="nc">PRNGKey</span><span class="p">(</span><span class="mi">42</span><span class="p">)</span>

<span class="k">for</span> <span class="n">iteration</span> <span class="ow">in</span> <span class="nf">range</span><span class="p">(</span><span class="mi">100</span><span class="p">):</span>
    <span class="n">_</span><span class="p">,</span> <span class="n">key</span> <span class="o">=</span> <span class="n">random</span><span class="p">.</span><span class="nf">split</span><span class="p">(</span><span class="n">key</span><span class="p">)</span>
    <span class="n">lossval</span><span class="p">,</span> <span class="n">opt_state</span> <span class="o">=</span> <span class="nf">update</span><span class="p">(</span><span class="n">opt_state</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="n">iteration</span><span class="p">)</span>
    
    <span class="c1"># For logging
</span>    <span class="n">losshistory</span><span class="p">.</span><span class="nf">append</span><span class="p">(</span><span class="o">-</span><span class="n">lossval</span><span class="p">)</span></code></pre></figure> <h2 id="results">Results</h2> <p>After the optimization is over, which should be relatively fast especially if you are running <code class="language-plaintext highlighter-rouge">jax</code> on a GPU, we can visualize the results:</p> <div class="row mt-3"> <div class="col-sm mt-3 mt-md-0"> <figure> <picture> <source class="responsive-img-srcset" media="(max-width: 480px)" srcset="/assets/img/diff_pat/optimized_beampattern-480.webp"/> <source class="responsive-img-srcset" media="(max-width: 800px)" srcset="/assets/img/diff_pat/optimized_beampattern-800.webp"/> <source class="responsive-img-srcset" media="(max-width: 1400px)" srcset="/assets/img/diff_pat/optimized_beampattern-1400.webp"/> <img src="/assets/img/diff_pat/optimized_beampattern.png" class="img-fluid rounded z-depth-1" width="auto" height="auto" data-zoomable="" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </div> <div class="col-sm mt-3 mt-md-0"> <figure> <picture> <source class="responsive-img-srcset" media="(max-width: 480px)" srcset="/assets/img/diff_pat/phase_encoding-480.webp"/> <source class="responsive-img-srcset" media="(max-width: 800px)" srcset="/assets/img/diff_pat/phase_encoding-800.webp"/> <source class="responsive-img-srcset" media="(max-width: 1400px)" srcset="/assets/img/diff_pat/phase_encoding-1400.webp"/> <img src="/assets/img/diff_pat/phase_encoding.png" class="img-fluid rounded z-depth-1" width="auto" height="auto" data-zoomable="" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </div> <div class="col-sm mt-3 mt-md-0"> <figure> <picture> <source class="responsive-img-srcset" media="(max-width: 480px)" srcset="/assets/img/diff_pat/corr_coefficient-480.webp"/> <source 
class="responsive-img-srcset" media="(max-width: 800px)" srcset="/assets/img/diff_pat/corr_coefficient-800.webp"/> <source class="responsive-img-srcset" media="(max-width: 1400px)" srcset="/assets/img/diff_pat/corr_coefficient-1400.webp"/> <img src="/assets/img/diff_pat/corr_coefficient.png" class="img-fluid rounded z-depth-1" width="auto" height="auto" data-zoomable="" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </div> </div> <p>The optimized field is fairly close to the target hologram, but does not match it exactly. One could experiment with different loss functions, or with different initial phases. Note, however, that we are currently controlling only the phase of the transducers. If one could also control the amplitude, then the wave propagator would be a <strong>linear operator</strong> acting on the complex input parameters \(P_{t}e^{j\phi_t}\), making the MSE optimization problem convex and therefore uniquely solvable (up to a global phase shift).</p> <h2 id="conclusions">Conclusions</h2> <p>In this tutorial, we have reproduced the Diff-PAT algorithm and shown how JAX can be used to prototype algorithms for numerical physics easily and efficiently, by exploiting its ability to conveniently transform functions in several ways.</p> <p>A Jupyter notebook implementing this tutorial can be found <a href="https://github.com/astanziola/diff-pat-jax">at the following GitHub repo</a>.</p> <p>The findings from Fushimi et al. could also be extended in a number of ways. 
For example, the hologram produced by a planar wavefront could be efficiently propagated in the Fourier domain: this is implemented in the <a href="https://github.com/ucl-bug/jwave/blob/e8884856b0cf88c5fe7ede5e003d98143c8973e5/jwave/acoustics/time_harmonic.py#L17"><code class="language-plaintext highlighter-rouge">angular_spectrum</code></a> function of the <code class="language-plaintext highlighter-rouge">jwave</code> package.</p> <h3 id="references">References</h3> <ul> <li>Fushimi, T., Yamamoto, K. &amp; Ochiai, Y. <strong>Acoustic hologram optimisation using automatic differentiation.</strong> Sci Rep 11, 12678 (2021). <a href="https://www.nature.com/articles/s41598-021-91880-2">https://doi.org/10.1038/s41598-021-91880-2</a></li> </ul>]]></content><author><name></name></author><category term="code-examples"/><category term="implementation"/><category term="acoustics"/><category term="jax"/><summary type="html"><![CDATA[Efficient implementation of the Fushimi et al. algorithm using jax]]></summary></entry></feed>