MLIR-AIE
AIETransformBfpTypes.cpp
Go to the documentation of this file.
1//===- AIETransformBfpTypes.cpp --------------------------------*- C++ -*-===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7// Copyright (C) 2025, Advanced Micro Devices, Inc.
8//
9//===----------------------------------------------------------------------===//
10
15
16#include "mlir/IR/Attributes.h"
17#include "mlir/IR/BuiltinAttributes.h"
18#include "mlir/IR/BuiltinTypes.h"
19#include "mlir/IR/Diagnostics.h"
20#include "mlir/Pass/Pass.h"
21#include "mlir/Support/LLVM.h"
22#include "mlir/Transforms/DialectConversion.h"
23#include "llvm/ADT/SmallVector.h"
24#include "llvm/IR/Type.h"
25#include "llvm/Support/raw_ostream.h"
26
27#define DEBUG_TYPE "transform-bfp-types"
28
29using namespace mlir;
30using namespace xilinx;
31using namespace xilinx::AIE;
32using namespace xilinx::AIEX;
33
34using namespace mlir;
35
36class BfpToIntegerConverter : public mlir::TypeConverter {
37public:
39 addTypeAttributeConversion([&](Type type, Attribute attr) {
40 auto newType = convertType(type);
41 if (!newType) {
42 llvm::errs() << "Failed to convert type: " << type << "\n";
43 return AttributeConversionResult::abort();
44 }
45 return AttributeConversionResult(TypeAttr::get(newType));
46 });
47
48 // Note that the most recently added conversions will be invoked first
49
50 // Leave other types unchanged
51 addConversion([](Type type) -> std::optional<Type> { return type; });
52
53 // Add a conversion for bfpTypes to an integer type
54 addConversion([&](BlockFloatType blockType) -> std::optional<IntegerType> {
55 bool isSupported =
56 targetModel.isSupportedBlockFormat(blockType.getBlockType().str());
57 if (!isSupported) {
58 llvm::errs() << "Block type " << blockType.getBlockType()
59 << " is not supported in the specified model\n";
60 // Note that returning a nullptr here will stop the conversion while
61 // returning a std::nullopt will allow the converter to keep trying the
62 // remaining conversions (thus reaching the default one in this case)
63 return nullptr;
64 }
65
66 return mlir::IntegerType::get(blockType.getContext(), blockType.getTotalSizeInBits());
67 });
68
69 // Add a conversion for MemRefType
70 addConversion([&](MemRefType memRefType) -> std::optional<MemRefType> {
71 auto newElementType = convertType(memRefType.getElementType());
72 if (!newElementType) {
73 llvm::errs() << "Failed to convert memref element type\n";
74 return nullptr;
75 }
76 return MemRefType::get(memRefType.getShape(), newElementType,
77 memRefType.getLayout(),
78 memRefType.getMemorySpace());
79 });
80
81 // Add a conversion for ObjectFifoType
82 addConversion([&](AIEObjectFifoType objectFifoType)
83 -> std::optional<AIEObjectFifoType> {
84 auto newElementType = convertType(objectFifoType.getElementType());
85 if (!newElementType) {
86 llvm::errs() << "Failed to convert ObjectFifoType element type\n";
87 return nullptr;
88 }
89
90 if (auto newMemRef = dyn_cast<MemRefType>(newElementType))
91 return AIEObjectFifoType::get(objectFifoType.getContext(), newMemRef);
92
93 llvm::errs()
94 << "ObjectFifoType converted element type is not a MemRefType\n";
95 return nullptr;
96 });
97
98 // Add a conversion for ObjectFifoSubviewType
99 addConversion([&](AIEObjectFifoSubviewType objectFifoSubviewType)
100 -> std::optional<AIEObjectFifoSubviewType> {
101 auto newElementType = convertType(objectFifoSubviewType.getElementType());
102 if (!newElementType) {
103 llvm::errs()
104 << "Failed to convert ObjectFifoSubviewType element type\n";
105 return nullptr;
106 }
107
108 if (auto newMemRef = dyn_cast<MemRefType>(newElementType))
109 return AIEObjectFifoSubviewType::get(objectFifoSubviewType.getContext(),
110 newMemRef);
111
112 llvm::errs()
113 << "ObjectFifoSubviewType element type is not a MemRefType\n";
114 return nullptr;
115 });
116
117 // Add a conversion for FunctionType
118 addConversion([&](FunctionType funcType) -> std::optional<FunctionType> {
119 llvm::SmallVector<Type> newInputTypes;
120 auto check = convertTypes(funcType.getInputs(), newInputTypes);
121 if (check.failed()) {
122 llvm::errs() << "Failed to convert function input types\n";
123 return nullptr;
124 }
125
126 llvm::SmallVector<Type> newOutputTypes;
127 check = convertTypes(funcType.getResults(), newOutputTypes);
128 if (check.failed()) {
129 llvm::errs() << "Failed to convert function output types\n";
130 return nullptr;
131 }
132
133 return FunctionType::get(funcType.getContext(), newInputTypes,
134 newOutputTypes);
135 });
136
137 // Add conversions for other types as needed (llvm arrays?)
138 }
139};
140
142public:
143 BfpToIntegerConversionPattern(TypeConverter &typeConverter,
144 MLIRContext *context, bool &conversionFailed)
145 : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), 1, context),
146 conversionFailed(conversionFailed) {}
147
148 LogicalResult
149 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
150 ConversionPatternRewriter &rewriter) const override {
151
152 // The objective is to replace all bfp operations by an integer of the
153 // appropriate width. This pass currently does not have any other
154 // functionality.
155
156 // Operation results
157 for (auto result : op->getResults()) {
158 auto conversion = typeConverter->convertType(result.getType());
159 if (!conversion) {
160 conversionFailed = true;
161 return op->emitError()
162 << "Failed to convert result type: " << result.getType();
163 }
164 result.setType(conversion);
165 }
166
167 // Operation operands
168 for (auto operand : op->getOperands()) {
169 auto conversion = typeConverter->convertType(operand.getType());
170 if (!conversion) {
171 conversionFailed = true;
172 return op->emitError()
173 << "Failed to convert operand type: " << operand.getType();
174 }
175 operand.setType(conversion);
176 }
177
178 // Operation attributes
179 // Note that the attribute list is immutable and
180 // needs to be recreated from scratch. Also note that type attributes cannot
181 // access their type and must therefore be managed through
182 // the convertTypeAttribute conversion instead
183 SmallVector<NamedAttribute> newAttrs;
184 for (auto attr : op->getAttrs()) {
185 if (auto typeAttr = dyn_cast<TypeAttr>(attr.getValue())) {
186 auto conversion = typeConverter->convertTypeAttribute(
187 typeAttr.getValue(), attr.getValue());
188 if (!conversion) {
189 conversionFailed = true;
190 return op->emitError()
191 << "Failed to convert attribute type: " << typeAttr.getValue();
192 }
193 newAttrs.push_back(NamedAttribute(attr.getName(), conversion.value()));
194 } else {
195 newAttrs.push_back(attr);
196 }
197 }
198 op->setAttrs(DictionaryAttr::get(op->getContext(), newAttrs));
199
200 return success();
201 }
202
203private:
204 bool &conversionFailed;
205};
206
208 : public AIETransformBfpTypesBase<AIETransformBfpTypesPass> {
209public:
210 void runOnOperation() override {
211 DeviceOp device = getOperation();
212 MLIRContext *context = device.getContext();
213
214 BfpToIntegerConverter typeConverter(device.getTargetModel());
215
216 // Set up an empty conversion target, since we have to iterate over all ops
217 ConversionTarget target(*context);
218
219 RewritePatternSet patterns(context);
220 bool conversionFailed = false;
221 patterns.add<BfpToIntegerConversionPattern>(typeConverter, context,
222 conversionFailed);
223
224 if (failed(applyPartialConversion(device, target, std::move(patterns))) ||
225 conversionFailed) {
226 signalPassFailure();
227 }
228 }
229};
230
/// Factory returning a new instance of the bfp-to-integer type transformation
/// pass (AIETransformBfpTypesPass).
// NOTE(review): the source line naming this function (original line 232) is
// missing from this listing; the member summary gives the declaration as
// createAIETransformBfpTypesPass().
231std::unique_ptr<OperationPass<DeviceOp>>
233 return std::make_unique<AIETransformBfpTypesPass>();
234}
LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const override
BfpToIntegerConversionPattern(TypeConverter &typeConverter, MLIRContext *context, bool &conversionFailed)
BfpToIntegerConverter(const AIETargetModel &targetModel)
virtual bool isSupportedBlockFormat(std::string const &format) const
std::unique_ptr< mlir::OperationPass< AIE::DeviceOp > > createAIETransformBfpTypesPass()
Include the generated interface declarations.