20 int64_t offsetInBytes = 0;
21 Value current = value;
26 if (
auto blockArg = dyn_cast<BlockArgument>(current)) {
30 Operation *defOp = current.getDefiningOp();
36 if (
auto castOp = dyn_cast<memref::CastOp>(defOp)) {
37 current = castOp.getSource();
42 if (
auto reinterpretOp = dyn_cast<memref::ReinterpretCastOp>(defOp)) {
44 dyn_cast<MemRefType>(reinterpretOp.getSource().getType());
50 if (
auto strided = dyn_cast<StridedLayoutAttr>(sourceType.getLayout())) {
51 for (int64_t stride : strided.getStrides()) {
59 current = reinterpretOp.getSource();
64 if (
auto subviewOp = dyn_cast<memref::SubViewOp>(defOp)) {
66 if (!subviewOp.getStaticOffsets().empty() &&
67 subviewOp.getStaticOffsets()[0] == ShapedType::kDynamic) {
70 if (!subviewOp.getStaticSizes().empty() &&
71 subviewOp.getStaticSizes()[0] == ShapedType::kDynamic) {
74 if (!subviewOp.getStaticStrides().empty() &&
75 subviewOp.getStaticStrides()[0] == ShapedType::kDynamic) {
80 if (subviewOp.getSourceType().getRank() != 1 ||
81 subviewOp.getType().getRank() != 1) {
86 if (!subviewOp.getStaticStrides().empty() &&
87 subviewOp.getStaticStrides()[0] != 1) {
92 auto sourceType = subviewOp.getSourceType();
93 unsigned elemSizeInBits =
94 sourceType.getElementType().getIntOrFloatBitWidth();
95 if (elemSizeInBits % 8 != 0) {
98 unsigned elemSizeInBytes = elemSizeInBits / 8;
99 int64_t offsetInElements = subviewOp.getStaticOffsets()[0];
100 offsetInBytes += offsetInElements * elemSizeInBytes;
102 current = subviewOp.getSource();
116 ArrayRef<uint32_t> words) {
117 uint32_t num_words = words.size();
118 MemRefType memrefType = MemRefType::get({num_words}, builder.getI32Type());
119 TensorType tensorType =
120 RankedTensorType::get({num_words}, builder.getI32Type());
121 memref::GlobalOp global =
nullptr;
122 auto initVal = DenseElementsAttr::get<uint32_t>(tensorType, words);
123 auto otherGlobals = dev.getOps<memref::GlobalOp>();
124 for (
auto g : otherGlobals) {
125 if (g.getType() != memrefType)
127 auto otherValue = g.getInitialValue();
130 if (*otherValue != initVal)
136 std::string name =
"blockwrite_data_";
137 while (dev.lookupSymbol(name + std::to_string(cachedId)))
139 name += std::to_string(cachedId);
140 global = memref::GlobalOp::create(builder, loc, name,
141 builder.getStringAttr(
"private"),
142 memrefType, initVal,
true,
nullptr);