@@ -461,11 +461,25 @@ def forward(self, indices: torch.Tensor) -> torch.Tensor:
461
461
self .weight , self .scales , None , 0 , 0 , indices , dtype = self .dtype
462
462
)
463
463
464
- result_weights = self .weight .index_select (0 , indices .view (- 1 ))
465
- result_scales = self .scales .index_select (0 , indices .view (- 1 ))
464
+
465
+ # result_weights = self.weight.index_select(0, indices.view(-1))
466
+ # result_scales = self.scales.index_select(0, indices.view(-1))
467
+
468
+ weight = self .weight
469
+ scales = self .scales .view (weight .shape [0 ], - 1 )
470
+
471
+ result_weights = F .embedding (indices , weight )
472
+ result_scales = F .embedding (indices , scales )
473
+
474
+ rw_view = result_weights .to (dtype = result_scales .dtype ).view (tuple (result_weights .shape [:- 1 ] + (scales .shape [1 ], - 1 , )))
475
+ rs_view = result_scales .view (tuple (result_scales .shape [:- 1 ]) + (scales .shape [1 ], 1 , ))
476
+ # print(f"rw_view {rw_view.shape}")
477
+ # print(f"rs_view {rs_view.shape}")
466
478
467
- r = result_weights . to ( dtype = result_scales . dtype ) * result_scales
479
+ r = rw_view * rs_view
468
480
return r .view (indices .size () + (- 1 ,))
481
+
482
+ # r = result_weights.to(dtype=result_scales.dtype).view(list(result_weights.shape[:-1] + (scales.shape[1], -1, )) * result_scales.view(scales.shape[-1] + (scales.shape[1], 1, ))
469
483
470
484
##################################################################
471
485
##### weight only int4 per channel groupwise quantized code ######
0 commit comments