Skip to content

Commit 00cb08d

Browse files
authored
Merge branch 'main' into export-D74020937
2 parents df104b8 + 70974aa commit 00cb08d

File tree

7 files changed

+703
-27
lines changed

7 files changed

+703
-27
lines changed

backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.glsl

Lines changed: 23 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,9 @@ const lowp int out_packed_dim = unhash_packed_dim(out_layout);
6060
// First iteration of reduce will have 32 threads sum up 64 elements.
6161
// Second iteration will have 16 threads sum up the 32 partial sums from the previous iteration, and so on.
6262
// Thus thread utilization starts at 100%.
63-
#define SHARED_MEMORY_FACTOR 2
63+
#define SHARED_MEMORY_FACTOR 1
6464

65-
#define offset_pos_index(index) ((index) + ((index) >> 2))
65+
#define offset_pos_index(index) ((index) + ((index) >> 3))
6666

6767
shared VEC4_T shared_input[offset_pos_index(MAX_WORKGROUP_SIZE * SHARED_MEMORY_FACTOR)];
6868

@@ -154,14 +154,13 @@ void reduce_non_packed_dim() {
154154
if (all(lessThan(in_pos, out_limits))) {
155155
in_val = load_texel(t_in, in_pos);
156156
}
157-
shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = in_val;
157+
mean += in_val;
158158
}
159-
160-
reduce_input(width_stride, shared_idx_offset);
161-
mean += shared_input[offset_pos_index(shared_idx_offset)];
162159
}
163160

164-
mean /= width;
161+
shared_input[offset_pos_index(shared_idx)] = mean;
162+
reduce_input(width_stride, shared_idx_offset);
163+
mean = shared_input[offset_pos_index(shared_idx_offset)] / width;
165164

166165
memoryBarrierShared();
167166
barrier();
@@ -178,14 +177,13 @@ void reduce_non_packed_dim() {
178177
}
179178

180179
const VEC4_T delta = in_val - mean;
181-
shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = delta * delta;
180+
var += delta * delta;
182181
}
183-
184-
reduce_input(width_stride, shared_idx_offset);
185-
var += shared_input[offset_pos_index(shared_idx_offset)];
186182
}
187183

188-
var /= width;
184+
shared_input[offset_pos_index(shared_idx)] = var;
185+
reduce_input(width_stride, shared_idx_offset);
186+
var = shared_input[offset_pos_index(shared_idx_offset)] / width;
189187

190188
VEC4_T rstd = pow(var + epsilon, VEC4_T(-0.5));
191189
VEC4_T offset = -rstd * mean;
@@ -226,6 +224,7 @@ void reduce_packed_dim() {
226224

227225
const int in_pos_x_limit = out_limits[in_axis_map.x];
228226

227+
VEC4_T accum = VEC4_T(0);
229228
// Loop over the width in stride increments
230229
for (int width_offset = 0; width_offset <= last_packed_width_index; width_offset += width_stride) {
231230
// Read input in shared memory
@@ -244,20 +243,20 @@ void reduce_packed_dim() {
244243
in_val.z = mix(in_val.z, T(0), remain_inv > 1);
245244
in_val.w = mix(in_val.w, T(0), remain_inv > 0);
246245
}
247-
248-
shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = in_val;
246+
accum += in_val;
249247
}
250-
251-
reduce_input(width_stride, shared_idx_offset);
252-
const VEC4_T val = shared_input[offset_pos_index(shared_idx_offset)];
253-
mean += val.x + val.y + val.z + val.w;
254248
}
255249

256-
mean /= width;
250+
shared_input[offset_pos_index(shared_idx)] = accum;
251+
reduce_input(width_stride, shared_idx_offset);
252+
VEC4_T val = shared_input[offset_pos_index(shared_idx_offset)];
253+
mean = (val.x + val.y + val.z + val.w) / width;
257254

258255
memoryBarrierShared();
259256
barrier();
260257

258+
VEC4_T delta2 = VEC4_T(0);
259+
261260
// Loop over the width in stride increments
262261
for (int width_offset = 0; width_offset <= last_packed_width_index; width_offset += width_stride) {
263262
// Read input in shared memory
@@ -278,16 +277,14 @@ void reduce_packed_dim() {
278277
}
279278

280279
const VEC4_T delta = in_val - mean;
281-
const VEC4_T delta2 = delta * delta;
282-
shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = delta2;
280+
delta2 += delta * delta;
283281
}
284-
285-
reduce_input(width_stride, shared_idx_offset);
286-
const VEC4_T val = shared_input[offset_pos_index(shared_idx_offset)];
287-
var += val.x + val.y + val.z + val.w;
288282
}
289283

290-
var /= width;
284+
shared_input[offset_pos_index(shared_idx)] = delta2;
285+
reduce_input(width_stride, shared_idx_offset);
286+
val = shared_input[offset_pos_index(shared_idx_offset)];
287+
var = (val.x + val.y + val.z + val.w) / width;
291288

292289
T rstd = pow(var + epsilon, T(-0.5));
293290
T offset = -rstd * mean;

examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelType.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,5 @@ public enum ModelType {
1414
LLAMA_3_2,
1515
LLAVA_1_5,
1616
LLAMA_GUARD_3,
17+
QWEN_3,
1718
}

examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelUtils.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ public static int getModelCategory(ModelType modelType, BackendType backendType)
2929
case LLAMA_3:
3030
case LLAMA_3_1:
3131
case LLAMA_3_2:
32+
case QWEN_3:
3233
default:
3334
return TEXT_MODEL;
3435
}

examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/PromptFormat.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ public static String getSystemPromptTemplate(ModelType modelType) {
2525
+ "<|eot_id|>";
2626
case LLAVA_1_5:
2727
return "USER: ";
28+
case QWEN_3:
29+
return "<|im_start|>system\n" + "You are a helpful assistant.\n" + "<|im_end|>\n";
2830
default:
2931
return SYSTEM_PLACEHOLDER;
3032
}
@@ -42,6 +44,14 @@ public static String getUserPromptTemplate(ModelType modelType) {
4244
+ "<|start_header_id|>assistant<|end_header_id|>";
4345

4446
case LLAVA_1_5:
47+
case QWEN_3:
48+
return "<|im_start|>user\n"
49+
+ USER_PLACEHOLDER
50+
+ "<|im_end|>\n"
51+
+ "<|im_start|>assistant\n"
52+
+ "<think>\n"
53+
+ "\n"
54+
+ "</think>\n\n\n";
4555
default:
4656
return USER_PLACEHOLDER;
4757
}
@@ -69,6 +79,8 @@ public static String getStopToken(ModelType modelType) {
6979
return "<|eot_id|>";
7080
case LLAVA_1_5:
7181
return "</s>";
82+
case QWEN_3:
83+
return "<|endoftext|>";
7284
default:
7385
return "";
7486
}

0 commit comments

Comments
 (0)