Skip to content

Commit e46dd90

Browse files
donaldhkuba-moo
authored andcommitted
tools/net/ynl: Add support for netlink-raw families
Refactor the ynl code to encapsulate protocol specifics into NetlinkProtocol and GenlProtocol. Signed-off-by: Donald Hunter <[email protected]> Link: https://lore.kernel.org/r/[email protected] Signed-off-by: Jakub Kicinski <[email protected]>
1 parent fb0a06d commit e46dd90

File tree

1 file changed

+91
-33
lines changed

1 file changed

+91
-33
lines changed

tools/net/ynl/lib/ynl.py

Lines changed: 91 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ class Netlink:
2525
NETLINK_ADD_MEMBERSHIP = 1
2626
NETLINK_CAP_ACK = 10
2727
NETLINK_EXT_ACK = 11
28+
NETLINK_GET_STRICT_CHK = 12
2829

2930
# Netlink message
3031
NLMSG_ERROR = 2
@@ -228,6 +229,9 @@ def __init__(self, msg, offset, attr_space=None):
228229
desc += f" ({spec['doc']})"
229230
self.extack['miss-type'] = desc
230231

232+
def cmd(self):
233+
return self.nl_type
234+
231235
def __repr__(self):
232236
msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}\n"
233237
if self.error:
@@ -322,6 +326,9 @@ def __init__(self, nl_msg):
322326
self.genl_cmd, self.genl_version, _ = struct.unpack_from("BBH", nl_msg.raw, 0)
323327
self.raw = nl_msg.raw[4:]
324328

329+
def cmd(self):
330+
return self.genl_cmd
331+
325332
def __repr__(self):
326333
msg = repr(self.nl)
327334
msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"
@@ -330,9 +337,41 @@ def __repr__(self):
330337
return msg
331338

332339

333-
class GenlFamily:
334-
def __init__(self, family_name):
340+
class NetlinkProtocol:
341+
def __init__(self, family_name, proto_num):
335342
self.family_name = family_name
343+
self.proto_num = proto_num
344+
345+
def _message(self, nl_type, nl_flags, seq=None):
346+
if seq is None:
347+
seq = random.randint(1, 1024)
348+
nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
349+
return nlmsg
350+
351+
def message(self, flags, command, version, seq=None):
352+
return self._message(command, flags, seq)
353+
354+
def _decode(self, nl_msg):
355+
return nl_msg
356+
357+
def decode(self, ynl, nl_msg):
358+
msg = self._decode(nl_msg)
359+
fixed_header_size = 0
360+
if ynl:
361+
op = ynl.rsp_by_value[msg.cmd()]
362+
fixed_header_size = ynl._fixed_header_size(op)
363+
msg.raw_attrs = NlAttrs(msg.raw[fixed_header_size:])
364+
return msg
365+
366+
def get_mcast_id(self, mcast_name, mcast_groups):
367+
if mcast_name not in mcast_groups:
368+
raise Exception(f'Multicast group "{mcast_name}" not present in the spec')
369+
return mcast_groups[mcast_name].value
370+
371+
372+
class GenlProtocol(NetlinkProtocol):
373+
def __init__(self, family_name):
374+
super().__init__(family_name, Netlink.NETLINK_GENERIC)
336375

337376
global genl_family_name_to_id
338377
if genl_family_name_to_id is None:
@@ -341,6 +380,19 @@ def __init__(self, family_name):
341380
self.genl_family = genl_family_name_to_id[family_name]
342381
self.family_id = genl_family_name_to_id[family_name]['id']
343382

383+
def message(self, flags, command, version, seq=None):
384+
nlmsg = self._message(self.family_id, flags, seq)
385+
genlmsg = struct.pack("BBH", command, version, 0)
386+
return nlmsg + genlmsg
387+
388+
def _decode(self, nl_msg):
389+
return GenlMsg(nl_msg)
390+
391+
def get_mcast_id(self, mcast_name, mcast_groups):
392+
if mcast_name not in self.genl_family['mcast']:
393+
raise Exception(f'Multicast group "{mcast_name}" not present in the family')
394+
return self.genl_family['mcast'][mcast_name]
395+
344396

345397
#
346398
# YNL implementation details.
@@ -353,9 +405,19 @@ def __init__(self, def_path, schema=None):
353405

354406
self.include_raw = False
355407

356-
self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC)
408+
try:
409+
if self.proto == "netlink-raw":
410+
self.nlproto = NetlinkProtocol(self.yaml['name'],
411+
self.yaml['protonum'])
412+
else:
413+
self.nlproto = GenlProtocol(self.yaml['name'])
414+
except KeyError:
415+
raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel")
416+
417+
self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.nlproto.proto_num)
357418
self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
358419
self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)
420+
self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_GET_STRICT_CHK, 1)
359421

360422
self.async_msg_ids = set()
361423
self.async_msg_queue = []
@@ -368,18 +430,12 @@ def __init__(self, def_path, schema=None):
368430
bound_f = functools.partial(self._op, op_name)
369431
setattr(self, op.ident_name, bound_f)
370432

