
@plampite
Last active November 6, 2024 04:25
A simple, multi-dimensional linear interpolation function in Fortran. Roughly inspired by FINT in Cernlib, but hopefully more readable and easier to understand.
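For reference, the quantity fint computes can be written as follows. This is a sketch in notation chosen here, not taken from the gist: N = ndim, g_{i,k} is the k-th coordinate along dimension i, m_i = ind(i) is the lower index of the cell containing the clipped coordinate \bar{x}_i, and f with N subscripts is the tabulated value at those subscripts.

```latex
\eta_{2,i} = \frac{\bar{x}_i - g_{i,m_i}}{g_{i,m_i+1} - g_{i,m_i}}, \qquad
\eta_{1,i} = 1 - \eta_{2,i}, \qquad i = 1, \dots, N,

f(\mathbf{x}) \approx \sum_{\mathbf{s} \in \{0,1\}^N}
  \left( \prod_{i=1}^{N} \eta_{1+s_i,\,i} \right) f_{m_1+s_1,\,\dots,\,m_N+s_N}
```

Each of the 2^N terms corresponds to one vertex of the cell. Because every pair eta(1,i) + eta(2,i) equals one, the 2^N weights also sum to one.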
!Copyright 2024 Paolo Lampitella
!This code is licensed under the terms of the MIT license
MODULE mod_fint
IMPLICIT NONE
INTEGER, PARAMETER :: IP4 = SELECTED_INT_KIND(9)
INTEGER, PARAMETER :: WP = SELECTED_REAL_KIND(15,307)
CONTAINS
FUNCTION fint(ndim,x,ng,grid,values)
!Multilinear interpolation in ndim dimensions
!Assumes ng(i) >= 2 and strictly increasing grid coordinates along each dimension
IMPLICIT NONE
INTEGER(IP4), INTENT(IN) :: ndim !Number of dimensions
INTEGER(IP4), INTENT(IN) :: ng(ndim) !Number of points along each dimension (i.e., [nx, ny, ...])
REAL(WP), INTENT(IN) :: x(ndim) !The interpolation point
REAL(WP), INTENT(IN) :: grid(SUM(ng)) !Coordinates, one dimension after the other (i.e., [x(:), y(:), ...])
REAL(WP), INTENT(IN) :: values(PRODUCT(ng)) !Tabulated values, linearized in column major format
REAL(WP) :: fint, eta(2,ndim), wei, xb
INTEGER(IP4) :: ind(ndim), i, j0, j1, j2, j
INTEGER(IP4) :: kp, ki, ks, delta, middle
!Loop to find the cell (hypercube) of the table containing the interpolation point x
!and compute the resulting interpolation factors eta
j0 = 0
DO i = 1, ndim
j1 = j0 + 1
j2 = j0 + ng(i)
xb = MIN(MAX(x(i),grid(j1)),grid(j2)) !Clipping the interpolation point to the grid
!Binary search to find the cell containing xb
DO
delta = j2-j1
IF (delta<=1) EXIT
middle = j1+delta/2
IF (xb>grid(middle)) THEN
j1 = middle
ELSE
j2 = middle
ENDIF
ENDDO
ind(i) = j1 - j0 !Index of the cell containing xb
eta(2,i) = (xb - grid(j1))/(grid(j1+1)-grid(j1)) !Using x(i) here (instead of xb) performs extrapolation
eta(1,i) = 1.0_WP-eta(2,i)
j0 = j0 + ng(i)
ENDDO
fint = 0.0_WP
!Loop over all the 2^ndim vertices of the cell containing the point
DO j = 1, 2**ndim
wei = 1.0_WP
kp = 1
ki = 1
!Loop over the dimensions to retrieve the index and weight of the given node
!We use the j bit pattern (which is made of ndim bits) to pick the lower/upper node in
!each dimension of the given cell
DO i = 1, ndim
!This magic number will give 1 if the (i-1)-th bit of j-1 is set, 0 otherwise
!It is derived from a more general formula for computing permutations with repetitions
ks = MOD((2*j-1)/2**i,2)
!Retrieve and use the weight for the i-th dimension of the j-th vertex
wei = wei*eta(1+ks,i)
!This is just the i-th step of sub2ind applied to sub(i)=ind(i)+ks which, in the end, returns
!the linear index ki corresponding to subscripts sub(i) for a ndim-dimensional array
!stored in column-major order (as values is). It is defined as:
!ki - 1 = SUM_i((sub(i)-1)*PROD_j(ng(j),j=1...i-1),i=1...ndim)
ki = ki + (ind(i)+ks-1)*kp
kp = kp*ng(i)
ENDDO
!Summing up this vertex contribution to the final interpolation
fint = fint + wei*values(ki)
ENDDO
ENDFUNCTION fint
ENDMODULE mod_fint
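The argument layout fint expects (all coordinates concatenated in grid, the table flattened in column-major order in values) can be built from ordinary arrays. The sketch below is illustrative only: the module mod_layout_demo and the procedures flatten_2d and lin_index are hypothetical names, not part of the gist. lin_index implements the same column-major formula that fint uses to accumulate ki.

```fortran
! Hypothetical helper module (not part of the gist): packs a 2D table into
! the flat grid/values arguments of fint and computes column-major indices.
MODULE mod_layout_demo
IMPLICIT NONE
INTEGER, PARAMETER :: IP4 = SELECTED_INT_KIND(9)
INTEGER, PARAMETER :: WP = SELECTED_REAL_KIND(15,307)
CONTAINS
PURE FUNCTION lin_index(ndim, ng, sub) RESULT(k)
  ! Column-major linear index of subscripts sub(1:ndim) in an array of shape ng:
  ! k - 1 = SUM_i (sub(i)-1) * PRODUCT(ng(1:i-1))
  INTEGER(IP4), INTENT(IN) :: ndim
  INTEGER(IP4), INTENT(IN) :: ng(ndim), sub(ndim)
  INTEGER(IP4) :: k, stride, i
  k = 1
  stride = 1
  DO i = 1, ndim
    k = k + (sub(i)-1)*stride
    stride = stride*ng(i)
  ENDDO
END FUNCTION lin_index

SUBROUTINE flatten_2d(xg, yg, f, ng, grid, values)
  ! Assumes f(i,j) = f(xg(i), yg(j)), i.e. SHAPE(f) = [SIZE(xg), SIZE(yg)]
  REAL(WP), INTENT(IN) :: xg(:), yg(:), f(:,:)
  INTEGER(IP4), INTENT(OUT) :: ng(2)
  REAL(WP), ALLOCATABLE, INTENT(OUT) :: grid(:), values(:)
  ng = [INT(SIZE(xg),IP4), INT(SIZE(yg),IP4)]
  grid = [xg, yg]                 ! coordinates, one dimension after the other
  values = RESHAPE(f, [SIZE(f)])  ! array element order is column-major
END SUBROUTINE flatten_2d
END MODULE mod_layout_demo
```

The resulting ng, grid and values can then be passed to fint with ndim = 2. RESHAPE is enough for the flattening because Fortran's array element order is already column-major.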
@huijunchen9260

This is such an elegant formulation! May I ask more about that magic number, ks? What does the "more general formula for computing permutations with repetitions" mean?

@plampite
Author

plampite commented Mar 6, 2024

Hi Hui-Jun, thanks for the comment. I will answer with some context. When I started, I only knew that I needed to loop over each vertex of the hypercube and sum its contribution with its own weight, but I couldn't figure out how to do it in general for N dimensions. Then, while looking at the Wikipedia image of trilinear interpolation, with all those 0s and 1s, I realized that the vertices of the cell correspond one-to-one to the N-bit binary numbers: each bit tells whether the vertex sits at the lower (0) or upper (1) node along one dimension.
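To make the vertex/bit correspondence concrete, here is a small sketch. The module mod_vertex_demo and its functions are hypothetical, not from the gist; they decode vertex j of an N-dimensional cell into its N bits with the same MOD expression fint uses, and form the vertex weight as the product of the per-dimension factors eta(1+s(i),i).

```fortran
! Hypothetical illustration (not from the gist): vertex j of an N-dimensional
! cell is identified by the N bits of j-1; bit i-1 selects the lower (0) or
! upper (1) node along dimension i, as ks does in fint.
MODULE mod_vertex_demo
IMPLICIT NONE
INTEGER, PARAMETER :: WP = SELECTED_REAL_KIND(15,307)
CONTAINS
PURE FUNCTION vertex_bits(j, ndim) RESULT(s)
  ! s(i) = bit i-1 of j-1, for j = 1, ..., 2**ndim
  INTEGER, INTENT(IN) :: j, ndim
  INTEGER :: s(ndim), i
  DO i = 1, ndim
    s(i) = MOD((2*j-1)/2**i, 2)
  ENDDO
END FUNCTION vertex_bits

PURE FUNCTION vertex_weight(j, ndim, eta) RESULT(w)
  ! Product of the per-dimension factors eta(1+s(i), i), as in fint
  INTEGER, INTENT(IN) :: j, ndim
  REAL(WP), INTENT(IN) :: eta(2, ndim)
  REAL(WP) :: w
  INTEGER :: s(ndim), i
  s = vertex_bits(j, ndim)
  w = 1.0_WP
  DO i = 1, ndim
    w = w*eta(1+s(i), i)
  ENDDO
END FUNCTION vertex_weight
END MODULE mod_vertex_demo
```

For a trilinear cell (N = 3), vertex j = 6 has bits [1,0,1], i.e. the upper node along the first and third dimensions and the lower node along the second. The eight weights always sum to one, because the sum factorizes into the product of eta(1,i) + eta(2,i) = 1 over the dimensions.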

I didn't immediately realize that ks is just bit i-1 of j-1, i.e. BTEST(j-1,i-1) converted to 0/1 (BTEST itself returns a LOGICAL), and I ended up not using it anyway because it is slightly slower than the implementation above. Instead, I concluded that I needed an algorithm that loops over all the numbers in a given range, given the symbols of the numbering system (binary, in this case) and the number of digits (N, in my case). Why did I think of this? Because I had already met this need when constructing nested loops with a variable number of iterators (which you can't write out explicitly, because you don't know in advance how many there will be).
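For comparison, one common way to write nested loops whose depth is only known at run time is an "odometer" counter that increments the last index and carries into the previous ones. This is a different technique from the closed-form MOD formula discussed in this thread; the sketch below (hypothetical module mod_odometer and subroutine next_tuple) visits the tuples in the same order as that formula, with the last index varying fastest.

```fortran
! Hypothetical sketch (not the gist author's code): emulate n nested DO loops,
! each running over 1..nv, when n is only known at run time. This uses an
! odometer increment with carry instead of a closed-form MOD formula.
MODULE mod_odometer
IMPLICIT NONE
CONTAINS
SUBROUTINE next_tuple(idx, nv, done)
  ! Advance idx(1:n) to the next tuple; idx(n) varies fastest, like the
  ! innermost of n nested loops. done is set once the last tuple is passed.
  INTEGER, INTENT(INOUT) :: idx(:)
  INTEGER, INTENT(IN) :: nv
  LOGICAL, INTENT(OUT) :: done
  INTEGER :: i
  done = .FALSE.
  DO i = SIZE(idx), 1, -1
    IF (idx(i) < nv) THEN
      idx(i) = idx(i) + 1
      RETURN
    ENDIF
    idx(i) = 1          ! wrap around and carry into the next (outer) position
  ENDDO
  done = .TRUE.         ! every position wrapped: the enumeration is finished
END SUBROUTINE next_tuple
END MODULE mod_odometer
```

Start from idx = 1, process each tuple, then call next_tuple until done becomes .TRUE.; for n = 3 and nv = 2 this visits 8 tuples and leaves idx back at all ones.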

Long story short, I needed the k-th permutation, with repetition, of the elements of v, taken n at a time, which I implemented like this:

DO i = 1, n
   DO j = 1, nk
      res(j,i) = v(MOD((2*k(j)-1)/(2*nv**(n-i)),nv)+1)
   ENDDO
ENDDO
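A self-contained version of the fragment above, as a sketch: the module name mod_perm, the subroutine name perm_rep and the argument declarations are assumptions, while the MOD expression is copied from the comment. Row res(j,:) holds the k(j)-th permutation with repetition, for k(j) between 1 and SIZE(v)**n.

```fortran
! Hypothetical wrapper (names and declarations assumed) around the
! permutation-with-repetition expression quoted in the comment above.
MODULE mod_perm
IMPLICIT NONE
CONTAINS
SUBROUTINE perm_rep(v, n, k, res)
  ! res(j,:) = k(j)-th permutation with repetition of the elements of v,
  ! taken n at a time (1 <= k(j) <= SIZE(v)**n)
  INTEGER, INTENT(IN) :: v(:), n, k(:)
  INTEGER, INTENT(OUT) :: res(SIZE(k), n)
  INTEGER :: i, j, nv, nk
  nv = SIZE(v)
  nk = SIZE(k)
  DO i = 1, n
    DO j = 1, nk
      ! Base-nv digit of k(j)-1 at position n-i (0 = least significant)
      res(j,i) = v(MOD((2*k(j)-1)/(2*nv**(n-i)),nv)+1)
    ENDDO
  ENDDO
END SUBROUTINE perm_rep
END MODULE mod_perm
```

For example, with v = [0,1], n = 3 and k = 6, row 6 is [1,0,1]: the binary digits of k-1 = 5, most significant first. With v = [10,20,30], n = 2 and k = 4, the result is [20,10], since k-1 = 3 is written 10 in base 3.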

Then I just simplified it, considering that n is ndim, v only contains 0 and 1 (so nv = 2), and k(j) is simply j (I needed all the permutations, in no particular order), running from 1 to 2^ndim. I can't remember why I implemented the general formula this way, or whether I took it from somewhere (as I said, I had had it around for a long time, from the variable-depth loops).
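Why the simplification works can be checked with a short derivation (added here for illustration, not from the thread). Write m = j-1 and read the divisions as Fortran integer division, which is the floor for non-negative operands:

```latex
\left\lfloor \frac{2j-1}{2^{i}} \right\rfloor
  = \left\lfloor \frac{2m+1}{2^{i}} \right\rfloor
  = \left\lfloor \frac{m}{2^{i-1}} \right\rfloor
  \qquad (i \ge 1),

\mathrm{ks} = \left\lfloor \frac{m}{2^{i-1}} \right\rfloor \bmod 2
  = \text{bit } (i-1) \text{ of } j-1 .
```

The middle step holds because 2m is even and every multiple of 2^i is even, so adding 1 to 2m cannot cross a multiple of 2^i. With nv = 2 the general formula divides by 2*2^(n-i) = 2^(n-i+1), which is the same expression with i replaced by n-i+1: the bits are assigned to the dimensions in reverse order, and every vertex is still visited exactly once.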

@huijunchen9260

Thanks! I think I understand now. I also tested it, and indeed BTEST is slower than the MOD method. That's such a smart way to do it!

@huijunchen9260

I just realized using ks = iand(ishft(j-1, -(i-1)), 1) is also sufficiently fast 😀
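The bit-extraction variants mentioned in this thread can be checked against each other with a small sketch (the module mod_bits and the ks_* function names are hypothetical). Two details are worth noting: BTEST returns a LOGICAL, so it needs MERGE to become 0/1, and the standard IBITS intrinsic is a fourth option not mentioned above.

```fortran
! Hypothetical check (not from the thread): four ways to obtain bit i-1 of
! j-1 as an integer 0 or 1.
MODULE mod_bits
IMPLICIT NONE
CONTAINS
ELEMENTAL INTEGER FUNCTION ks_mod(j, i)
  INTEGER, INTENT(IN) :: j, i
  ks_mod = MOD((2*j-1)/2**i, 2)            ! expression used in fint
END FUNCTION ks_mod

ELEMENTAL INTEGER FUNCTION ks_btest(j, i)
  INTEGER, INTENT(IN) :: j, i
  ks_btest = MERGE(1, 0, BTEST(j-1, i-1))  ! LOGICAL converted to 0/1
END FUNCTION ks_btest

ELEMENTAL INTEGER FUNCTION ks_shift(j, i)
  INTEGER, INTENT(IN) :: j, i
  ks_shift = IAND(ISHFT(j-1, -(i-1)), 1)   ! shift right, keep lowest bit
END FUNCTION ks_shift

ELEMENTAL INTEGER FUNCTION ks_ibits(j, i)
  INTEGER, INTENT(IN) :: j, i
  ks_ibits = IBITS(j-1, i-1, 1)            ! bit field of length 1 at i-1
END FUNCTION ks_ibits
END MODULE mod_bits
```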
