Skip to content

Commit eceedac

Browse files
committed
Update base for Update on "[ET][EZ] Enable operator<< for Half tensor data"
Useful for debugging `Half` i.e. `fp16` models, when we have `EValue`s that are `Half` dtype and we do the following: ``` std::cout << "===== INPUT =====" << std::endl; for (EValue& v : inputs) { std::cout << v << std::endl; } std::cout << "===== OUTPUT =====" << std::endl; for (EValue& v : outputs) { std::cout << v << std::endl; } ``` ## Before ``` ===== INPUT ===== tensor(sizes=[1, 3, 96, 72], [<unhandled scalar type 5>]) ===== OUTPUT ===== tensor(sizes=[1, 2, 96, 72], [<unhandled scalar type 5>]) ``` ## After ``` ===== INPUT ===== tensor(sizes=[1, 3, 96, 72], [0.279785, 0.271484, 0.364746, ..., 0.150391, 0.836426, 0.019043]) ===== OUTPUT ===== tensor(sizes=[1, 2, 96, 72], [18.2344, -10.0938, 1.35059, ..., -33.6875, 4.07422, -22.5312]) ``` Differential Revision: [D57977366](https://our.internmc.facebook.com/intern/diff/D57977366/) [ghstack-poisoned]
1 parent a4ffb2a commit eceedac

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

backends/vulkan/runtime/graph/ops/glsl/avg_pool2d.glsl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
#version 450 core
1010

1111
#define PRECISION ${PRECISION}
12-
#define FLT_MIN -3.402823466e+38
12+
13+
#define VEC4_T ${texel_type(DTYPE)}
1314

1415
#include "indexing_utils.h"
1516

@@ -36,7 +37,7 @@ void main() {
3637
const ivec2 start = max(ivec2(0), ipos);
3738
const ivec2 end = min(ipos + kernel_size, ivec2(in_sizes.xy));
3839

39-
vec4 sum = vec4(0);
40+
VEC4_T sum = VEC4_T(0);
4041
for (int y = start.y; y < end.y; ++y) {
4142
for (int x = start.x; x < end.x; ++x) {
4243
sum += texelFetch(t_in, ivec3(x, y, pos.z), 0);

backends/vulkan/runtime/graph/ops/glsl/avg_pool2d.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,6 @@ avg_pool2d:
1313
DTYPE:
1414
- VALUE: half
1515
- VALUE: float
16+
- VALUE: int
1617
shader_variants:
1718
- NAME: avg_pool2d

0 commit comments

Comments
 (0)