Skip to content

Commit e38caba

Browse files
committed
Add Enumerable#sample(n=1, ...)
- Only support the case of n == 1
1 parent 7b582da commit e38caba

2 files changed

Lines changed: 213 additions & 0 deletions

File tree

ext/enumerable/statistics/extension/statistics.c

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1413,6 +1413,124 @@ enum_stdev(int argc, VALUE* argv, VALUE obj)
14131413
return stdev;
14141414
}
14151415

1416+
#if SIZEOF_SIZE_T == SIZEOF_LONG
1417+
static inline size_t
1418+
random_usize_limited(VALUE rnd, size_t max)
1419+
{
1420+
return (size_t)rb_random_ulong_limited(rnd, max);
1421+
}
1422+
#else
1423+
static inline size_t
1424+
random_usize_limited(VALUE rnd, size_t max)
1425+
{
1426+
if (max <= ULONG_MAX) {
1427+
return (size_t)rb_random_ulong_limited(rnd, (unsigned long)max);
1428+
}
1429+
else {
1430+
VALUE num = rb_random_int(rnd, SIZET2NUM(max));
1431+
return NUM2SIZET(num);
1432+
}
1433+
}
1434+
#endif
1435+
1436+
struct sample_single_memo {
1437+
size_t k;
1438+
VALUE sample;
1439+
VALUE random;
1440+
};
1441+
1442+
static VALUE
1443+
enum_sample_single_i(RB_BLOCK_CALL_FUNC_ARGLIST(e, data))
1444+
{
1445+
struct sample_single_memo *memo = (struct sample_single_memo *)data;
1446+
ENUM_WANT_SVALUE();
1447+
1448+
if (++memo->k <= 1) {
1449+
memo->sample = e;
1450+
}
1451+
else {
1452+
size_t j = random_usize_limited(memo->random, memo->k - 1);
1453+
if (j == 1) {
1454+
memo->sample = e;
1455+
}
1456+
}
1457+
1458+
return Qnil;
1459+
}
1460+
1461+
static VALUE
1462+
enum_sample_single(VALUE obj, VALUE random)
1463+
{
1464+
struct sample_single_memo memo;
1465+
1466+
memo.k = 0;
1467+
memo.sample = Qundef;
1468+
memo.random = random;
1469+
1470+
rb_block_call(obj, id_each, 0, 0, enum_sample_single_i, (VALUE)&memo);
1471+
1472+
return memo.sample;
1473+
}
1474+
1475+
static VALUE
1476+
enum_sample_multiple_unweighted(VALUE obj, long size, int replace_p)
1477+
{
1478+
assert(size > 1);
1479+
1480+
return Qnil;
1481+
}
1482+
1483+
/* call-seq:
1484+
* enum.sample(n=1, random: Random, replace: false)
1485+
*/
1486+
static VALUE
1487+
enum_sample(int argc, VALUE *argv, VALUE obj)
1488+
{
1489+
VALUE size_v, random_v, replace_v, weights_v, opts;
1490+
long size;
1491+
int replace_p;
1492+
1493+
random_v = rb_cRandom;
1494+
replace_v = Qundef;
1495+
weights_v = Qundef;
1496+
1497+
if (argc == 0) goto single;
1498+
1499+
rb_scan_args(argc, argv, "01:", &size_v, &opts);
1500+
size = NIL_P(size_v) ? 1 : NUM2LONG(size_v);
1501+
1502+
if (size == 1 && NIL_P(opts)) {
1503+
goto single;
1504+
}
1505+
1506+
if (!NIL_P(opts)) {
1507+
static ID keywords[3];
1508+
VALUE kwargs[3];
1509+
if (!keywords[0]) {
1510+
keywords[0] = rb_intern("random");
1511+
keywords[1] = rb_intern("replace");
1512+
/* keywords[2] = rb_intern("weights"); */
1513+
}
1514+
rb_get_kwargs(opts, keywords, 0, 2, kwargs);
1515+
random_v = kwargs[0];
1516+
replace_v = kwargs[1];
1517+
/* weights_v = kwargs[2]; */
1518+
}
1519+
if (random_v == Qundef) {
1520+
random_v = rb_cRandom;
1521+
}
1522+
1523+
if (size == 1) {
1524+
single:
1525+
return enum_sample_single(obj, random_v);
1526+
}
1527+
1528+
replace_p = (replace_v == Qundef) ? 1 : RTEST(replace_v);
1529+
1530+
return enum_sample_unweighted(obj, NUM2LONG(size), replace_p);
1531+
}
1532+
1533+
14161534
/* call-seq:
14171535
* ary.mean_stdev(population: false)
14181536
*
@@ -1479,6 +1597,7 @@ Init_extension(void)
14791597
rb_define_method(rb_mEnumerable, "variance", enum_variance, -1);
14801598
rb_define_method(rb_mEnumerable, "mean_stdev", enum_mean_stdev, -1);
14811599
rb_define_method(rb_mEnumerable, "stdev", enum_stdev, -1);
1600+
rb_define_method(rb_mEnumerable, "sample", enum_sample, -1);
14821601

14831602
#ifndef HAVE_ARRAY_SUM
14841603
rb_define_method(rb_cArray, "sum", ary_sum, -1);

spec/enum/sample_spec.rb

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
require 'spec_helper'
2+
require 'enumerable/statistics'
3+
4+
RSpec.describe Enumerable, '#sample' do
5+
let(:random) { Random.new }
6+
let(:n) { 20 }
7+
8+
context 'without weight' do
9+
let(:enum) { 1.upto(100000) }
10+
11+
context 'without size' do
12+
context 'without rng' do
13+
context 'without weight' do
14+
specify do
15+
result = enum.sample
16+
expect(result).to be_an(Integer)
17+
other_results = Array.new(100) { enum.sample }
18+
expect(other_results).not_to be_all(eq result)
19+
end
20+
end
21+
end
22+
23+
context 'with rng' do
24+
specify do
25+
save_random = random.dup
26+
result = enum.sample(random: random)
27+
expect(result).to be_an(Integer)
28+
other_results = Array.new(100) { enum.sample(random: save_random.dup) }
29+
expect(other_results).to be_all(eq result)
30+
end
31+
end
32+
end
33+
34+
context 'with size (== 1)' do
35+
context 'without rng' do
36+
context 'without weight' do
37+
specify do
38+
result = enum.sample(1)
39+
expect(result).to be_an(Integer)
40+
other_results = Array.new(100) { enum.sample(1) }
41+
expect(other_results).not_to be_all(eq result)
42+
end
43+
end
44+
end
45+
46+
context 'with rng' do
47+
specify do
48+
save_random = random.dup
49+
result = enum.sample(1, random: random)
50+
expect(result).to be_an(Integer)
51+
other_results = Array.new(100) { enum.sample(1, random: save_random.dup) }
52+
expect(other_results).to be_all(eq result)
53+
end
54+
end
55+
end
56+
57+
context 'with size (> 1)' do
58+
context 'without replacement' do
59+
context 'without rng' do
60+
subject(:result) { enum.sample(n) }
61+
62+
specify do
63+
result = enum.sample(n)
64+
expect(result).to be_an(Array)
65+
expect(result.length).to eq(n)
66+
other_results = Array.new(100) { enum.sample(n) }
67+
expect(other_results).not_to be_all(eq result)
68+
end
69+
end
70+
71+
context 'with rng' do
72+
subject(:result) { enum.sample(n, random: random) }
73+
74+
specify do
75+
save_random = random.dup
76+
result = enum.sample(n, random: random)
77+
expect(result).to be_an(Array)
78+
expect(result.length).to eq(n)
79+
other_results = Array.new(100) { enum.sample(n, random: random) }
80+
expect(other_results).to be_all(eq result)
81+
end
82+
end
83+
end
84+
85+
context 'with replacement' do
86+
pending
87+
end
88+
end
89+
end
90+
91+
context 'with weight' do
92+
pending
93+
end
94+
end

0 commit comments

Comments
 (0)