MLIR-AIE
AIEUtils.cpp
Go to the documentation of this file.
1//===- AIEUtils.cpp ---------------------------------------------*- C++ -*-===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7// (c) Copyright 2025 Advanced Micro Devices, Inc.
8//
9//===----------------------------------------------------------------------===//
10
12
13using namespace mlir;
14using namespace xilinx;
15
// Suffix counter for the `blockwrite_data_<N>` globals created by
// getOrCreateDataMemref. It lives at file scope, so its value persists across
// calls and the next unused-name search resumes from the last suffix.
// NOTE(review): shared mutable state with no synchronization.
static unsigned cachedId{0};
17
18std::optional<AIEX::SubviewTraceResult>
20 int64_t offsetInBytes = 0;
21 Value current = value;
22
23 // Walk through the chain of operations until we reach a block argument
24 while (current) {
25 // Check if we've reached a block argument
26 if (auto blockArg = dyn_cast<BlockArgument>(current)) {
27 return SubviewTraceResult{blockArg, offsetInBytes};
28 }
29
30 Operation *defOp = current.getDefiningOp();
31 if (!defOp) {
32 return std::nullopt;
33 }
34
35 // Handle memref.cast (just pass through)
36 if (auto castOp = dyn_cast<memref::CastOp>(defOp)) {
37 current = castOp.getSource();
38 continue;
39 }
40
41 // Handle memref.reinterpret_cast (validate and pass through)
42 if (auto reinterpretOp = dyn_cast<memref::ReinterpretCastOp>(defOp)) {
43 auto sourceType =
44 dyn_cast<MemRefType>(reinterpretOp.getSource().getType());
45 if (!sourceType) {
46 return std::nullopt;
47 }
48
49 // Validate that source is contiguous (all strides must be 1)
50 if (auto strided = dyn_cast<StridedLayoutAttr>(sourceType.getLayout())) {
51 for (int64_t stride : strided.getStrides()) {
52 if (stride != 1) {
53 return std::nullopt; // Non-contiguous memory, cannot safely
54 // reinterpret
55 }
56 }
57 }
58
59 current = reinterpretOp.getSource();
60 continue;
61 }
62
63 // Handle memref.subview (accumulate offset and validate)
64 if (auto subviewOp = dyn_cast<memref::SubViewOp>(defOp)) {
65 // Verify static offsets, sizes, strides
66 if (!subviewOp.getStaticOffsets().empty() &&
67 subviewOp.getStaticOffsets()[0] == ShapedType::kDynamic) {
68 return std::nullopt;
69 }
70 if (!subviewOp.getStaticSizes().empty() &&
71 subviewOp.getStaticSizes()[0] == ShapedType::kDynamic) {
72 return std::nullopt;
73 }
74 if (!subviewOp.getStaticStrides().empty() &&
75 subviewOp.getStaticStrides()[0] == ShapedType::kDynamic) {
76 return std::nullopt;
77 }
78
79 // Only support rank-1 subviews
80 if (subviewOp.getSourceType().getRank() != 1 ||
81 subviewOp.getType().getRank() != 1) {
82 return std::nullopt;
83 }
84
85 // Only support stride of 1 (contiguous)
86 if (!subviewOp.getStaticStrides().empty() &&
87 subviewOp.getStaticStrides()[0] != 1) {
88 return std::nullopt;
89 }
90
91 // Calculate and accumulate offset in bytes
92 auto sourceType = subviewOp.getSourceType();
93 unsigned elemSizeInBits =
94 sourceType.getElementType().getIntOrFloatBitWidth();
95 if (elemSizeInBits % 8 != 0) {
96 return std::nullopt;
97 }
98 unsigned elemSizeInBytes = elemSizeInBits / 8;
99 int64_t offsetInElements = subviewOp.getStaticOffsets()[0];
100 offsetInBytes += offsetInElements * elemSizeInBytes;
101
102 current = subviewOp.getSource();
103 continue;
104 }
105
106 // Encountered an unsupported operation
107 return std::nullopt;
108 }
109
110 return std::nullopt;
111}
112
113memref::GlobalOp AIEX::getOrCreateDataMemref(OpBuilder &builder,
114 AIE::DeviceOp dev,
115 mlir::Location loc,
116 ArrayRef<uint32_t> words) {
117 uint32_t num_words = words.size();
118 MemRefType memrefType = MemRefType::get({num_words}, builder.getI32Type());
119 TensorType tensorType =
120 RankedTensorType::get({num_words}, builder.getI32Type());
121 memref::GlobalOp global = nullptr;
122 auto initVal = DenseElementsAttr::get<uint32_t>(tensorType, words);
123 auto otherGlobals = dev.getOps<memref::GlobalOp>();
124 for (auto g : otherGlobals) {
125 if (g.getType() != memrefType)
126 continue;
127 auto otherValue = g.getInitialValue();
128 if (!otherValue)
129 continue;
130 if (*otherValue != initVal)
131 continue;
132 global = g;
133 break;
134 }
135 if (!global) {
136 std::string name = "blockwrite_data_";
137 while (dev.lookupSymbol(name + std::to_string(cachedId)))
138 cachedId++;
139 name += std::to_string(cachedId);
140 global = memref::GlobalOp::create(builder, loc, name,
141 builder.getStringAttr("private"),
142 memrefType, initVal, true, nullptr);
143 }
144 return global;
145}
std::optional< SubviewTraceResult > traceSubviewToBlockArgument(Value value)
Definition AIEUtils.cpp:19
memref::GlobalOp getOrCreateDataMemref(OpBuilder &builder, AIE::DeviceOp dev, mlir::Location loc, ArrayRef< uint32_t > words)
Definition AIEUtils.cpp:113