-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathTutorial.cpp
More file actions
115 lines (94 loc) · 3.89 KB
/
Tutorial.cpp
File metadata and controls
115 lines (94 loc) · 3.89 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
#include "Tutorial.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#define GET_OP_CLASSES
#include "Tutorial.cpp.inc"
// Dialect initialization hook: registers the dialect's operations.
// The template argument list of addOperations<> is not written by hand; it is
// filled in by the generated "Tutorial.cpp.inc", included here with
// GET_OP_LIST defined so that the file emits the op list.
void mlir::tutorial::TutorialDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "Tutorial.cpp.inc"
>();
}
#include "TutorialDialect.cpp.inc"
namespace mlir::tutorial {

//===----------------------------------------------------------------------===//
// DequantOp: TilingInterface implementation
//===----------------------------------------------------------------------===//
//
// DequantOp is tiled as an element-wise op. Every loop of its iteration domain
// is parallel, and input, scale and result tiles are all addressed with the
// same offsets/sizes.
// NOTE(review): this assumes `scale` has exactly the same shape as `input`
// (no broadcasting, e.g. no per-channel scale) — confirm against the op
// definition before relying on it.

/// Returns one loop range per dimension of the input tensor: offset 0, the
/// (static or dynamic) dimension size, and stride 1.
SmallVector<Range> DequantOp::getIterationDomain(OpBuilder &b) {
  int64_t rank = getInput().getType().getRank();
  OpFoldResult zero = b.getIndexAttr(0);
  OpFoldResult one = b.getIndexAttr(1);
  // One OpFoldResult per dimension of the input (see tensor::getMixedSizes).
  SmallVector<OpFoldResult> sizes =
      tensor::getMixedSizes(b, getLoc(), getInput());
  SmallVector<Range> loopBounds(rank);
  for (auto dim : llvm::seq<int64_t>(rank)) {
    loopBounds[dim].offset = zero;
    loopBounds[dim].size = sizes[dim];
    loopBounds[dim].stride = one;
  }
  return loopBounds;
}

/// Every loop is parallel: each output element depends only on the input and
/// scale elements at the same position.
SmallVector<utils::IteratorType> DequantOp::getLoopIteratorTypes() {
  int64_t rank = getInput().getType().getRank();
  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
}

/// The iteration domain and the result coincide, so a tile of the iteration
/// space maps to the result tile with identical offsets and sizes.
LogicalResult DequantOp::getResultTilePosition(
    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
    SmallVector<OpFoldResult> &resultSizes) {
  resultOffsets.assign(offsets.begin(), offsets.end());
  resultSizes.assign(sizes.begin(), sizes.end());
  return success();
}

/// Materialises the tile [offsets, offsets + sizes) of this op. It slices the
/// input and scale operands with unit strides and clones the op onto those
/// slices.
///
/// Returns a TilingResult holding the cloned op, its results, and the two
/// extract_slice ops that were generated.
FailureOr<TilingResult> DequantOp::getTiledImplementation(
    OpBuilder &b, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes) {
  Location loc = getLoc();
  int64_t rank = getInput().getType().getRank();
  // Unit strides as index attributes, consistent with getIterationDomain.
  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
  auto inputTile = b.create<tensor::ExtractSliceOp>(loc, getInput(), offsets,
                                                    sizes, strides);
  auto scaleTile = b.create<tensor::ExtractSliceOp>(loc, getScale(), offsets,
                                                    sizes, strides);
  // The tiled result has the slice's shape, but its element type must come
  // from this op's own result, not from the input slice. A dequantizing op's
  // result element type need not match its input's element type, so reusing
  // the input slice type verbatim could give the clone a wrong result type.
  Type resultElementType =
      llvm::cast<ShapedType>(getOperation()->getResult(0).getType())
          .getElementType();
  Type resultType = RankedTensorType::get(inputTile.getResultType().getShape(),
                                          resultElementType);
  Operation *tiledOp =
      mlir::clone(b, getOperation(), {resultType}, {inputTile, scaleTile});
  return TilingResult{{tiledOp},
                      SmallVector<Value>(tiledOp->getResults()),
                      {inputTile, scaleTile}};
}

/// Inverse of getResultTilePosition: a result tile is also an iteration-space
/// tile with the same offsets and sizes.
LogicalResult DequantOp::getIterationDomainTileFromResultTile(
    OpBuilder &b, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes,
    SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
    SmallVectorImpl<OpFoldResult> &iterDomainSizes) {
  iterDomainOffsets.assign(offsets.begin(), offsets.end());
  iterDomainSizes.assign(sizes.begin(), sizes.end());
  return success();
}

/// Produces the value of the requested result tile. It maps the result tile
/// back to the iteration space and then tiles the op there.
FailureOr<TilingResult> DequantOp::generateResultTileValue(
    OpBuilder &b, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes) {
  SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
  if (failed(getIterationDomainTileFromResultTile(
          b, resultNumber, offsets, sizes, mappedOffsets, mappedSizes)))
    return failure();
  return getTiledImplementation(b, mappedOffsets, mappedSizes);
}

/// Maps an operand tile to an iteration-space tile with identical offsets and
/// sizes. The same mapping is used for every operand number.
/// NOTE(review): this is only correct while `scale` has the same shape as
/// `input` — confirm.
LogicalResult DequantOp::getIterationDomainTileFromOperandTile(
    OpBuilder &b, unsigned operandNumber, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes,
    SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
    SmallVectorImpl<OpFoldResult> &iterDomainSizes) {
  iterDomainOffsets.assign(offsets.begin(), offsets.end());
  iterDomainSizes.assign(sizes.begin(), sizes.end());
  return success();
}

/// Tiles the op starting from a tile of one of its operands (used when fusing
/// a producer tile into this consumer). It maps the operand tile to the
/// iteration space and then tiles the op there.
FailureOr<TilingResult> DequantOp::getTiledImplementationFromOperandTile(
    OpBuilder &b, unsigned operandNumber, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes) {
  SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
  if (failed(getIterationDomainTileFromOperandTile(
          b, operandNumber, offsets, sizes, mappedOffsets, mappedSizes)))
    return failure();
  return getTiledImplementation(b, mappedOffsets, mappedSizes);
}

} // namespace mlir::tutorial