我需要写一个加权版的random。选择(列表中的每个元素有不同的被选择的概率)。这是我想到的:
def weightedChoice(choices):
"""Like random.choice, but each element can have a different chance of
being selected.
choices can be any iterable containing iterables with two items each.
Technically, they can have more than two items, the rest will just be
ignored. The first item is the thing being chosen, the second item is
its weight. The weights can be any numeric values, what matters is the
relative differences between them.
"""
space = {}
current = 0
for choice, weight in choices:
if weight > 0:
space[current] = choice
current += weight
rand = random.uniform(0, current)
for key in sorted(space.keys() + [current]):
if rand < key:
return choice
choice = space[key]
return None
这个函数对我来说太复杂了,而且很丑。我希望这里的每个人都能提供一些改进的建议或其他方法。对我来说,效率没有代码的整洁和可读性重要。
假设你有
items = [11, 23, 43, 91]
probability = [0.2, 0.3, 0.4, 0.1]
你有一个函数,它生成一个介于[0,1)之间的随机数(我们可以在这里使用random.random())。
现在求概率的前缀和
prefix_probability=[0.2,0.5,0.9,1]
现在,我们只需取一个0-1之间的随机数,然后使用二分搜索来查找该数字在prefix_probability中的位置。这个索引就是你的答案
代码是这样的
return items[bisect.bisect(prefix_probability,random.random())]
我可能已经来不及提供任何有用的东西了,但这里有一个简单,简短,非常有效的片段:
def choose_index(probabilies):
cmf = probabilies[0]
choice = random.random()
for k in xrange(len(probabilies)):
if choice <= cmf:
return k
else:
cmf += probabilies[k+1]
不需要排序你的概率或用你的cmf创建一个向量,它一旦找到它的选择就会终止。内存:O(1),时间:O(N),平均运行时间~ N/2。
如果你有权重,只需添加一行:
def choose_index(weights):
probabilities = weights / sum(weights)
cmf = probabilies[0]
choice = random.random()
for k in xrange(len(probabilies)):
if choice <= cmf:
return k
else:
cmf += probabilies[k+1]
粗糙的,但可能足够:
import random
weighted_choice = lambda s : random.choice(sum(([v]*wt for v,wt in s),[]))
这有用吗?
# define choices and relative weights
choices = [("WHITE",90), ("RED",8), ("GREEN",2)]
# initialize tally dict
tally = dict.fromkeys(choices, 0)
# tally up 1000 weighted choices
for i in xrange(1000):
tally[weighted_choice(choices)] += 1
print tally.items()
打印:
[('WHITE', 904), ('GREEN', 22), ('RED', 74)]
假设所有权重都是整数。它们的和不一定是100,我这么做只是为了让测试结果更容易理解。(如果权重是浮点数,则将它们都乘以10,直到所有权重>= 1。)
weights = [.6, .2, .001, .199]
while any(w < 1.0 for w in weights):
weights = [w*10 for w in weights]
weights = map(int, weights)