Reformat TransformerComputeKernelsLayered by adjusting Javadoc annotations, reorganizing line breaks, and improving parameter alignment for enhanced readability and consistency.
@@ -16,22 +16,23 @@ public TransformerComputeKernelsLayered() {
 }

 /**
- * Performs RMS (Root Mean Square) normalization using parallel reduction.
- * This is the first phase of RMS normalization that computes the variance
- * and scaling factor across all work groups.
+ * Performs RMS (Root Mean Square) normalization using parallel reduction. This is the first phase of RMS normalization that computes the variance and scaling factor across all work groups.
  *
- * Algorithm:
- * 1. Each thread computes square of its input element
- * 2. Work group performs parallel reduction of squares
- * 3. Partial sums stored per work group
- * 4. First thread combines all partial sums and computes normalization factor
+ * Algorithm: 1. Each thread computes square of its input element 2. Work group performs parallel reduction of squares 3. Partial sums stored per work group 4. First thread combines all partial
+ * sums and computes normalization factor
  *
- * @param context Kernel execution context
- * @param output Array to store partial sums and final normalization factor
- * @param x Input array to normalize
- * @param size Number of elements to process
- * @param ermsNorm Epsilon value squared for numerical stability
- * @param localMemSize Size of local memory allocation (must match work group size)
+ * @param context
+ * Kernel execution context
+ * @param output
+ * Array to store partial sums and final normalization factor
+ * @param x
+ * Input array to normalize
+ * @param size
+ * Number of elements to process
+ * @param ermsNorm
+ * Epsilon value squared for numerical stability
+ * @param localMemSize
+ * Size of local memory allocation (must match work group size)
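The four algorithm steps named in this Javadoc can be sketched in plain Java, with work groups simulated by ordinary loops. This is not the kernel from the commit: the class name, the method names, the `groupSize` and `eps` parameters, and the final formula `1 / sqrt(mean + eps)` are assumptions made for illustration.

```java
final class RmsNormReductionSketch {

    /** Steps 1-3: each simulated work group reduces the squares of its slice to one partial sum. */
    static float[] partialSumsOfSquares(float[] x, int groupSize) {
        if (groupSize <= 0 || Integer.bitCount(groupSize) != 1) {
            throw new IllegalArgumentException("groupSize must be a power of two");
        }
        int numGroups = (x.length + groupSize - 1) / groupSize;
        float[] partials = new float[numGroups];
        float[] local = new float[groupSize]; // stands in for per-group local memory
        for (int g = 0; g < numGroups; g++) {
            // Step 1: every simulated thread squares its own element (zero past the end).
            for (int lid = 0; lid < groupSize; lid++) {
                int gid = g * groupSize + lid;
                local[lid] = gid < x.length ? x[gid] * x[gid] : 0.0f;
            }
            // Step 2: tree reduction; the active range halves every round.
            for (int stride = groupSize / 2; stride > 0; stride /= 2) {
                for (int lid = 0; lid < stride; lid++) {
                    local[lid] += local[lid + stride];
                }
            }
            // Step 3: one partial sum is stored per work group.
            partials[g] = local[0];
        }
        return partials;
    }

    /** Step 4: a single thread combines the partial sums into the scaling factor. */
    static float normalizationFactor(float[] partials, int size, float eps) {
        float sum = 0.0f;
        for (float p : partials) {
            sum += p;
        }
        // Assumed formula: reciprocal root of the mean square plus a stabilizing epsilon.
        return 1.0f / (float) Math.sqrt(sum / size + eps);
    }

    public static void main(String[] args) {
        float[] x = {1f, 2f, 3f, 4f};
        float[] partials = partialSumsOfSquares(x, 2);
        System.out.println(normalizationFactor(partials, x.length, 1e-5f));
    }
}
```

The power-of-two check mirrors the Javadoc's note that the local memory size must match the work group size: the halving loop only covers every slot when the group size is a power of two.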
@@ -99,21 +104,26 @@ public static void reductionOneBlock2WithLayer(KernelContext context, FloatArray
 }

 /**
- * Copies keys and values into the key-value cache for attention computation.
- * Enables efficient access to past key-value pairs during autoregressive generation.
+ * Copies keys and values into the key-value cache for attention computation. Enables efficient access to past key-value pairs during autoregressive generation.
  *
- * Cache layout: [layer][position][dimension]
- * - Each layer has its own key and value cache
- * - Each position in sequence has a key and value vector
+ * Cache layout: [layer][position][dimension] - Each layer has its own key and value cache - Each position in sequence has a key and value vector
  *
- * @param destKeyCache Destination array for key cache
- * @param srcKey Source keys to copy
- * @param destValueCache Destination array for value cache
- * @param srcValue Source values to copy
- * @param positioNlayer Array containing current position
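The `[layer][position][dimension]` layout described in this Javadoc implies a simple offset computation when the cache is stored as one flat array. The following plain-Java sketch shows that indexing under stated assumptions: the class and method names, the explicit `layer`, `position`, `contextLength`, and `kvDim` parameters, and the use of `float[]` instead of the kernel's own array type are all hypothetical and not taken from the commit.

```java
final class KvCacheCopySketch {

    /** Flat offset of the vector stored at [layer][position][0]. */
    static int cacheOffset(int layer, int position, int contextLength, int kvDim) {
        return (layer * contextLength + position) * kvDim;
    }

    /** Copies one key vector and one value vector into their cache slots. */
    static void copyToCache(float[] destKeyCache, float[] srcKey,
                            float[] destValueCache, float[] srcValue,
                            int layer, int position, int contextLength, int kvDim) {
        int offset = cacheOffset(layer, position, contextLength, kvDim);
        // Keys and values share the same layout, so one offset serves both caches.
        System.arraycopy(srcKey, 0, destKeyCache, offset, kvDim);
        System.arraycopy(srcValue, 0, destValueCache, offset, kvDim);
    }

    public static void main(String[] args) {
        int layers = 2, contextLength = 3, kvDim = 2;
        float[] keyCache = new float[layers * contextLength * kvDim];
        float[] valueCache = new float[layers * contextLength * kvDim];
        copyToCache(keyCache, new float[] {1f, 2f}, valueCache, new float[] {3f, 4f},
                1, 2, contextLength, kvDim);
        System.out.println(java.util.Arrays.toString(keyCache));
    }
}
```

Because earlier positions are never overwritten, each generation step only writes one new vector per layer and can read all previous ones in place.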
@@ -144,20 +153,23 @@ public static void copyChunk(FloatArray in, FloatArray out, int dim1In, int dim1
 }

 /**
- * Applies Rotary Position Encoding (RoPE) to query and key vectors.
- * RoPE rotates pairs of dimensions based on their position in the sequence,
- * enabling the model to learn relative positional information.
+ * Applies Rotary Position Encoding (RoPE) to query and key vectors. RoPE rotates pairs of dimensions based on their position in the sequence, enabling the model to learn relative positional
+ * information.
  *
- * For each pair of dimensions (2*i, 2*i+1):
- * - Compute rotation angle based on position and frequency
- * - Apply 2D rotation to the pair
+ * For each pair of dimensions (2*i, 2*i+1): - Compute rotation angle based on position and frequency - Apply 2D rotation to the pair
  *
- * @param context Kernel execution context
- * @param positionHolder Array containing current position
- * @param sq Query vectors to rotate
- * @param sk Key vectors to rotate
- * @param kv_dim Dimension of key/value vectors
- * @param head_size Dimension of each attention head
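The per-pair rotation described in this Javadoc can be sketched in plain Java as follows. The frequency formula `1 / theta^(d / headSize)`, the `theta` parameter, and all class and method names are assumptions borrowed from common RoPE implementations; the commit itself only states that the angle depends on position and frequency.

```java
final class RopeSketch {

    /** Rotates pairs (i, i+1) of vec in place for all i below limit. */
    static void rotate(float[] vec, int limit, int position, int headSize, float theta) {
        for (int i = 0; i + 1 < limit; i += 2) {
            int headDim = i % headSize; // index of the pair inside its attention head
            // Assumed frequency schedule: lower dimensions rotate faster.
            double freq = 1.0 / Math.pow(theta, headDim / (double) headSize);
            double angle = position * freq;
            float cos = (float) Math.cos(angle);
            float sin = (float) Math.sin(angle);
            float v0 = vec[i];
            float v1 = vec[i + 1];
            // Standard 2D rotation of the pair (v0, v1) by the computed angle.
            vec[i] = v0 * cos - v1 * sin;
            vec[i + 1] = v0 * sin + v1 * cos;
        }
    }

    /** Rotates the full query vector, and the key vector up to kvDim entries. */
    static void applyRope(float[] sq, float[] sk, int position, int kvDim, int headSize, float theta) {
        rotate(sq, sq.length, position, headSize, theta);
        rotate(sk, Math.min(kvDim, sk.length), position, headSize, theta);
    }

    public static void main(String[] args) {
        float[] q = {1f, 0f, 0f, 1f};
        float[] k = {1f, 0f};
        applyRope(q, k, 3, 2, 2, 10000f);
        System.out.println(java.util.Arrays.toString(q));
    }
}
```

Each pair is rotated without changing its length, so the encoding alters only the direction of the query and key vectors, which is what lets dot products between them depend on relative position.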
* Algorithm: 1. Each thread computes partial dot product 2. Partial results stored in local memory 3. Tree-based reduction combines partial results 4. Returns final dot product for the row
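The dot-product steps listed in this line can be sketched in plain Java, with one work group simulated by loops over a local array. The class name, the method signature, the row-major weight layout, and the strided assignment of columns to threads are assumptions for illustration and are not confirmed by the visible part of the diff.

```java
final class RowDotProductSketch {

    /** Computes the dot product of row `row` of an n-column matrix w with x. */
    static float rowDot(float[] w, float[] x, int row, int n, int localSize) {
        if (localSize <= 0 || Integer.bitCount(localSize) != 1) {
            throw new IllegalArgumentException("localSize must be a power of two");
        }
        float[] local = new float[localSize]; // stands in for work-group local memory
        int rowOffset = row * n;              // assumed row-major layout
        // Steps 1-2: each simulated thread accumulates a strided slice of the row
        // and stores its partial result in local memory.
        for (int lid = 0; lid < localSize; lid++) {
            float partial = 0.0f;
            for (int j = lid; j < n; j += localSize) {
                partial += w[rowOffset + j] * x[j];
            }
            local[lid] = partial;
        }
        // Step 3: tree-based reduction; the active range halves every round.
        for (int stride = localSize / 2; stride > 0; stride /= 2) {
            for (int lid = 0; lid < stride; lid++) {
                local[lid] += local[lid + stride];
            }
        }
        // Step 4: slot 0 now holds the dot product for the row.
        return local[0];
    }

    public static void main(String[] args) {
        float[] w = {1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f};
        float[] x = {1f, 1f, 1f, 1f};
        System.out.println(rowDot(w, x, 0, 4, 2));
        System.out.println(rowDot(w, x, 1, 4, 2));
    }
}
```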