-
Notifications
You must be signed in to change notification settings - Fork 329
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
70 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
#include <hip/hip_runtime.h> | ||
#include <iostream> | ||
#include <algorithm> | ||
#include <type_traits> | ||
#include <cstdlib> // For std::abort | ||
#include <typeinfo> // For typeid | ||
|
||
std::string demangle(const char* name) { | ||
if (std::string(name) == "i") return "int"; | ||
else if (std::string(name) == "f") return "float"; | ||
else if (std::string(name) == "d") return "double"; | ||
else if (std::string(name) == "j") return "unsigned int"; | ||
else if (std::string(name) == "l") return "long"; | ||
else if (std::string(name) == "m") return "unsigned long"; | ||
else if (std::string(name) == "x") return "long long"; | ||
else if (std::string(name) == "y") return "unsigned long long"; | ||
else return std::string(name); | ||
} | ||
|
||
void checkHipCall(hipError_t status, const char* msg) { | ||
if (status != hipSuccess) { | ||
std::cerr << "HIP Error: " << msg << " - " << hipGetErrorString(status) << std::endl; | ||
std::abort(); | ||
} | ||
} | ||
|
||
template<typename T1, typename T2> | ||
void compareResults(T1 hipResult, T2 stdResult, const std::string& testName) { | ||
using CommonType = typename std::common_type<T1, T2>::type; | ||
if (static_cast<CommonType>(hipResult) != static_cast<CommonType>(stdResult)) { | ||
std::cerr << testName << " mismatch: HIP result " << hipResult << " (" << demangle(typeid(hipResult).name()) << "), std result " << stdResult << " (" << demangle(typeid(stdResult).name()) << ")" << std::endl; | ||
std::abort(); | ||
} | ||
} | ||
|
||
template<typename T1, typename T2> | ||
void runTest(T1 a, T2 b) { | ||
std::cout << "\nTesting with values: " << a << " (" << demangle(typeid(a).name()) << ") and " << b << " (" << demangle(typeid(b).name()) << ")" << std::endl; | ||
|
||
// Using std::min and std::max explicitly for host code to ensure clarity and correctness | ||
using CommonType = typename std::common_type<T1, T2>::type; | ||
CommonType stdMinResult = std::min<CommonType>(a, b); | ||
CommonType stdMaxResult = std::max<CommonType>(a, b); | ||
std::cout << "Host std::min result: " << stdMinResult << " (Type: " << demangle(typeid(stdMinResult).name()) << ")" << std::endl; | ||
std::cout << "Host std::max result: " << stdMaxResult << " (Type: " << demangle(typeid(stdMaxResult).name()) << ")" << std::endl; | ||
|
||
// Using HIP's global min/max functions | ||
CommonType hipMinResult = min(a, b); // Note: This directly uses HIP's min, assuming it's correctly overloaded for host code | ||
CommonType hipMaxResult = max(a, b); // Note: This directly uses HIP's max, assuming it's correctly overloaded for host code | ||
std::cout << "Host HIP min result: " << hipMinResult << " (Type: " << demangle(typeid(hipMinResult).name()) << ")" << std::endl; | ||
std::cout << "Host HIP max result: " << hipMaxResult << " (Type: " << demangle(typeid(hipMaxResult).name()) << ")" << std::endl; | ||
|
||
// Ensure the host HIP and std results match | ||
compareResults(hipMinResult, stdMinResult, "HIP vs std min"); | ||
compareResults(hipMaxResult, stdMaxResult, "HIP vs std max"); | ||
} | ||
|
||
int main() { | ||
checkHipCall(hipSetDevice(0), "hipSetDevice failed"); | ||
|
||
runTest(10uLL, -5LL); // Testing with unsigned int and long long | ||
runTest(-15, 20u); // Testing with int and unsigned int | ||
runTest(2147483647, 2147483648u); // Testing with int and unsigned int | ||
runTest(-922337203685477580LL, 922337203685477580uLL); // Testing with long long and unsigned long long | ||
runTest(2.5f, 3.14159); // Testing with float and double | ||
|
||
std::cout << "\nPass\n"; // Output "Pass" at the end if all tests pass without aborting | ||
return 0; | ||
} |