您的位置:首页 > 教育 > 培训 > 内存分配抽象定义二

内存分配抽象定义二

2025/7/12 16:38:30 来源:https://blog.csdn.net/qq_35054151/article/details/139865223  浏览:    关键词:内存分配抽象定义二
#include <iostream>
#include <memory>
#include <stdexcept>
#include <cuda_runtime.h>template <typename T>
struct MemoryDeleter {bool UseCUDA; // 成员变量用于标记是否使用CUDAMemoryDeleter(bool useCUDA) : UseCUDA(useCUDA) {} // 构造函数初始化UseCUDAvoid operator()(T* ptr) {if (UseCUDA) {cudaError_t cudaStatus = cudaFree(ptr);if (cudaStatus != cudaSuccess) {std::cerr << "CUDA memory free error: " << cudaGetErrorString(cudaStatus) << std::endl;}}else {delete[] ptr; // 使用delete[]释放CPU内存}}
};template <typename T, bool UseCUDA>
using SharedMemoryPtr = std::conditional_t<UseCUDA, std::shared_ptr<T>, std::unique_ptr<T[], MemoryDeleter<T>>>;template <typename T, bool UseCUDA>
class MemoryManager {
public:static SharedMemoryPtr<T, UseCUDA> Allocate(size_t size);static void Set(T* ptr, int value, size_t size);static void Copy(T* dest, const T* src, size_t size);
};template <typename T, bool UseCUDA>
SharedMemoryPtr<T, UseCUDA> MemoryManager<T, UseCUDA>::Allocate(size_t size) {T* ptr = nullptr;if constexpr (UseCUDA) {cudaMalloc((T**)&ptr, size * sizeof(T));}else {ptr = new T[size];}return SharedMemoryPtr<T, UseCUDA>(ptr, MemoryDeleter<T>(UseCUDA));
}template <typename T, bool UseCUDA>
void MemoryManager<T, UseCUDA>::Set(T* ptr, int value, size_t size) {if constexpr (UseCUDA) {cudaMemset(ptr, value, size * sizeof(T));}else {for (size_t i = 0; i < size; ++i) {ptr[i] = static_cast<T>(value);}}
}template <typename T, bool UseCUDA>
void MemoryManager<T, UseCUDA>::Copy(T* dest, const T* src, size_t size) {if constexpr (UseCUDA) {cudaMemcpy(dest, src, size * sizeof(T), cudaMemcpyHostToDevice);}else {memcpy(dest, src, size * sizeof(T));}
}int main() {int size = 512 * 512 * 500;SharedMemoryPtr<float, true> ptr = MemoryManager<float, true>::Allocate(size);int value = 0;MemoryManager<float, true>::Set(ptr.get(), value, size);// float hostData[512 * 512 * 100]={ 0 };float* hostData = new float[size];for (int i = 0; i < size; ++i) {hostData[i] = static_cast<float>(i);}MemoryManager<float, true>::Copy(ptr.get(), hostData, size);//ptr.reset();return 0;
}

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com