827 ConversionPatternRewriter &rewriter)
const override {
828 Location loc = op.getLoc();
830 Value result = op.getResult();
831 VectorType resultType = cast<VectorType>(result.getType());
833 Type resultScaTy = resultType.getElementType();
834 unsigned resultBitWidth = resultScaTy.getIntOrFloatBitWidth();
836 int resultVectorSize = resultBitWidth * resultLanes;
838 Value opSrcVal = adaptor.getSource();
839 auto srcVecTy = cast<VectorType>(opSrcVal.getType());
841 if (srcVecTy != fltSrcVecTy)
844 .create<vector::ShapeCastOp>(op.getLoc(), fltSrcVecTy, opSrcVal)
849 Value upsIntrOp =
nullptr;
850 if (llvm::isa<IntegerType>(resultScaTy)) {
852 auto signCst = rewriter.create<LLVM::ConstantOp>(
853 loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1));
854 auto shiftCst = rewriter.create<LLVM::ConstantOp>(
855 loc, rewriter.getI32Type(),
856 rewriter.getI32IntegerAttr(op.getShift()));
858 SmallVector<Value> operands({opSrcVal, shiftCst, signCst});
859 if (resultVectorSize == 512) {
860 if (resultBitWidth == 32) {
862 upsIntrOp = rewriter.create<xllvm::Acc32V16I256UpsIntrOp>(
863 loc, VectorType::get({8}, rewriter.getI64Type()),
864 forceCastOperandsToSignature(
865 rewriter, loc, operands,
866 {VectorType::get({16}, rewriter.getI16Type()),
867 rewriter.getI32Type(), rewriter.getI32Type()}));
868 }
else if (resultBitWidth == 64) {
870 upsIntrOp = rewriter.create<xllvm::Acc64V8I256UpsIntrOp>(
871 loc, VectorType::get({8}, rewriter.getI64Type()),
872 forceCastOperandsToSignature(
873 rewriter, loc, operands,
874 {VectorType::get({8}, rewriter.getI32Type()),
875 rewriter.getI32Type(), rewriter.getI32Type()}));
877 }
else if (resultVectorSize == 1024) {
878 Value src = opSrcVal;
879 VectorType srcType = cast<VectorType>(src.getType());
880 Type srcScaType = srcType.getElementType();
881 unsigned srcBitWidth = srcScaType.getIntOrFloatBitWidth();
883 if (resultBitWidth == 32 && srcBitWidth == 16) {
885 upsIntrOp = rewriter.create<xllvm::Acc32V32I512UpsIntrOp>(
886 loc, VectorType::get({16}, rewriter.getI64Type()),
887 forceCastOperandsToSignature(
888 rewriter, loc, operands,
889 {VectorType::get({32}, rewriter.getI16Type()),
890 rewriter.getI32Type(), rewriter.getI32Type()}));
891 }
else if (resultBitWidth == 64 && srcBitWidth == 32) {
893 upsIntrOp = rewriter.create<xllvm::Acc64V16I512UpsIntrOp>(
894 loc, VectorType::get({16}, rewriter.getI64Type()),
895 forceCastOperandsToSignature(
896 rewriter, loc, operands,
897 {VectorType::get({16}, rewriter.getI32Type()),
898 rewriter.getI32Type(), rewriter.getI32Type()}));
899 }
else if (resultBitWidth == 64 && srcBitWidth == 16) {
901 upsIntrOp = rewriter.create<xllvm::Acc64V16I256UpsIntrOp>(
902 loc, VectorType::get({16}, rewriter.getI64Type()),
903 forceCastOperandsToSignature(
904 rewriter, loc, operands,
905 {VectorType::get({16}, rewriter.getI16Type()),
906 rewriter.getI32Type(), rewriter.getI32Type()}));
907 }
else if (resultBitWidth == 32 && srcBitWidth == 8) {
909 upsIntrOp = rewriter.create<xllvm::Acc32V32I256UpsIntrOp>(
910 loc, VectorType::get({16}, rewriter.getI64Type()),
911 forceCastOperandsToSignature(
912 rewriter, loc, operands,
913 {VectorType::get({32}, rewriter.getI8Type()),
914 rewriter.getI32Type(), rewriter.getI32Type()}));
919 if (resultVectorSize == 512) {
921 upsIntrOp = rewriter.create<xllvm::Vector16BF16ToV16AccFloatIntrOp>(
922 loc, VectorType::get({8}, rewriter.getI64Type()),
923 forceCastOperandsToSignature(
924 rewriter, loc, {opSrcVal},
925 {VectorType::get({16}, rewriter.getBF16Type())}));
926 }
else if (resultVectorSize == 1024) {
934 auto indexZeroCst = rewriter.create<LLVM::ConstantOp>(
935 loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
936 auto indexOneCst = rewriter.create<LLVM::ConstantOp>(
937 loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1));
938 auto extractUps = [&](Value source, Value index) -> Value {
939 auto extOp = rewriter.create<xllvm::ExtI256I512IntrOp>(
940 loc, VectorType::get({8}, rewriter.getI32Type()),
941 forceCastOperandsToSignature(
942 rewriter, loc, {source, index},
943 {VectorType::get({16}, rewriter.getI32Type()),
944 rewriter.getI32Type()}));
945 return rewriter.create<xllvm::Vector16BF16ToV16AccFloatIntrOp>(
946 loc, VectorType::get({8}, rewriter.getI64Type()),
947 forceCastOperandsToSignature(
948 rewriter, loc, {extOp},
949 {VectorType::get({16}, rewriter.getBF16Type())}));
951 auto resLo = extractUps(opSrcVal, indexZeroCst);
952 auto resHi = extractUps(opSrcVal, indexOneCst);
955 upsIntrOp = rewriter.create<xllvm::ConcatI1024I512IntrOp>(
956 loc, VectorType::get({32}, rewriter.getI32Type()),
957 forceCastOperandsToSignature(
958 rewriter, loc, {resLo, resHi},
959 {VectorType::get({16}, rewriter.getI32Type()),
960 VectorType::get({16}, rewriter.getI32Type())}));
965 op.emitWarning() <<
"aievec.ups is not supported.\n";
970 if (flatResTy != upsIntrOp.getType())
971 upsIntrOp = rewriter.create<LLVM::BitcastOp>(loc, flatResTy, upsIntrOp);
973 if (flatResTy != resultType)
975 rewriter.create<vector::ShapeCastOp>(loc, resultType, upsIntrOp);
977 rewriter.replaceOp(op, upsIntrOp);