只需循环遍历每个键,并使用除键的最后一个元素之外的所有元素来添加词典.保持对如此设置的最后一个字典的引用,然后使用键元组中的最后一个元素在输出字典中实际设置键值对:
def nest(d: dict) -> dict:
result = {}
for key, value in d.items():
target = result
for k in key[:-1]: # traverse all keys but the last
target = target.setdefault(k, {})
target[key[-1]] = value
return result
from functools import reduce
def nest(d: dict) -> dict:
result = {}
traverse = lambda r, k: r.setdefault(k, {})
for key, value in d.items():
reduce(traverse, key[:-1], result)[key[-1]] = value
return result
我使用了dict.setdefault()而不是auto-vivication defaultdict(nested_dict)选项,因为这会生成一个常规字典,当它们尚不存在时不会进一步隐式添加密钥.
演示:
>>> from pprint import pprint
>>> pprint(nest(d1))
{'A': {0: 0, 1: 1}, 'B': {0: 2, 1: 3}}
>>> pprint(nest(d2))
{'A': {0: {False: 1, True: 0}, 1: {False: 3, True: 2}},
'B': {0: {False: 5, True: 4}, 1: {False: 7, True: 6}}}
>>> pprint(nest(d3))
{'C': {0: {False: {'A': 2, 'B': 3}, True: {'A': 0, 'B': 1}},
1: {False: {'A': 6, 'B': 7}, True: {'A': 4, 'B': 5}}},
'D': {0: {False: {'A': 10, 'B': 11}, True: {'A': 8, 'B': 9}},
1: {False: {'A': 14, 'B': 15}, True: {'A': 12, 'B': 13}}}}
注意,上面的解决方案是一个干净的O(N)循环(N是输入字典的长度),而Ajax1234提出的groupby解决方案必须对输入进行排序才能工作,从而使其成为O(NlogN)解决方案.这意味着对于具有1000个元素的字典,groupby()将需要10.000步来产生输出,而O(N)线性循环仅需要1000步.对于一百万个键,这增加到2000万步等.
此外,Python中的递归速度很慢,因为Python无法将这些解决方案优化为迭代方法.函数调用相对昂贵,因此使用递归会带来显着的性能成本,因为您会大大增加函数调用的数量并扩展帧堆栈操作.
计时表明这有多重要;使用您的样品d3和100k运行,我们很容易看到5倍的速度差异:
>>> from timeit import timeit
>>> timeit('n(d)', 'from __main__ import create_nested_dict as n, d3; d=d3.items()', number=100_000)
8.210276518017054
>>> timeit('n(d)', 'from __main__ import nest as n, d3 as d', number=100_000)
1.6089267160277814