Skip to content

Commit 317886e

Browse files
committed
Support n > 1 in Enumerable#sample
1 parent e38caba commit 317886e

2 files changed

Lines changed: 52 additions & 14 deletions

File tree

ext/enumerable/statistics/extension/statistics.c

Lines changed: 43 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1433,16 +1433,17 @@ random_usize_limited(VALUE rnd, size_t max)
14331433
}
14341434
#endif
14351435

1436-
struct sample_single_memo {
1436+
struct enum_sample_memo {
14371437
size_t k;
1438+
long n;
14381439
VALUE sample;
14391440
VALUE random;
14401441
};
14411442

14421443
static VALUE
14431444
enum_sample_single_i(RB_BLOCK_CALL_FUNC_ARGLIST(e, data))
14441445
{
1445-
struct sample_single_memo *memo = (struct sample_single_memo *)data;
1446+
struct enum_sample_memo *memo = (struct enum_sample_memo *)data;
14461447
ENUM_WANT_SVALUE();
14471448

14481449
if (++memo->k <= 1) {
@@ -1461,9 +1462,10 @@ enum_sample_single_i(RB_BLOCK_CALL_FUNC_ARGLIST(e, data))
14611462
static VALUE
14621463
enum_sample_single(VALUE obj, VALUE random)
14631464
{
1464-
struct sample_single_memo memo;
1465+
struct enum_sample_memo memo;
14651466

14661467
memo.k = 0;
1468+
memo.n = 1;
14671469
memo.sample = Qundef;
14681470
memo.random = random;
14691471

@@ -1473,13 +1475,46 @@ enum_sample_single(VALUE obj, VALUE random)
14731475
}
14741476

14751477
static VALUE
1476-
enum_sample_multiple_unweighted(VALUE obj, long size, int replace_p)
1478+
enum_sample_multiple_without_replace_unweighted_i(RB_BLOCK_CALL_FUNC_ARGLIST(e, data))
14771479
{
1478-
assert(size > 1);
1480+
struct enum_sample_memo *memo = (struct enum_sample_memo *)data;
1481+
ENUM_WANT_SVALUE();
1482+
1483+
if (++memo->k <= memo->n) {
1484+
rb_ary_push(memo->sample, e);
1485+
}
1486+
else {
1487+
size_t j = random_usize_limited(memo->random, memo->k - 1);
1488+
if (j <= memo->n) {
1489+
rb_ary_store(memo->sample, (long)(j - 1), e);
1490+
}
1491+
}
14791492

14801493
return Qnil;
14811494
}
14821495

1496+
static VALUE
1497+
enum_sample_multiple_unweighted(VALUE obj, long size, VALUE random, int replace_p)
1498+
{
1499+
struct enum_sample_memo memo;
1500+
1501+
assert(size > 1);
1502+
1503+
memo.k = 0;
1504+
memo.n = size;
1505+
memo.sample = rb_ary_new_capa(size);
1506+
memo.random = random;
1507+
1508+
if (replace_p) {
1509+
return Qnil;
1510+
}
1511+
else {
1512+
rb_block_call(obj, id_each, 0, 0, enum_sample_multiple_without_replace_unweighted_i, (VALUE)&memo);
1513+
}
1514+
1515+
return memo.sample;
1516+
}
1517+
14831518
/* call-seq:
14841519
* enum.sample(n=1, random: Random, replace: false)
14851520
*/
@@ -1516,6 +1551,7 @@ enum_sample(int argc, VALUE *argv, VALUE obj)
15161551
replace_v = kwargs[1];
15171552
/* weights_v = kwargs[2]; */
15181553
}
1554+
15191555
if (random_v == Qundef) {
15201556
random_v = rb_cRandom;
15211557
}
@@ -1525,9 +1561,9 @@ enum_sample(int argc, VALUE *argv, VALUE obj)
15251561
return enum_sample_single(obj, random_v);
15261562
}
15271563

1528-
replace_p = (replace_v == Qundef) ? 1 : RTEST(replace_v);
1564+
replace_p = (replace_v == Qundef) ? 0 : RTEST(replace_v);
15291565

1530-
return enum_sample_unweighted(obj, NUM2LONG(size), replace_p);
1566+
return enum_sample_multiple_unweighted(obj, size, random_v, replace_p);
15311567
}
15321568

15331569

spec/enum/sample_spec.rb

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
result = enum.sample
1616
expect(result).to be_an(Integer)
1717
other_results = Array.new(100) { enum.sample }
18-
expect(other_results).not_to be_all(eq result)
18+
expect(other_results).not_to be_all {|i| i == result }
1919
end
2020
end
2121
end
@@ -26,7 +26,7 @@
2626
result = enum.sample(random: random)
2727
expect(result).to be_an(Integer)
2828
other_results = Array.new(100) { enum.sample(random: save_random.dup) }
29-
expect(other_results).to be_all(eq result)
29+
expect(other_results).to be_all {|i| i == result }
3030
end
3131
end
3232
end
@@ -38,7 +38,7 @@
3838
result = enum.sample(1)
3939
expect(result).to be_an(Integer)
4040
other_results = Array.new(100) { enum.sample(1) }
41-
expect(other_results).not_to be_all(eq result)
41+
expect(other_results).not_to be_all {|i| i == result }
4242
end
4343
end
4444
end
@@ -49,7 +49,7 @@
4949
result = enum.sample(1, random: random)
5050
expect(result).to be_an(Integer)
5151
other_results = Array.new(100) { enum.sample(1, random: save_random.dup) }
52-
expect(other_results).to be_all(eq result)
52+
expect(other_results).to be_all {|i| i == result }
5353
end
5454
end
5555
end
@@ -63,8 +63,9 @@
6363
result = enum.sample(n)
6464
expect(result).to be_an(Array)
6565
expect(result.length).to eq(n)
66+
expect(result.uniq.length).to eq(n)
6667
other_results = Array.new(100) { enum.sample(n) }
67-
expect(other_results).not_to be_all(eq result)
68+
expect(other_results).not_to be_all {|i| i == result }
6869
end
6970
end
7071

@@ -76,8 +77,9 @@
7677
result = enum.sample(n, random: random)
7778
expect(result).to be_an(Array)
7879
expect(result.length).to eq(n)
79-
other_results = Array.new(100) { enum.sample(n, random: random) }
80-
expect(other_results).to be_all(eq result)
80+
expect(result.uniq.length).to eq(n)
81+
other_results = Array.new(100) { enum.sample(n, random: save_random.dup) }
82+
expect(other_results).to be_all {|i| i == result }
8183
end
8284
end
8385
end

0 commit comments

Comments
 (0)