828 ConversionPatternRewriter &rewriter)
const override {
829 Location loc = op.getLoc();
831 Value result = op.getResult();
832 VectorType resultType = cast<VectorType>(result.getType());
834 Type resultScaTy = resultType.getElementType();
835 unsigned resultBitWidth = resultScaTy.getIntOrFloatBitWidth();
837 int resultVectorSize = resultBitWidth * resultLanes;
839 Value opSrcVal = adaptor.getSource();
840 auto srcVecTy = cast<VectorType>(opSrcVal.getType());
842 if (srcVecTy != fltSrcVecTy)
845 .create<vector::ShapeCastOp>(op.getLoc(), fltSrcVecTy, opSrcVal)
850 Value upsIntrOp =
nullptr;
851 if (llvm::isa<IntegerType>(resultScaTy)) {
853 auto signCst = rewriter.create<LLVM::ConstantOp>(
854 loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1));
855 auto shiftCst = rewriter.create<LLVM::ConstantOp>(
856 loc, rewriter.getI32Type(),
857 rewriter.getI32IntegerAttr(op.getShift()));
859 SmallVector<Value> operands({opSrcVal, shiftCst, signCst});
860 if (resultVectorSize == 512) {
861 if (resultBitWidth == 32) {
863 upsIntrOp = rewriter.create<xllvm::Acc32V16I256UpsIntrOp>(
864 loc, VectorType::get({8}, rewriter.getI64Type()),
865 forceCastOperandsToSignature(
866 rewriter, loc, operands,
867 {VectorType::get({16}, rewriter.getI16Type()),
868 rewriter.getI32Type(), rewriter.getI32Type()}));
869 }
else if (resultBitWidth == 64) {
871 upsIntrOp = rewriter.create<xllvm::Acc64V8I256UpsIntrOp>(
872 loc, VectorType::get({8}, rewriter.getI64Type()),
873 forceCastOperandsToSignature(
874 rewriter, loc, operands,
875 {VectorType::get({8}, rewriter.getI32Type()),
876 rewriter.getI32Type(), rewriter.getI32Type()}));
878 }
else if (resultVectorSize == 1024) {
879 Value src = opSrcVal;
880 VectorType srcType = cast<VectorType>(src.getType());
881 Type srcScaType = srcType.getElementType();
882 unsigned srcBitWidth = srcScaType.getIntOrFloatBitWidth();
884 if (resultBitWidth == 32 && srcBitWidth == 16) {
886 upsIntrOp = rewriter.create<xllvm::Acc32V32I512UpsIntrOp>(
887 loc, VectorType::get({16}, rewriter.getI64Type()),
888 forceCastOperandsToSignature(
889 rewriter, loc, operands,
890 {VectorType::get({32}, rewriter.getI16Type()),
891 rewriter.getI32Type(), rewriter.getI32Type()}));
892 }
else if (resultBitWidth == 64 && srcBitWidth == 32) {
894 upsIntrOp = rewriter.create<xllvm::Acc64V16I512UpsIntrOp>(
895 loc, VectorType::get({16}, rewriter.getI64Type()),
896 forceCastOperandsToSignature(
897 rewriter, loc, operands,
898 {VectorType::get({16}, rewriter.getI32Type()),
899 rewriter.getI32Type(), rewriter.getI32Type()}));
900 }
else if (resultBitWidth == 64 && srcBitWidth == 16) {
902 upsIntrOp = rewriter.create<xllvm::Acc64V16I256UpsIntrOp>(
903 loc, VectorType::get({16}, rewriter.getI64Type()),
904 forceCastOperandsToSignature(
905 rewriter, loc, operands,
906 {VectorType::get({16}, rewriter.getI16Type()),
907 rewriter.getI32Type(), rewriter.getI32Type()}));
908 }
else if (resultBitWidth == 32 && srcBitWidth == 8) {
910 upsIntrOp = rewriter.create<xllvm::Acc32V32I256UpsIntrOp>(
911 loc, VectorType::get({16}, rewriter.getI64Type()),
912 forceCastOperandsToSignature(
913 rewriter, loc, operands,
914 {VectorType::get({32}, rewriter.getI8Type()),
915 rewriter.getI32Type(), rewriter.getI32Type()}));
920 if (resultVectorSize == 512) {
922 upsIntrOp = rewriter.create<xllvm::Vector16BF16ToV16AccFloatIntrOp>(
923 loc, VectorType::get({8}, rewriter.getI64Type()),
924 forceCastOperandsToSignature(
925 rewriter, loc, {opSrcVal},
926 {VectorType::get({16}, rewriter.getBF16Type())}));
927 }
else if (resultVectorSize == 1024) {
935 auto indexZeroCst = rewriter.create<LLVM::ConstantOp>(
936 loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
937 auto indexOneCst = rewriter.create<LLVM::ConstantOp>(
938 loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1));
939 auto extractUps = [&](Value source, Value index) -> Value {
940 auto extOp = rewriter.create<xllvm::ExtI256I512IntrOp>(
941 loc, VectorType::get({8}, rewriter.getI32Type()),
942 forceCastOperandsToSignature(
943 rewriter, loc, {source, index},
944 {VectorType::get({16}, rewriter.getI32Type()),
945 rewriter.getI32Type()}));
946 return rewriter.create<xllvm::Vector16BF16ToV16AccFloatIntrOp>(
947 loc, VectorType::get({8}, rewriter.getI64Type()),
948 forceCastOperandsToSignature(
949 rewriter, loc, {extOp},
950 {VectorType::get({16}, rewriter.getBF16Type())}));
952 auto resLo = extractUps(opSrcVal, indexZeroCst);
953 auto resHi = extractUps(opSrcVal, indexOneCst);
956 upsIntrOp = rewriter.create<xllvm::ConcatI1024I512IntrOp>(
957 loc, VectorType::get({32}, rewriter.getI32Type()),
958 forceCastOperandsToSignature(
959 rewriter, loc, {resLo, resHi},
960 {VectorType::get({16}, rewriter.getI32Type()),
961 VectorType::get({16}, rewriter.getI32Type())}));
966 op.emitWarning() <<
"aievec.ups is not supported.\n";
971 if (flatResTy != upsIntrOp.getType())
972 upsIntrOp = rewriter.create<LLVM::BitcastOp>(loc, flatResTy, upsIntrOp);
974 if (flatResTy != resultType)
976 rewriter.create<vector::ShapeCastOp>(loc, resultType, upsIntrOp);
978 rewriter.replaceOp(op, upsIntrOp);