
Commit 47406da

Reformat TransformerComputeKernelsLayered by adjusting Javadoc annotations, reorganizing line breaks, and improving parameter alignment for enhanced readability and consistency.
1 parent e60f478 commit 47406da

File tree

1 file changed: +115 -89 lines changed


src/main/java/com/example/tornadovm/TransformerComputeKernelsLayered.java

Lines changed: 115 additions & 89 deletions
@@ -16,22 +16,23 @@ public TransformerComputeKernelsLayered() {
     }
 
     /**
-     * Performs RMS (Root Mean Square) normalization using parallel reduction.
-     * This is the first phase of RMS normalization that computes the variance
-     * and scaling factor across all work groups.
+     * Performs RMS (Root Mean Square) normalization using parallel reduction. This is the first phase of RMS normalization that computes the variance and scaling factor across all work groups.
      *
-     * Algorithm:
-     * 1. Each thread computes square of its input element
-     * 2. Work group performs parallel reduction of squares
-     * 3. Partial sums stored per work group
-     * 4. First thread combines all partial sums and computes normalization factor
+     * Algorithm: 1. Each thread computes square of its input element 2. Work group performs parallel reduction of squares 3. Partial sums stored per work group 4. First thread combines all partial
+     * sums and computes normalization factor
      *
-     * @param context Kernel execution context
-     * @param output Array to store partial sums and final normalization factor
-     * @param x Input array to normalize
-     * @param size Number of elements to process
-     * @param ermsNorm Epsilon value squared for numerical stability
-     * @param localMemSize Size of local memory allocation (must match work group size)
+     * @param context
+     *            Kernel execution context
+     * @param output
+     *            Array to store partial sums and final normalization factor
+     * @param x
+     *            Input array to normalize
+     * @param size
+     *            Number of elements to process
+     * @param ermsNorm
+     *            Epsilon value squared for numerical stability
+     * @param localMemSize
+     *            Size of local memory allocation (must match work group size)
      */
     public static void reductionOneBlockWithLayer(KernelContext context, FloatArray output, FloatArray x, int size, float ermsNorm, int localMemSize) {
         int gid = context.globalIdx;
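
For reference, the factor this two-phase reduction produces can be emulated sequentially on the host. The sketch below is not the TornadoVM kernel: it uses a plain float[] instead of FloatArray and a hypothetical helper name, and it assumes the usual RMSNorm definition (factor = 1 / sqrt(mean of squares + epsilon)) implied by the Javadoc.

// Hypothetical host-side reference for phase 1, not the GPU kernel above.
// Each GPU thread squares one element; work groups reduce the partial sums.
static float rmsNormFactor(float[] x, int size, float ermsNorm) {
    float sumOfSquares = 0.0f;
    for (int i = 0; i < size; i++) {
        sumOfSquares += x[i] * x[i];
    }
    // Assumed RMSNorm scaling: 1 / sqrt(mean(x^2) + eps)
    return 1.0f / (float) Math.sqrt(sumOfSquares / size + ermsNorm);
}
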
@@ -80,16 +81,20 @@ public static void reductionOneBlockWithLayer(KernelContext context, FloatArray
     }
 
     /**
-     * Applies the computed normalization factor to input and weight elements.
-     * This is the second phase of RMS normalization.
+     * Applies the computed normalization factor to input and weight elements. This is the second phase of RMS normalization.
      *
      * Formula: output[i] = weight[i] * (normalizationFactor * x[i])
      *
-     * @param context Kernel execution context
-     * @param output Array for normalized output
-     * @param x Input values to normalize
-     * @param weights Weight values for each element
-     * @param temp Temporary array containing normalization factor at index 0
+     * @param context
+     *            Kernel execution context
+     * @param output
+     *            Array for normalized output
+     * @param x
+     *            Input values to normalize
+     * @param weights
+     *            Weight values for each element
+     * @param temp
+     *            Temporary array containing normalization factor at index 0
      */
     public static void reductionOneBlock2WithLayer(KernelContext context, FloatArray output, FloatArray x, FloatArray weights, FloatArray temp) {
         int gid = context.globalIdx;
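
The second phase is a plain element-wise scale. A minimal sketch of the documented formula, again with plain float[] and a hypothetical helper name:

// Sketch of the documented formula: output[i] = weights[i] * (normalizationFactor * x[i])
static void applyRmsNorm(float[] output, float[] x, float[] weights, float normalizationFactor) {
    for (int i = 0; i < x.length; i++) {
        output[i] = weights[i] * (normalizationFactor * x[i]);
    }
}
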
@@ -99,21 +104,26 @@ public static void reductionOneBlock2WithLayer(KernelContext context, FloatArray
     }
 
     /**
-     * Copies keys and values into the key-value cache for attention computation.
-     * Enables efficient access to past key-value pairs during autoregressive generation.
+     * Copies keys and values into the key-value cache for attention computation. Enables efficient access to past key-value pairs during autoregressive generation.
      *
-     * Cache layout: [layer][position][dimension]
-     * - Each layer has its own key and value cache
-     * - Each position in sequence has a key and value vector
+     * Cache layout: [layer][position][dimension] - Each layer has its own key and value cache - Each position in sequence has a key and value vector
      *
-     * @param destKeyCache Destination array for key cache
-     * @param srcKey Source keys to copy
-     * @param destValueCache Destination array for value cache
-     * @param srcValue Source values to copy
-     * @param positioNlayer Array containing current position
-     * @param kvDim Dimension of key/value vectors
-     * @param layer Current transformer layer index
-     * @param contextLength Maximum sequence length
+     * @param destKeyCache
+     *            Destination array for key cache
+     * @param srcKey
+     *            Source keys to copy
+     * @param destValueCache
+     *            Destination array for value cache
+     * @param srcValue
+     *            Source values to copy
+     * @param positioNlayer
+     *            Array containing current position
+     * @param kvDim
+     *            Dimension of key/value vectors
+     * @param layer
+     *            Current transformer layer index
+     * @param contextLength
+     *            Maximum sequence length
      */
     public static void copyToCache(FloatArray destKeyCache, FloatArray srcKey, FloatArray destValueCache, FloatArray srcValue, IntArray positioNlayer, int kvDim, int layer, int contextLength) {
 
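Under the [layer][position][dimension] layout described in the Javadoc, one (layer, position) slot would start at offset layer * contextLength * kvDim + position * kvDim in a flat cache. The sketch below is a hedged host-side illustration of that layout with plain arrays and a hypothetical helper name; the kernel's actual offset arithmetic may differ.

// Hypothetical reference for the documented cache layout [layer][position][dimension].
// Copies one position's key and value vectors into flat caches.
static void copyToCacheReference(float[] keyCache, float[] key, float[] valueCache, float[] value,
        int position, int kvDim, int layer, int contextLength) {
    int base = layer * contextLength * kvDim + position * kvDim;  // start of this (layer, position) slot
    for (int i = 0; i < kvDim; i++) {
        keyCache[base + i] = key[i];
        valueCache[base + i] = value[i];
    }
}
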
@@ -127,7 +137,6 @@ public static void copyToCache(FloatArray destKeyCache, FloatArray srcKey, Float
         }
     }
 
-
     public static void copyTo(FloatArray src, int srcOffset, FloatArray dest, int destOffset, int size) {
         // Generic copy: src[srcOffset:srcOffset+size] -> dest[destOffset:destOffset+size]
         for (@Parallel int i = 0; i < size; i++) {
@@ -144,20 +153,23 @@ public static void copyChunk(FloatArray in, FloatArray out, int dim1In, int dim1
     }
 
     /**
-     * Applies Rotary Position Encoding (RoPE) to query and key vectors.
-     * RoPE rotates pairs of dimensions based on their position in the sequence,
-     * enabling the model to learn relative positional information.
+     * Applies Rotary Position Encoding (RoPE) to query and key vectors. RoPE rotates pairs of dimensions based on their position in the sequence, enabling the model to learn relative positional
+     * information.
      *
-     * For each pair of dimensions (2*i, 2*i+1):
-     * - Compute rotation angle based on position and frequency
-     * - Apply 2D rotation to the pair
+     * For each pair of dimensions (2*i, 2*i+1): - Compute rotation angle based on position and frequency - Apply 2D rotation to the pair
      *
-     * @param context Kernel execution context
-     * @param positionHolder Array containing current position
-     * @param sq Query vectors to rotate
-     * @param sk Key vectors to rotate
-     * @param kv_dim Dimension of key/value vectors
-     * @param head_size Dimension of each attention head
+     * @param context
+     *            Kernel execution context
+     * @param positionHolder
+     *            Array containing current position
+     * @param sq
+     *            Query vectors to rotate
+     * @param sk
+     *            Key vectors to rotate
+     * @param kv_dim
+     *            Dimension of key/value vectors
+     * @param head_size
+     *            Dimension of each attention head
      */
    public static void ropeRotation(KernelContext context, IntArray positionHolder, FloatArray sq, FloatArray sk, int kv_dim, int head_size) {
        int i = context.globalIdx * 2;
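
The pairwise rotation this Javadoc describes (dimensions 2*i and 2*i+1 rotated by a position-dependent angle) can be written out as a short single-vector reference. The frequency base 10000 is the common LLaMA-style choice and is an assumption here, not something stated in the diff; vec.length is assumed even and a multiple of headSize.

// Hedged reference for RoPE on one vector; the kernel applies this to sq and sk in parallel.
static void ropeRotateReference(float[] vec, int position, int headSize) {
    for (int i = 0; i < vec.length; i += 2) {
        int dimInHead = i % headSize;                                          // pair's index within its head
        double freq = 1.0 / Math.pow(10000.0, (double) dimInHead / headSize);  // assumed base 10000
        double angle = position * freq;
        float cos = (float) Math.cos(angle);
        float sin = (float) Math.sin(angle);
        float v0 = vec[i];
        float v1 = vec[i + 1];
        vec[i] = v0 * cos - v1 * sin;        // 2D rotation of the pair (2*i, 2*i+1)
        vec[i + 1] = v0 * sin + v1 * cos;
    }
}
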
@@ -194,8 +206,9 @@ public static void ropeRotationPhi3(KernelContext context, IntArray positionHold
         int dimHalf = head_size / 2;
 
         // Each thread processes one dimension pair
-        if (idx >= dimHalf)
+        if (idx >= dimHalf) {
             return;
+        }
 
         int position = positionHolder.get(0);
 
@@ -209,8 +222,9 @@ public static void ropeRotationPhi3(KernelContext context, IntArray positionHold
         int totalDim = sq.getSize();
         for (int base = 0; base < totalDim; base += head_size) {
             // Skip if we're beyond the bounds
-            if (base + idx >= totalDim || base + idx + dimHalf >= totalDim)
+            if (base + idx >= totalDim || base + idx + dimHalf >= totalDim) {
                 break;
+            }
 
             // Rotate query
             float v0 = sq.get(base + idx);
@@ -719,18 +733,24 @@ public static void matrixVectorGeneric(
     // @formatter:on
 
     /**
-     * Matrix-vector multiplication with residual connection.
-     * Combines regular matrix multiplication with addition of existing values.
+     * Matrix-vector multiplication with residual connection. Combines regular matrix multiplication with addition of existing values.
      *
      * Formula: hb[i] = hb[i] + w[i]·x
      *
-     * @param context Kernel execution context
-     * @param x Input vector
-     * @param hb Input/output vector (contains residual, receives result)
-     * @param w Weight matrix
-     * @param n Input dimension
-     * @param d Output dimension
-     * @param localWorkGroupSize Work group size
+     * @param context
+     *            Kernel execution context
+     * @param x
+     *            Input vector
+     * @param hb
+     *            Input/output vector (contains residual, receives result)
+     * @param w
+     *            Weight matrix
+     * @param n
+     *            Input dimension
+     * @param d
+     *            Output dimension
+     * @param localWorkGroupSize
+     *            Work group size
      */
     public static void matrixVectorGenericWithResidual(KernelContext context, FloatArray x, FloatArray hb, HalfFloatArray w, int n, int d, int localWorkGroupSize) {
         // One row per workgroup (not per thread)
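
A sequential sketch of the documented formula hb[i] = hb[i] + w[i]·x, using plain float[] instead of KernelContext/HalfFloatArray and assuming row-major weight storage (row i at w[i * n .. i * n + n - 1]):

// Reference for the residual matrix-vector product: hb[row] += dot(row of w, x)
static void matVecWithResidualReference(float[] x, float[] hb, float[] w, int n, int d) {
    for (int row = 0; row < d; row++) {        // one work group per row in the kernel
        float dot = 0.0f;
        for (int col = 0; col < n; col++) {    // threads split this loop, then reduce
            dot += w[row * n + col] * x[col];
        }
        hb[row] += dot;                        // add to the residual already stored in hb
    }
}
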
@@ -753,20 +773,26 @@ public static void matrixVectorGenericWithResidual(KernelContext context, FloatA
     }
 
     /**
-     * Fused feed-forward network with SiLU activation and GLU gating.
-     * Implements the SwiGLU variant used in LLaMA-style models.
+     * Fused feed-forward network with SiLU activation and GLU gating. Implements the SwiGLU variant used in LLaMA-style models.
      *
-     * Formula: FFN(x) = SiLU(x·W1) ⊙ (x·W3)
-     * where ⊙ denotes element-wise multiplication
+     * Formula: FFN(x) = SiLU(x·W1) ⊙ (x·W3) where ⊙ denotes element-wise multiplication
      *
-     * @param context Kernel execution context
-     * @param x Input vector
-     * @param hb Output buffer
-     * @param w1 First feed-forward weight matrix
-     * @param w3 Third feed-forward weight matrix (gate)
-     * @param n Input dimension
-     * @param d Hidden dimension
-     * @param localWorkGroupSize Work group size
+     * @param context
+     *            Kernel execution context
+     * @param x
+     *            Input vector
+     * @param hb
+     *            Output buffer
+     * @param w1
+     *            First feed-forward weight matrix
+     * @param w3
+     *            Third feed-forward weight matrix (gate)
+     * @param n
+     *            Input dimension
+     * @param d
+     *            Hidden dimension
+     * @param localWorkGroupSize
+     *            Work group size
      */
     public static void fusedFeedForwardWithSiLUAndGLUActivation(KernelContext context, FloatArray x, FloatArray hb, HalfFloatArray w1, HalfFloatArray w3, int n, int d, int localWorkGroupSize) {
         // One row per workgroup (not per thread)
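
The fused form computes both projections for a row and gates them before writing. A minimal plain-Java sketch of FFN(x) = SiLU(x·W1) ⊙ (x·W3), assuming row-major weights and float[] buffers (the kernel itself works on HalfFloatArray with a per-row work group):

// Sketch of the fused SwiGLU feed-forward step, one output element per row.
static void fusedFfnReference(float[] x, float[] hb, float[] w1, float[] w3, int n, int d) {
    for (int row = 0; row < d; row++) {
        float a = 0.0f;   // x · W1[row]
        float b = 0.0f;   // x · W3[row] (gate)
        for (int col = 0; col < n; col++) {
            a += w1[row * n + col] * x[col];
            b += w3[row * n + col] * x[col];
        }
        float silu = a / (1.0f + (float) Math.exp(-a));   // SiLU(a) = a * sigmoid(a)
        hb[row] = silu * b;                                // element-wise gating
    }
}
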
@@ -789,10 +815,10 @@ public static void fusedFeedForwardWithSiLUAndGLUActivation(KernelContext contex
     }
 
     /**
-     * Gaussian Error Linear Unit (GELU) activation function.
-     * Approximation formula: GELU(x) ≈ 0.5 * x * (1 + tanh(√(2/π) * (x + 0.044715 * x³)))
+     * Gaussian Error Linear Unit (GELU) activation function. Approximation formula: GELU(x) ≈ 0.5 * x * (1 + tanh(√(2/π) * (x + 0.044715 * x³)))
      *
-     * @param x Input value
+     * @param x
+     *            Input value
      * @return Activated value
      */
     public static float geluActivation(float x) {
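
Written out in plain Java, the tanh approximation quoted in the Javadoc corresponds to the following sketch; it uses java.lang.Math for illustration, whereas the kernel would use TornadoMath.

// GELU(x) ≈ 0.5 * x * (1 + tanh(sqrt(2/PI) * (x + 0.044715 * x^3)))
static float geluReference(float x) {
    float sqrtTwoOverPi = (float) Math.sqrt(2.0 / Math.PI);
    float inner = sqrtTwoOverPi * (x + 0.044715f * x * x * x);
    return 0.5f * x * (1.0f + (float) Math.tanh(inner));
}
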
@@ -801,33 +827,33 @@ public static float geluActivation(float x) {
     }
 
     /**
-     * Sigmoid-weighted Linear Unit (SiLU) activation function.
-     * Also known as Swish activation.
+     * Sigmoid-weighted Linear Unit (SiLU) activation function. Also known as Swish activation.
      *
      * Formula: SiLU(x) = x * σ(x) = x / (1 + e^(-x))
      *
-     * @param x Input value
+     * @param x
+     *            Input value
      * @return Activated value
      */
     public static float siluActivation(float x) {
         return x * (1.0f / (1.0f + TornadoMath.exp(-x)));
     }
 
     /**
-     * Optimized row-major matrix-vector multiplication for a single row.
-     * Uses parallel reduction within a work group to compute one dot product.
+     * Optimized row-major matrix-vector multiplication for a single row. Uses parallel reduction within a work group to compute one dot product.
      *
-     * Algorithm:
-     * 1. Each thread computes partial dot product
-     * 2. Partial results stored in local memory
-     * 3. Tree-based reduction combines partial results
-     * 4. Returns final dot product for the row
+     * Algorithm: 1. Each thread computes partial dot product 2. Partial results stored in local memory 3. Tree-based reduction combines partial results 4. Returns final dot product for the row
      *
-     * @param context Kernel execution context
-     * @param localSize Work group size
-     * @param x Input vector
-     * @param w Weight matrix row
-     * @param n Input dimension
+     * @param context
+     *            Kernel execution context
+     * @param localSize
+     *            Work group size
+     * @param x
+     *            Input vector
+     * @param w
+     *            Weight matrix row
+     * @param n
+     *            Input dimension
      * @return Dot product result for this row
      */
     public static float matrixVectorRowMajorOptimized(KernelContext context, int localSize, FloatArray x, FloatArray w, int n) {
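
The four-step algorithm in this Javadoc (partial dot products, local memory, tree reduction, final result) can be emulated on the host with an explicit "local" buffer. This is a hedged sketch, not the kernel: it assumes localSize is a power of two, uses plain float[] in place of FloatArray, and serializes what the GPU does with barriers between reduction steps.

// Host-side emulation of the work-group dot-product reduction described above.
static float dotProductReductionReference(float[] x, float[] w, int n, int localSize) {
    float[] local = new float[localSize];              // stands in for local/shared memory
    for (int t = 0; t < localSize; t++) {              // steps 1-2: strided partial dot products
        float partial = 0.0f;
        for (int i = t; i < n; i += localSize) {
            partial += w[i] * x[i];
        }
        local[t] = partial;
    }
    for (int stride = localSize / 2; stride > 0; stride /= 2) {   // step 3: tree-based reduction
        for (int t = 0; t < stride; t++) {
            local[t] += local[t + stride];             // a barrier separates these rounds on the GPU
        }
    }
    return local[0];                                   // step 4: final dot product for the row
}
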
@@ -915,7 +941,7 @@ public static void siluInPlace(FloatArray array, int size) {
         // SiLU activation: silu(x) = x * sigmoid(x) = x / (1 + exp(-x))
         for (@Parallel int i = 0; i < size; i++) {
             float x = array.get(i);
-            float silu = x / (1.0f + TornadoMath.exp(-x));
+            float silu = x / (1.0f + TornadoMath.exp(-x));
             array.set(i, silu);
         }
     }
