MLIR-AIE
AIETransformBfpTypes.cpp
Go to the documentation of this file.
1//===- AIETransformBfpTypes.cpp --------------------------------*- C++ -*-===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7// Copyright (C) 2025, Advanced Micro Devices, Inc.
8//
9//===----------------------------------------------------------------------===//
10
15
16#include "mlir/IR/Attributes.h"
17#include "mlir/IR/BuiltinAttributes.h"
18#include "mlir/IR/BuiltinTypes.h"
19#include "mlir/IR/Diagnostics.h"
20#include "mlir/Pass/Pass.h"
21#include "mlir/Support/LLVM.h"
22#include "mlir/Transforms/DialectConversion.h"
23#include "llvm/ADT/SmallVector.h"
24#include "llvm/IR/Type.h"
25#include "llvm/Support/raw_ostream.h"
26
27namespace xilinx::AIEX {
28#define GEN_PASS_DEF_AIETRANSFORMBFPTYPES
29#include "aie/Dialect/AIEX/Transforms/AIEXPasses.h.inc"
30} // namespace xilinx::AIEX
31
32#define DEBUG_TYPE "transform-bfp-types"
33
34using namespace mlir;
35using namespace xilinx;
36using namespace xilinx::AIE;
37using namespace xilinx::AIEX;
38
39using namespace mlir;
40
41class BfpToIntegerConverter : public mlir::TypeConverter {
42public:
44 addTypeAttributeConversion([&](Type type, Attribute attr) {
45 auto newType = convertType(type);
46 if (!newType) {
47 llvm::errs() << "Failed to convert type: " << type << "\n";
48 return AttributeConversionResult::abort();
49 }
50 return AttributeConversionResult(TypeAttr::get(newType));
51 });
52
53 // Note that the most recently added conversions will be invoked first
54
55 // Leave other types unchanged
56 addConversion([](Type type) -> std::optional<Type> { return type; });
57
58 // Add a conversion for bfpTypes to an integer type
59 addConversion([&](BlockFloatType blockType) -> std::optional<IntegerType> {
60 bool isSupported =
61 targetModel.isSupportedBlockFormat(blockType.getBlockType().str());
62 if (!isSupported) {
63 llvm::errs() << "Block type " << blockType.getBlockType()
64 << " is not supported in the specified model\n";
65 // Note that returning a nullptr here will stop the conversion while
66 // returning a std::nullopt will allow the converter to keep trying the
67 // remaining conversions (thus reaching the default one in this case)
68 return nullptr;
69 }
70
71 return mlir::IntegerType::get(blockType.getContext(),
72 blockType.getTotalSizeInBits());
73 });
74
75 // Add a conversion for MemRefType
76 addConversion([&](MemRefType memRefType) -> std::optional<MemRefType> {
77 auto newElementType = convertType(memRefType.getElementType());
78 if (!newElementType) {
79 llvm::errs() << "Failed to convert memref element type\n";
80 return nullptr;
81 }
82 return MemRefType::get(memRefType.getShape(), newElementType,
83 memRefType.getLayout(),
84 memRefType.getMemorySpace());
85 });
86
87 // Add a conversion for ObjectFifoType
88 addConversion([&](AIEObjectFifoType objectFifoType)
89 -> std::optional<AIEObjectFifoType> {
90 auto newElementType = convertType(objectFifoType.getElementType());
91 if (!newElementType) {
92 llvm::errs() << "Failed to convert ObjectFifoType element type\n";
93 return nullptr;
94 }
95
96 if (auto newMemRef = dyn_cast<MemRefType>(newElementType))
97 return AIEObjectFifoType::get(objectFifoType.getContext(), newMemRef);
98
99 llvm::errs()
100 << "ObjectFifoType converted element type is not a MemRefType\n";
101 return nullptr;
102 });
103
104 // Add a conversion for ObjectFifoSubviewType
105 addConversion([&](AIEObjectFifoSubviewType objectFifoSubviewType)
106 -> std::optional<AIEObjectFifoSubviewType> {
107 auto newElementType = convertType(objectFifoSubviewType.getElementType());
108 if (!newElementType) {
109 llvm::errs()
110 << "Failed to convert ObjectFifoSubviewType element type\n";
111 return nullptr;
112 }
113
114 if (auto newMemRef = dyn_cast<MemRefType>(newElementType))
115 return AIEObjectFifoSubviewType::get(objectFifoSubviewType.getContext(),
116 newMemRef);
117
118 llvm::errs()
119 << "ObjectFifoSubviewType element type is not a MemRefType\n";
120 return nullptr;
121 });
122
123 // Add a conversion for FunctionType
124 addConversion([&](FunctionType funcType) -> std::optional<FunctionType> {
125 llvm::SmallVector<Type> newInputTypes;
126 auto check = convertTypes(funcType.getInputs(), newInputTypes);
127 if (check.failed()) {
128 llvm::errs() << "Failed to convert function input types\n";
129 return nullptr;
130 }
131
132 llvm::SmallVector<Type> newOutputTypes;
133 check = convertTypes(funcType.getResults(), newOutputTypes);
134 if (check.failed()) {
135 llvm::errs() << "Failed to convert function output types\n";
136 return nullptr;
137 }
138
139 return FunctionType::get(funcType.getContext(), newInputTypes,
140 newOutputTypes);
141 });
142
143 // Add conversions for other types as needed (llvm arrays?)
144 }
145};
146
148public:
149 BfpToIntegerConversionPattern(TypeConverter &typeConverter,
150 MLIRContext *context, bool &conversionFailed)
151 : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), 1, context),
152 conversionFailed(conversionFailed) {}
153
154 LogicalResult
155 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
156 ConversionPatternRewriter &rewriter) const override {
157
158 // The objective is to replace all bfp operations by an integer of the
159 // appropriate width. This pass currently does not have any other
160 // functionality.
161
162 // Operation results
163 for (auto result : op->getResults()) {
164 auto conversion = typeConverter->convertType(result.getType());
165 if (!conversion) {
166 conversionFailed = true;
167 return op->emitError()
168 << "Failed to convert result type: " << result.getType();
169 }
170 result.setType(conversion);
171 }
172
173 // Operation operands
174 for (auto operand : op->getOperands()) {
175 auto conversion = typeConverter->convertType(operand.getType());
176 if (!conversion) {
177 conversionFailed = true;
178 return op->emitError()
179 << "Failed to convert operand type: " << operand.getType();
180 }
181 operand.setType(conversion);
182 }
183
184 // Operation attributes
185 // Note that the attribute list is immutable and
186 // needs to be recreated from scratch. Also note that type attributes cannot
187 // access their type and must therefore be managed through
188 // the convertTypeAttribute conversion instead
189 SmallVector<NamedAttribute> newAttrs;
190 for (auto attr : op->getAttrs()) {
191 if (auto typeAttr = dyn_cast<TypeAttr>(attr.getValue())) {
192 auto conversion = typeConverter->convertTypeAttribute(
193 typeAttr.getValue(), attr.getValue());
194 if (!conversion) {
195 conversionFailed = true;
196 return op->emitError()
197 << "Failed to convert attribute type: " << typeAttr.getValue();
198 }
199 newAttrs.push_back(NamedAttribute(attr.getName(), conversion.value()));
200 } else {
201 newAttrs.push_back(attr);
202 }
203 }
204 op->setAttrs(DictionaryAttr::get(op->getContext(), newAttrs));
205
206 return success();
207 }
208
209private:
210 bool &conversionFailed;
211};
212
215 AIETransformBfpTypesPass> {
216public:
217 void runOnOperation() override {
218 DeviceOp device = getOperation();
219 MLIRContext *context = device.getContext();
220
221 BfpToIntegerConverter typeConverter(device.getTargetModel());
222
223 // Set up an empty conversion target, since we have to iterate over all ops
224 ConversionTarget target(*context);
225
226 RewritePatternSet patterns(context);
227 bool conversionFailed = false;
228 patterns.add<BfpToIntegerConversionPattern>(typeConverter, context,
229 conversionFailed);
230
231 if (failed(applyPartialConversion(device, target, std::move(patterns))) ||
232 conversionFailed) {
233 signalPassFailure();
234 }
235 }
236};
237
238std::unique_ptr<OperationPass<DeviceOp>>
240 return std::make_unique<AIETransformBfpTypesPass>();
241}
LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const override
BfpToIntegerConversionPattern(TypeConverter &typeConverter, MLIRContext *context, bool &conversionFailed)
BfpToIntegerConverter(const AIETargetModel &targetModel)
virtual bool isSupportedBlockFormat(std::string const &format) const
std::unique_ptr< mlir::OperationPass< AIE::DeviceOp > > createAIETransformBfpTypesPass()
Include the generated interface declarations.