forked from rapidsai/legate-boost
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.h
More file actions
74 lines (65 loc) · 2.63 KB
/
utils.h
File metadata and controls
74 lines (65 loc) · 2.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
/* Copyright 2023 NVIDIA Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
#pragma once
#include "legate_library.h"
#include <core/type/type_info.h>
namespace legateboost {
// Logger object in the legateboost namespace. `extern` makes this a declaration
// only: the object itself must be defined in some other translation unit.
extern Legion::Logger logger;
/**
 * @brief Throw if a runtime precondition does not hold.
 *
 * @param condition Outcome of the check; the function is a no-op when true.
 * @param message   Description of the failure, appended to the error text.
 * @param file      Source file named in the error (EXPECT passes `__FILE__`).
 * @param line      Source line named in the error (EXPECT passes `__LINE__`).
 *
 * @throws std::runtime_error with the text `<file>(<line>): <message>` when
 *         `condition` is false.
 */
inline void expect(bool condition, const std::string& message, const std::string& file, int line)
{
  // The strings are only read, so they are bound by const reference: callers
  // that already hold a std::string (e.g. the helpers that forward `file`)
  // no longer pay for a copy on every check.
  if (condition) { return; }
  throw std::runtime_error(file + "(" + std::to_string(line) + "): " + message);
}
// Convenience form of expect() that fills in the call site's file and line.
#define EXPECT(condition, message) (expect(condition, message, __FILE__, __LINE__))
/**
 * @brief Check that two shapes have identical bounds along one axis.
 *
 * Compares `a.lo[AXIS]` with `b.lo[AXIS]` and `a.hi[AXIS]` with `b.hi[AXIS]`
 * and reports a mismatch through expect().
 *
 * @tparam AXIS    Index of the dimension to compare.
 * @tparam ShapeAT Type of `a`; must expose indexable `lo` and `hi` members.
 * @tparam ShapeBT Type of `b`; must expose indexable `lo` and `hi` members.
 * @param a    First shape.
 * @param b    Second shape.
 * @param file Source file to report on failure (the macro passes `__FILE__`).
 * @param line Source line to report on failure (the macro passes `__LINE__`).
 *
 * @throws Whatever expect() throws, with the message
 *         "Inconsistent axis alignment.", when either bound differs.
 */
template <int AXIS, typename ShapeAT, typename ShapeBT>
void expect_axis_aligned(const ShapeAT& a, const ShapeBT& b, const std::string& file, int line)
{
  // `file` is only forwarded, so it is taken by const reference rather than
  // copied into this frame first.
  expect((a.lo[AXIS] == b.lo[AXIS]) && (a.hi[AXIS] == b.hi[AXIS]),
         "Inconsistent axis alignment.",
         file,
         line);
}
// Convenience form of expect_axis_aligned() that fills in the call site's file and line.
#define EXPECT_AXIS_ALIGNED(axis, shape_a, shape_b) \
  (expect_axis_aligned<axis>(shape_a, shape_b, __FILE__, __LINE__))
/**
 * @brief Check that the inspected lower-bound coordinates of a shape are zero.
 *
 * @tparam ShapeT Shape type; must expose an indexable `lo` member with an `x`
 *                field and be streamable to a std::ostream.
 * @param shape Shape to check.
 * @param file  Source file to report on failure (the macro passes `__FILE__`).
 * @param line  Source line to report on failure (the macro passes `__LINE__`).
 *
 * @throws Whatever expect() throws, with the message
 *         "Expected a broadcast store. Got shape: <shape>.", for the first
 *         inspected coordinate of `shape.lo` that is non-zero.
 */
template <typename ShapeT>
void expect_is_broadcast(const ShapeT& shape, const std::string& file, int line)
{
  // Number of coordinates inspected. The expression is unchanged; it is only
  // hoisted and converted once so the loop compares int with int.
  // NOTE(review): this equals the number of dimensions only if `lo.x` is an
  // array of coordinates. If `x` is a single coordinate the loop inspects just
  // index 0 - TODO confirm against the point type's definition.
  const int num_checked = static_cast<int>(sizeof(shape.lo.x) / sizeof(shape.lo[0]));

  for (int i = 0; i < num_checked; ++i) {
    if (shape.lo[i] == 0) { continue; }

    // Format the shape only once a violation is found, so passing checks do
    // not pay for a stringstream per coordinate.
    std::stringstream ss;
    ss << "Expected a broadcast store. Got shape: " << shape << ".";
    // The condition is known to be false here; expect() raises the error.
    expect(shape.lo[i] == 0, ss.str(), file, line);
  }
}
// Convenience form of expect_is_broadcast() that fills in the call site's file and line.
#define EXPECT_IS_BROADCAST(shape) (expect_is_broadcast(shape, __FILE__, __LINE__))
/**
 * @brief Turn a runtime floating-point type code into a compile-time one.
 *
 * Calls `f.template operator()<CODE>(args...)` where `CODE` is the enumerator
 * equal to `code`. Only FLOAT16, FLOAT32 and FLOAT64 are handled.
 *
 * @param code Runtime type code to dispatch on.
 * @param f    Callable whose call operator is templated on a type code.
 * @param args Arguments forwarded to the selected call operator.
 * @return Whatever the selected call operator returns.
 *
 * @throws Whatever EXPECT raises, with the message
 *         "Expected floating point data.", for any other type code.
 */
template <typename Functor, typename... Fnargs>
constexpr decltype(auto) type_dispatch_float(legate::Type::Code code, Functor f, Fnargs&&... args)
{
  using Code = legate::Type::Code;

  if (code == Code::FLOAT16) {
    return f.template operator()<Code::FLOAT16>(std::forward<Fnargs>(args)...);
  }
  if (code == Code::FLOAT64) {
    return f.template operator()<Code::FLOAT64>(std::forward<Fnargs>(args)...);
  }

  // Anything that is not FLOAT32 at this point is unsupported; EXPECT raises
  // and the return below is never reached for such codes.
  if (code != Code::FLOAT32) { EXPECT(false, "Expected floating point data."); }
  return f.template operator()<Code::FLOAT32>(std::forward<Fnargs>(args)...);
}
// Declaration only; the definition is not in this header.
// NOTE(review): judging by the name and signature, this presumably sums the
// `count` doubles at `x` across the participants of the task behind `context`
// and writes the result back into `x` - confirm against the definition.
void SumAllReduce(legate::TaskContext context, double* x, int count);
} // namespace legateboost