371-
try:
372-
self.family = GenlFamily(self.yaml['name'])
373-
except KeyError:
374-
raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel")
375433

376434
def ntf_subscribe(self, mcast_name):
377-
if mcast_name not in self.family.genl_family['mcast']:
378-
raise Exception(f'Multicast group "{mcast_name}" not present in the family')
379-
435+
mcast_id = self.nlproto.get_mcast_id(mcast_name, self.mcast_groups)
380436
self.sock.bind((0, 0))
381437
self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP,
382-
self.family.genl_family['mcast'][mcast_name])
438+
mcast_id)
383439

384440
def _add_attr(self, space, name, value):
385441
try:
@@ -505,11 +561,9 @@ def _decode_extack(self, request, op, extack):
505561
if 'bad-attr-offs' not in extack:
506562
return
507563

508-
genl_req = GenlMsg(NlMsg(request, 0, op.attr_set))
509-
fixed_header_size = self._fixed_header_size(op)
510-
offset = 20 + fixed_header_size
511-
path = self._decode_extack_path(NlAttrs(genl_req.raw[fixed_header_size:]),
512-
op.attr_set, offset,
564+
msg = self.nlproto.decode(self, NlMsg(request, 0, op.attr_set))
565+
offset = 20 + self._fixed_header_size(op)
566+
path = self._decode_extack_path(msg.raw_attrs, op.attr_set, offset,
513567
extack['bad-attr-offs'])
514568
if path:
515569
del extack['bad-attr-offs']
@@ -539,14 +593,17 @@ def _decode_fixed_header(self, msg, name):
539593
fixed_header_attrs[m.name] = value
540594
return fixed_header_attrs
541595

542-
def handle_ntf(self, nl_msg, genl_msg):
596+
def handle_ntf(self, decoded):
543597
msg = dict()
544598
if self.include_raw:
545-
msg['nlmsg'] = nl_msg
546-
msg['genlmsg'] = genl_msg
547-
op = self.rsp_by_value[genl_msg.genl_cmd]
599+
msg['raw'] = decoded
600+
op = self.rsp_by_value[decoded.cmd()]
601+
attrs = self._decode(decoded.raw_attrs, op.attr_set.name)
602+
if op.fixed_header:
603+
attrs.update(self._decode_fixed_header(decoded, op.fixed_header))
604+
548605
msg['name'] = op['name']
549-
msg['msg'] = self._decode(genl_msg.raw_attrs, op.attr_set.name)
606+
msg['msg'] = attrs
550607
self.async_msg_queue.append(msg)
551608

552609
def check_ntf(self):
@@ -566,12 +623,12 @@ def check_ntf(self):
566623
print("Netlink done while checking for ntf!?")
567624
continue
568625

569-
gm = GenlMsg(nl_msg)
570-
if gm.genl_cmd not in self.async_msg_ids:
571-
print("Unexpected msg id done while checking for ntf", gm)
626+
decoded = self.nlproto.decode(self, nl_msg)
627+
if decoded.cmd() not in self.async_msg_ids:
628+
print("Unexpected msg id done while checking for ntf", decoded)
572629
continue
573630

574-
self.handle_ntf(nl_msg, gm)
631+
self.handle_ntf(decoded)
575632

576633
def operation_do_attributes(self, name):
577634
"""
@@ -592,7 +649,7 @@ def _op(self, method, vals, dump=False):
592649
nl_flags |= Netlink.NLM_F_DUMP
593650

594651
req_seq = random.randint(1024, 65535)
595-
msg = _genl_msg(self.family.family_id, nl_flags, op.req_value, 1, req_seq)
652+
msg = self.nlproto.message(nl_flags, op.req_value, 1, req_seq)
596653
fixed_header_members = []
597654
if op.fixed_header:
598655
fixed_header_members = self.consts[op.fixed_header].members
@@ -624,19 +681,20 @@ def _op(self, method, vals, dump=False):
624681
done = True
625682
break
626683

627-
gm = GenlMsg(nl_msg)
684+
decoded = self.nlproto.decode(self, nl_msg)
685+
628686
# Check if this is a reply to our request
629-
if nl_msg.nl_seq != req_seq or gm.genl_cmd != op.rsp_value:
630-
if gm.genl_cmd in self.async_msg_ids:
631-
self.handle_ntf(nl_msg, gm)
687+
if nl_msg.nl_seq != req_seq or decoded.cmd() != op.rsp_value:
688+
if decoded.cmd() in self.async_msg_ids:
689+
self.handle_ntf(decoded)
632690
continue
633691
else:
634-
print('Unexpected message: ' + repr(gm))
692+
print('Unexpected message: ' + repr(decoded))
635693
continue
636694

637-
rsp_msg = self._decode(NlAttrs(gm.raw), op.attr_set.name)
695+
rsp_msg = self._decode(decoded.raw_attrs, op.attr_set.name)
638696
if op.fixed_header:
639-
rsp_msg.update(self._decode_fixed_header(gm, op.fixed_header))
697+
rsp_msg.update(self._decode_fixed_header(decoded, op.fixed_header))
640698
rsp.append(rsp_msg)
641699

642700
if not rsp:

0 commit comments

Comments
 (0)