-
Notifications
You must be signed in to change notification settings - Fork 14.9k
[Offload] Full AMD support for olMemFill #154958
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-offload @llvm/pr-subscribers-backend-amdgpu Author: Ross Brunton (RossBrunton) ChangesFull diff: https://github.com/llvm/llvm-project/pull/154958.diff 3 Files Affected:
diff --git a/offload/plugins-nextgen/amdgpu/src/rtl.cpp b/offload/plugins-nextgen/amdgpu/src/rtl.cpp
index 7ba55715ff58d..4cc158f326add 100644
--- a/offload/plugins-nextgen/amdgpu/src/rtl.cpp
+++ b/offload/plugins-nextgen/amdgpu/src/rtl.cpp
@@ -924,6 +924,7 @@ struct AMDGPUStreamTy {
void *Dst;
const void *Src;
size_t Size;
+ size_t NumTimes;
};
/// Utility struct holding arguments for freeing buffers to memory managers.
@@ -974,9 +975,14 @@ struct AMDGPUStreamTy {
StreamSlotTy() : Signal(nullptr), Callbacks({}), ActionArgs({}) {}
/// Schedule a host memory copy action on the slot.
- Error schedHostMemoryCopy(void *Dst, const void *Src, size_t Size) {
+ ///
+ /// Num times will repeat the copy that many times, sequentually in the dest
+ /// buffer.
+ Error schedHostMemoryCopy(void *Dst, const void *Src, size_t Size,
+ size_t NumTimes = 1) {
Callbacks.emplace_back(memcpyAction);
- ActionArgs.emplace_back().MemcpyArgs = MemcpyArgsTy{Dst, Src, Size};
+ ActionArgs.emplace_back().MemcpyArgs =
+ MemcpyArgsTy{Dst, Src, Size, NumTimes};
return Plugin::success();
}
@@ -1216,7 +1222,12 @@ struct AMDGPUStreamTy {
assert(Args->Dst && "Invalid destination buffer");
assert(Args->Src && "Invalid source buffer");
- std::memcpy(Args->Dst, Args->Src, Args->Size);
+ auto BasePtr = Args->Dst;
+ for (size_t I = 0; I < Args->NumTimes; I++) {
+ std::memcpy(BasePtr, Args->Src, Args->Size);
+ BasePtr = reinterpret_cast<void *>(reinterpret_cast<uintptr_t>(BasePtr) +
+ Args->Size);
+ }
return Plugin::success();
}
@@ -1421,7 +1432,8 @@ struct AMDGPUStreamTy {
/// manager once the operation completes.
Error pushMemoryCopyH2DAsync(void *Dst, const void *Src, void *Inter,
uint64_t CopySize,
- AMDGPUMemoryManagerTy &MemoryManager) {
+ AMDGPUMemoryManagerTy &MemoryManager,
+ size_t NumTimes = 1) {
// Retrieve available signals for the operation's outputs.
AMDGPUSignalTy *OutputSignals[2] = {};
if (auto Err = SignalManager.getResources(/*Num=*/2, OutputSignals))
@@ -1443,7 +1455,8 @@ struct AMDGPUStreamTy {
// The std::memcpy is done asynchronously using an async handler. We store
// the function's information in the action but it is not actually a
// post action.
- if (auto Err = Slots[Curr].schedHostMemoryCopy(Inter, Src, CopySize))
+ if (auto Err =
+ Slots[Curr].schedHostMemoryCopy(Inter, Src, CopySize, NumTimes))
return Err;
// Make changes on this slot visible to the async handler's thread.
@@ -1464,7 +1477,12 @@ struct AMDGPUStreamTy {
std::tie(Curr, InputSignal) = consume(OutputSignal);
} else {
// All preceding operations completed, copy the memory synchronously.
- std::memcpy(Inter, Src, CopySize);
+ auto *InterPtr = Inter;
+ for (size_t I = 0; I < NumTimes; I++) {
+ std::memcpy(InterPtr, Src, CopySize);
+ InterPtr = reinterpret_cast<void *>(
+ reinterpret_cast<uintptr_t>(InterPtr) + CopySize);
+ }
// Return the second signal because it will not be used.
OutputSignals[1]->decreaseUseCount();
@@ -1481,11 +1499,11 @@ struct AMDGPUStreamTy {
if (InputSignal && InputSignal->load()) {
hsa_signal_t InputSignalRaw = InputSignal->get();
return hsa_utils::asyncMemCopy(UseMultipleSdmaEngines, Dst, Agent, Inter,
- Agent, CopySize, 1, &InputSignalRaw,
- OutputSignal->get());
+ Agent, CopySize * NumTimes, 1,
+ &InputSignalRaw, OutputSignal->get());
}
return hsa_utils::asyncMemCopy(UseMultipleSdmaEngines, Dst, Agent, Inter,
- Agent, CopySize, 0, nullptr,
+ Agent, CopySize * NumTimes, 0, nullptr,
OutputSignal->get());
}
@@ -2611,26 +2629,73 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
Error dataFillImpl(void *TgtPtr, const void *PatternPtr, int64_t PatternSize,
int64_t Size,
AsyncInfoWrapperTy &AsyncInfoWrapper) override {
- hsa_status_t Status;
+ // Fast case, where we can use the 4 byte hsa_amd_memory_fill
+ if (Size % 4 == 0 &&
+ (PatternSize == 4 || PatternSize == 2 || PatternSize == 1)) {
+ uint32_t Pattern;
+ if (PatternSize == 1) {
+ auto *Byte = reinterpret_cast<const uint8_t *>(PatternPtr);
+ Pattern = *Byte | *Byte << 8 | *Byte << 16 | *Byte << 24;
+ } else if (PatternSize == 2) {
+ auto *Word = reinterpret_cast<const uint16_t *>(PatternPtr);
+ Pattern = *Word | (*Word << 16);
+ } else if (PatternSize == 4) {
+ Pattern = *reinterpret_cast<const uint32_t *>(PatternPtr);
+ } else {
+ // Shouldn't be here if the pattern size is outwith those values
+ std::terminate();
+ }
- // We can use hsa_amd_memory_fill for this size, but it's not async so the
- // queue needs to be synchronized first
- if (PatternSize == 4) {
- if (AsyncInfoWrapper.hasQueue())
- if (auto Err = synchronize(AsyncInfoWrapper))
+ if (hasPendingWorkImpl(AsyncInfoWrapper)) {
+ AMDGPUStreamTy *Stream = nullptr;
+ if (auto Err = getStream(AsyncInfoWrapper, Stream))
return Err;
- Status = hsa_amd_memory_fill(TgtPtr,
- *static_cast<const uint32_t *>(PatternPtr),
- Size / PatternSize);
- if (auto Err =
- Plugin::check(Status, "error in hsa_amd_memory_fill: %s\n"))
- return Err;
- } else {
- // TODO: Implement for AMDGPU. Most likely by doing the fill in pinned
- // memory and copying to the device in one go.
- return Plugin::error(ErrorCode::UNSUPPORTED, "Unsupported fill size");
+ struct MemFillArgsTy {
+ void *Dst;
+ uint32_t Pattern;
+ int64_t Size;
+ };
+ auto *Args = new MemFillArgsTy{TgtPtr, Pattern, Size / 4};
+ auto Fill = [](void *Data) {
+ MemFillArgsTy *Args = reinterpret_cast<MemFillArgsTy *>(Data);
+ assert(Args && "Invalid arguments");
+
+ auto Status =
+ hsa_amd_memory_fill(Args->Dst, Args->Pattern, Args->Size);
+ delete Args;
+ auto Err =
+ Plugin::check(Status, "error in hsa_amd_memory_fill: %s\n");
+ if (Err) {
+ FATAL_MESSAGE(1, "error performing async fill: %s",
+ toString(std::move(Err)).data());
+ }
+ };
+
+ // hsa_amd_memory_fill doesn't signal completion using a signal, so use
+ // the existing host callback logic to handle that instead
+ return Stream->pushHostCallback(Fill, Args);
+ } else {
+ // If there is no pending work, do the fill synchronously
+ auto Status = hsa_amd_memory_fill(TgtPtr, Pattern, Size / 4);
+ return Plugin::check(Status, "error in hsa_amd_memory_fill: %s\n");
+ }
}
+
+ // Slow case; allocate an appropriate memory size and enqueue copies
+ void *PinnedPtr = nullptr;
+ AMDGPUMemoryManagerTy &PinnedMemoryManager =
+ HostDevice.getPinnedMemoryManager();
+ if (auto Err = PinnedMemoryManager.allocate(Size, &PinnedPtr))
+ return Err;
+
+ AMDGPUStreamTy *Stream = nullptr;
+ if (auto Err = getStream(AsyncInfoWrapper, Stream))
+ return Err;
+
+ return Stream->pushMemoryCopyH2DAsync(TgtPtr, PatternPtr, PinnedPtr,
+ PatternSize, PinnedMemoryManager,
+ Size / PatternSize);
}
/// Initialize the async info for interoperability purposes.
diff --git a/offload/unittests/OffloadAPI/common/Fixtures.hpp b/offload/unittests/OffloadAPI/common/Fixtures.hpp
index fe7198a9c283f..0538e60f276e3 100644
--- a/offload/unittests/OffloadAPI/common/Fixtures.hpp
+++ b/offload/unittests/OffloadAPI/common/Fixtures.hpp
@@ -89,6 +89,40 @@ template <typename Fn> inline void threadify(Fn body) {
}
}
+/// Enqueues a task to the queue that can be manually resolved.
+// It will block until `trigger` is called.
+struct ManuallyTriggeredTask {
+ std::mutex M;
+ std::condition_variable CV;
+ bool Flag = false;
+ ol_event_handle_t CompleteEvent;
+
+ ol_result_t enqueue(ol_queue_handle_t Queue) {
+ if (auto Err = olLaunchHostFunction(
+ Queue,
+ [](void *That) {
+ static_cast<ManuallyTriggeredTask *>(That)->wait();
+ },
+ this))
+ return Err;
+
+ return olCreateEvent(Queue, &CompleteEvent);
+ }
+
+ void wait() {
+ std::unique_lock<std::mutex> lk(M);
+ CV.wait_for(lk, std::chrono::milliseconds(1000), [&] { return Flag; });
+ EXPECT_TRUE(Flag);
+ }
+
+ ol_result_t trigger() {
+ Flag = true;
+ CV.notify_one();
+
+ return olSyncEvent(CompleteEvent);
+ }
+};
+
struct OffloadTest : ::testing::Test {
ol_device_handle_t Host = TestEnvironment::getHostDevice();
};
diff --git a/offload/unittests/OffloadAPI/memory/olMemFill.cpp b/offload/unittests/OffloadAPI/memory/olMemFill.cpp
index 1b0bafa202080..a84ed3d78eccf 100644
--- a/offload/unittests/OffloadAPI/memory/olMemFill.cpp
+++ b/offload/unittests/OffloadAPI/memory/olMemFill.cpp
@@ -10,75 +10,129 @@
#include <OffloadAPI.h>
#include <gtest/gtest.h>
-using olMemFillTest = OffloadQueueTest;
+struct olMemFillTest : OffloadQueueTest {
+ template <typename PatternTy, PatternTy PatternVal, size_t Size,
+ bool Block = false>
+ void test_body() {
+ ManuallyTriggeredTask Manual;
+
+ // Block/enqueue tests ensure that the test has been enqueued to a queue
+ // (rather than being done synchronously if the queue happens to be empty)
+ if constexpr (Block) {
+ ASSERT_SUCCESS(Manual.enqueue(Queue));
+ }
+
+ void *Alloc;
+ ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED, Size, &Alloc));
+
+ PatternTy Pattern = PatternVal;
+ ASSERT_SUCCESS(olMemFill(Queue, Alloc, sizeof(Pattern), &Pattern, Size));
+
+ if constexpr (Block) {
+ ASSERT_SUCCESS(Manual.trigger());
+ }
+ olSyncQueue(Queue);
+
+ size_t N = Size / sizeof(Pattern);
+ for (size_t i = 0; i < N; i++) {
+ PatternTy *AllocPtr = reinterpret_cast<PatternTy *>(Alloc);
+ ASSERT_EQ(AllocPtr[i], Pattern);
+ }
+
+ olMemFree(Alloc);
+ }
+};
OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olMemFillTest);
-TEST_P(olMemFillTest, Success8) {
- constexpr size_t Size = 1024;
- void *Alloc;
- ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED, Size, &Alloc));
-
- uint8_t Pattern = 0x42;
- ASSERT_SUCCESS(olMemFill(Queue, Alloc, sizeof(Pattern), &Pattern, Size));
-
- olSyncQueue(Queue);
+TEST_P(olMemFillTest, Success8) { test_body<uint8_t, 0x42, 1024>(); }
+TEST_P(olMemFillTest, Success8NotMultiple4) {
+ test_body<uint8_t, 0x42, 1023>();
+}
+TEST_P(olMemFillTest, Success8Enqueue) {
+ test_body<uint8_t, 0x42, 1024, true>();
+}
+TEST_P(olMemFillTest, Success8NotMultiple4Enqueue) {
+ test_body<uint8_t, 0x42, 1023, true>();
+}
- size_t N = Size / sizeof(Pattern);
- for (size_t i = 0; i < N; i++) {
- uint8_t *AllocPtr = reinterpret_cast<uint8_t *>(Alloc);
- ASSERT_EQ(AllocPtr[i], Pattern);
- }
+TEST_P(olMemFillTest, Success16) { test_body<uint8_t, 0x42, 1024>(); }
+TEST_P(olMemFillTest, Success16NotMultiple4) {
+ test_body<uint16_t, 0x4243, 1022>();
+}
+TEST_P(olMemFillTest, Success16Enqueue) {
+ test_body<uint8_t, 0x42, 1024, true>();
+}
+TEST_P(olMemFillTest, Success16NotMultiple4Enqueue) {
+ test_body<uint16_t, 0x4243, 1022, true>();
+}
- olMemFree(Alloc);
+TEST_P(olMemFillTest, Success32) { test_body<uint32_t, 0xDEADBEEF, 1024>(); }
+TEST_P(olMemFillTest, Success32Enqueue) {
+ test_body<uint32_t, 0xDEADBEEF, 1024, true>();
}
-TEST_P(olMemFillTest, Success16) {
+TEST_P(olMemFillTest, SuccessLarge) {
constexpr size_t Size = 1024;
void *Alloc;
ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED, Size, &Alloc));
- uint16_t Pattern = 0x4242;
+ struct PatternT {
+ uint64_t A;
+ uint64_t B;
+ } Pattern{UINT64_MAX, UINT64_MAX};
+
ASSERT_SUCCESS(olMemFill(Queue, Alloc, sizeof(Pattern), &Pattern, Size));
olSyncQueue(Queue);
size_t N = Size / sizeof(Pattern);
for (size_t i = 0; i < N; i++) {
- uint16_t *AllocPtr = reinterpret_cast<uint16_t *>(Alloc);
- ASSERT_EQ(AllocPtr[i], Pattern);
+ PatternT *AllocPtr = reinterpret_cast<PatternT *>(Alloc);
+ ASSERT_EQ(AllocPtr[i].A, UINT64_MAX);
+ ASSERT_EQ(AllocPtr[i].B, UINT64_MAX);
}
olMemFree(Alloc);
}
-TEST_P(olMemFillTest, Success32) {
+TEST_P(olMemFillTest, SuccessLargeEnqueue) {
constexpr size_t Size = 1024;
void *Alloc;
+ ManuallyTriggeredTask Manual;
+ ASSERT_SUCCESS(Manual.enqueue(Queue));
+
ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED, Size, &Alloc));
- uint32_t Pattern = 0xDEADBEEF;
+ struct PatternT {
+ uint64_t A;
+ uint64_t B;
+ } Pattern{UINT64_MAX, UINT64_MAX};
+
ASSERT_SUCCESS(olMemFill(Queue, Alloc, sizeof(Pattern), &Pattern, Size));
+ Manual.trigger();
olSyncQueue(Queue);
size_t N = Size / sizeof(Pattern);
for (size_t i = 0; i < N; i++) {
- uint32_t *AllocPtr = reinterpret_cast<uint32_t *>(Alloc);
- ASSERT_EQ(AllocPtr[i], Pattern);
+ PatternT *AllocPtr = reinterpret_cast<PatternT *>(Alloc);
+ ASSERT_EQ(AllocPtr[i].A, UINT64_MAX);
+ ASSERT_EQ(AllocPtr[i].B, UINT64_MAX);
}
olMemFree(Alloc);
}
-TEST_P(olMemFillTest, SuccessLarge) {
- constexpr size_t Size = 1024;
+TEST_P(olMemFillTest, SuccessLargeByteAligned) {
+ constexpr size_t Size = 17 * 64;
void *Alloc;
ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED, Size, &Alloc));
- struct PatternT {
+ struct __attribute__((packed)) PatternT {
uint64_t A;
uint64_t B;
- } Pattern{UINT64_MAX, UINT64_MAX};
+ uint8_t C;
+ } Pattern{UINT64_MAX, UINT64_MAX, 255};
ASSERT_SUCCESS(olMemFill(Queue, Alloc, sizeof(Pattern), &Pattern, Size));
@@ -89,14 +143,18 @@ TEST_P(olMemFillTest, SuccessLarge) {
PatternT *AllocPtr = reinterpret_cast<PatternT *>(Alloc);
ASSERT_EQ(AllocPtr[i].A, UINT64_MAX);
ASSERT_EQ(AllocPtr[i].B, UINT64_MAX);
+ ASSERT_EQ(AllocPtr[i].C, 255);
}
olMemFree(Alloc);
}
-TEST_P(olMemFillTest, SuccessLargeByteAligned) {
+TEST_P(olMemFillTest, SuccessLargeByteAlignedEnqueue) {
constexpr size_t Size = 17 * 64;
void *Alloc;
+ ManuallyTriggeredTask Manual;
+ ASSERT_SUCCESS(Manual.enqueue(Queue));
+
ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED, Size, &Alloc));
struct __attribute__((packed)) PatternT {
@@ -107,6 +165,7 @@ TEST_P(olMemFillTest, SuccessLargeByteAligned) {
ASSERT_SUCCESS(olMemFill(Queue, Alloc, sizeof(Pattern), &Pattern, Size));
+ Manual.trigger();
olSyncQueue(Queue);
size_t N = Size / sizeof(Pattern);
|
if (Err) { | ||
FATAL_MESSAGE(1, "error performing async fill: %s", | ||
toString(std::move(Err)).data()); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should try as hard as possible not to just roll over and die inside of the plugin. We don't do a great job of it so far.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree, but that would require liboffload/PluginInterface to have some kind of async error handling. Maybe olPlatformPopAsyncError()
that pulls from a queue in the platform that the async handlers can push to?
Anyway, I think that should be a separate task. What I'm doing here is the same as asyncActionCallback
, and such a solution would touch both.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's probably something we should try to figure out at some point. I forget if we discussed it before, but we definitely don't want anything like cudaGetLastError
. I could see the result of a stream result in an error and allowing users to check for it through an event or something.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we have discussed it before, that memory is lost to me. What are your particular grievances with cudaGetLastError
? Would having olSyncEvent
/Queue
return any pending async errors work? What about adding an error callback function via something like olSetAsyncErrorHandler([](ol_result_t Result) {...})
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The same reason that errno
is a nightmare to deal with. Global errors aren't a good fit for an API that is supposed to be asynchronous. It's confusing who should own it and it's difficult to reason with, hence why CUDA had to provide both cudaGetLastError
and cudaPeekLastError
to let people figure out if they should even be using the error on the stack.
I think HSA uses callbacks in a similar way, but in my head it would probably be best if we just made it an event or something on the stream the user can query if needed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I like the idea of linking errors to streams (e.g. the stream is put into an error state and syncQueue
somehow returns an error code), but that has problems with cross-queue waits. For example, if queue A waits on queue B, and queue B reports an error, what happens with queue A?
I think this is something that warrants its own thoughts and discussion. Can I merge this as is (assuming no other issues) just to unblock AMD and look into a design for error handling as a separate change?
@@ -89,6 +89,40 @@ template <typename Fn> inline void threadify(Fn body) { | |||
} | |||
} | |||
|
|||
/// Enqueues a task to the queue that can be manually resolved. | |||
// It will block until `trigger` is called. | |||
struct ManuallyTriggeredTask { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not entirely sure why we need this but I'll leave it be
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The AMD driver has two different code paths depending on whether the queue is empty or not. This struct allows the test framework to "hold" a task in the queue making it non-empty so that the non-empty path can be tested.
@arsenm I'm assuming you're happy with using |
No description provided.