@@ -25,6 +25,7 @@ class Netlink:
25
25
NETLINK_ADD_MEMBERSHIP = 1
26
26
NETLINK_CAP_ACK = 10
27
27
NETLINK_EXT_ACK = 11
28
+ NETLINK_GET_STRICT_CHK = 12
28
29
29
30
# Netlink message
30
31
NLMSG_ERROR = 2
@@ -228,6 +229,9 @@ def __init__(self, msg, offset, attr_space=None):
228
229
desc += f" ({ spec ['doc' ]} )"
229
230
self .extack ['miss-type' ] = desc
230
231
232
+ def cmd (self ):
233
+ return self .nl_type
234
+
231
235
def __repr__ (self ):
232
236
msg = f"nl_len = { self .nl_len } ({ len (self .raw )} ) nl_flags = 0x{ self .nl_flags :x} nl_type = { self .nl_type } \n "
233
237
if self .error :
@@ -322,6 +326,9 @@ def __init__(self, nl_msg):
322
326
self .genl_cmd , self .genl_version , _ = struct .unpack_from ("BBH" , nl_msg .raw , 0 )
323
327
self .raw = nl_msg .raw [4 :]
324
328
329
+ def cmd (self ):
330
+ return self .genl_cmd
331
+
325
332
def __repr__ (self ):
326
333
msg = repr (self .nl )
327
334
msg += f"\t genl_cmd = { self .genl_cmd } genl_ver = { self .genl_version } \n "
@@ -330,9 +337,41 @@ def __repr__(self):
330
337
return msg
331
338
332
339
333
- class GenlFamily :
334
- def __init__ (self , family_name ):
340
+ class NetlinkProtocol :
341
+ def __init__ (self , family_name , proto_num ):
335
342
self .family_name = family_name
343
+ self .proto_num = proto_num
344
+
345
+ def _message (self , nl_type , nl_flags , seq = None ):
346
+ if seq is None :
347
+ seq = random .randint (1 , 1024 )
348
+ nlmsg = struct .pack ("HHII" , nl_type , nl_flags , seq , 0 )
349
+ return nlmsg
350
+
351
+ def message (self , flags , command , version , seq = None ):
352
+ return self ._message (command , flags , seq )
353
+
354
+ def _decode (self , nl_msg ):
355
+ return nl_msg
356
+
357
+ def decode (self , ynl , nl_msg ):
358
+ msg = self ._decode (nl_msg )
359
+ fixed_header_size = 0
360
+ if ynl :
361
+ op = ynl .rsp_by_value [msg .cmd ()]
362
+ fixed_header_size = ynl ._fixed_header_size (op )
363
+ msg .raw_attrs = NlAttrs (msg .raw [fixed_header_size :])
364
+ return msg
365
+
366
+ def get_mcast_id (self , mcast_name , mcast_groups ):
367
+ if mcast_name not in mcast_groups :
368
+ raise Exception (f'Multicast group "{ mcast_name } " not present in the spec' )
369
+ return mcast_groups [mcast_name ].value
370
+
371
+
372
+ class GenlProtocol (NetlinkProtocol ):
373
+ def __init__ (self , family_name ):
374
+ super ().__init__ (family_name , Netlink .NETLINK_GENERIC )
336
375
337
376
global genl_family_name_to_id
338
377
if genl_family_name_to_id is None :
@@ -341,6 +380,19 @@ def __init__(self, family_name):
341
380
self .genl_family = genl_family_name_to_id [family_name ]
342
381
self .family_id = genl_family_name_to_id [family_name ]['id' ]
343
382
383
+ def message (self , flags , command , version , seq = None ):
384
+ nlmsg = self ._message (self .family_id , flags , seq )
385
+ genlmsg = struct .pack ("BBH" , command , version , 0 )
386
+ return nlmsg + genlmsg
387
+
388
+ def _decode (self , nl_msg ):
389
+ return GenlMsg (nl_msg )
390
+
391
+ def get_mcast_id (self , mcast_name , mcast_groups ):
392
+ if mcast_name not in self .genl_family ['mcast' ]:
393
+ raise Exception (f'Multicast group "{ mcast_name } " not present in the family' )
394
+ return self .genl_family ['mcast' ][mcast_name ]
395
+
344
396
345
397
#
346
398
# YNL implementation details.
@@ -353,9 +405,19 @@ def __init__(self, def_path, schema=None):
353
405
354
406
self .include_raw = False
355
407
356
- self .sock = socket .socket (socket .AF_NETLINK , socket .SOCK_RAW , Netlink .NETLINK_GENERIC )
408
+ try :
409
+ if self .proto == "netlink-raw" :
410
+ self .nlproto = NetlinkProtocol (self .yaml ['name' ],
411
+ self .yaml ['protonum' ])
412
+ else :
413
+ self .nlproto = GenlProtocol (self .yaml ['name' ])
414
+ except KeyError :
415
+ raise Exception (f"Family '{ self .yaml ['name' ]} ' not supported by the kernel" )
416
+
417
+ self .sock = socket .socket (socket .AF_NETLINK , socket .SOCK_RAW , self .nlproto .proto_num )
357
418
self .sock .setsockopt (Netlink .SOL_NETLINK , Netlink .NETLINK_CAP_ACK , 1 )
358
419
self .sock .setsockopt (Netlink .SOL_NETLINK , Netlink .NETLINK_EXT_ACK , 1 )
420
+ self .sock .setsockopt (Netlink .SOL_NETLINK , Netlink .NETLINK_GET_STRICT_CHK , 1 )
359
421
360
422
self .async_msg_ids = set ()
361
423
self .async_msg_queue = []
@@ -368,18 +430,12 @@ def __init__(self, def_path, schema=None):
368
430
bound_f = functools .partial (self ._op , op_name )
369
431
setattr (self , op .ident_name , bound_f )
370
432
371
- try :
372
- self .family = GenlFamily (self .yaml ['name' ])
373
- except KeyError :
374
- raise Exception (f"Family '{ self .yaml ['name' ]} ' not supported by the kernel" )
375
433
376
434
def ntf_subscribe (self , mcast_name ):
377
- if mcast_name not in self .family .genl_family ['mcast' ]:
378
- raise Exception (f'Multicast group "{ mcast_name } " not present in the family' )
379
-
435
+ mcast_id = self .nlproto .get_mcast_id (mcast_name , self .mcast_groups )
380
436
self .sock .bind ((0 , 0 ))
381
437
self .sock .setsockopt (Netlink .SOL_NETLINK , Netlink .NETLINK_ADD_MEMBERSHIP ,
382
- self . family . genl_family [ 'mcast' ][ mcast_name ] )
438
+ mcast_id )
383
439
384
440
def _add_attr (self , space , name , value ):
385
441
try :
@@ -505,11 +561,9 @@ def _decode_extack(self, request, op, extack):
505
561
if 'bad-attr-offs' not in extack :
506
562
return
507
563
508
- genl_req = GenlMsg (NlMsg (request , 0 , op .attr_set ))
509
- fixed_header_size = self ._fixed_header_size (op )
510
- offset = 20 + fixed_header_size
511
- path = self ._decode_extack_path (NlAttrs (genl_req .raw [fixed_header_size :]),
512
- op .attr_set , offset ,
564
+ msg = self .nlproto .decode (self , NlMsg (request , 0 , op .attr_set ))
565
+ offset = 20 + self ._fixed_header_size (op )
566
+ path = self ._decode_extack_path (msg .raw_attrs , op .attr_set , offset ,
513
567
extack ['bad-attr-offs' ])
514
568
if path :
515
569
del extack ['bad-attr-offs' ]
@@ -539,14 +593,17 @@ def _decode_fixed_header(self, msg, name):
539
593
fixed_header_attrs [m .name ] = value
540
594
return fixed_header_attrs
541
595
542
- def handle_ntf (self , nl_msg , genl_msg ):
596
+ def handle_ntf (self , decoded ):
543
597
msg = dict ()
544
598
if self .include_raw :
545
- msg ['nlmsg' ] = nl_msg
546
- msg ['genlmsg' ] = genl_msg
547
- op = self .rsp_by_value [genl_msg .genl_cmd ]
599
+ msg ['raw' ] = decoded
600
+ op = self .rsp_by_value [decoded .cmd ()]
601
+ attrs = self ._decode (decoded .raw_attrs , op .attr_set .name )
602
+ if op .fixed_header :
603
+ attrs .update (self ._decode_fixed_header (decoded , op .fixed_header ))
604
+
548
605
msg ['name' ] = op ['name' ]
549
- msg ['msg' ] = self . _decode ( genl_msg . raw_attrs , op . attr_set . name )
606
+ msg ['msg' ] = attrs
550
607
self .async_msg_queue .append (msg )
551
608
552
609
def check_ntf (self ):
@@ -566,12 +623,12 @@ def check_ntf(self):
566
623
print ("Netlink done while checking for ntf!?" )
567
624
continue
568
625
569
- gm = GenlMsg ( nl_msg )
570
- if gm . genl_cmd not in self .async_msg_ids :
571
- print ("Unexpected msg id done while checking for ntf" , gm )
626
+ decoded = self . nlproto . decode ( self , nl_msg )
627
+ if decoded . cmd () not in self .async_msg_ids :
628
+ print ("Unexpected msg id done while checking for ntf" , decoded )
572
629
continue
573
630
574
- self .handle_ntf (nl_msg , gm )
631
+ self .handle_ntf (decoded )
575
632
576
633
def operation_do_attributes (self , name ):
577
634
"""
@@ -592,7 +649,7 @@ def _op(self, method, vals, dump=False):
592
649
nl_flags |= Netlink .NLM_F_DUMP
593
650
594
651
req_seq = random .randint (1024 , 65535 )
595
- msg = _genl_msg ( self .family . family_id , nl_flags , op .req_value , 1 , req_seq )
652
+ msg = self .nlproto . message ( nl_flags , op .req_value , 1 , req_seq )
596
653
fixed_header_members = []
597
654
if op .fixed_header :
598
655
fixed_header_members = self .consts [op .fixed_header ].members
@@ -624,19 +681,20 @@ def _op(self, method, vals, dump=False):
624
681
done = True
625
682
break
626
683
627
- gm = GenlMsg (nl_msg )
684
+ decoded = self .nlproto .decode (self , nl_msg )
685
+
628
686
# Check if this is a reply to our request
629
- if nl_msg .nl_seq != req_seq or gm . genl_cmd != op .rsp_value :
630
- if gm . genl_cmd in self .async_msg_ids :
631
- self .handle_ntf (nl_msg , gm )
687
+ if nl_msg .nl_seq != req_seq or decoded . cmd () != op .rsp_value :
688
+ if decoded . cmd () in self .async_msg_ids :
689
+ self .handle_ntf (decoded )
632
690
continue
633
691
else :
634
- print ('Unexpected message: ' + repr (gm ))
692
+ print ('Unexpected message: ' + repr (decoded ))
635
693
continue
636
694
637
- rsp_msg = self ._decode (NlAttrs ( gm . raw ) , op .attr_set .name )
695
+ rsp_msg = self ._decode (decoded . raw_attrs , op .attr_set .name )
638
696
if op .fixed_header :
639
- rsp_msg .update (self ._decode_fixed_header (gm , op .fixed_header ))
697
+ rsp_msg .update (self ._decode_fixed_header (decoded , op .fixed_header ))
640
698
rsp .append (rsp_msg )
641
699
642
700
if not rsp :
0 commit comments