Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support configurable LMUL in VectorXoshiro. #2046

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Support configurable LMUL in VectorXoshiro.
PiperOrigin-RevId: 620864942
  • Loading branch information
tgale96 authored and copybara-github committed Apr 2, 2024
commit 7c2fbaf32c046ecb501fd758a358d3160d0803fa
29 changes: 16 additions & 13 deletions hwy/contrib/random/random-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -170,18 +170,21 @@ class Xoshiro {

} // namespace internal

template <int kPow2 = 0>
class VectorXoshiro {
private:
using VU64 = Vec<ScalableTag<std::uint64_t>>;
using TagU64 = ScalableTag<std::uint64_t, kPow2>;
using TagF64 = ScalableTag<double, kPow2>;

using VU64 = Vec<TagU64>;
using StateType = AlignedNDArray<std::uint64_t, 2>;
#if HWY_HAVE_FLOAT64
using VF64 = Vec<ScalableTag<double>>;
using VF64 = Vec<TagF64>;
#endif
public:
explicit VectorXoshiro(const std::uint64_t seed,
const std::uint64_t threadNumber = 0)
: state_{{internal::Xoshiro::StateSize(),
Lanes(ScalableTag<std::uint64_t>{})}},
: state_{{internal::Xoshiro::StateSize(), Lanes(TagU64{})}},
streams{state_.shape().back()} {
internal::Xoshiro xoshiro{seed};

Expand All @@ -202,7 +205,7 @@ class VectorXoshiro {

AlignedVector<std::uint64_t> operator()(const std::size_t n) {
AlignedVector<std::uint64_t> result(n);
const ScalableTag<std::uint64_t> tag{};
const TagU64 tag{};
auto s0 = Load(tag, state_[{0}].data());
auto s1 = Load(tag, state_[{1}].data());
auto s2 = Load(tag, state_[{2}].data());
Expand All @@ -221,7 +224,7 @@ class VectorXoshiro {
template <std::uint64_t N>
std::array<std::uint64_t, N> operator()() noexcept {
alignas(HWY_ALIGNMENT) std::array<std::uint64_t, N> result;
const ScalableTag<std::uint64_t> tag{};
const TagU64 tag{};
auto s0 = Load(tag, state_[{0}].data());
auto s1 = Load(tag, state_[{1}].data());
auto s2 = Load(tag, state_[{2}].data());
Expand All @@ -246,7 +249,7 @@ class VectorXoshiro {
#if HWY_HAVE_FLOAT64

HWY_INLINE VF64 Uniform() noexcept {
const ScalableTag<double> real_tag{};
const TagF64 real_tag{};
const auto MUL_VALUE = Set(real_tag, internal::kMulConst);
const auto bits = ShiftRight<11>(Next());
const auto real = ConvertTo(real_tag, bits);
Expand All @@ -255,8 +258,8 @@ class VectorXoshiro {

AlignedVector<double> Uniform(const std::size_t n) {
AlignedVector<double> result(n);
const ScalableTag<std::uint64_t> tag{};
const ScalableTag<double> real_tag{};
const TagU64 tag{};
const TagF64 real_tag{};
const auto MUL_VALUE = Set(real_tag, internal::kMulConst);

auto s0 = Load(tag, state_[{0}].data());
Expand All @@ -282,8 +285,8 @@ class VectorXoshiro {
template <std::uint64_t N>
std::array<double, N> Uniform() noexcept {
alignas(HWY_ALIGNMENT) std::array<double, N> result;
const ScalableTag<std::uint64_t> tag{};
const ScalableTag<double> real_tag{};
const TagU64 tag{};
const TagF64 real_tag{};
const auto MUL_VALUE = Set(real_tag, internal::kMulConst);

auto s0 = Load(tag, state_[{0}].data());
Expand Down Expand Up @@ -326,7 +329,7 @@ class VectorXoshiro {
}

HWY_INLINE VU64 Next() noexcept {
const ScalableTag<std::uint64_t> tag{};
const TagU64 tag{};
auto s0 = Load(tag, state_[{0}].data());
auto s1 = Load(tag, state_[{1}].data());
auto s2 = Load(tag, state_[{2}].data());
Expand Down Expand Up @@ -368,7 +371,7 @@ class CachedXoshiro {
}

private:
VectorXoshiro generator_;
VectorXoshiro<> generator_;
alignas(HWY_ALIGNMENT) std::array<result_type, size> cache_;
std::size_t index_;

Expand Down
16 changes: 8 additions & 8 deletions hwy/contrib/random/random_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ std::uint64_t GetSeed() { return static_cast<uint64_t>(std::time(nullptr)); }
void RngLoop(const std::uint64_t seed, std::uint64_t* HWY_RESTRICT result,
const size_t size) {
const ScalableTag<std::uint64_t> d;
VectorXoshiro generator{seed};
VectorXoshiro<> generator{seed};
for (size_t i = 0; i < size; i += Lanes(d)) {
Store(generator(), d, result + i);
}
Expand All @@ -40,7 +40,7 @@ void RngLoop(const std::uint64_t seed, std::uint64_t* HWY_RESTRICT result,
void UniformLoop(const std::uint64_t seed, double* HWY_RESTRICT result,
const size_t size) {
const ScalableTag<double> d;
VectorXoshiro generator{seed};
VectorXoshiro<> generator{seed};
for (size_t i = 0; i < size; i += Lanes(d)) {
Store(generator.Uniform(), d, result + i);
}
Expand All @@ -49,7 +49,7 @@ void UniformLoop(const std::uint64_t seed, double* HWY_RESTRICT result,

void TestSeeding() {
const std::uint64_t seed = GetSeed();
VectorXoshiro generator{seed};
VectorXoshiro<> generator{seed};
internal::Xoshiro reference{seed};
const auto& state = generator.GetState();
const ScalableTag<std::uint64_t> d;
Expand All @@ -72,7 +72,7 @@ void TestSeeding() {
void TestMultiThreadSeeding() {
const std::uint64_t seed = GetSeed();
const std::uint64_t threadId = std::random_device()() % 1000;
VectorXoshiro generator{seed, threadId};
VectorXoshiro<> generator{seed, threadId};
internal::Xoshiro reference{seed};

for (std::size_t i = 0UL; i < threadId; ++i) {
Expand Down Expand Up @@ -146,7 +146,7 @@ void TestUniformDist() {

void TestNextNRandomUint64() {
const std::uint64_t seed = GetSeed();
VectorXoshiro generator{seed};
VectorXoshiro<> generator{seed};
const auto result_array = generator.operator()(tests);
std::vector<internal::Xoshiro> reference;
reference.emplace_back(seed);
Expand Down Expand Up @@ -174,7 +174,7 @@ void TestNextNRandomUint64() {

void TestNextFixedNRandomUint64() {
const std::uint64_t seed = GetSeed();
VectorXoshiro generator{seed};
VectorXoshiro<> generator{seed};
const auto result_array = generator.operator()<tests>();
std::vector<internal::Xoshiro> reference;
reference.emplace_back(seed);
Expand Down Expand Up @@ -203,7 +203,7 @@ void TestNextFixedNRandomUint64() {
#if HWY_HAVE_FLOAT64
void TestNextNUniformDist() {
const std::uint64_t seed = GetSeed();
VectorXoshiro generator{seed};
VectorXoshiro<> generator{seed};
const auto result_array = generator.Uniform(tests);
internal::Xoshiro reference{seed};
const ScalableTag<double> d;
Expand All @@ -222,7 +222,7 @@ void TestNextNUniformDist() {

void TestNextFixedNUniformDist() {
const std::uint64_t seed = GetSeed();
VectorXoshiro generator{seed};
VectorXoshiro<> generator{seed};
const auto result_array = generator.Uniform<tests>();
internal::Xoshiro reference{seed};
const ScalableTag<double> d;
Expand Down
Loading