Skip to content

API Reference

nearai

EntryLocation

Bases: BaseModel

EntryLocation

Source code in nearai/openapi_client/models/entry_location.py
class EntryLocation(BaseModel):
    """
    EntryLocation
    """ # noqa: E501
    # Registry coordinates of an entry; all three parts are strict strings.
    namespace: StrictStr
    name: StrictStr
    version: StrictStr
    __properties: ClassVar[List[str]] = ["namespace", "name", "version"]

    model_config = ConfigDict(
        populate_by_name=True,
        validate_assignment=True,
        protected_namespaces=(),
    )


    def to_str(self) -> str:
        """Pretty-print the aliased model dump."""
        dumped = self.model_dump(by_alias=True)
        return pprint.pformat(dumped)

    def to_json(self) -> str:
        """Serialize the aliased dict form (see ``to_dict``) as a JSON string."""
        # TODO: pydantic v2: use .model_dump_json(by_alias=True, exclude_unset=True) instead
        as_dict = self.to_dict()
        return json.dumps(as_dict)

    @classmethod
    def from_json(cls, json_str: str) -> Optional[Self]:
        """Decode ``json_str`` and build an instance through ``from_dict``."""
        decoded = json.loads(json_str)
        return cls.from_dict(decoded)

    def to_dict(self) -> Dict[str, Any]:
        """Dump the model to a dict keyed by field alias.

        Unlike a plain ``self.model_dump(by_alias=True)``, fields whose value is
        ``None`` are omitted from the result (``exclude_none=True``).
        """
        # Nothing is excluded for this model; the empty set is passed explicitly.
        excluded_fields: Set[str] = set()
        return self.model_dump(
            exclude_none=True,
            exclude=excluded_fields,
            by_alias=True,
        )

    @classmethod
    def from_dict(cls, obj: Optional[Dict[str, Any]]) -> Optional[Self]:
        """Build an instance from ``obj``.

        ``None`` yields ``None``; a non-dict is handed to ``model_validate`` as is;
        for a dict only the ``namespace``, ``name`` and ``version`` keys are read
        (absent keys become ``None`` and are then rejected by validation).
        """
        if obj is None:
            return None
        if isinstance(obj, dict):
            wanted = ("namespace", "name", "version")
            return cls.model_validate({key: obj.get(key) for key in wanted})
        return cls.model_validate(obj)

from_dict classmethod

from_dict(obj: Optional[Dict[str, Any]]) -> Optional[Self]

Create an instance of EntryLocation from a dict

Source code in nearai/openapi_client/models/entry_location.py
@classmethod
def from_dict(cls, obj: Optional[Dict[str, Any]]) -> Optional[Self]:
    """Create an instance of EntryLocation from a dict"""
    if obj is None:
        return None

    if not isinstance(obj, dict):
        return cls.model_validate(obj)

    _obj = cls.model_validate({
        "namespace": obj.get("namespace"),
        "name": obj.get("name"),
        "version": obj.get("version")
    })
    return _obj

from_json classmethod

from_json(json_str: str) -> Optional[Self]

Create an instance of EntryLocation from a JSON string

Source code in nearai/openapi_client/models/entry_location.py
@classmethod
def from_json(cls, json_str: str) -> Optional[Self]:
    """Create an instance of EntryLocation from a JSON string"""
    return cls.from_dict(json.loads(json_str))

to_dict

to_dict() -> Dict[str, Any]

Return the dictionary representation of the model using alias.

This has the following differences from calling pydantic's self.model_dump(by_alias=True):

  • None is only added to the output dict for nullable fields that were set at model initialization. Other fields with value None are ignored.
Source code in nearai/openapi_client/models/entry_location.py
def to_dict(self) -> Dict[str, Any]:
    """Dump the model to a dict keyed by field alias.

    Compared with a plain ``self.model_dump(by_alias=True)``, fields whose
    value is ``None`` are left out (``exclude_none=True``).
    """
    # No field is excluded for this model; the empty set is passed explicitly.
    excluded_fields: Set[str] = set()
    return self.model_dump(
        exclude_none=True,
        exclude=excluded_fields,
        by_alias=True,
    )

to_json

to_json() -> str

Returns the JSON representation of the model using alias

Source code in nearai/openapi_client/models/entry_location.py
def to_json(self) -> str:
    """Serialize ``self.to_dict()`` to a JSON string."""
    # TODO: pydantic v2: use .model_dump_json(by_alias=True, exclude_unset=True) instead
    payload = self.to_dict()
    return json.dumps(payload)

to_str

to_str() -> str

Returns the string representation of the model using alias

Source code in nearai/openapi_client/models/entry_location.py
def to_str(self) -> str:
    """Pretty-print the model dump keyed by field alias."""
    dumped = self.model_dump(by_alias=True)
    return pprint.pformat(dumped)

parse_location

parse_location(entry_location: str) -> EntryLocation

Create an EntryLocation from a string in the format namespace/name/version.

Source code in nearai/lib.py
def parse_location(entry_location: str) -> EntryLocation:
    """Create an EntryLocation from a string in the format namespace/name/version.

    Raises:
        ValueError: if ``entry_location`` does not match ``entry_location_pattern``.
    """
    parsed = entry_location_pattern.match(entry_location)
    if parsed is None:
        raise ValueError(f"Invalid entry format: {entry_location}. Should have the format <namespace>/<name>/<version>")

    parts = {key: parsed.group(key) for key in ("namespace", "name", "version")}
    return EntryLocation(**parts)

agents

agent

Agent

Bases: object

Source code in nearai/agents/agent.py
class Agent(object):
    """An agent whose files are staged in a private temporary directory.

    Construction copies the agent files into ``agent_<uuid>`` under the system
    temp dir and reads name/version/env vars/model defaults from ``metadata``.
    :meth:`run` then compiles and executes ``AGENT_FILENAME`` from that folder.
    """

    def __init__(  # noqa: D107
        self, identifier: str, agent_files: Union[List, Path], metadata: Dict, change_to_temp_dir: bool = True
    ):  # noqa: D107
        self.identifier = identifier
        # identifier is split as "namespace/name/version"; name and version are
        # overwritten below by set_agent_metadata().
        name_parts = identifier.split("/")
        self.namespace = name_parts[0]
        self.name = name_parts[1]
        self.version = name_parts[2]

        self.metadata = metadata
        self.env_vars: Dict[str, Any] = {}

        # Defaults; may be overridden by metadata["details"]["agent"]["defaults"].
        self.model = ""
        self.model_provider = ""
        self.model_temperature: Optional[float] = None
        self.model_max_tokens: Optional[int] = None
        self.max_iterations = 1
        self.welcome_title: Optional[str] = None
        self.welcome_description: Optional[str] = None

        self.set_agent_metadata(metadata)
        self.agent_files = agent_files
        self.original_cwd = os.getcwd()

        self.temp_dir = self.write_agent_files_to_temp(agent_files)
        self.change_to_temp_dir = change_to_temp_dir
        # Stays empty until run() has successfully compiled the agent code.
        self.agent_filename = ""

    def get_full_name(self):
        """Returns full agent name."""
        return f"{self.namespace}/{self.name}/{self.version}"

    @staticmethod
    def write_agent_files_to_temp(agent_files):
        """Write agent files to a new, uniquely named temporary directory.

        Args:
            agent_files: Either a list of ``{"filename": ..., "content": ...}``
                dicts (content may be str, bytes, or a dict/list that is
                JSON-encoded), or a path to a directory whose whole tree is copied.

        Returns:
            The path of the created temporary directory.

        Raises:
            Exception: whatever error occurred while writing a listed file
                (it is printed first, then re-raised).

        """
        unique_id = uuid.uuid4().hex
        temp_dir = os.path.join(tempfile.gettempdir(), f"agent_{unique_id}")

        if isinstance(agent_files, list):
            os.makedirs(temp_dir, exist_ok=True)

            for file_obj in agent_files:
                # NOTE(review): the filename is joined unchecked, so an absolute
                # path or "../" component would land outside temp_dir — confirm
                # that agent filenames are trusted.
                file_path = os.path.join(temp_dir, file_obj["filename"])

                try:
                    # exist_ok avoids the check-then-create race and handles
                    # several files sharing a subdirectory.
                    os.makedirs(os.path.dirname(file_path), exist_ok=True)

                    content = file_obj["content"]

                    if isinstance(content, (dict, list)):
                        try:
                            content = json.dumps(content)
                        except Exception as e:
                            # Best effort: fall back to the Python repr of the data.
                            print(f"Error converting content to json: {e}")
                            content = str(content)

                    if isinstance(content, str):
                        content = content.encode("utf-8")

                    with open(file_path, "wb") as f:
                        f.write(content)
                except Exception as e:
                    print(f"Error writing file {file_path}: {e}")
                    raise

        else:
            # if agent files is a PosixPath, it is a path to the agent directory
            # Copy all agent files including subfolders
            shutil.copytree(agent_files, temp_dir, dirs_exist_ok=True)

        return temp_dir

    def set_agent_metadata(self, metadata) -> None:
        """Set name, version, env vars, welcome text and model defaults from ``metadata``.

        Raises:
            ValueError: if ``name`` or ``version`` is missing or empty.

        """
        try:
            self.name = metadata["name"]
            self.version = metadata["version"]
        except KeyError as e:
            raise ValueError(f"Missing key in metadata: {e}") from None

        details = metadata.get("details", {})
        agent = details.get("agent", {})
        welcome = agent.get("welcome", {})

        self.env_vars = details.get("env_vars", {})
        self.welcome_title = welcome.get("title")
        self.welcome_description = welcome.get("description")

        # Only keys present in "defaults" override the values set in __init__.
        if defaults := agent.get("defaults", None):
            self.model = defaults.get("model", self.model)
            self.model_provider = defaults.get("model_provider", self.model_provider)
            self.model_temperature = defaults.get("model_temperature", self.model_temperature)
            self.model_max_tokens = defaults.get("model_max_tokens", self.model_max_tokens)
            self.max_iterations = defaults.get("max_iterations", self.max_iterations)

        if not self.version or not self.name:
            raise ValueError("Both 'version' and 'name' must be non-empty in metadata.")

    def run(self, env: Any, task: Optional[str] = None) -> None:  # noqa: D102
        """Execute the agent's code against ``env``.

        The agent's env vars are merged with ``env.env_vars`` (``env`` wins on
        conflicts), exported to ``os.environ`` and stored back on ``env``. The
        code is compiled once and cached on the instance. It runs with ``env``,
        ``agent``, ``task``, ``__name__ == "__main__"`` and ``__file__`` as
        globals. When ``env.agent_runner_user`` is set, the code runs in a child
        process that first switches to that user.

        Raises:
            ValueError: if ``AGENT_FILENAME`` does not exist in the temp dir.

        """
        # combine agent.env_vars and env.env_vars
        total_env_vars = {**self.env_vars, **env.env_vars}

        # save os env vars
        os.environ.update(total_env_vars)
        # save env.env_vars
        env.env_vars = total_env_vars

        if not self.agent_filename:
            agent_filename = os.path.join(self.temp_dir, AGENT_FILENAME)
            if not os.path.exists(agent_filename):
                raise ValueError(f"Agent run error: {AGENT_FILENAME} does not exist")
            with open(agent_filename, "r") as f:
                self.code = compile(f.read(), agent_filename, "exec")
            # Mark the code as cached only after it compiled; otherwise a failed
            # first run would make later runs use a missing self.code.
            self.agent_filename = agent_filename
        else:
            print("Using cached agent code")

        namespace = {
            "env": env,
            "agent": self,
            "task": task,
            "__name__": "__main__",
            "__file__": self.agent_filename,
        }

        def run_agent_code(namespace):
            # switch to user env.agent_runner_user
            # NOTE(review): supplementary groups are not dropped here
            # (no os.setgroups/os.initgroups) — confirm this is acceptable.
            if env.agent_runner_user:
                user_info = pwd.getpwnam(env.agent_runner_user)
                os.setgid(user_info.pw_gid)
                os.setuid(user_info.pw_uid)

            # Run the code
            # NOTE: runpy.run_path does not work in a multithreaded environment when running benchmark.
            #       The performance of runpy.run_path may also change depending on a system, e.g. it may
            #       work on Linux but not work on Mac.
            #       `compile` and `exec` have been tested to work properly in a multithreaded environment.
            exec(self.code, namespace)

        added_to_sys_path = False
        try:
            if self.change_to_temp_dir:
                if not os.path.exists(self.temp_dir):
                    os.makedirs(self.temp_dir, exist_ok=True)
                os.chdir(self.temp_dir)
            sys.path.insert(0, self.temp_dir)
            added_to_sys_path = True

            if env.agent_runner_user:
                # run_agent_code takes exactly one argument (the globals dict).
                # NOTE(review): a nested function as Process target relies on the
                # "fork" start method; "spawn" cannot pickle it. The child's exit
                # code is not checked.
                process = multiprocessing.Process(target=run_agent_code, args=(namespace,))
                process.start()
                process.join()
            else:
                run_agent_code(namespace)
        finally:
            # Undo only our own insertion, and never raise from cleanup (which
            # would mask the original error).
            if added_to_sys_path and self.temp_dir in sys.path:
                sys.path.remove(self.temp_dir)
            if self.change_to_temp_dir:
                os.chdir(self.original_cwd)

    @staticmethod
    def load_agents(agents: str, config: ClientConfig, local: bool = False):
        """Load each agent named in the comma-separated ``agents`` string, in order."""
        return [Agent.load_agent(agent, config, local) for agent in agents.split(",")]

    @staticmethod
    def load_agent(
        name: str,
        config: ClientConfig,
        local: bool = False,
    ):
        """Load a single agent, from the local registry folder or by downloading it.

        When ``local`` is true, ``name`` is a path relative to the registry folder
        and the namespace comes from ``config.auth`` (``"not-logged-in"`` when
        unauthenticated). Otherwise ``name`` is downloaded from the registry and
        used as the identifier.

        Raises:
            FileNotFoundError: if the agent folder has no ``metadata.json``.

        """
        from nearai.registry import get_registry_folder, registry

        identifier = None
        if local:
            agent_files_path = get_registry_folder() / name
            if config.auth is None:
                namespace = "not-logged-in"
            else:
                namespace = config.auth.account_id
        else:
            agent_files_path = registry.download(name)
            identifier = name
        assert agent_files_path is not None, f"Agent {name} not found."

        metadata_path = os.path.join(agent_files_path, "metadata.json")
        if not os.path.exists(metadata_path):
            raise FileNotFoundError(f"Metadata file not found: {metadata_path}")
        with open(metadata_path) as f:
            metadata: Dict[str, Any] = json.load(f)

        if not identifier:
            identifier = "/".join([namespace, metadata["name"], metadata["version"]])

        return Agent(identifier, agent_files_path, metadata)
get_full_name
get_full_name()

Returns full agent name.

Source code in nearai/agents/agent.py
def get_full_name(self):
    """Return the agent identifier as ``namespace/name/version``."""
    return "{}/{}/{}".format(self.namespace, self.name, self.version)
load_agent staticmethod
load_agent(name: str, config: ClientConfig, local: bool = False)

Loads a single agent from the registry.

Source code in nearai/agents/agent.py
@staticmethod
def load_agent(
    name: str,
    config: ClientConfig,
    local: bool = False,
):
    """Load a single agent, from the local registry folder or by downloading it.

    With ``local`` set, ``name`` is resolved under the registry folder and the
    namespace is the logged-in account (``"not-logged-in"`` without auth);
    otherwise ``name`` is downloaded and doubles as the identifier.

    Raises:
        FileNotFoundError: if ``metadata.json`` is missing from the agent folder.
    """
    from nearai.registry import get_registry_folder, registry

    if local:
        agent_files_path = get_registry_folder() / name
        namespace = "not-logged-in" if config.auth is None else config.auth.account_id
        identifier = None
    else:
        agent_files_path = registry.download(name)
        identifier = name
    assert agent_files_path is not None, f"Agent {name} not found."

    metadata_path = os.path.join(agent_files_path, "metadata.json")
    if not os.path.exists(metadata_path):
        raise FileNotFoundError(f"Metadata file not found: {metadata_path}")
    with open(metadata_path) as metadata_file:
        metadata: Dict[str, Any] = json.load(metadata_file)

    if not identifier:
        # Local agents get "<namespace>/<name>/<version>" built from their metadata.
        identifier = "/".join([namespace, metadata["name"], metadata["version"]])

    return Agent(identifier, agent_files_path, metadata)
load_agents staticmethod
load_agents(agents: str, config: ClientConfig, local: bool = False)

Loads agents from the registry.

Source code in nearai/agents/agent.py
@staticmethod
def load_agents(agents: str, config: ClientConfig, local: bool = False):
    """Load every agent named in the comma-separated ``agents`` string, in order."""
    loaded = []
    for agent_name in agents.split(","):
        loaded.append(Agent.load_agent(agent_name, config, local))
    return loaded
set_agent_metadata
set_agent_metadata(metadata) -> None

Set agent details from metadata.

Source code in nearai/agents/agent.py
def set_agent_metadata(self, metadata) -> None:
    """Populate name, version, env vars, welcome text and model defaults from ``metadata``.

    Raises:
        ValueError: if ``name`` or ``version`` is missing or empty.
    """
    try:
        self.name = metadata["name"]
        self.version = metadata["version"]
    except KeyError as e:
        raise ValueError(f"Missing key in metadata: {e}") from None

    details = metadata.get("details", {})
    agent_section = details.get("agent", {})
    welcome = agent_section.get("welcome", {})

    self.env_vars = details.get("env_vars", {})
    self.welcome_title = welcome.get("title")
    self.welcome_description = welcome.get("description")

    defaults = agent_section.get("defaults", None) if agent_section else None
    if defaults:
        # Each defaults key matches an attribute name; absent keys keep the current value.
        for attr in ("model", "model_provider", "model_temperature", "model_max_tokens", "max_iterations"):
            setattr(self, attr, defaults.get(attr, getattr(self, attr)))

    if not (self.name and self.version):
        raise ValueError("Both 'version' and 'name' must be non-empty in metadata.")
write_agent_files_to_temp staticmethod
write_agent_files_to_temp(agent_files)

Write agent files to a temporary directory.

Source code in nearai/agents/agent.py
@staticmethod
def write_agent_files_to_temp(agent_files):
    """Write agent files to a new, uniquely named temporary directory.

    Args:
        agent_files: Either a list of ``{"filename": ..., "content": ...}``
            dicts (content may be str, bytes, or a dict/list that is
            JSON-encoded), or a path to a directory whose whole tree is copied.

    Returns:
        The path of the created temporary directory.

    Raises:
        Exception: whatever error occurred while writing a listed file
            (it is printed first, then re-raised).

    """
    unique_id = uuid.uuid4().hex
    temp_dir = os.path.join(tempfile.gettempdir(), f"agent_{unique_id}")

    if isinstance(agent_files, list):
        os.makedirs(temp_dir, exist_ok=True)

        for file_obj in agent_files:
            # NOTE(review): the filename is joined unchecked, so an absolute
            # path or "../" component would land outside temp_dir — confirm
            # that agent filenames are trusted.
            file_path = os.path.join(temp_dir, file_obj["filename"])

            try:
                # exist_ok avoids the check-then-create race and handles
                # several files sharing a subdirectory.
                os.makedirs(os.path.dirname(file_path), exist_ok=True)

                content = file_obj["content"]

                if isinstance(content, (dict, list)):
                    try:
                        content = json.dumps(content)
                    except Exception as e:
                        # Best effort: fall back to the Python repr of the data.
                        print(f"Error converting content to json: {e}")
                        content = str(content)

                if isinstance(content, str):
                    content = content.encode("utf-8")

                with open(file_path, "wb") as f:
                    f.write(content)
            except Exception as e:
                print(f"Error writing file {file_path}: {e}")
                raise

    else:
        # if agent files is a PosixPath, it is a path to the agent directory
        # Copy all agent files including subfolders
        shutil.copytree(agent_files, temp_dir, dirs_exist_ok=True)

    return temp_dir

environment

Environment

Bases: object

Source code in nearai/agents/environment.py
  93
  94
  95
  96
  97
  98
  99
 100
 101
 102
 103
 104
 105
 106
 107
 108
 109
 110
 111
 112
 113
 114
 115
 116
 117
 118
 119
 120
 121
 122
 123
 124
 125
 126
 127
 128
 129
 130
 131
 132
 133
 134
 135
 136
 137
 138
 139
 140
 141
 142
 143
 144
 145
 146
 147
 148
 149
 150
 151
 152
 153
 154
 155
 156
 157
 158
 159
 160
 161
 162
 163
 164
 165
 166
 167
 168
 169
 170
 171
 172
 173
 174
 175
 176
 177
 178
 179
 180
 181
 182
 183
 184
 185
 186
 187
 188
 189
 190
 191
 192
 193
 194
 195
 196
 197
 198
 199
 200
 201
 202
 203
 204
 205
 206
 207
 208
 209
 210
 211
 212
 213
 214
 215
 216
 217
 218
 219
 220
 221
 222
 223
 224
 225
 226
 227
 228
 229
 230
 231
 232
 233
 234
 235
 236
 237
 238
 239
 240
 241
 242
 243
 244
 245
 246
 247
 248
 249
 250
 251
 252
 253
 254
 255
 256
 257
 258
 259
 260
 261
 262
 263
 264
 265
 266
 267
 268
 269
 270
 271
 272
 273
 274
 275
 276
 277
 278
 279
 280
 281
 282
 283
 284
 285
 286
 287
 288
 289
 290
 291
 292
 293
 294
 295
 296
 297
 298
 299
 300
 301
 302
 303
 304
 305
 306
 307
 308
 309
 310
 311
 312
 313
 314
 315
 316
 317
 318
 319
 320
 321
 322
 323
 324
 325
 326
 327
 328
 329
 330
 331
 332
 333
 334
 335
 336
 337
 338
 339
 340
 341
 342
 343
 344
 345
 346
 347
 348
 349
 350
 351
 352
 353
 354
 355
 356
 357
 358
 359
 360
 361
 362
 363
 364
 365
 366
 367
 368
 369
 370
 371
 372
 373
 374
 375
 376
 377
 378
 379
 380
 381
 382
 383
 384
 385
 386
 387
 388
 389
 390
 391
 392
 393
 394
 395
 396
 397
 398
 399
 400
 401
 402
 403
 404
 405
 406
 407
 408
 409
 410
 411
 412
 413
 414
 415
 416
 417
 418
 419
 420
 421
 422
 423
 424
 425
 426
 427
 428
 429
 430
 431
 432
 433
 434
 435
 436
 437
 438
 439
 440
 441
 442
 443
 444
 445
 446
 447
 448
 449
 450
 451
 452
 453
 454
 455
 456
 457
 458
 459
 460
 461
 462
 463
 464
 465
 466
 467
 468
 469
 470
 471
 472
 473
 474
 475
 476
 477
 478
 479
 480
 481
 482
 483
 484
 485
 486
 487
 488
 489
 490
 491
 492
 493
 494
 495
 496
 497
 498
 499
 500
 501
 502
 503
 504
 505
 506
 507
 508
 509
 510
 511
 512
 513
 514
 515
 516
 517
 518
 519
 520
 521
 522
 523
 524
 525
 526
 527
 528
 529
 530
 531
 532
 533
 534
 535
 536
 537
 538
 539
 540
 541
 542
 543
 544
 545
 546
 547
 548
 549
 550
 551
 552
 553
 554
 555
 556
 557
 558
 559
 560
 561
 562
 563
 564
 565
 566
 567
 568
 569
 570
 571
 572
 573
 574
 575
 576
 577
 578
 579
 580
 581
 582
 583
 584
 585
 586
 587
 588
 589
 590
 591
 592
 593
 594
 595
 596
 597
 598
 599
 600
 601
 602
 603
 604
 605
 606
 607
 608
 609
 610
 611
 612
 613
 614
 615
 616
 617
 618
 619
 620
 621
 622
 623
 624
 625
 626
 627
 628
 629
 630
 631
 632
 633
 634
 635
 636
 637
 638
 639
 640
 641
 642
 643
 644
 645
 646
 647
 648
 649
 650
 651
 652
 653
 654
 655
 656
 657
 658
 659
 660
 661
 662
 663
 664
 665
 666
 667
 668
 669
 670
 671
 672
 673
 674
 675
 676
 677
 678
 679
 680
 681
 682
 683
 684
 685
 686
 687
 688
 689
 690
 691
 692
 693
 694
 695
 696
 697
 698
 699
 700
 701
 702
 703
 704
 705
 706
 707
 708
 709
 710
 711
 712
 713
 714
 715
 716
 717
 718
 719
 720
 721
 722
 723
 724
 725
 726
 727
 728
 729
 730
 731
 732
 733
 734
 735
 736
 737
 738
 739
 740
 741
 742
 743
 744
 745
 746
 747
 748
 749
 750
 751
 752
 753
 754
 755
 756
 757
 758
 759
 760
 761
 762
 763
 764
 765
 766
 767
 768
 769
 770
 771
 772
 773
 774
 775
 776
 777
 778
 779
 780
 781
 782
 783
 784
 785
 786
 787
 788
 789
 790
 791
 792
 793
 794
 795
 796
 797
 798
 799
 800
 801
 802
 803
 804
 805
 806
 807
 808
 809
 810
 811
 812
 813
 814
 815
 816
 817
 818
 819
 820
 821
 822
 823
 824
 825
 826
 827
 828
 829
 830
 831
 832
 833
 834
 835
 836
 837
 838
 839
 840
 841
 842
 843
 844
 845
 846
 847
 848
 849
 850
 851
 852
 853
 854
 855
 856
 857
 858
 859
 860
 861
 862
 863
 864
 865
 866
 867
 868
 869
 870
 871
 872
 873
 874
 875
 876
 877
 878
 879
 880
 881
 882
 883
 884
 885
 886
 887
 888
 889
 890
 891
 892
 893
 894
 895
 896
 897
 898
 899
 900
 901
 902
 903
 904
 905
 906
 907
 908
 909
 910
 911
 912
 913
 914
 915
 916
 917
 918
 919
 920
 921
 922
 923
 924
 925
 926
 927
 928
 929
 930
 931
 932
 933
 934
 935
 936
 937
 938
 939
 940
 941
 942
 943
 944
 945
 946
 947
 948
 949
 950
 951
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
class Environment(object):
    def __init__(  # noqa: D107
        self,
        path: str,
        agents: List[Agent],
        client: InferenceClient,
        hub_client: OpenAI,
        thread_id: str,
        run_id: str,
        create_files: bool = True,
        env_vars: Optional[Dict[str, Any]] = None,
        tool_resources: Optional[Dict[str, Any]] = None,
        print_system_log: bool = False,
        agent_runner_user: Optional[str] = None,
        approvals: Optional[Dict[str, Any]] = default_approvals,
    ) -> None:
        # Warning: never expose `client` or `_hub_client` to agent's environment

        # Placeholder for solver
        self.client: Optional[InferenceClient] = None

        self._path = path
        self._agents = agents
        self._done = False
        self._pending_ext_agent = False
        self.env_vars: Dict[str, Any] = env_vars if env_vars else {}
        self._last_used_model = ""
        self.tool_resources: Dict[str, Any] = tool_resources if tool_resources else {}
        self.print_system_log = print_system_log
        self.agent_runner_user = agent_runner_user
        self._approvals = approvals
        self._thread_id = thread_id
        self._run_id = run_id
        self._debug_mode = True if self.env_vars.get("DEBUG") else False

        self._tools = ToolRegistry()

        if create_files:
            os.makedirs(self._path, exist_ok=True)
            open(os.path.join(self._path, CHAT_FILENAME), "a").close()
        os.chdir(self._path)

        def signer_account_id() -> Optional[str]:
            """Expose the NEAR account_id of a user that signs this request to run an agent."""
            try:
                return client._config.auth.account_id if client._config.auth else None
            except (AttributeError, TypeError):
                return None

        self.signer_account_id = signer_account_id()

        # Client methods
        def query_vector_store(vector_store_id: str, query: str, full_files: bool = False):
            """Queries a vector store.

            vector_store_id: The id of the vector store to query.
            query: The query to search for.
            """
            return client.query_vector_store(vector_store_id, query, full_files)

        self.query_vector_store = query_vector_store

        def upload_file(
            file_content: str,
            purpose: Literal["assistants", "batch", "fine-tune", "vision"] = "assistants",
        ):
            """Uploads a file to the registry."""
            return client.upload_file(file_content, purpose)

        self.upload_file = upload_file

        def create_vector_store_from_source(
            name: str,
            source: Union[GitHubSource, GitLabSource],
            source_auth: Optional[str] = None,
            chunking_strategy: Optional[ChunkingStrategy] = None,
            expires_after: Optional[ExpiresAfter] = None,
            metadata: Optional[Dict[str, str]] = None,
        ) -> VectorStore:
            """Creates a vector store from the given source.

            Args:
            ----
                name: The name of the vector store.
                source: The source from which to create the vector store.
                source_auth: The source authentication token.
                chunking_strategy: The chunking strategy to use.
                expires_after: The expiration policy.
                metadata: Additional metadata.

            Returns:
            -------
                VectorStore: The created vector store.

            """
            return client.create_vector_store_from_source(
                name=name,
                source=source,
                source_auth=source_auth,
                chunking_strategy=chunking_strategy,
                expires_after=expires_after,
                metadata=metadata,
            )

        self.create_vector_store_from_source = create_vector_store_from_source

        def add_file_to_vector_store(vector_store_id: str, file_id: str):
            """Adds a file to the vector store."""
            return client.add_file_to_vector_store(vector_store_id, file_id)

        self.add_file_to_vector_store = add_file_to_vector_store

        def create_vector_store(
            name: str,
            file_ids: list,
            expires_after: Union[ExpiresAfter, NotGiven] = NOT_GIVEN,
            chunking_strategy: Union[
                AutoFileChunkingStrategyParam, StaticFileChunkingStrategyParam, NotGiven
            ] = NOT_GIVEN,
            metadata: Optional[Dict[str, str]] = None,
        ) -> VectorStore:
            """Creates a vector store.

            Args:
            ----
                name: The name of the vector store.
                file_ids: List of file ids to create the vector store.
                chunking_strategy: The chunking strategy to use.
                expires_after: The expiration policy.
                metadata: Additional metadata.

            Returns:
            -------
                VectorStore: The created vector store.

            """
            return client.create_vector_store(
                name=name,
                file_ids=file_ids,
                chunking_strategy=chunking_strategy,
                expires_after=expires_after,
                metadata=metadata,
            )

        self.create_vector_store = create_vector_store

        def get_vector_store(self, vector_store_id: str) -> VectorStore:
            """Gets a vector store by id."""
            return client.get_vector_store(vector_store_id)

        self.get_vector_store = get_vector_store

        # Save cache of requested models for inference to avoid extra server calls
        self.cached_models_for_inference: Dict[str, str] = {}

        def get_model_for_inference(model: str = "") -> str:
            """Returns 'provider::model_full_path'."""
            if self.cached_models_for_inference.get(model, None) is None:
                provider = self._agents[0].model_provider if self._agents else ""
                if model == "":
                    model = self._agents[0].model if self._agents else ""
                if model == "":
                    return DEFAULT_PROVIDER_MODEL

                _, model_for_inference = client.provider_models.match_provider_model(model, provider)

                self.cached_models_for_inference[model] = model_for_inference

            return self.cached_models_for_inference[model]

        self.get_model_for_inference = get_model_for_inference

        def _run_inference_completions(
            messages: Union[Iterable[ChatCompletionMessageParam], str],
            model: Union[Iterable[ChatCompletionMessageParam], str],
            stream: bool,
            **kwargs: Any,
        ) -> Union[ModelResponse, CustomStreamWrapper]:
            """Run inference completions for given parameters."""
            params, kwargs = self.get_inference_parameters(messages, model, stream, **kwargs)

            completions = client.completions(
                params.model, params.messages, params.stream, params.temperature, params.max_tokens, **kwargs
            )

            return completions

        self._run_inference_completions = _run_inference_completions

        def get_agent_public_key():
            """Returns public key of the agent."""
            agent_name = self.get_primary_agent().get_full_name()

            return client.get_agent_public_key(agent_name)

        self.get_agent_public_key = get_agent_public_key

        def run_agent(
            owner: str,
            agent_name: str,
            version: str,
            model: Optional[str] = None,
            query: Optional[str] = None,
            fork_thread: bool = True,
        ):
            """Start a child agent run and return the id of the thread it runs on.

            owner, agent_name, version: identify the child agent as ``owner/agent_name/version``.
            model: see the NOTE in the body; it is not forwarded to the child run.
            query: if given, posted as a ``user`` message on the child's thread first.
            fork_thread: when True (default) the child runs on a fork of the current
                thread; when False it runs on the current thread itself.

            Sets ``self._pending_ext_agent = True`` after requesting the run.
            """
            if fork_thread:
                target_thread_id = client.threads_fork(self._thread_id).id
                self.add_system_log(f"Forked thread {target_thread_id}", logging.INFO)
            else:
                target_thread_id = self._thread_id

            if query:
                client.threads_messages_create(thread_id=target_thread_id, content=query, role="user")

            child_assistant_id = f"{owner}/{agent_name}/{version}"
            # NOTE(review): `model` is resolved here but never passed to `client.run_agent`,
            # so it has no effect on the child run. Forwarding it depends on whether the
            # client API accepts a model argument; confirm before wiring it through.
            model = model or DEFAULT_PROVIDER_MODEL
            self.add_system_log(f"Running agent {child_assistant_id}", logging.INFO)
            client.run_agent(
                current_run_id=self._run_id,
                child_thread_id=target_thread_id,
                assistant_id=child_assistant_id,
            )
            self._pending_ext_agent = True
            return target_thread_id

        self.run_agent = run_agent

        def schedule_run(
            agent: str,
            input_message: str,
            run_at: datetime,
            run_params: Optional[Dict[str, str]] = None,
            thread_id: Optional[str] = None,
        ):
            """Schedule `agent` to run with `input_message` at `run_at`.

            Thin wrapper over ``client.schedule_run``. Note that the client takes its
            arguments in a different order: agent, input_message, thread_id, run_params, run_at.
            """
            scheduled = client.schedule_run(agent, input_message, thread_id, run_params, run_at)
            return scheduled

        self.schedule_run = schedule_run

        # TODO(https://github.com/nearai/nearai/issues/549): Allow only a subset of agents to access/update user memory.
        def add_user_memory(memory: str):
            """Store `memory` via ``client.add_user_memory`` and return the client's response."""
            response = client.add_user_memory(memory)
            return response

        self.add_user_memory = add_user_memory

        def query_user_memory(query: str):
            """Search user memory with `query` via ``client.query_user_memory`` and return the result."""
            matches = client.query_user_memory(query)
            return matches

        self.query_user_memory = query_user_memory

        def generate_image(prompt: str):
            """Request an image for `prompt` via ``client.generate_image`` and return the client's response."""
            image_response = client.generate_image(prompt)
            return image_response

        self.generate_image = generate_image

        def save_agent_data(key, data: Dict[str, Any]):
            """Persist `data` under `key` via ``client.save_agent_data`` and return the client's response."""
            saved = client.save_agent_data(key, data)
            return saved

        self.save_agent_data = save_agent_data

        def get_agent_data():
            """Return all stored agent data as reported by ``client.get_agent_data``."""
            stored = client.get_agent_data()
            return stored

        self.get_agent_data = get_agent_data

        def get_agent_data_by_key(key, default=None):
            """Fetch the stored agent data entry for `key`.

            If the client returns a falsy result, a placeholder dict is returned instead,
            with keys value/namespace/key/name/updated_at/created_at: `value` is `default`,
            `namespace` and `name` come from the primary agent, and both timestamps are "".
            """
            primary = self._agents[0]
            agent_namespace, agent_name = primary.namespace, primary.name
            stored = client.get_agent_data_by_key(key)
            if stored:
                return stored
            return {
                "value": default,
                "namespace": agent_namespace,
                "key": key,
                "name": agent_name,
                "updated_at": "",
                "created_at": "",
            }

        self.get_agent_data_by_key = get_agent_data_by_key

        # HubClient methods
        def add_reply(
            message: str,
            attachments: Optional[Iterable[Attachment]] = None,
            message_type: Optional[str] = None,
        ):
            """Post `message` to the current thread with role "assistant" and return the created message.

            attachments: optional file attachments for the message.
            message_type: when given, sent as ``metadata["message_type"]``; otherwise no metadata is sent.

            The primary agent's identifier and the current run id are sent via ``extra_body``.
            """
            # NOTE: message from `user` are not stored in the memory
            hub_extra = {"assistant_id": self._agents[0].identifier, "run_id": self._run_id}
            reply_metadata = {"message_type": message_type} if message_type else None
            return hub_client.beta.threads.messages.create(
                thread_id=self._thread_id,
                role="assistant",
                content=message,
                extra_body=hub_extra,
                attachments=attachments,
                metadata=reply_metadata,
            )

        self.add_reply = add_reply

        def _add_message(
            role: str,
            message: str,
            attachments: Optional[Iterable[Attachment]] = None,
            **kwargs: Any,
        ):
            """Post `message` to the current thread with an explicit `role`.

            Extra keyword arguments are sent verbatim as the message ``metadata``
            (for example ``tool_call_id`` and ``name`` for tool responses). The primary
            agent's identifier and the run id go in ``extra_body``.
            """
            hub_extra = {"assistant_id": self._agents[0].identifier, "run_id": self._run_id}
            return hub_client.beta.threads.messages.create(
                thread_id=self._thread_id,
                role=role,  # type: ignore
                content=message,
                extra_body=hub_extra,
                metadata=kwargs,
                attachments=attachments,
            )

        self._add_message = _add_message

        def _list_messages(
            limit: Union[int, NotGiven] = NOT_GIVEN,
            order: Literal["asc", "desc"] = "asc",
            thread_id: Optional[str] = None,
        ) -> List[Message]:
            """Fetch messages of a thread from the hub and log how many were retrieved.

            limit: maximum number of messages; NOT_GIVEN defers to the API's default.
            order: sort order passed to the API ("asc" or "desc").
            thread_id: thread to read; defaults to the current thread.
            """
            target_thread = thread_id or self._thread_id
            page = hub_client.beta.threads.messages.list(thread_id=target_thread, limit=limit, order=order)
            fetched = page.data
            self.add_system_log(f"Retrieved {len(fetched)} messages from NEAR AI Hub")
            return fetched

        self._list_messages = _list_messages

        def list_files_from_thread(
            order: Literal["asc", "desc"] = "asc", thread_id: Optional[str] = None
        ) -> List[FileObject]:
            """Return the file objects attached to messages of a thread.

            order: message order passed to the hub ("asc" or "desc"); the returned files
                follow the order of the messages they are attached to.
            thread_id: thread to inspect; defaults to the current thread.

            Attachments without a file id are skipped.
            """
            # Forward thread_id; it was accepted but previously ignored, so files were always
            # listed from the current thread.
            messages = self._list_messages(order=order, thread_id=thread_id)
            file_ids = [
                attachment.file_id
                for message in messages
                if message.attachments
                for attachment in message.attachments
            ]
            return [hub_client.files.retrieve(file_id) for file_id in file_ids if file_id]

        self.list_files_from_thread = list_files_from_thread

        def read_file_by_id(file_id: str):
            """Download the hub file `file_id` and return its content decoded as UTF-8."""
            # No stdout dump of the content here: files can be large or sensitive, and the
            # previous debug print leaked them into the runner's output.
            return hub_client.files.content(file_id).content.decode("utf-8")

        self.read_file_by_id = read_file_by_id

        def write_file(
            filename: str,
            content: Union[str, bytes],
            encoding: str = "utf-8",
            filetype: str = "text/plain",
            write_to_disk: bool = True,
        ) -> FileObject:
            """Writes a file to the environment.

            filename: The name of the file to write to
            content: The content to write to the file
            encoding: The encoding to use when writing the file (default is utf-8)
            filetype: The MIME type of the file (default is text/plain)
            write_to_disk: If True, write locally to disk (default is True)
            """
            # (The docstring above is kept verbatim: this function is registered as a tool.)
            if write_to_disk:
                # Mirror into the primary agent's temp dir, creating subdirectories as needed.
                local_path = Path(self.get_primary_agent_temp_dir()) / filename
                local_path.parent.mkdir(parents=True, exist_ok=True)
                if isinstance(content, bytes):
                    local_path.write_bytes(content)
                else:
                    with open(local_path, "w", encoding=encoding) as out:
                        out.write(content)

            # Bytes are uploaded as-is; text is encoded with `encoding` first.
            upload_payload = content if isinstance(content, bytes) else io.BytesIO(content.encode(encoding))  # type:ignore
            uploaded = hub_client.files.create(file=(filename, upload_payload, filetype), purpose="assistants")

            # NOTE: for bytes content this counts bytes, although the messages say "characters".
            size = len(content)
            reply = self.add_reply(
                message=f"Successfully wrote {size} characters to {filename}",
                attachments=[{"file_id": uploaded.id, "tools": [{"type": "file_search"}]}],
                message_type="system:file_write",
            )
            self.add_system_log(
                f"Uploaded file {filename} with {size} characters, id: {uploaded.id}. Added in thread as: {reply.id}"
            )
            return uploaded

        self.write_file = write_file

        def mark_done() -> Run:
            """Flag this environment as done and mark the hub run as completed.

            Returns the updated run. ``completed_at`` is ``datetime.now().isoformat()``,
            i.e. a naive local-time timestamp.
            """
            self._done = True
            completion_fields = {"status": "completed", "completed_at": datetime.now().isoformat()}
            return hub_client.beta.threads.runs.update(
                thread_id=self._thread_id,
                run_id=self._run_id,
                extra_body=completion_fields,
            )

        self.mark_done = mark_done

        def mark_failed() -> Run:
            """Flag this environment as done, log an error, and mark the hub run as failed.

            Returns the updated run. ``failed_at`` is a naive local-time ISO timestamp.
            """
            self._done = True
            self.add_system_log("Environment run failed", logging.ERROR)
            failure_fields = {"status": "failed", "failed_at": datetime.now().isoformat()}
            return hub_client.beta.threads.runs.update(
                thread_id=self._thread_id,
                run_id=self._run_id,
                extra_body=failure_fields,
            )

        self.mark_failed = mark_failed

        def request_user_input() -> Run:
            """Must be called to request input from the user."""
            # (Docstring kept verbatim: this function is registered as a tool.)
            # Sets the hub run status to "requires_action" and returns the updated run.
            pending_status = {"status": "requires_action"}
            return hub_client.beta.threads.runs.update(
                thread_id=self._thread_id, run_id=self._run_id, extra_body=pending_status
            )

        self.request_user_input = request_user_input

        # Must be placed after method definitions
        self.register_standard_tools()

    def get_tool_registry(self, new: bool = False) -> ToolRegistry:
        """Returns the tool registry, a dictionary of tools that can be called by the agent."""
        if new:
            self._tools = ToolRegistry()
        return self._tools

    def register_standard_tools(self) -> None:
        """Register the built-in environment tools on the current tool registry.

        Registers, in order: exec_command, read_file, write_file, request_user_input,
        list_files and query_vector_store.
        """
        registry = self.get_tool_registry()
        # Look each tool up just before registering it, matching one-at-a-time registration.
        for tool_name in (
            "exec_command",
            "read_file",
            "write_file",
            "request_user_input",
            "list_files",
            "query_vector_store",
        ):
            registry.register_tool(getattr(self, tool_name))

    def get_last_message(self, role: str = "user"):
        """Return the most recent legacy-format message whose "role" equals `role`, or None."""
        return next(
            (candidate for candidate in reversed(self.list_messages()) if candidate.get("role") == role),
            None,
        )

    def add_message(
        self,
        role: str,
        message: str,
        attachments: Optional[Iterable[Attachment]] = None,
        **kwargs: Any,
    ):
        """Deprecated. Please use `add_reply` instead. Assistant adds a message to the environment."""
        # Prevent agent to save messages on behalf of `user` to avoid adding false memory
        role = "assistant"

        return self._add_message(role, message, attachments, **kwargs)

    def add_system_log(self, log: str, level: int = logging.INFO) -> None:
        """Write `log` at `level` to the logger named "system_logger".

        The first call configures the logger with a DEBUG threshold and these handlers:
        a file handler for SYSTEM_LOG_FILENAME under the environment path; a console handler
        if ``print_system_log`` is set; and, in debug mode, a ``CustomLogHandler`` that
        forwards records through ``add_reply``. Every handler is flushed after each record.
        """
        # NOTE(review): logging.getLogger returns a process-wide logger per name, and setup is
        # skipped once handlers exist, so a second environment in the same process keeps
        # writing to the first environment's log file.
        system_logger = logging.getLogger("system_logger")
        if not system_logger.handlers:
            system_logger.setLevel(logging.DEBUG)
            line_format = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S")

            def attach(handler: logging.Handler) -> None:
                handler.setFormatter(line_format)
                system_logger.addHandler(handler)

            attach(logging.FileHandler(os.path.join(self._path, SYSTEM_LOG_FILENAME)))
            if self.print_system_log:
                attach(logging.StreamHandler())
            # In debug mode, system logs are also forwarded via add_reply.
            if self._debug_mode:
                attach(CustomLogHandler(self.add_reply, "system"))

        system_logger.log(level, log)
        # Flush explicitly so the log file is current after every record.
        for handler in system_logger.handlers:
            handler.flush()

    def add_agent_log(self, log: str, level: int = logging.INFO) -> None:
        """Write `log` at `level` to the logger named "agent_logger".

        The first call configures the logger with a DEBUG threshold, a file handler for
        AGENT_LOG_FILENAME under the environment path and, in debug mode, a
        ``CustomLogHandler`` that forwards records through ``add_reply``. Every handler is
        flushed after each record.
        """
        # NOTE(review): like add_system_log, setup happens only once per process-wide logger,
        # so later environments reuse the first environment's log file.
        agent_logger = logging.getLogger("agent_logger")
        if not agent_logger.handlers:
            agent_logger.setLevel(logging.DEBUG)
            line_format = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S")

            def attach(handler: logging.Handler) -> None:
                handler.setFormatter(line_format)
                agent_logger.addHandler(handler)

            attach(logging.FileHandler(os.path.join(self._path, AGENT_LOG_FILENAME)))
            if self._debug_mode:
                attach(CustomLogHandler(self.add_reply, "agent"))

        agent_logger.log(level, log)
        for handler in agent_logger.handlers:
            handler.flush()

    def add_agent_start_system_log(self, agent_idx: int) -> None:
        """Log that the agent at `agent_idx` is starting.

        When the agent has a model configured, the model is resolved with
        ``get_model_for_inference`` and remembered as ``_last_used_model``. The resolved model
        and any truthy temperature / max_tokens settings are then appended to the message.
        """
        agent = self._agents[agent_idx]
        parts = [f"Running agent {agent.name}"]
        if agent.model != "":
            resolved_model = self.get_model_for_inference(agent.model)
            self._last_used_model = resolved_model
            parts.append(f" that will connect to {resolved_model}")
            if agent.model_temperature:
                parts.append(f", temperature={agent.model_temperature}")
            if agent.model_max_tokens:
                parts.append(f", max_tokens={agent.model_max_tokens}")
        self.add_system_log("".join(parts))

    def list_terminal_commands(self, filename: str = TERMINAL_FILENAME) -> List[Any]:
        """Return the JSON records stored in the terminal log `filename` under the environment path.

        Records are separated by DELIMITER and empty segments are skipped. A missing file
        yields an empty list.
        """
        terminal_path = os.path.join(self._path, filename)
        if not os.path.exists(terminal_path):
            return []
        with open(terminal_path, "r") as terminal_file:
            raw = terminal_file.read()
        return [json.loads(chunk) for chunk in raw.split(DELIMITER) if chunk]

    def list_messages(self, thread_id: Optional[str] = None):
        """Return thread messages as legacy chat dicts with "id", "content" and "role" keys.

        thread_id: thread to read; defaults to the current thread.

        In debug mode, messages whose metadata ``message_type`` is "system:log" or
        "agent:log" are dropped, so agent behavior does not depend on logging. The text
        parts of each message are joined with newlines.
        """
        messages = self._list_messages(thread_id=thread_id)

        if self._debug_mode:
            log_types = ("system:log", "agent:log")
            # Use .get(): messages can carry metadata without "message_type" (e.g. tool
            # responses stored with only tool_call_id/name), which used to raise KeyError.
            messages = [
                m
                for m in messages
                if not (m.metadata and m.metadata.get("message_type") in log_types)  # type: ignore
            ]
        return [
            {
                "id": m.id,
                "content": "\n".join(c.text.value for c in m.content),  # type: ignore
                "role": m.role,
            }
            for m in messages
        ]

    def verify_message(
        self,
        account_id: str,
        public_key: str,
        signature: str,
        message: str,
        nonce: str,
        callback_url: str,
    ) -> near.SignatureVerificationResult:
        """Verify that `message` was signed by the NEAR account `account_id`.

        Delegates to ``near.verify_signed_message``. The primary agent's name is passed
        between `nonce` and `callback_url` (presumably as the expected recipient; confirm
        against ``near.verify_signed_message``).
        """
        primary_agent_name = self._agents[0].name
        return near.verify_signed_message(
            account_id, public_key, signature, message, nonce, primary_agent_name, callback_url
        )

    def list_files(self, path: str, order: Literal["asc", "desc"] = "asc") -> List[str]:
        """Lists files in the environment."""
        # (Docstring kept verbatim: this method is registered as a tool.)
        # `path` is relative to the primary agent's temp dir. `order` is accepted for
        # interface compatibility but unused; entries come back in os.listdir order.
        target_dir = os.path.join(self.get_primary_agent_temp_dir(), path)
        return os.listdir(target_dir)

    def get_system_path(self) -> Path:
        """Return the environment's base directory (where chat.txt and the system log live) as a Path."""
        base_dir = self._path
        return Path(base_dir)

    def get_agent_temp_path(self) -> Path:
        """Return the primary agent's temp directory (same value as ``get_primary_agent_temp_dir``).

        NOTE: ``exec_command`` runs commands with cwd set to the environment path, not this directory.
        """
        temp_dir = self.get_primary_agent_temp_dir()
        return temp_dir

    def read_file(self, filename: str):
        """Reads a file from the environment or thread."""
        # (Docstring kept verbatim: this method is registered as a tool.)
        # Lookup: the local copy in the primary agent's temp dir is read first, but a file of the
        # same name attached to the thread (most recent first) takes precedence. Whatever is found
        # is written back to the local path. Returns None (and logs a warning) if nothing is found.
        file_content = None
        local_path = os.path.join(self.get_primary_agent_temp_dir(), filename)
        if os.path.isfile(local_path):
            # utf-8 matches write_file's default encoding and read_file_by_id's decoding.
            with open(local_path, "r", encoding="utf-8") as local_file:
                file_content = local_file.read()

        for thread_file in self.list_files_from_thread(order="desc"):
            if thread_file.filename == filename:
                file_content = self.read_file_by_id(thread_file.id)
                break

        # `is None` rather than falsiness: an existing empty file is not "not found".
        if file_content is None:
            self.add_system_log(f"Warn: File {filename} not found during read_file operation")
            return None

        # Thread files may live in a subdirectory that does not exist locally yet.
        os.makedirs(os.path.dirname(local_path), exist_ok=True)
        with open(local_path, "w", encoding="utf-8") as local_file:
            local_file.write(file_content)
        return file_content

    def exec_command(self, command: str) -> Dict[str, Union[str, int]]:
        """Executes a command in the environment and logs the output.

        The environment does not allow running interactive programs.
        It will run a program for 2 seconds then will interrupt it if it is still running
        or if it is waiting for user input.
        command: The command to execute, like 'ls -l' or 'python3 tests.py'
        """
        # (This docstring is the tool description; it now states the real 2-second timeout.)
        # Every command must be approved by the runner-provided "confirm_execution" callback.
        approval_function = self._approvals["confirm_execution"] if self._approvals else None
        if not approval_function:
            return {
                "stderr": "Agent runner misconfiguration. No command execution approval function found.",
            }
        if not approval_function(command):
            return {
                "command": command,
                "returncode": 999,
                "stdout": "",
                "stderr": "Command execution was not approved.",
            }

        try:
            # shlex.split without a shell: pipes, globs and redirects are not interpreted.
            process = subprocess.Popen(
                shlex.split(command),
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                bufsize=0,
                universal_newlines=True,
                cwd=self._path,
            )
        except Exception as e:
            return {
                "command": command,
                "returncode": 999,
                "stdout": "",
                "stderr": "Failed to execute: " + str(e),
            }

        msg = ""

        def kill_process_tree(p: Any) -> None:
            """Timer callback: kill `p` and all of its descendants."""
            nonlocal msg
            msg = "Killing process due to timeout"
            try:
                root = psutil.Process(p.pid)
                victims = root.children(recursive=True) + [root]
            except psutil.NoSuchProcess:
                return  # The process exited just as the timer fired.
            for victim in victims:
                try:
                    victim.kill()
                except psutil.NoSuchProcess:
                    pass  # Already gone; nothing left to kill.

        timer = threading.Timer(2, kill_process_tree, (process,))
        timer.start()
        try:
            # communicate() drains stdout/stderr while waiting. wait() followed by read() can fill
            # the OS pipe buffer on large output, block the child, and end in a spurious timeout kill.
            stdout, stderr = process.communicate()
        finally:
            timer.cancel()

        result = {
            "command": command,
            "stdout": stdout or "",
            "stderr": stderr or "",
            "returncode": process.returncode,
            "msg": msg,
        }
        # Append to the terminal log that list_terminal_commands reads back.
        with open(os.path.join(self._path, TERMINAL_FILENAME), "a") as f:
            f.write(json.dumps(result) + DELIMITER)
        return result

    def get_inference_parameters(
        self,
        messages: Union[Iterable[ChatCompletionMessageParam], str],
        model: Union[Iterable[ChatCompletionMessageParam], str],
        stream: bool,
        **kwargs: Any,
    ) -> Tuple[InferenceParameters, Any]:
        """Normalize completion arguments into ``InferenceParameters``.

        Supports the deprecated (model, messages) calling order: if `messages` is a string,
        the two arguments are swapped and a warning is logged. The model is resolved with
        ``get_model_for_inference``, and a change of model is logged. ``temperature`` and
        ``max_tokens`` are popped from `kwargs`, defaulting to the primary agent's settings
        (None when there are no agents).

        Returns the parameters and the remaining kwargs.
        """
        if isinstance(messages, str):
            self.add_system_log(
                "Deprecated completions call. Pass `messages` as a first parameter.",
                logging.WARNING,
            )
            # Legacy order: the caller passed (model, messages).
            messages, model = model, messages
        resolved_model = self.get_model_for_inference(cast(str, model))
        chat_messages = cast(Iterable[ChatCompletionMessageParam], messages)
        if resolved_model != self._last_used_model:
            self._last_used_model = resolved_model
            self.add_system_log(f"Connecting to {resolved_model}")

        primary = self._agents[0] if self._agents else None
        temperature = kwargs.pop("temperature", primary.model_temperature if primary is not None else None)
        max_tokens = kwargs.pop("max_tokens", primary.model_max_tokens if primary is not None else None)

        inference_params = InferenceParameters(
            model=resolved_model,
            messages=chat_messages,
            stream=stream,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        return inference_params, kwargs

    # TODO(286): `messages` may be model and `model` may be messages temporarily to support deprecated API.
    def completions(
        self,
        messages: Union[Iterable[ChatCompletionMessageParam], str],
        model: Union[Iterable[ChatCompletionMessageParam], str] = "",
        stream: bool = False,
        **kwargs: Any,
    ) -> Union[ModelResponse, CustomStreamWrapper]:
        """Returns all completions for given messages using the given model."""
        return self._run_inference_completions(messages, model, stream, **kwargs)

    def verify_signed_message(
        self,
        completion: str,
        messages: Union[Iterable[ChatCompletionMessageParam], str],
        public_key: Union[str, None] = None,
        signature: Union[str, None] = None,
        model: Union[Iterable[ChatCompletionMessageParam], str] = "",
        **kwargs: Any,
    ) -> bool:
        """Verifies a signed message."""
        if public_key is None or signature is None:
            return False

        params, _ = self.get_inference_parameters(messages, model, False, **kwargs)

        messages_without_ids = [{k: v for k, v in item.items() if k != "id"} for item in params.messages]
        ordered_messages_without_ids = [
            {"role": str(item["role"]), "content": str(item["content"])} for item in messages_without_ids
        ]

        return validate_completion_signature(
            public_key,
            signature,
            CompletionSignaturePayload(
                agent_name=self.get_primary_agent().get_full_name(),
                completion=completion,
                model=params.model,
                messages=ordered_messages_without_ids,
                temperature=params.temperature,
                max_tokens=params.max_tokens,
            ),
        )

    def completions_and_run_tools(
        self,
        messages: List[ChatCompletionMessageParam],
        model: str = "",
        tools: Optional[List] = None,
        add_responses_to_messages: bool = True,
        agent_role_name="assistant",
        tool_role_name="tool",
        **kwargs: Any,
    ) -> ModelResponse:
        """Returns all completions for given messages using the given model and runs tools."""
        if self._use_llama_tool_syntax(model, tools):
            tool_prompt = self._llama_tool_prompt(tools)
            messages.append({"role": "system", "content": tool_prompt})
        raw_response = self._run_inference_completions(messages, model, stream=False, tools=tools, **kwargs)
        assert isinstance(raw_response, ModelResponse), "Expected ModelResponse"
        response: ModelResponse = raw_response
        assert all(map(lambda choice: isinstance(choice, Choices), response.choices)), "Expected Choices"
        choices: List[Choices] = response.choices  # type: ignore
        response_message = choices[0].message

        self._handle_tool_calls(response_message, add_responses_to_messages, agent_role_name, tool_role_name)

        return response

    def _handle_tool_calls(
        self,
        response_message,
        add_responses_to_messages,
        agent_role_name,
        tool_role_name,
    ):
        """Record the assistant text of `response_message` and execute its tool calls.

        Tool calls come from ``_parse_tool_call`` (native ``tool_calls`` or Llama inline syntax).
        For each call the tool definition is looked up in the registry, the JSON arguments are
        parsed, and the tool is invoked. A truthy result is JSON-encoded and, when
        `add_responses_to_messages` is set, added as a message with ``tool_call_id``/``name``
        metadata. Any exception during a call is logged and, when `add_responses_to_messages`
        is set, reported as a message instead of propagating.
        """
        message_without_tool_call, tool_calls = self._parse_tool_call(response_message)
        # Gate on the text left after stripping tool calls, not the raw content: a reply made only
        # of inline tool calls would otherwise be stored as an empty assistant message.
        if add_responses_to_messages and message_without_tool_call:
            self.add_message(agent_role_name, message_without_tool_call)
        for tool_call in tool_calls or []:
            function_name = tool_call.function.name
            try:
                # Explicit raises instead of assert so validation survives `python -O`.
                if not function_name:
                    raise ValueError("Tool call must have a function name")
                function_signature = self.get_tool_registry().get_tool_definition(function_name)
                if not function_signature:
                    raise ValueError(f"Tool {function_name} not found")
                function_args = tool_json_helper.parse_json_args(function_signature, tool_call.function.arguments)
                self.add_system_log(f"Calling tool {function_name} with args {function_args}")
                function_response = self._tools.call_tool(function_name, **(function_args or {}))

                if function_response:
                    # Encode even when not recording, so unserializable results are reported below.
                    function_response_json = json.dumps(function_response)
                    if add_responses_to_messages:
                        self.add_message(
                            tool_role_name,
                            function_response_json,
                            tool_call_id=tool_call.id,
                            name=function_name,
                        )
            except Exception as e:
                error_message = f"Error calling tool {function_name}: {e}"
                self.add_system_log(error_message, level=logging.ERROR)
                if add_responses_to_messages:
                    self.add_message(
                        tool_role_name,
                        error_message,
                        tool_call_id=tool_call.id,
                        name=function_name,
                    )

    @staticmethod
    def _parse_tool_call(
        response_message,
    ) -> Tuple[Optional[str], Optional[List[ChatCompletionMessageToolCall]]]:
        if hasattr(response_message, "tool_calls") and response_message.tool_calls:
            return response_message.content, response_message.tool_calls
        content = response_message.content
        if content is None:
            return None, None
        llama_matches = LLAMA_TOOL_FORMAT_PATTERN.findall(content)
        if llama_matches:
            text = ""
            tool_calls = []
            for llama_match in llama_matches:
                before_call_text, function_name, args, end_tag, after_call_text = llama_match
                function = Function(name=function_name, arguments=args)
                tool_call = ChatCompletionMessageToolCall(id=str(uuid.uuid4()), function=function)
                text += before_call_text + after_call_text
                tool_calls.append(tool_call)
            return text, tool_calls

        llama_matches = LLAMA_TOOL_FORMAT_PATTERN2.findall(content)
        if llama_matches:
            text = ""
            tool_calls = []
            for llama_match in llama_matches:
                before_call_text, function_name_and_args, after_call_text = llama_match
                try:
                    parsed_function_name_and_args = json.loads(function_name_and_args)
                    function_name = parsed_function_name_and_args.get("name")
                    args = parsed_function_name_and_args.get("arguments")
                    function = Function(name=function_name, arguments=args)
                    tool_call = ChatCompletionMessageToolCall(id=str(uuid.uuid4()), function=function)
                    text += before_call_text + after_call_text
                    tool_calls.append(tool_call)
                except json.JSONDecodeError:
                    print(f"Error parsing tool_call function name and args: {function_name_and_args}")
                    continue
            return text, tool_calls

        return content, None

    @staticmethod
    def _use_llama_tool_syntax(model: str, tools: Optional[List]) -> bool:
        return tools is not None and "llama" in model

    @staticmethod
    def _llama_tool_prompt(tools: Optional[List]) -> str:
        """Build the system prompt that teaches Llama models the inline tool-call syntax.

        The tool definitions are embedded as ``json.dumps(tools)``. The requested reply format
        is ``<function=NAME>{json args}</function>``, which is presumably what
        LLAMA_TOOL_FORMAT_PATTERN in ``_parse_tool_call`` matches (confirm against the pattern).
        """
        # NOTE: the literal text below, including its indentation, is sent to the model verbatim
        # as a system message by completions_and_run_tools.
        # NOTE(review): there is no separator between the dumped JSON and "Think very carefully...".
        return (
            """Answer the user's question by making use of the following functions if needed.
            If none of the function can be used, please say so.
            Here is a list of functions in JSON format:"""
            + json.dumps(tools)
            + """Think very carefully before calling functions.
            If you choose to call a function ONLY reply in the following format with no prefix or suffix:

            <function=example_function_name>{"example_name": "example_value"}</function>

            Reminder:
            - Function calls MUST follow the specified format, start with <function= and end with </function>
            - Function arguments MUST be in JSON format using double quotes
            - Required parameters MUST be specified
            - Multiple functions can be called in one message as long as they are on separate lines.
            - Put the entire function call reply on one line
        """
        )

    # TODO(286): `messages` may be model and `model` may be messages temporarily to support deprecated API.
    def completion(
        self,
        messages: Union[Iterable[ChatCompletionMessageParam], str],
        model: Union[Iterable[ChatCompletionMessageParam], str] = "",
        **kwargs: Any,
    ) -> str:
        """Return the text of the first choice of a completion for `messages`.

        Raises AssertionError if the response is not a ``ModelResponse`` made of ``Choices``,
        or if the first choice has empty content.
        """
        raw_response = self.completions(messages, model, **kwargs)
        assert isinstance(raw_response, ModelResponse), "Expected ModelResponse"
        assert all(isinstance(choice, Choices) for choice in raw_response.choices), "Expected Choices"
        first_choice: Choices = raw_response.choices[0]  # type: ignore
        content = first_choice.message.content
        assert content, "No completions returned"
        return content

    def signed_completion(
        self,
        messages: Union[Iterable[ChatCompletionMessageParam], str],
        model: Union[Iterable[ChatCompletionMessageParam], str] = "",
        **kwargs: Any,
    ) -> Dict[str, str]:
        """Return a completion plus the signature data reported with it.

        The primary agent's full name is sent as ``agent_name``. The signature and public key
        are read from the response's JSON-encoded ``system_fingerprint`` (None when absent).

        Returns a dict with "response", "signature" and "public_key" keys.
        """
        # TODO Return signed completions for non-latest versions only?
        raw_response = self.completions(
            messages, model, agent_name=self.get_primary_agent().get_full_name(), **kwargs
        )
        assert isinstance(raw_response, ModelResponse), "Expected ModelResponse"

        fingerprint = raw_response.system_fingerprint
        signature_data = json.loads(fingerprint) if fingerprint else {}

        assert all(isinstance(choice, Choices) for choice in raw_response.choices), "Expected Choices"
        response_text = raw_response.choices[0].message.content  # type: ignore
        assert response_text, "No completions returned"

        return {
            "response": response_text,
            "signature": signature_data.get("signature"),
            "public_key": signature_data.get("public_key"),
        }

    def completion_and_run_tools(
        self,
        messages: List[ChatCompletionMessageParam],
        model: str = "",
        tools: Optional[List] = None,
        **kwargs: Any,
    ) -> Optional[str]:
        """Run ``completions_and_run_tools`` and return the text of the first choice (may be None).

        Keyword arguments are forwarded unchanged.
        """
        model_response = self.completions_and_run_tools(messages, model, tools, **kwargs)
        assert all(isinstance(choice, Choices) for choice in model_response.choices), "Expected Choices"
        first_choice: Choices = model_response.choices[0]  # type: ignore
        return first_choice.message.content

    def call_agent(self, agent_index: int, task: str) -> None:
        """Run the agent at position `agent_index` on this environment with `task`."""
        target_agent = self._agents[agent_index]
        target_agent.run(self, task=task)

    def get_agents(self) -> List[Agent]:
        """Return the environment's agent list itself (not a copy)."""
        agents = self._agents
        return agents

    def get_primary_agent(self) -> Agent:
        """Return the first agent in the list, which is the one invoked first."""
        agents = self._agents
        return agents[0]

    def get_primary_agent_temp_dir(self) -> Path:
        """Return the `temp_dir` of the primary (first) agent."""
        primary = self._agents[0]
        return primary.temp_dir

    def is_done(self) -> bool:  # noqa: D102
        """Return the `_done` flag (set by `mark_done` / `mark_failed`, cleared for the agent's turn)."""
        finished = self._done
        return finished

    def create_snapshot(self) -> bytes:
        """Create an in-memory gzip-compressed tar snapshot of the environment directory.

        The contents of `self._path` are archived under `arcname="."`, so the
        archive is relative and can be restored into any directory with
        `load_snapshot`.

        Returns:
        -------
            The raw bytes of the `.tar.gz` archive.

        """
        # Build the archive in a BytesIO buffer rather than a NamedTemporaryFile:
        # the snapshot really stays in memory (as documented), nothing is
        # written outside `self._path`, and no random temp-file name ends up in
        # the gzip header.
        buffer = io.BytesIO()
        with tarfile.open(fileobj=buffer, mode="w:gz") as tar:
            tar.add(self._path, arcname=".")
        # Closing the TarFile finalises the gzip stream but leaves `buffer` open.
        return buffer.getvalue()

    def environment_run_info(self, base_id, run_type) -> dict:
        """Build the registry metadata describing one run of this environment.

        Args:
        ----
            base_id: Stored verbatim under `details["base_id"]`.
            run_type: Used in the description and stored under `details["run_type"]`.

        Returns:
        -------
            A dict whose `name` is unique per call (the primary agent's
            namespace/name/version plus a uuid4 hex), with fixed version,
            category and tags, and a `details` mapping that records the agents,
            the primary agent's identity, the run id and a UTC ISO-8601 timestamp.

        Raises:
        ------
            ValueError: If there are no agents or the first agent is falsy.

        """
        if not (self._agents and self._agents[0]):
            raise ValueError("Agent not found")
        lead = self._agents[0]

        agent_path = "/".join([lead.namespace, lead.name, lead.version])
        unique_suffix = uuid.uuid4().hex
        created_at = datetime.now(timezone.utc).isoformat()

        run_details = {
            "base_id": base_id,
            "timestamp": created_at,
            "agents": [member.name for member in self._agents],
            "primary_agent_namespace": lead.namespace,
            "primary_agent_name": lead.name,
            "primary_agent_version": lead.version,
            "run_id": self._run_id,
            "run_type": run_type,
        }
        slug = agent_path.replace("/", "_")
        return {
            "name": f"environment_run_{slug}_{unique_suffix}",
            "version": "0",
            "description": f"Agent {run_type} {agent_path} {unique_suffix} {created_at}",
            "category": "environment",
            "tags": ["environment"],
            "details": run_details,
            "show_entry": True,
        }

    def load_snapshot(self, snapshot: bytes) -> None:
        """Replace the environment directory with the contents of a snapshot.

        The existing directory at `self._path` is removed (errors ignored), and
        the gzip-compressed tar archive in `snapshot` (as produced by
        `create_snapshot`) is extracted into it.

        Args:
        ----
            snapshot: Raw `.tar.gz` bytes.

        Raises:
        ------
            tarfile.FilterError: On Pythons that support extraction filters, if
                a member would land outside `self._path`, has an absolute path,
                or is an unsafe link.

        """
        shutil.rmtree(self._path, ignore_errors=True)

        # Read straight from memory; there is no need to round-trip through a temp file.
        with tarfile.open(fileobj=io.BytesIO(snapshot), mode="r:gz") as tar:
            if hasattr(tarfile, "data_filter"):
                # Snapshot bytes may originate outside this process. Refuse
                # members that escape the target directory (path traversal,
                # absolute paths, dangerous links). Filters exist since 3.12 and
                # were backported to 3.8.17 / 3.9.17 / 3.10.12 / 3.11.4.
                tar.extractall(self._path, filter="data")
            else:
                # Older interpreter without extraction filters.
                tar.extractall(self._path)

    def __str__(self) -> str:  # noqa: D105
        """Return a short label that embeds the environment path."""
        return "Environment({})".format(self._path)

    def clear_temp_agent_files(self, verbose=True) -> None:
        """Delete each agent's `temp_dir` (the copies prepared for `runpy`) when it exists.

        Args:
        ----
            verbose: When true, print each directory just before it is removed.

        """
        for agent in self._agents:
            if not os.path.exists(agent.temp_dir):
                continue
            if verbose:
                print("removed agent.temp_files", agent.temp_dir)
            shutil.rmtree(agent.temp_dir)

    def set_next_actor(self, who: str) -> None:
        """Record who acts next by writing `who` to the `.next_action` file.

        Handing the turn to "agent" also clears the done flag so the run loop
        may continue.
        """
        if who == "agent":
            self._done = False
        marker_path = os.path.join(self._path, ".next_action")
        with open(marker_path, "w") as marker:
            marker.write(who)

    def get_next_actor(self) -> str:  # noqa: D102
        """Return the actor stored in `.next_action`, or "user" when none has been recorded."""
        marker_path = os.path.join(self._path, ".next_action")
        if not os.path.exists(marker_path):
            # A fresh conversation is opened by the user.
            return "user"
        with open(marker_path) as marker:
            return marker.read().strip(" \n")

    def run(
        self,
        new_message: Optional[str] = None,
        max_iterations: int = 10,
    ) -> None:
        """Run the primary agent against a new or previously created environment.

        Args:
        ----
            new_message: Optional user message added to the thread before the
                agent runs; it is also passed to the agent as `task`.
            max_iterations: Upper bound on how many times the primary agent is
                invoked during this call.

        The loop stops early once the environment is done or the next actor is
        "user". If the agent raises, the failure is logged, the run is marked
        failed and the exception propagates. Otherwise the run is marked done,
        unless an external agent was started (`_pending_ext_agent`), in which
        case the run is left to be continued later.
        """
        if new_message:
            self._add_message("user", new_message)

        self.set_next_actor("agent")

        for iteration in range(1, max_iterations + 1):
            if self.is_done() or self.get_next_actor() == "user":
                break
            if max_iterations > 1:
                self.add_system_log(
                    f"Running agent, iteration {iteration}/{max_iterations}",
                    logging.INFO,
                )
            try:
                self._agents[0].run(self, task=new_message)
            except Exception as e:
                self.add_system_log(f"Environment run failed: {e}", logging.ERROR)
                self.mark_failed()
                raise

        # With a pending external agent, this environment pauses and the run resumes later.
        if not self._pending_ext_agent:
            self.mark_done()

    def generate_folder_hash_id(self, path: str) -> str:
        """Return an MD5 hex digest identifying the files under `path`, recursively.

        Each file contributes its path relative to `path` and its contents.
        Directories and files are visited in sorted order, so the id is
        deterministic regardless of filesystem enumeration order. Because the
        relative path is included, renaming or moving a file changes the id.

        Args:
        ----
            path: Root directory to hash.

        Returns:
        -------
            A 32-character lowercase hex string.

        """
        hash_obj = hashlib.md5()

        for root, dirs, files in os.walk(path):
            # os.walk yields entries in filesystem-dependent order; sorting
            # `dirs` in place also fixes the order in which subfolders are visited.
            dirs.sort()
            for file in sorted(files):
                file_path = os.path.join(root, file)
                rel_name = os.fsencode(os.path.relpath(file_path, path).replace(os.sep, "/"))

                file_hash = hashlib.md5()
                with open(file_path, "rb") as f:
                    while chunk := f.read(8192):
                        file_hash.update(chunk)

                # Length-prefixed name plus a fixed-size per-file digest keeps
                # the boundaries between entries unambiguous (e.g. files "ab"+"c"
                # no longer hash like "a"+"bc").
                hash_obj.update(len(rel_name).to_bytes(8, "big"))
                hash_obj.update(rel_name)
                hash_obj.update(file_hash.digest())

        return hash_obj.hexdigest()
__init__
__init__(path: str, agents: List[Agent], client: InferenceClient, hub_client: OpenAI, thread_id: str, run_id: str, create_files: bool = True, env_vars: Optional[Dict[str, Any]] = None, tool_resources: Optional[Dict[str, Any]] = None, print_system_log: bool = False, agent_runner_user: Optional[str] = None, approvals: Optional[Dict[str, Any]] = default_approvals) -> None
Source code in nearai/agents/environment.py
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
def __init__(  # noqa: D107
    self,
    path: str,
    agents: List[Agent],
    client: InferenceClient,
    hub_client: OpenAI,
    thread_id: str,
    run_id: str,
    create_files: bool = True,
    env_vars: Optional[Dict[str, Any]] = None,
    tool_resources: Optional[Dict[str, Any]] = None,
    print_system_log: bool = False,
    agent_runner_user: Optional[str] = None,
    approvals: Optional[Dict[str, Any]] = default_approvals,
) -> None:
    """Set up an agent environment rooted at `path` for one run on a thread.

    Args:
    ----
        path: Working directory of the environment. It is created (with an
            empty CHAT_FILENAME) when `create_files` is true, and the process
            working directory is changed to it.
        agents: Agents available in the environment; the first is the primary one.
        client: Inference client. Only the closures defined below capture it;
            `self.client` is merely a `None` placeholder.
        hub_client: OpenAI-compatible Hub client used for thread messages,
            files and run status updates; likewise only captured by closures.
        thread_id: Id of the thread this environment reads from and writes to.
        run_id: Id of the run being executed.
        create_files: Whether to create `path` and the chat file.
        env_vars: Environment variables; a truthy "DEBUG" entry enables debug mode.
        tool_resources: Stored as `self.tool_resources` (empty dict if falsy).
        print_system_log: Whether system logs are also sent to a console handler.
        agent_runner_user: Stored as `self.agent_runner_user`.
        approvals: Stored as `self._approvals`.

    """
    # Warning: never expose `client` or `_hub_client` to agent's environment

    # Placeholder for solver
    self.client: Optional[InferenceClient] = None

    self._path = path
    self._agents = agents
    self._done = False
    # Set to True by `run_agent` once a child agent run has been requested.
    self._pending_ext_agent = False
    self.env_vars: Dict[str, Any] = env_vars if env_vars else {}
    self._last_used_model = ""
    self.tool_resources: Dict[str, Any] = tool_resources if tool_resources else {}
    self.print_system_log = print_system_log
    self.agent_runner_user = agent_runner_user
    self._approvals = approvals
    self._thread_id = thread_id
    self._run_id = run_id
    self._debug_mode = bool(self.env_vars.get("DEBUG"))

    self._tools = ToolRegistry()

    if create_files:
        os.makedirs(self._path, exist_ok=True)
        open(os.path.join(self._path, CHAT_FILENAME), "a").close()
    # NOTE: this changes the working directory of the whole process.
    os.chdir(self._path)

    def signer_account_id() -> Optional[str]:
        """Expose the NEAR account_id of a user that signs this request to run an agent."""
        try:
            return client._config.auth.account_id if client._config.auth else None
        except (AttributeError, TypeError):
            return None

    self.signer_account_id = signer_account_id()

    # Client methods. These closures are stored as plain instance attributes
    # (not bound methods), so they must NOT take a `self` parameter.
    def query_vector_store(vector_store_id: str, query: str, full_files: bool = False):
        """Queries a vector store.

        vector_store_id: The id of the vector store to query.
        query: The query to search for.
        full_files: Forwarded unchanged to the client.
        """
        return client.query_vector_store(vector_store_id, query, full_files)

    self.query_vector_store = query_vector_store

    def upload_file(
        file_content: str,
        purpose: Literal["assistants", "batch", "fine-tune", "vision"] = "assistants",
    ):
        """Uploads a file to the registry."""
        return client.upload_file(file_content, purpose)

    self.upload_file = upload_file

    def create_vector_store_from_source(
        name: str,
        source: Union[GitHubSource, GitLabSource],
        source_auth: Optional[str] = None,
        chunking_strategy: Optional[ChunkingStrategy] = None,
        expires_after: Optional[ExpiresAfter] = None,
        metadata: Optional[Dict[str, str]] = None,
    ) -> VectorStore:
        """Creates a vector store from the given source.

        Args:
        ----
            name: The name of the vector store.
            source: The source from which to create the vector store.
            source_auth: The source authentication token.
            chunking_strategy: The chunking strategy to use.
            expires_after: The expiration policy.
            metadata: Additional metadata.

        Returns:
        -------
            VectorStore: The created vector store.

        """
        return client.create_vector_store_from_source(
            name=name,
            source=source,
            source_auth=source_auth,
            chunking_strategy=chunking_strategy,
            expires_after=expires_after,
            metadata=metadata,
        )

    self.create_vector_store_from_source = create_vector_store_from_source

    def add_file_to_vector_store(vector_store_id: str, file_id: str):
        """Adds a file to the vector store."""
        return client.add_file_to_vector_store(vector_store_id, file_id)

    self.add_file_to_vector_store = add_file_to_vector_store

    def create_vector_store(
        name: str,
        file_ids: list,
        expires_after: Union[ExpiresAfter, NotGiven] = NOT_GIVEN,
        chunking_strategy: Union[
            AutoFileChunkingStrategyParam, StaticFileChunkingStrategyParam, NotGiven
        ] = NOT_GIVEN,
        metadata: Optional[Dict[str, str]] = None,
    ) -> VectorStore:
        """Creates a vector store.

        Args:
        ----
            name: The name of the vector store.
            file_ids: List of file ids to create the vector store.
            chunking_strategy: The chunking strategy to use.
            expires_after: The expiration policy.
            metadata: Additional metadata.

        Returns:
        -------
            VectorStore: The created vector store.

        """
        return client.create_vector_store(
            name=name,
            file_ids=file_ids,
            chunking_strategy=chunking_strategy,
            expires_after=expires_after,
            metadata=metadata,
        )

    self.create_vector_store = create_vector_store

    def get_vector_store(vector_store_id: str) -> VectorStore:
        """Gets a vector store by id."""
        # Previously declared as `(self, vector_store_id)`. Because this closure
        # is an unbound instance attribute, `env.get_vector_store(id)` bound the
        # id to `self` and raised TypeError for the missing argument.
        return client.get_vector_store(vector_store_id)

    self.get_vector_store = get_vector_store

    # Save cache of requested models for inference to avoid extra server calls
    self.cached_models_for_inference: Dict[str, str] = {}

    def get_model_for_inference(model: str = "") -> str:
        """Returns 'provider::model_full_path'.

        An empty `model` falls back to the primary agent's model, and then to
        DEFAULT_PROVIDER_MODEL. Resolved names are cached per model key.
        """
        if self.cached_models_for_inference.get(model, None) is None:
            provider = self._agents[0].model_provider if self._agents else ""
            if model == "":
                model = self._agents[0].model if self._agents else ""
            if model == "":
                return DEFAULT_PROVIDER_MODEL

            _, model_for_inference = client.provider_models.match_provider_model(model, provider)

            self.cached_models_for_inference[model] = model_for_inference

        return self.cached_models_for_inference[model]

    self.get_model_for_inference = get_model_for_inference

    def _run_inference_completions(
        messages: Union[Iterable[ChatCompletionMessageParam], str],
        model: Union[Iterable[ChatCompletionMessageParam], str],
        stream: bool,
        **kwargs: Any,
    ) -> Union[ModelResponse, CustomStreamWrapper]:
        """Run inference completions for given parameters."""
        params, kwargs = self.get_inference_parameters(messages, model, stream, **kwargs)

        completions = client.completions(
            params.model, params.messages, params.stream, params.temperature, params.max_tokens, **kwargs
        )

        return completions

    self._run_inference_completions = _run_inference_completions

    def get_agent_public_key():
        """Returns public key of the agent."""
        agent_name = self.get_primary_agent().get_full_name()

        return client.get_agent_public_key(agent_name)

    self.get_agent_public_key = get_agent_public_key

    def run_agent(
        owner: str,
        agent_name: str,
        version: str,
        model: Optional[str] = None,
        query: Optional[str] = None,
        fork_thread: bool = True,
    ):
        """Runs a child agent on the thread.

        With `fork_thread` the child runs on a fork of the current thread;
        otherwise it runs on the current thread. An optional `query` is posted
        as a user message on the child thread first. Returns the child thread id.
        """
        child_thread_id = self._thread_id
        if fork_thread:
            child_thread_id = client.threads_fork(self._thread_id).id
            self.add_system_log(f"Forked thread {child_thread_id}", logging.INFO)

        if query:
            client.threads_messages_create(thread_id=child_thread_id, content=query, role="user")

        assistant_id = f"{owner}/{agent_name}/{version}"
        # NOTE(review): `model` is resolved here but not forwarded to
        # `client.run_agent`; confirm whether it should be.
        model = model or DEFAULT_PROVIDER_MODEL
        self.add_system_log(f"Running agent {assistant_id}", logging.INFO)
        client.run_agent(
            current_run_id=self._run_id,
            child_thread_id=child_thread_id,
            assistant_id=assistant_id,
        )
        # Keeps `run` from marking this run completed; it will continue later.
        self._pending_ext_agent = True

        return child_thread_id

    self.run_agent = run_agent

    def schedule_run(
        agent: str,
        input_message: str,
        run_at: datetime,
        run_params: Optional[Dict[str, str]] = None,
        thread_id: Optional[str] = None,
    ):
        """Schedules a run."""
        return client.schedule_run(agent, input_message, thread_id, run_params, run_at)

    self.schedule_run = schedule_run

    # TODO(https://github.com/nearai/nearai/issues/549): Allow only a subset of agents to access/update user memory.
    def add_user_memory(memory: str):
        """Add user memory."""
        return client.add_user_memory(memory)

    self.add_user_memory = add_user_memory

    def query_user_memory(query: str):
        """Query user memory."""
        return client.query_user_memory(query)

    self.query_user_memory = query_user_memory

    def generate_image(prompt: str):
        """Generate an image."""
        return client.generate_image(prompt)

    self.generate_image = generate_image

    def save_agent_data(key, data: Dict[str, Any]):
        """Save agent data."""
        return client.save_agent_data(key, data)

    self.save_agent_data = save_agent_data

    def get_agent_data():
        """Get agent data."""
        return client.get_agent_data()

    self.get_agent_data = get_agent_data

    def get_agent_data_by_key(key, default=None):
        """Get agent data by key, or a placeholder record holding `default` when nothing is stored."""
        namespace = self._agents[0].namespace
        name = self._agents[0].name
        result = client.get_agent_data_by_key(key)
        return (
            result
            if result
            else {
                "value": default,
                "namespace": namespace,
                "key": key,
                "name": name,
                "updated_at": "",
                "created_at": "",
            }
        )

    self.get_agent_data_by_key = get_agent_data_by_key

    # HubClient methods
    def add_reply(
        message: str,
        attachments: Optional[Iterable[Attachment]] = None,
        message_type: Optional[str] = None,
    ):
        """Assistant adds a message to the environment."""
        # NOTE: message from `user` are not stored in the memory

        return hub_client.beta.threads.messages.create(
            thread_id=self._thread_id,
            role="assistant",
            content=message,
            extra_body={
                "assistant_id": self._agents[0].identifier,
                "run_id": self._run_id,
            },
            attachments=attachments,
            metadata={"message_type": message_type} if message_type else None,
        )

    self.add_reply = add_reply

    def _add_message(
        role: str,
        message: str,
        attachments: Optional[Iterable[Attachment]] = None,
        **kwargs: Any,
    ):
        """Create a thread message with an arbitrary `role`; extra kwargs become its metadata."""
        return hub_client.beta.threads.messages.create(
            thread_id=self._thread_id,
            role=role,  # type: ignore
            content=message,
            extra_body={
                "assistant_id": self._agents[0].identifier,
                "run_id": self._run_id,
            },
            metadata=kwargs,
            attachments=attachments,
        )

    self._add_message = _add_message

    def _list_messages(
        limit: Union[int, NotGiven] = NOT_GIVEN,
        order: Literal["asc", "desc"] = "asc",
        thread_id: Optional[str] = None,
    ) -> List[Message]:
        """Returns messages from the environment (or from `thread_id` when given)."""
        messages = hub_client.beta.threads.messages.list(
            thread_id=thread_id or self._thread_id, limit=limit, order=order
        )
        self.add_system_log(f"Retrieved {len(messages.data)} messages from NEAR AI Hub")
        return messages.data

    self._list_messages = _list_messages

    def list_files_from_thread(
        order: Literal["asc", "desc"] = "asc", thread_id: Optional[str] = None
    ) -> List[FileObject]:
        """Lists files attached to messages of the thread (the current one unless `thread_id` is given)."""
        # Forward `thread_id`; it used to be accepted but silently ignored.
        messages = self._list_messages(order=order, thread_id=thread_id)
        # Extract attachments from messages
        attachments = [a for m in messages if m.attachments for a in m.attachments]
        # Extract files from attachments
        file_ids = [a.file_id for a in attachments]
        files = [hub_client.files.retrieve(f) for f in file_ids if f]
        return files

    self.list_files_from_thread = list_files_from_thread

    def read_file_by_id(file_id: str):
        """Read a file from the thread and return its content decoded as UTF-8."""
        content = hub_client.files.content(file_id).content.decode("utf-8")
        print("file content returned by api", content)
        return content

    self.read_file_by_id = read_file_by_id

    def write_file(
        filename: str,
        content: Union[str, bytes],
        encoding: str = "utf-8",
        filetype: str = "text/plain",
        write_to_disk: bool = True,
    ) -> FileObject:
        """Writes a file to the environment.

        filename: The name of the file to write to
        content: The content to write to the file
        encoding: The encoding to use when writing the file (default is utf-8)
        filetype: The MIME type of the file (default is text/plain)
        write_to_disk: If True, write locally to disk (default is True)
        """
        if write_to_disk:
            # Write locally
            path = Path(self.get_primary_agent_temp_dir()) / filename
            path.parent.mkdir(parents=True, exist_ok=True)
            if isinstance(content, bytes):
                with open(path, "wb") as f:
                    f.write(content)
            else:
                with open(path, "w", encoding=encoding) as f:
                    f.write(content)

        if isinstance(content, bytes):
            file_data = content
        else:
            file_data = io.BytesIO(content.encode(encoding))  # type:ignore

        # Upload to Hub
        file = hub_client.files.create(file=(filename, file_data, filetype), purpose="assistants")
        res = self.add_reply(
            message=f"Successfully wrote {len(content) if content else 0} characters to {filename}",
            attachments=[{"file_id": file.id, "tools": [{"type": "file_search"}]}],
            message_type="system:file_write",
        )
        self.add_system_log(
            f"Uploaded file {filename} with {len(content)} characters, id: {file.id}. Added in thread as: {res.id}"
        )
        return file

    self.write_file = write_file

    def mark_done() -> Run:  # noqa: D102
        """Mark the environment run as completed on the Hub and set the done flag."""
        self._done = True
        res = hub_client.beta.threads.runs.update(
            thread_id=self._thread_id,
            run_id=self._run_id,
            extra_body={
                "status": "completed",
                "completed_at": datetime.now().isoformat(),
            },
        )
        return res

    self.mark_done = mark_done

    def mark_failed() -> Run:
        """Marks the environment run as failed."""
        self._done = True
        self.add_system_log("Environment run failed", logging.ERROR)
        res = hub_client.beta.threads.runs.update(
            thread_id=self._thread_id,
            run_id=self._run_id,
            extra_body={"status": "failed", "failed_at": datetime.now().isoformat()},
        )
        return res

    self.mark_failed = mark_failed

    def request_user_input() -> Run:
        """Must be called to request input from the user."""
        return hub_client.beta.threads.runs.update(
            thread_id=self._thread_id,
            run_id=self._run_id,
            extra_body={"status": "requires_action"},
        )

    self.request_user_input = request_user_input

    # Must be placed after method definitions
    self.register_standard_tools()
add_agent_log
add_agent_log(log: str, level: int = logging.INFO) -> None

Add agent log with timestamp and log level.

Source code in nearai/agents/environment.py
def add_agent_log(self, log: str, level: int = logging.INFO) -> None:
    """Emit `log` at `level` through the process-wide "agent_logger" logger.

    Handlers are attached only while the logger has none: a file handler for
    AGENT_LOG_FILENAME inside this environment's directory and, in debug
    mode, a `CustomLogHandler` built with `self.add_reply` (presumably it
    mirrors records into the thread; confirm in CustomLogHandler). Every
    handler is flushed after the record is emitted.

    NOTE(review): since configuration happens once per process, later
    environments with a different `_path` keep logging to the first one's file.
    """
    agent_logger = logging.getLogger("agent_logger")
    if not agent_logger.handlers:
        agent_logger.setLevel(logging.DEBUG)
        line_format = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S")

        def attach(handler: logging.Handler) -> None:
            handler.setFormatter(line_format)
            agent_logger.addHandler(handler)

        attach(logging.FileHandler(os.path.join(self._path, AGENT_LOG_FILENAME)))
        if self._debug_mode:
            attach(CustomLogHandler(self.add_reply, "agent"))

    agent_logger.log(level, log)
    # Flush right away so the record reaches disk immediately.
    for attached in agent_logger.handlers:
        attached.flush()
add_agent_start_system_log
add_agent_start_system_log(agent_idx: int) -> None

Adds agent start system log.

Source code in nearai/agents/environment.py
def add_agent_start_system_log(self, agent_idx: int) -> None:
    """Write a system log saying which agent starts and, if it has a model, its inference settings.

    When the agent names a model, the model is resolved with
    `get_model_for_inference` and remembered in `self._last_used_model`.
    Temperature and max tokens are mentioned only when truthy.
    """
    agent = self._agents[agent_idx]
    parts = [f"Running agent {agent.name}"]
    if agent.model != "":
        resolved = self.get_model_for_inference(agent.model)
        self._last_used_model = resolved
        parts.append(f" that will connect to {resolved}")
        if agent.model_temperature:
            parts.append(f", temperature={agent.model_temperature}")
        if agent.model_max_tokens:
            parts.append(f", max_tokens={agent.model_max_tokens}")
    self.add_system_log("".join(parts))
add_message
add_message(role: str, message: str, attachments: Optional[Iterable[Attachment]] = None, **kwargs: Any)

Deprecated. Please use add_reply instead. Assistant adds a message to the environment.

Source code in nearai/agents/environment.py
def add_message(
    self,
    role: str,
    message: str,
    attachments: Optional[Iterable[Attachment]] = None,
    **kwargs: Any,
):
    """Deprecated: use `add_reply` instead. Adds an assistant message to the environment.

    The `role` argument is accepted for compatibility but ignored; the message
    is always stored with the "assistant" role so an agent cannot record
    messages on behalf of the `user` (false memory). Extra keyword arguments
    are forwarded to `_add_message`.
    """
    return self._add_message("assistant", message, attachments, **kwargs)
add_system_log
add_system_log(log: str, level: int = logging.INFO) -> None

Add system log with timestamp and log level.

Source code in nearai/agents/environment.py
def add_system_log(self, log: str, level: int = logging.INFO) -> None:
    """Emit `log` at `level` through the process-wide "system_logger" logger.

    Handlers are attached only while the logger has none, in this order:
    - a file handler for SYSTEM_LOG_FILENAME inside this environment's directory;
    - a console `StreamHandler` when `print_system_log` is set;
    - in debug mode, a `CustomLogHandler` built with `self.add_reply`.

    Every handler is flushed after the record is emitted.
    """
    system_logger = logging.getLogger("system_logger")
    if not system_logger.handlers:
        system_logger.setLevel(logging.DEBUG)
        line_format = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S")

        def attach(handler: logging.Handler) -> None:
            handler.setFormatter(line_format)
            system_logger.addHandler(handler)

        attach(logging.FileHandler(os.path.join(self._path, SYSTEM_LOG_FILENAME)))
        if self.print_system_log:
            attach(logging.StreamHandler())
        if self._debug_mode:
            attach(CustomLogHandler(self.add_reply, "system"))

    system_logger.log(level, log)
    # Flush right away so the record reaches disk immediately.
    for attached in system_logger.handlers:
        attached.flush()
call_agent
call_agent(agent_index: int, task: str) -> None

Calls agent with given task.

Source code in nearai/agents/environment.py
def call_agent(self, agent_index: int, task: str) -> None:
    """Invoke `run` on the agent at `agent_index`, passing this environment and `task`."""
    agents = self._agents
    agents[agent_index].run(self, task=task)
clear_temp_agent_files
clear_temp_agent_files(verbose=True) -> None

Remove temp agent files created to be used in runpy.

Source code in nearai/agents/environment.py
def clear_temp_agent_files(self, verbose=True) -> None:
    """Remove every agent's `temp_dir` (files staged for `runpy`) that currently exists.

    Args:
    ----
        verbose: When true, print each directory just before deleting it.

    """
    for agent in self._agents:
        if os.path.exists(agent.temp_dir):
            if verbose:
                print("removed agent.temp_files", agent.temp_dir)
            shutil.rmtree(agent.temp_dir)
completion
completion(messages: Union[Iterable[ChatCompletionMessageParam], str], model: Union[Iterable[ChatCompletionMessageParam], str] = '', **kwargs: Any) -> str

Returns a completion for the given messages using the given model.

Source code in nearai/agents/environment.py
def completion(
    self,
    messages: Union[Iterable[ChatCompletionMessageParam], str],
    model: Union[Iterable[ChatCompletionMessageParam], str] = "",
    **kwargs: Any,
) -> str:
    """Return the text of the first choice of a (non-streaming) completion.

    Raises:
    ------
        AssertionError: If the response is not a `ModelResponse`, any choice is
            not a `Choices`, or the first choice has empty content.

    """
    response = self.completions(messages, model, **kwargs)
    assert isinstance(response, ModelResponse), "Expected ModelResponse"
    assert all(isinstance(candidate, Choices) for candidate in response.choices), "Expected Choices"
    text = response.choices[0].message.content  # type: ignore
    assert text, "No completions returned"
    return text
completion_and_run_tools
completion_and_run_tools(messages: List[ChatCompletionMessageParam], model: str = '', tools: Optional[List] = None, **kwargs: Any) -> Optional[str]

Returns a completion for the given messages using the given model and runs tools.

Source code in nearai/agents/environment.py
def completion_and_run_tools(
    self,
    messages: List[ChatCompletionMessageParam],
    model: str = "",
    tools: Optional[List] = None,
    **kwargs: Any,
) -> Optional[str]:
    """Run a tool-enabled completion and return its first choice's text (may be None).

    The inference call and tool handling are done by `completions_and_run_tools`.
    """
    result = self.completions_and_run_tools(messages, model, tools, **kwargs)
    choice_list = result.choices
    assert all(isinstance(candidate, Choices) for candidate in choice_list), "Expected Choices"
    return choice_list[0].message.content  # type: ignore
completions
completions(messages: Union[Iterable[ChatCompletionMessageParam], str], model: Union[Iterable[ChatCompletionMessageParam], str] = '', stream: bool = False, **kwargs: Any) -> Union[ModelResponse, CustomStreamWrapper]

Returns all completions for given messages using the given model.

Source code in nearai/agents/environment.py
def completions(
    self,
    messages: Union[Iterable[ChatCompletionMessageParam], str],
    model: Union[Iterable[ChatCompletionMessageParam], str] = "",
    stream: bool = False,
    **kwargs: Any,
) -> Union[ModelResponse, CustomStreamWrapper]:
    """Return the raw completion response (or stream wrapper) for `messages` using `model`.

    This is a thin public wrapper that forwards all arguments to
    `_run_inference_completions`.
    """
    run_inference = self._run_inference_completions
    return run_inference(messages, model, stream, **kwargs)
completions_and_run_tools
completions_and_run_tools(messages: List[ChatCompletionMessageParam], model: str = '', tools: Optional[List] = None, add_responses_to_messages: bool = True, agent_role_name='assistant', tool_role_name='tool', **kwargs: Any) -> ModelResponse

Returns all completions for given messages using the given model and runs tools.

Source code in nearai/agents/environment.py
def completions_and_run_tools(
    self,
    messages: List[ChatCompletionMessageParam],
    model: str = "",
    tools: Optional[List] = None,
    add_responses_to_messages: bool = True,
    agent_role_name="assistant",
    tool_role_name="tool",
    **kwargs: Any,
) -> ModelResponse:
    """Run a non-streaming, tool-enabled completion and process its tool calls.

    When llama-style tool syntax applies, a system message containing the tool
    prompt is appended to `messages` first. The first choice's message is then
    passed to `_handle_tool_calls` together with the role/flag arguments, and
    the full response is returned.

    Raises:
    ------
        AssertionError: If the response is not a `ModelResponse` or any choice
            is not a `Choices`.

    """
    if self._use_llama_tool_syntax(model, tools):
        # NOTE(review): this mutates the caller's list in place, so repeated
        # calls with the same list accumulate tool prompts. Confirm this is intended.
        messages.append({"role": "system", "content": self._llama_tool_prompt(tools)})

    response = self._run_inference_completions(messages, model, stream=False, tools=tools, **kwargs)
    assert isinstance(response, ModelResponse), "Expected ModelResponse"
    assert all(isinstance(candidate, Choices) for candidate in response.choices), "Expected Choices"
    first_message = response.choices[0].message  # type: ignore

    self._handle_tool_calls(first_message, add_responses_to_messages, agent_role_name, tool_role_name)
    return response
create_snapshot
create_snapshot() -> bytes

Create an in memory snapshot.

Source code in nearai/agents/environment.py
def create_snapshot(self) -> bytes:
    """Pack the environment directory into gzip-compressed tar bytes.

    The whole of ``self._path`` is archived with entries rooted at ``"."``. The
    archive is staged in a named temporary file, which is deleted when closed.

    Returns:
        The raw ``.tar.gz`` bytes.
    """
    with tempfile.NamedTemporaryFile(suffix=".tar.gz") as staging:
        with tarfile.open(fileobj=staging, mode="w:gz") as archive:
            archive.add(self._path, arcname=".")
        # Make sure everything tarfile wrote is visible before reading it back.
        staging.flush()
        staging.seek(0)
        return staging.read()
environment_run_info
environment_run_info(base_id, run_type) -> dict

Returns the environment run information.

Source code in nearai/agents/environment.py
def environment_run_info(self, base_id, run_type) -> dict:
    """Describe this environment run as a registry entry.

    The entry is named ``environment_run_<namespace>_<name>_<version>_<uuid hex>``
    after the primary agent. It records the base id, an ISO-8601 UTC timestamp, the
    names of all agents, the primary agent's coordinates, the run id and the run type.

    Raises:
        ValueError: if there is no (truthy) primary agent.
    """
    agents = self._agents
    if not (agents and agents[0]):
        raise ValueError("Agent not found")
    primary = agents[0]

    full_agent_name = "/".join([primary.namespace, primary.name, primary.version])
    uid = uuid.uuid4().hex
    run_name = "environment_run_{}_{}".format(full_agent_name.replace("/", "_"), uid)
    created_at = datetime.now(timezone.utc).isoformat()

    details = {
        "base_id": base_id,
        "timestamp": created_at,
        "agents": [a.name for a in agents],
        "primary_agent_namespace": primary.namespace,
        "primary_agent_name": primary.name,
        "primary_agent_version": primary.version,
        "run_id": self._run_id,
        "run_type": run_type,
    }
    return {
        "name": run_name,
        "version": "0",
        "description": f"Agent {run_type} {full_agent_name} {uid} {created_at}",
        "category": "environment",
        "tags": ["environment"],
        "details": details,
        "show_entry": True,
    }
exec_command
exec_command(command: str) -> Dict[str, Union[str, int]]

Executes a command in the environment and logs the output.

The environment does not allow running interactive programs. It runs a program for up to 2 seconds and then interrupts it if it is still running, for example because it is waiting for user input. command: the command to execute, such as 'ls -l' or 'python3 tests.py'.

Source code in nearai/agents/environment.py
def exec_command(self, command: str) -> Dict[str, Union[str, int]]:
    """Executes a command in the environment and logs the output.

    The environment does not allow running interactive programs. A watchdog
    kills the process and all of its descendants if it is still running after
    2 seconds, for example because it is waiting for user input.

    Execution must first be approved by ``self._approvals["confirm_execution"]``.
    The command is tokenized with ``shlex.split`` and run without a shell, using
    ``self._path`` as the working directory.

    Args:
        command: The command to execute, like 'ls -l' or 'python3 tests.py'.

    Returns:
        When the command runs, a dict with ``command``, ``stdout``, ``stderr``,
        ``returncode`` and ``msg``. ``msg`` is non-empty when the watchdog fired.
        This dict is also appended to the terminal log file in ``self._path``.

        The early exits are not logged:

        * no approval function configured: only ``stderr`` is returned;
        * command rejected or failed to start: ``returncode`` is 999.
    """
    approval_function = self._approvals["confirm_execution"] if self._approvals else None
    if not approval_function:
        return {
            "stderr": "Agent runner misconfiguration. No command execution approval function found.",
        }
    if not approval_function(command):
        return {
            "command": command,
            "returncode": 999,
            "stdout": "",
            "stderr": "Command execution was not approved.",
        }

    try:
        process = subprocess.Popen(
            shlex.split(command),
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            bufsize=0,
            universal_newlines=True,
            cwd=self._path,
        )
    except Exception as e:
        return {
            "command": command,
            "returncode": 999,
            "stdout": "",
            "stderr": "Failed to execute: " + str(e),
        }

    msg = ""

    def kill_process_tree(p: Any) -> None:
        nonlocal msg
        msg = "Killing process due to timeout"

        try:
            parent = psutil.Process(p.pid)
            # Kill descendants first, then the process itself.
            victims = parent.children(recursive=True) + [parent]
        except psutil.NoSuchProcess:
            # The process exited between the timer firing and this lookup.
            return
        for proc in victims:
            try:
                proc.kill()
            except psutil.NoSuchProcess:
                pass  # already exited on its own

    timer = threading.Timer(2, kill_process_tree, (process,))
    timer.start()
    try:
        # communicate() drains both pipes while waiting. Calling wait() before reading
        # deadlocks once the child fills an OS pipe buffer, so large outputs would
        # only end when the watchdog killed the process.
        stdout, stderr = process.communicate()
    finally:
        timer.cancel()

    result = {
        "command": command,
        "stdout": stdout or "",
        "stderr": stderr or "",
        "returncode": process.returncode,
        "msg": msg,
    }
    with open(os.path.join(self._path, TERMINAL_FILENAME), "a") as f:
        f.write(json.dumps(result) + DELIMITER)
    return result
generate_folder_hash_id
generate_folder_hash_id(path: str) -> str

Returns hash based on files and their contents in path, including subfolders.

Source code in nearai/agents/environment.py
def generate_folder_hash_id(self, path: str) -> str:
    """Returns hash based on files and their contents in path, including subfolders.

    Directories and files are visited in sorted order, so the id does not depend
    on filesystem enumeration order.

    For each file, the hash covers:

    * its path relative to ``path`` (with ``/`` separators);
    * its size;
    * its bytes.

    Renaming or moving files, or splitting the same bytes across different files,
    therefore changes the id. MD5 is used only as a content fingerprint.

    Returns:
        The 32-character hex digest.
    """
    hash_obj = hashlib.md5()

    for root, dirs, files in os.walk(path):
        # os.walk yields directories in filesystem order; sort in place so the
        # traversal (and therefore the digest) is deterministic.
        dirs.sort()
        for file in sorted(files):
            file_path = os.path.join(root, file)
            rel_path = os.path.relpath(file_path, path).replace(os.sep, "/")
            with open(file_path, "rb") as f:
                size = os.fstat(f.fileno()).st_size
                # Delimit name and size so file boundaries cannot be confused.
                hash_obj.update(os.fsencode(rel_path) + b"\0" + str(size).encode("ascii") + b"\0")
                while chunk := f.read(8192):
                    hash_obj.update(chunk)

    return hash_obj.hexdigest()
get_agent_temp_path
get_agent_temp_path() -> Path

Returns temp dir for primary agent where execution happens.

Source code in nearai/agents/environment.py
def get_agent_temp_path(self) -> Path:
    """Return the primary agent's temporary directory, where execution happens.

    This is an alias of ``get_primary_agent_temp_dir``.
    """
    primary_temp_dir = self.get_primary_agent_temp_dir()
    return primary_temp_dir
get_agents
get_agents() -> List[Agent]

Returns list of agents available in environment.

Source code in nearai/agents/environment.py
def get_agents(self) -> List[Agent]:
    """Return the agents loaded into this environment.

    The first entry is the primary agent. The internal list is returned, not a copy.
    """
    agents = self._agents
    return agents
get_inference_parameters
get_inference_parameters(messages: Union[Iterable[ChatCompletionMessageParam], str], model: Union[Iterable[ChatCompletionMessageParam], str], stream: bool, **kwargs: Any) -> Tuple[InferenceParameters, Any]

Builds the inference parameters used to run completions.

Source code in nearai/agents/environment.py
def get_inference_parameters(
    self,
    messages: Union[Iterable[ChatCompletionMessageParam], str],
    model: Union[Iterable[ChatCompletionMessageParam], str],
    stream: bool,
    **kwargs: Any,
) -> Tuple[InferenceParameters, Any]:
    """Assemble the ``InferenceParameters`` used to run a completions call.

    Legacy callers passed ``(model, messages)``. That order is detected by
    ``messages`` being a ``str``; the two arguments are then swapped and a
    deprecation warning is written to the system log.

    The model name is resolved through ``self.get_model_for_inference``. A
    "Connecting to ..." system log entry is emitted whenever it differs from the
    previously used model.

    ``temperature`` and ``max_tokens`` are popped from ``kwargs``. When absent, they
    default to the primary agent's ``model_temperature`` / ``model_max_tokens``, or
    to ``None`` when there are no agents.

    Returns:
        A ``(params, kwargs)`` tuple, where ``kwargs`` holds the keyword arguments
        not consumed here.
    """
    legacy_argument_order = isinstance(messages, str)
    if legacy_argument_order:
        self.add_system_log(
            "Deprecated completions call. Pass `messages` as a first parameter.",
            logging.WARNING,
        )
        # Right-hand side is evaluated before assignment, so this is a clean swap.
        model, messages = (
            cast(str, messages),
            cast(Iterable[ChatCompletionMessageParam], model),
        )
    else:
        model = cast(str, model)
        messages = cast(Iterable[ChatCompletionMessageParam], messages)

    resolved_model = self.get_model_for_inference(model)
    if resolved_model != self._last_used_model:
        self._last_used_model = resolved_model
        self.add_system_log(f"Connecting to {resolved_model}")

    default_temperature = self._agents[0].model_temperature if self._agents else None
    temperature = kwargs.pop("temperature", default_temperature)
    default_max_tokens = self._agents[0].model_max_tokens if self._agents else None
    max_tokens = kwargs.pop("max_tokens", default_max_tokens)

    return (
        InferenceParameters(
            model=resolved_model,
            messages=messages,
            stream=stream,
            temperature=temperature,
            max_tokens=max_tokens,
        ),
        kwargs,
    )
get_last_message
get_last_message(role: str = 'user')

Reads last message from the given role and returns it.

Source code in nearai/agents/environment.py
def get_last_message(self, role: str = "user"):
    """Return the most recent message whose ``"role"`` equals ``role``, or ``None``.

    Messages come from ``self.list_messages()`` and are scanned newest-first.
    """
    return next(
        (message for message in reversed(self.list_messages()) if message.get("role") == role),
        None,
    )
get_primary_agent
get_primary_agent() -> Agent

Returns the agent that is invoked first.

Source code in nearai/agents/environment.py
def get_primary_agent(self) -> Agent:
    """Return the agent that is invoked first: the first entry of ``self._agents``.

    Raises:
        IndexError: if no agents are loaded.
    """
    agents = self._agents
    return agents[0]
get_primary_agent_temp_dir
get_primary_agent_temp_dir() -> Path

Returns temp dir for primary agent.

Source code in nearai/agents/environment.py
def get_primary_agent_temp_dir(self) -> Path:
    """Return the ``temp_dir`` of the primary agent (``self._agents[0]``)."""
    primary_agent = self._agents[0]
    return primary_agent.temp_dir
get_system_path
get_system_path() -> Path

Returns the system path where chat.txt & system_log are stored.

Source code in nearai/agents/environment.py
def get_system_path(self) -> Path:
    """Return ``self._path`` as a ``Path``: where chat.txt and the system log live."""
    system_dir = self._path
    return Path(system_dir)
get_tool_registry
get_tool_registry(new: bool = False) -> ToolRegistry

Returns the tool registry, a dictionary of tools that can be called by the agent.

Source code in nearai/agents/environment.py
def get_tool_registry(self, new: bool = False) -> ToolRegistry:
    """Return the environment's tool registry (tools callable by the agent).

    Args:
        new: When true, replace the current registry with a fresh, empty
            ``ToolRegistry`` before returning it.
    """
    if not new:
        return self._tools
    self._tools = ToolRegistry()
    return self._tools
list_files
list_files(path: str, order: Literal['asc', 'desc'] = 'asc') -> List[str]

Lists files in the environment.

Source code in nearai/agents/environment.py
def list_files(self, path: str, order: Literal["asc", "desc"] = "asc") -> List[str]:
    """Lists files in the environment.

    Args:
        path: Directory relative to the primary agent's temp directory.
        order: ``"asc"`` (default) returns entry names sorted ascending;
            ``"desc"`` returns them sorted descending. Previously this argument
            was ignored and ``os.listdir``'s arbitrary order was returned.

    Returns:
        The names (not full paths) of the entries in that directory.
    """
    entries = os.listdir(os.path.join(self.get_primary_agent_temp_dir(), path))
    return sorted(entries, reverse=(order == "desc"))
list_messages
list_messages(thread_id: Optional[str] = None)

Backwards compatibility for chat_completions messages.

Source code in nearai/agents/environment.py
def list_messages(self, thread_id: Optional[str] = None):
    """Backwards compatibility for chat_completions messages.

    Fetches messages with ``self._list_messages(thread_id=...)`` and converts each
    one to a legacy ``{"id", "content", "role"}`` dict. The text parts of each
    message are joined with newlines.

    In debug mode, messages whose metadata ``message_type`` is ``"system:log"``
    or ``"agent:log"`` are dropped, so agent behavior does not change based on
    logs. Metadata without a ``message_type`` key is treated as a normal message
    rather than raising ``KeyError``.
    """
    messages = self._list_messages(thread_id=thread_id)

    if self._debug_mode:
        log_message_types = ("system:log", "agent:log")
        messages = [
            m
            for m in messages
            if not (m.metadata and m.metadata.get("message_type") in log_message_types)  # type: ignore
        ]
    legacy_messages = [
        {
            "id": m.id,
            "content": "\n".join([c.text.value for c in m.content]),  # type: ignore
            "role": m.role,
        }
        for m in messages
    ]
    return legacy_messages
list_terminal_commands
list_terminal_commands(filename: str = TERMINAL_FILENAME) -> List[Any]

Returns the terminal commands from the terminal file.

Source code in nearai/agents/environment.py
def list_terminal_commands(self, filename: str = TERMINAL_FILENAME) -> List[Any]:
    """Return the JSON records stored in the terminal log file under ``self._path``.

    The file holds JSON documents separated by ``DELIMITER``; empty segments are
    skipped. A missing file yields an empty list.
    """
    log_path = os.path.join(self._path, filename)
    if not os.path.exists(log_path):
        return []

    with open(log_path, "r") as log_file:
        raw = log_file.read()
    return [json.loads(entry) for entry in raw.split(DELIMITER) if entry]
load_snapshot
load_snapshot(snapshot: bytes) -> None

Load Environment from Snapshot.

Source code in nearai/agents/environment.py
def load_snapshot(self, snapshot: bytes) -> None:
    """Load Environment from Snapshot.

    Deletes ``self._path`` and re-creates it from ``snapshot``, a ``.tar.gz`` archive
    (as produced by ``create_snapshot``).

    Extraction refuses members that would land outside ``self._path``: absolute
    paths, ``..`` components, and links pointing outside. Snapshot bytes may come
    from outside this process, and plain ``extractall`` allows path traversal.
    Where available, the tarfile ``"data"`` filter (PEP 706) is used; otherwise
    members are checked manually.

    Raises:
        tarfile.TarError / ValueError: if the archive is invalid or unsafe.
    """
    import io

    shutil.rmtree(self._path, ignore_errors=True)
    dest = os.path.realpath(self._path)

    def _ensure_inside(member_path: str) -> None:
        target = os.path.realpath(os.path.join(dest, member_path))
        if os.path.commonpath([dest, target]) != dest:
            raise ValueError(f"Refusing to extract {member_path!r} outside of {self._path!r}")

    # No need to stage the bytes on disk; tarfile reads from any file object.
    with tarfile.open(fileobj=io.BytesIO(snapshot), mode="r:gz") as tar:
        if hasattr(tarfile, "data_filter"):
            tar.extractall(self._path, filter="data")
            return

        # Fallback for Python versions without extraction filters.
        for member in tar.getmembers():
            _ensure_inside(member.name)
            if member.issym():
                _ensure_inside(os.path.join(os.path.dirname(member.name), member.linkname))
            elif member.islnk():
                _ensure_inside(member.linkname)
        tar.extractall(self._path)
read_file
read_file(filename: str)

Reads a file from the environment or thread.

Source code in nearai/agents/environment.py
def read_file(self, filename: str):
    """Reads a file from the environment or thread.

    The local copy under the primary agent's temp directory is read first. A file
    with the same name attached to the thread (searched most-recent first) takes
    precedence over it. Whatever content is found is written back to the local
    path, creating parent directories as needed.

    Returns:
        The file content, or ``None`` when the file exists neither locally nor in
        the thread. In that case a warning is added to the system log. An existing
        empty file returns ``""`` and is not reported as missing.
    """
    file_content = None
    # First try to read from local filesystem
    local_path = os.path.join(self.get_primary_agent_temp_dir(), filename)
    if os.path.exists(local_path):
        with open(local_path, "r") as local_file:
            file_content = local_file.read()

    thread_files = self.list_files_from_thread(order="desc")

    # Then try to read from thread, starting from the most recent
    for f in thread_files:
        if f.filename == filename:
            file_content = self.read_file_by_id(f.id)
            break

    # Compare with None: an empty file is still a found file.
    if file_content is not None:
        # Thread filenames may include subdirectories that don't exist locally yet.
        os.makedirs(os.path.dirname(local_path), exist_ok=True)
        with open(local_path, "w") as local_file:
            local_file.write(file_content)
    else:
        self.add_system_log(f"Warn: File {filename} not found during read_file operation")

    return file_content
run
run(new_message: Optional[str] = None, max_iterations: int = 10) -> None

Runs agent(s) against a new or previously created environment.

Source code in nearai/agents/environment.py
def run(
    self,
    new_message: Optional[str] = None,
    max_iterations: int = 10,
) -> None:
    """Run the primary agent against a new or previously created environment.

    A non-empty ``new_message`` is first added as a user message. The primary agent
    is then run repeatedly, at most ``max_iterations`` times. The loop stops early
    once the environment reports done or the next actor is ``"user"``.

    If an agent run raises, the error is logged, the run is marked failed and the
    exception propagates. Otherwise the run is marked done, unless an external
    agent call is pending (that run continues later).
    """
    if new_message:
        self._add_message("user", new_message)

    iteration = 0
    self.set_next_actor("agent")

    while iteration < max_iterations:
        if self.is_done() or self.get_next_actor() == "user":
            break
        iteration += 1
        if max_iterations > 1:
            self.add_system_log(
                f"Running agent, iteration {iteration}/{max_iterations}",
                logging.INFO,
            )
        try:
            self._agents[0].run(self, task=new_message)
        except Exception as e:
            self.add_system_log(f"Environment run failed: {e}", logging.ERROR)
            self.mark_failed()
            raise

    # A pending external agent means this run will be resumed later.
    if not self._pending_ext_agent:
        self.mark_done()
set_next_actor
set_next_actor(who: str) -> None

Set the next actor / action in the dialogue.

Source code in nearai/agents/environment.py
def set_next_actor(self, who: str) -> None:
    """Record ``who`` as the next actor in ``<self._path>/.next_action``.

    Handing the turn to ``"agent"`` also clears the done flag.
    """
    marker_path = os.path.join(self._path, ".next_action")
    if who == "agent":
        self._done = False

    with open(marker_path, "w") as marker:
        marker.write(who)
signed_completion
signed_completion(messages: Union[Iterable[ChatCompletionMessageParam], str], model: Union[Iterable[ChatCompletionMessageParam], str] = '', **kwargs: Any) -> Dict[str, str]

Returns a completion for the given messages using the given model with the agent signature.

Source code in nearai/agents/environment.py
def signed_completion(
    self,
    messages: Union[Iterable[ChatCompletionMessageParam], str],
    model: Union[Iterable[ChatCompletionMessageParam], str] = "",
    **kwargs: Any,
) -> Dict[str, str]:
    """Returns a completion for the given messages using the given model with the agent signature.

    The completion is requested with ``agent_name`` set to the primary agent's full
    name. The signature data is read from ``response.system_fingerprint``, which is
    expected to hold a JSON object with ``signature`` and ``public_key``. A
    fingerprint that is missing, not valid JSON, or not a JSON object is treated as
    unsigned, so both fields are ``None``, instead of raising.

    Returns:
        A dict with ``response`` (the first choice's text), ``signature`` and
        ``public_key``.
    """
    # TODO Return signed completions for non-latest versions only?
    agent_name = self.get_primary_agent().get_full_name()
    raw_response = self.completions(messages, model, agent_name=agent_name, **kwargs)
    assert isinstance(raw_response, ModelResponse), "Expected ModelResponse"
    response: ModelResponse = raw_response

    signature_data: Dict[str, Any] = {}
    if response.system_fingerprint:
        try:
            parsed = json.loads(response.system_fingerprint)
        except ValueError:
            # Plain (non-JSON) fingerprint: no signature available.
            parsed = None
        if isinstance(parsed, dict):
            signature_data = parsed

    assert all(map(lambda choice: isinstance(choice, Choices), response.choices)), "Expected Choices"
    choices: List[Choices] = response.choices  # type: ignore
    response_message = choices[0].message.content
    assert response_message, "No completions returned"

    return {
        "response": response_message,
        "signature": signature_data.get("signature", None),
        "public_key": signature_data.get("public_key", None),
    }
verify_message
verify_message(account_id: str, public_key: str, signature: str, message: str, nonce: str, callback_url: str) -> SignatureVerificationResult

Verifies that the user message is signed with NEAR Account.

Source code in nearai/agents/environment.py
def verify_message(
    self,
    account_id: str,
    public_key: str,
    signature: str,
    message: str,
    nonce: str,
    callback_url: str,
) -> near.SignatureVerificationResult:
    """Check that ``message`` was signed by the NEAR account ``account_id``.

    Delegates to ``near.verify_signed_message``. The primary agent's name is passed
    between ``nonce`` and ``callback_url`` (presumably as the signature's intended
    recipient; confirm against ``near.verify_signed_message``).
    """
    recipient = self._agents[0].name
    return near.verify_signed_message(account_id, public_key, signature, message, nonce, recipient, callback_url)
verify_signed_message
verify_signed_message(completion: str, messages: Union[Iterable[ChatCompletionMessageParam], str], public_key: Union[str, None] = None, signature: Union[str, None] = None, model: Union[Iterable[ChatCompletionMessageParam], str] = '', **kwargs: Any) -> bool

Verifies a signed message.

Source code in nearai/agents/environment.py
def verify_signed_message(
    self,
    completion: str,
    messages: Union[Iterable[ChatCompletionMessageParam], str],
    public_key: Union[str, None] = None,
    signature: Union[str, None] = None,
    model: Union[Iterable[ChatCompletionMessageParam], str] = "",
    **kwargs: Any,
) -> bool:
    """Check ``signature`` over ``completion`` and the request that produced it.

    The signed payload is rebuilt from the same inference parameters a completions
    call would use: the primary agent's full name, the resolved model, the
    messages reduced to string ``role``/``content`` pairs (ids dropped),
    temperature and max_tokens.

    Returns:
        ``False`` immediately if ``public_key`` or ``signature`` is missing;
        otherwise the result of ``validate_completion_signature``.
    """
    if public_key is None or signature is None:
        return False

    params, _unused_kwargs = self.get_inference_parameters(messages, model, False, **kwargs)

    id_free = [{key: value for key, value in entry.items() if key != "id"} for entry in params.messages]
    payload_messages = [{"role": str(entry["role"]), "content": str(entry["content"])} for entry in id_free]

    payload = CompletionSignaturePayload(
        agent_name=self.get_primary_agent().get_full_name(),
        completion=completion,
        model=params.model,
        messages=payload_messages,
        temperature=params.temperature,
        max_tokens=params.max_tokens,
    )
    return validate_completion_signature(public_key, signature, payload)

tool_json_helper

parse_json_args
parse_json_args(signature: dict, args: str)

Parses LLM generated JSON args, trying various repair strategies if args are not valid JSON.

Source code in nearai/agents/tool_json_helper.py
def parse_json_args(signature: dict, args: str):
    """Parses LLM generated JSON args, trying various repair strategies if args are not valid JSON.

    Empty ``args`` (``""`` or ``"{}"``) are accepted only when the function declares
    no required parameters; a missing ``"required"`` key counts as none.

    Otherwise the transforms are tried in order:

    1. plain ``json.loads``;
    2. ``_ending_transform``;
    3. signature-guided extraction.

    A result is accepted only if every key is a declared property.

    Raises:
        ValueError: if arguments are required but ``args`` is empty.
        json.JSONDecodeError: if no strategy produces valid arguments.
    """
    # if args is empty or an empty json object check if the function has no arguments
    if not args or args == "{}":
        if not signature["function"]["parameters"].get("required"):
            return {}
        raise ValueError("Function requires arguments")

    transforms = [
        lambda x: json.loads(x),
        _ending_transform,
        lambda x: parse_json_args_based_on_signature(signature, x),
    ]

    for transform in transforms:
        try:
            result = transform(args)
            # check that all result keys are valid properties in the signature
            for key in result.keys():
                if key not in signature["function"]["parameters"]["properties"]:
                    raise json.JSONDecodeError(f"Unknown parameter {key}", args, 0)
            return result
        except json.JSONDecodeError:
            continue
        except Exception as err:
            raise json.JSONDecodeError("Error parsing function args", args, 0) from err

    # Every strategy was rejected; never fall through to an implicit None.
    raise json.JSONDecodeError("Error parsing function args", args, 0)
parse_json_args_based_on_signature
parse_json_args_based_on_signature(signature: dict, args: str)

Finds parameter names based on the signature and tries to extract the values in between from the args string.

Source code in nearai/agents/tool_json_helper.py
def parse_json_args_based_on_signature(signature: dict, args: str):
    """Finds parameter names based on the signature and tries to extract the values in between from the args string.

    Each declared property is located as ``"name":`` (optionally preceded by a
    comma). The text up to the next located parameter, or to the end of the object
    for the last one, is taken as its raw value; surrounding double quotes are
    removed. Values are returned as strings.

    Required parameters must be present. Optional parameters that do not appear in
    ``args`` are omitted from the result.

    Raises:
        ValueError: if a required parameter is missing, or no declared parameter
            can be found at all.
    """
    parameters = signature["function"]["parameters"]
    properties = parameters["properties"]
    required = set(parameters.get("required") or [])

    # find each parameter name in the args string
    #   assuming each parameter name is surrounded by "s, followed by a colon and optionally preceded by a comma,
    #   extract the intervening values as values
    parameter_positions = {}
    for param in properties:
        # Escape the name: property names may contain regex metacharacters such as ".".
        match = re.search(rf',?\s*"({re.escape(param)})"\s*:', args)
        if not match:
            if param in required:
                raise ValueError(f"Parameter {param} not found in args {args}")
            continue
        parameter_positions[param] = (match.start(), match.end())

    if properties and not parameter_positions:
        raise ValueError(f"No parameters found in args {args}")

    # sort the parameter positions by start position
    sorted_positions = sorted(parameter_positions.items(), key=lambda item: item[1][0])

    parameter_values = {}
    for i, (param, (_start, end)) in enumerate(sorted_positions):
        if i == len(sorted_positions) - 1:
            # Last parameter: take everything to the end, then drop the object's closing
            # brace plus at most one stray extra brace (LLMs sometimes emit "}}").
            # Unlike slicing off the last character blindly, this keeps the value
            # intact when the closing brace is missing.
            raw_value = args[end:].rstrip()
            for _ in range(2):
                if raw_value.endswith("}"):
                    raw_value = raw_value[:-1].rstrip()
        else:
            # otherwise, extract the value up to the start position of the next parameter
            next_start = sorted_positions[i + 1][1][0]
            raw_value = args[end:next_start]
        raw_value = raw_value.strip()
        if raw_value.startswith('"') and raw_value.endswith('"'):
            raw_value = raw_value[1:-1]
        parameter_values[param] = raw_value
    return parameter_values

tool_registry

ToolRegistry

A registry for tools that can be called by the agent.

Tool definitions follow this structure:

{
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The city and state, e.g. San Francisco, CA",
                },
                "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
            },
            "required": ["location"],
        },
    },
}
Source code in nearai/agents/tool_registry.py
class ToolRegistry:
    """A registry for tools that can be called by the agent.

    Tool definitions follow this structure:

        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state, e.g. San Francisco, CA",
                        },
                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                    },
                    "required": ["location"],
                },
            },
        }

    """

    def __init__(self) -> None:
        """Create an empty registry."""
        self.tools: Dict[str, Callable] = {}

    def register_tool(self, tool: Callable) -> None:
        """Register a tool under its ``__name__``, replacing any tool with that name."""
        self.tools[tool.__name__] = tool

    def get_tool(self, name: str) -> Optional[Callable]:
        """Get a tool by name, or ``None`` if it is not registered."""
        return self.tools.get(name)

    def get_all_tools(self) -> Dict[str, Callable]:
        """Get all tools keyed by name (the registry's own dict, not a copy)."""
        return self.tools

    def call_tool(self, name: str, **kwargs: Any) -> Any:
        """Call a tool by name with ``kwargs``.

        Raises:
            ValueError: if no tool is registered under ``name``.
        """
        tool = self.get_tool(name)
        if tool is None:
            raise ValueError(f"Tool '{name}' not found.")
        return tool(**kwargs)

    def get_tool_definition(self, name: str) -> Optional[Dict]:
        """Get the definition of a tool by name.

        The definition is built from the tool's signature, type hints and docstring:

        * The first docstring line becomes the function description.
        * Later lines of the form ``param: text`` or ``param (type): text``
          describe parameters.
        * Parameters without a default value are listed as required.
        * ``*args`` / ``**kwargs`` are skipped.

        Type hints map to JSON Schema types:

        * int → integer, float → number, str → string, bool → boolean;
        * list/tuple → array, dict → object;
        * ``Literal[...]`` → string;
        * ``Optional[X]`` → the type for X; other unions → string;
        * missing hints default to ``str``;
        * other classes use their lowercased name.

        Returns:
            The definition dict, or ``None`` if no tool named ``name`` is registered.

        Raises:
            AssertionError: if the tool has no docstring.
        """
        import types
        from typing import Literal, Union, get_args, get_origin

        tool = self.get_tool(name)
        if tool is None:
            return None

        assert tool.__doc__ is not None, f"Docstring missing for tool '{name}'."
        docstring = tool.__doc__.strip().split("\n")

        # The first line of the docstring is the function description
        function_description = docstring[0].strip()

        # The rest of the lines contain parameter descriptions
        param_descriptions = docstring[1:]

        union_types = (Union, getattr(types, "UnionType", Union))  # X | Y syntax on 3.10+
        json_type_names = {
            "int": "integer",
            "float": "number",
            "str": "string",
            "bool": "boolean",
            "list": "array",
            "tuple": "array",
            "dict": "object",
        }

        def to_json_type(tp: Any) -> str:
            origin = get_origin(tp)
            if origin is Literal:
                return "string"
            if origin in union_types:
                non_none = [arg for arg in get_args(tp) if arg is not type(None)]
                # Optional[X] -> X; a real multi-type union has no single JSON type.
                return to_json_type(non_none[0]) if len(non_none) == 1 else "string"
            type_name = getattr(origin or tp, "__name__", None)
            if type_name is None:
                return "string"
            lowered = type_name.lower()
            return json_type_names.get(lowered, lowered)

        # Extract parameter names and types
        signature = inspect.signature(tool)
        type_hints = get_type_hints(tool)

        parameters: Dict[str, Any] = {"type": "object", "properties": {}, "required": []}

        for param in signature.parameters.values():
            if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
                continue  # *args / **kwargs cannot be expressed as named JSON properties
            param_name = param.name
            param_type = type_hints.get(param_name, str)  # Default to str if type hint is missing
            param_description = ""

            # Find the parameter description in the docstring
            for line in param_descriptions:
                stripped = line.strip()
                if not stripped.startswith(param_name):
                    continue
                rest = stripped[len(param_name):]
                # The name must end here ("item:" / "item (str):"), so "item" does not
                # match "items: ...", and lines without a colon are skipped instead of
                # raising IndexError.
                if rest[:1] in (":", " ", "(") and ":" in rest:
                    param_description = rest.split(":", 1)[1].strip()
                    break

            # Add parameter to the definition
            parameters["properties"][param_name] = {
                "description": param_description,
                "type": to_json_type(param_type),
            }

            # Params without default values are required params
            if param.default is inspect.Parameter.empty:
                parameters["required"].append(param_name)

        return {
            "type": "function",
            "function": {"name": tool.__name__, "description": function_description, "parameters": parameters},
        }

    def get_all_tool_definitions(self) -> list[Dict]:
        """Return the definitions of all registered tools, in registration order."""
        definitions = []
        for tool_name in self.tools:
            definition = self.get_tool_definition(tool_name)
            if definition is not None:
                definitions.append(definition)
        return definitions
call_tool
call_tool(name: str, **kwargs: Any) -> Any

Call a tool by name.

Source code in nearai/agents/tool_registry.py
def call_tool(self, name: str, **kwargs: Any) -> Any:  # noqa: D102
    """Look up the tool registered as ``name`` and call it with ``kwargs``.

    Raises:
        ValueError: if no such tool is registered.
    """
    tool = self.get_tool(name)
    if tool is not None:
        return tool(**kwargs)
    raise ValueError(f"Tool '{name}' not found.")
get_all_tools
get_all_tools() -> Dict[str, Callable]

Get all tools.

Source code in nearai/agents/tool_registry.py
def get_all_tools(self) -> Dict[str, Callable]:  # noqa: D102
    """Return every registered tool keyed by name (the live dict, not a copy)."""
    registered = self.tools
    return registered
get_tool
get_tool(name: str) -> Optional[Callable]

Get a tool by name.

Source code in nearai/agents/tool_registry.py
def get_tool(self, name: str) -> Optional[Callable]:  # noqa: D102
    """Return the tool registered as ``name``, or ``None`` when there is none."""
    registered = self.tools
    return registered.get(name)
get_tool_definition
get_tool_definition(name: str) -> Optional[Dict]

Get the definition of a tool by name.

Source code in nearai/agents/tool_registry.py
def get_tool_definition(self, name: str) -> Optional[Dict]:  # noqa: D102
    """Get the definition of a tool by name.

    Returns a dict shaped like::

        {"type": "function",
         "function": {"name": ..., "description": ..., "parameters": {...}}}

    The definition is built from the tool's signature, type hints and docstring:

    * The first docstring line becomes the description.
    * Later lines of the form ``param: text`` or ``param (type): text`` describe
      parameters.
    * Parameters without a default value are listed as required.
    * ``*args`` / ``**kwargs`` are skipped.

    Type hints map to JSON Schema types:

    * int → integer, float → number, str → string, bool → boolean;
    * list/tuple → array, dict → object;
    * ``Literal[...]`` → string;
    * ``Optional[X]`` → the type for X; other unions → string;
    * missing hints default to ``str``;
    * other classes use their lowercased name.

    Returns:
        The definition, or ``None`` if no tool named ``name`` is registered.

    Raises:
        AssertionError: if the tool has no docstring.
    """
    import types
    from typing import Literal, Union, get_args, get_origin

    tool = self.get_tool(name)
    if tool is None:
        return None

    assert tool.__doc__ is not None, f"Docstring missing for tool '{name}'."
    docstring = tool.__doc__.strip().split("\n")

    # The first line of the docstring is the function description
    function_description = docstring[0].strip()

    # The rest of the lines contain parameter descriptions
    param_descriptions = docstring[1:]

    union_types = (Union, getattr(types, "UnionType", Union))  # X | Y syntax on 3.10+
    json_type_names = {
        "int": "integer",
        "float": "number",
        "str": "string",
        "bool": "boolean",
        "list": "array",
        "tuple": "array",
        "dict": "object",
    }

    def to_json_type(tp: Any) -> str:
        origin = get_origin(tp)
        if origin is Literal:
            return "string"
        if origin in union_types:
            non_none = [arg for arg in get_args(tp) if arg is not type(None)]
            # Optional[X] -> X; a real multi-type union has no single JSON type.
            return to_json_type(non_none[0]) if len(non_none) == 1 else "string"
        type_name = getattr(origin or tp, "__name__", None)
        if type_name is None:
            return "string"
        lowered = type_name.lower()
        return json_type_names.get(lowered, lowered)

    # Extract parameter names and types
    signature = inspect.signature(tool)
    type_hints = get_type_hints(tool)

    parameters: Dict[str, Any] = {"type": "object", "properties": {}, "required": []}

    for param in signature.parameters.values():
        if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
            continue  # *args / **kwargs cannot be expressed as named JSON properties
        param_name = param.name
        param_type = type_hints.get(param_name, str)  # Default to str if type hint is missing
        param_description = ""

        # Find the parameter description in the docstring
        for line in param_descriptions:
            stripped = line.strip()
            if not stripped.startswith(param_name):
                continue
            rest = stripped[len(param_name):]
            # The name must end here ("item:" / "item (str):"), so "item" does not match
            # "items: ...", and lines without a colon are skipped instead of raising.
            if rest[:1] in (":", " ", "(") and ":" in rest:
                param_description = rest.split(":", 1)[1].strip()
                break

        # Add parameter to the definition
        parameters["properties"][param_name] = {
            "description": param_description,
            "type": to_json_type(param_type),
        }

        # Params without default values are required params
        if param.default is inspect.Parameter.empty:
            parameters["required"].append(param_name)

    return {
        "type": "function",
        "function": {"name": tool.__name__, "description": function_description, "parameters": parameters},
    }
register_tool
register_tool(tool: Callable) -> None

Register a tool.

Source code in nearai/agents/tool_registry.py
def register_tool(self, tool: Callable) -> None:  # noqa: D102
    """Store ``tool`` under its ``__name__``, overwriting any previous entry."""
    key = tool.__name__
    self.tools[key] = tool

cli

AgentCli

Source code in nearai/cli.py
class AgentCli:
    """``nearai agent`` commands: create or fork agents and run them."""

    def dev(self) -> int:
        """Run local UI for development of agents that have their own UI.

        Copies ``hub/demo/.env.example`` to ``hub/demo/.env`` when the latter is
        missing, installs the demo's npm dependencies and starts its dev server.

        Returns:
            The exit status of ``npm install`` when it fails; otherwise the exit
            status of ``npm run dev``.
        """
        if not os.path.exists("hub/demo/.env"):
            shutil.copy("hub/demo/.env.example", "hub/demo/.env")

        ret_val = os.system("npm install --prefix hub/demo")
        if ret_val != 0:
            print("Node.js is required to run the development server.")
            print("Please install Node.js from https://nodejs.org/")
            # Without installed dependencies the dev server cannot start, so stop
            # here instead of launching a second command that is bound to fail.
            return ret_val
        return os.system("npm run dev --prefix hub/demo")

    def inspect(self, path: str) -> None:
        """Inspect environment from given path.

        Launches the ``streamlit_inspect.py`` app that lives next to this module
        via ``streamlit run``, passing ``path`` to the script after ``--``.
        """
        import subprocess

        filename = Path(os.path.abspath(__file__)).parent / "streamlit_inspect.py"
        subprocess.call(["streamlit", "run", filename, "--", path])

    def interactive(
        self,
        agent: str,
        thread_id: Optional[str] = None,
        tool_resources: Optional[Dict[str, Any]] = None,
        local: bool = False,
        env_vars: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Runs agent interactively.

        Reads messages from stdin and sends each one to ``agent`` on the same
        thread until the user types ``exit`` or ends input (Ctrl-D / Ctrl-C).
        """
        last_message_id = None
        while True:
            try:
                new_message = input("> ")
            except (EOFError, KeyboardInterrupt):
                # End of input or interrupt at the prompt: leave the session
                # cleanly instead of dumping a traceback.
                print()
                break
            if new_message.strip().lower() == "exit":
                break

            last_message_id = self._task(
                agent=agent,
                task=new_message,
                thread_id=thread_id,
                tool_resources=tool_resources,
                last_message_id=last_message_id,
                local=local,
                env_vars=env_vars,
            )

            # The first call creates a thread when none was given; reuse it so
            # the whole session shares one conversation.
            if thread_id is None:
                thread_id = self.last_thread_id

    def task(
        self,
        agent: str,
        task: str,
        thread_id: Optional[str] = None,
        tool_resources: Optional[Dict[str, Any]] = None,
        file_ids: Optional[List[str]] = None,
        local: bool = False,
        env_vars: Optional[Dict[str, Any]] = None,
    ) -> None:
        """CLI wrapper for the _task method."""
        last_message_id = self._task(
            agent=agent,
            task=task,
            thread_id=thread_id,
            tool_resources=tool_resources,
            file_ids=file_ids,
            local=local,
            env_vars=env_vars,
        )
        if last_message_id:
            print(f"Task completed. Thread ID: {self.last_thread_id}")
            print(f"Last message ID: {last_message_id}")

    def _task(
        self,
        agent: str,
        task: str,
        thread_id: Optional[str] = None,
        tool_resources: Optional[Dict[str, Any]] = None,
        file_ids: Optional[List[str]] = None,
        last_message_id: Optional[str] = None,
        local: bool = False,
        env_vars: Optional[Dict[str, Any]] = None,
    ) -> Optional[str]:
        """Runs agent non-interactively with a single task.

        Posts ``task`` as a user message (creating a thread when ``thread_id`` is
        not given) and runs ``agent`` on that thread, either on the hub or
        locally through ``LocalRunner``. It then prints the assistant replies
        listed after ``last_message_id`` and stores the thread id in
        ``self.last_thread_id``.

        Returns:
            The id of the newest listed message, or the incoming
            ``last_message_id`` when no new messages were listed.
        """
        hub_client = get_hub_client()
        if thread_id:
            thread = hub_client.beta.threads.retrieve(thread_id)
        else:
            thread = hub_client.beta.threads.create(
                tool_resources=tool_resources,
            )

        hub_client.beta.threads.messages.create(
            thread_id=thread.id,
            role="user",
            content=task,
            attachments=[Attachment(file_id=file_id) for file_id in file_ids] if file_ids else None,
        )

        if not local:
            hub_client.beta.threads.runs.create_and_poll(
                thread_id=thread.id,
                assistant_id=agent,
            )
        else:
            # Ask the hub not to execute the run itself; it is executed here.
            run = hub_client.beta.threads.runs.create(
                thread_id=thread.id,
                assistant_id=agent,
                extra_body={"delegate_execution": True},
            )
            params = {
                "api_url": CONFIG.api_url,
                "tool_resources": run.tools,
                "data_source": "local_files",
                "user_env_vars": env_vars,
                "agent_env_vars": {},
            }
            auth = CONFIG.auth
            assert auth is not None
            LocalRunner(agent, agent, thread.id, run.id, auth, params)

        # List new messages
        messages = hub_client.beta.threads.messages.list(thread_id=thread.id, after=last_message_id, order="asc")
        message_list = list(messages)
        if message_list:
            for msg in message_list:
                # Messages tagged with a metadata "message_type" are not echoed.
                if msg.metadata and msg.metadata.get("message_type"):
                    continue
                if msg.role != "assistant" or not msg.content:
                    continue
                # Skip content parts without a text payload instead of crashing
                # with IndexError/AttributeError.
                text = getattr(msg.content[0], "text", None)
                if text is not None:
                    print(f"Assistant: {text.value}")
            last_message_id = message_list[-1].id
        else:
            print("No new messages")

        # Store the thread_id for potential use in interactive mode
        self.last_thread_id = thread.id

        return last_message_id

    def create(self, name: Optional[str] = None, description: Optional[str] = None, fork: Optional[str] = None) -> None:
        """Create a new agent or fork an existing one.

        Usage:
          nearai agent create
          nearai agent create --name <agent_name> --description <description>
          nearai agent create --fork <namespace/agent_name/version> [--name <new_agent_name>]

        Options:
          --name          Name of the new agent.
          --description   Description of the new agent.
          --fork          Fork an existing agent specified by namespace/agent_name/version.

        Examples
        --------
          nearai agent create
          nearai agent create --name my_agent --description "My new agent"
          nearai agent create --fork agentic.near/summary/0.0.3 --name new_summary_agent

        """
        # Check if the user is authenticated
        if CONFIG.auth is None or CONFIG.auth.namespace is None:
            print("Please login with `nearai login` before creating an agent.")
            return

        namespace = CONFIG.auth.namespace

        if fork:
            # Fork an existing agent
            self._fork_agent(fork, namespace, name)
        else:
            # Create a new agent from scratch
            self._create_new_agent(namespace, name, description)

    def _create_new_agent(self, namespace: str, name: Optional[str], description: Optional[str]) -> None:
        """Create a new agent from scratch.

        Prompts for a missing or invalid name and a missing description, then
        writes ``metadata.json`` and a starter ``agent.py`` into
        ``<registry folder>/<namespace>/<name>/0.0.1``. Existing agent files are
        never overwritten.

        Args:
            namespace: Namespace the agent is created under.
            name: Agent name; must match ``[a-zA-Z0-9_\\-.]+`` and not be
                ``.``/``..`` because it becomes a directory name.
            description: Agent description; prompted for when empty.
        """
        identifier_pattern = re.compile(r"^[a-zA-Z0-9_\-.]+$")

        # Prompt for agent name if not provided
        if not name or not isinstance(name, str):
            name = input("Name: ").strip()
        # The name is used as a path segment: reject separators and "."/"..".
        while not name or name in (".", "..") or identifier_pattern.match(name) is None:
            if not name:
                print("Agent name cannot be empty.")
            else:
                print("Invalid Name, please choose something different")
            name = input("Name: ").strip()

        # Prompt for description if not provided
        while not description or not isinstance(description, str):
            print("A description is needed for agent matching and cannot be empty.")
            description = input("Description: ").strip()

        # Set the agent path
        agent_path = get_registry_folder() / namespace / name / "0.0.1"
        metadata_path = agent_path / "metadata.json"
        agent_py_path = agent_path / "agent.py"

        # Refuse to clobber an existing agent's code or metadata.
        if metadata_path.exists() or agent_py_path.exists():
            print(f"Agent already exists at: {agent_path}")
            print("Refusing to overwrite it; please choose a different name.")
            return

        agent_path.mkdir(parents=True, exist_ok=True)

        # Create metadata.json
        metadata = {
            "name": name,
            "version": "0.0.1",
            "description": description,
            "category": "agent",
            "tags": [],
            "details": {
                "agent": {
                    "defaults": {
                        "model": DEFAULT_MODEL,
                        "model_provider": DEFAULT_PROVIDER,
                        "model_temperature": DEFAULT_MODEL_TEMPERATURE,
                        "model_max_tokens": DEFAULT_MODEL_MAX_TOKENS,
                    }
                }
            },
            "show_entry": True,
        }

        with open(metadata_path, "w") as f:
            json.dump(metadata, f, indent=2)

        # Create a default agent.py
        agent_py_content = """from nearai.agents.environment import Environment


def run(env: Environment):
    # Your agent code here
    # Example:
    prompt = {"role": "system", "content": "You are a helpful assistant."}
    result = env.completion([prompt] + env.list_messages())
    env.add_reply(result)
    env.request_user_input()

run(env)

"""
        with open(agent_py_path, "w") as f:
            f.write(agent_py_content)

        print(f"\nAgent created at: {agent_path}")
        print("Consider editing:")
        print(f"\t{agent_path}/agent.py")
        print(f"\t{agent_path}/metadata.json")
        print("\nUseful commands:")
        print(f"  > nearai agent interactive {agent_path} --local")
        print(f"  > nearai registry upload {agent_path}")

    def _fork_agent(self, fork: str, namespace: str, new_name: Optional[str]) -> None:
        """Fork an existing agent.

        Downloads ``fork`` (``namespace/name/version``) from the registry and
        copies it to ``<registry folder>/<namespace>/<new_name>/0.0.1``. It then
        rewrites ``name`` and ``version`` in the copy's ``metadata.json``.

        Args:
            fork: Location of the agent to fork.
            namespace: Namespace the fork is created under.
            new_name: Name of the fork; prompted for when empty. Whether given
                or prompted, it must match ``[a-zA-Z0-9_\\-.]+`` and not be
                ``.``/``..``.
        """
        import shutil

        # Parse the fork parameter
        try:
            entry_location = parse_location(fork)
            fork_namespace = entry_location.namespace
            fork_name = entry_location.name
            fork_version = entry_location.version
        except ValueError:
            print("Invalid fork parameter format. Expected format: <namespace>/<agent-name>/<version>")
            return

        # Download the agent from the registry
        agent_location = f"{fork_namespace}/{fork_name}/{fork_version}"
        print(f"Downloading agent '{agent_location}'...")
        registry.download(agent_location, force=False, show_progress=True)
        source_path = get_registry_folder() / fork_namespace / fork_name / fork_version

        # Prompt for the new agent name if not provided
        if not new_name:
            new_name = input("Enter the new agent name: ").strip()
            if not new_name:
                print("Agent name cannot be empty.")
                return

        # Validate the name whether it was prompted for or passed via --name:
        # it becomes a path segment, so separators and "."/".." are rejected.
        identifier_pattern = re.compile(r"^[a-zA-Z0-9_\-.]+$")
        if (
            not isinstance(new_name, str)
            or new_name in (".", "..")
            or identifier_pattern.match(new_name) is None
        ):
            print("Invalid Name, please choose something different")
            return

        # Set the destination path
        dest_path = get_registry_folder() / namespace / new_name / "0.0.1"
        if dest_path.exists():
            # copytree would raise FileExistsError; report it instead.
            print(f"Destination '{dest_path}' already exists; please choose a different name.")
            return

        # Copy the agent files
        shutil.copytree(source_path, dest_path)

        # Update metadata.json
        metadata_path = dest_path / "metadata.json"
        with open(metadata_path, "r") as file:
            metadata = json.load(file)

        metadata["name"] = new_name
        metadata["version"] = "0.0.1"

        with open(metadata_path, "w") as file:
            json.dump(metadata, file, indent=2)

        print(f"\nForked agent '{agent_location}' to '{dest_path}'")
        print(f"Agent '{new_name}' created at '{dest_path}' with updated metadata.")
        print("\nUseful commands:")
        # Same form as the hint printed by _create_new_agent (the local path).
        print(f"  > nearai agent interactive {dest_path} --local")
        print(f"  > nearai registry upload {dest_path}")
_create_new_agent
_create_new_agent(namespace: str, name: Optional[str], description: Optional[str]) -> None

Create a new agent from scratch.

Source code in nearai/cli.py
    def _create_new_agent(self, namespace: str, name: Optional[str], description: Optional[str]) -> None:
        """Create a new agent from scratch."""
        # Prompt for agent name if not provided
        if not name or not isinstance(name, str):
            name = input("Name: ").strip()
            while not name or not isinstance(name, str):
                print("Agent name cannot be empty.")
                name = input("Name: ").strip()

        # Prompt for description if not provided
        while not description or not isinstance(description, str):
            print("A description is needed for agent matching and cannot be empty.")
            description = input("Description: ").strip()

        # Set the agent path
        agent_path = get_registry_folder() / namespace / name / "0.0.1"
        agent_path.mkdir(parents=True, exist_ok=True)

        # Create metadata.json
        metadata = {
            "name": name,
            "version": "0.0.1",
            "description": description,
            "category": "agent",
            "tags": [],
            "details": {
                "agent": {
                    "defaults": {
                        "model": DEFAULT_MODEL,
                        "model_provider": DEFAULT_PROVIDER,
                        "model_temperature": DEFAULT_MODEL_TEMPERATURE,
                        "model_max_tokens": DEFAULT_MODEL_MAX_TOKENS,
                    }
                }
            },
            "show_entry": True,
        }

        metadata_path = agent_path / "metadata.json"
        with open(metadata_path, "w") as f:
            json.dump(metadata, f, indent=2)

        # Create a default agent.py
        agent_py_content = """from nearai.agents.environment import Environment


def run(env: Environment):
    # Your agent code here
    # Example:
    prompt = {"role": "system", "content": "You are a helpful assistant."}
    result = env.completion([prompt] + env.list_messages())
    env.add_reply(result)
    env.request_user_input()

run(env)

"""
        agent_py_path = agent_path / "agent.py"
        with open(agent_py_path, "w") as f:
            f.write(agent_py_content)

        print(f"\nAgent created at: {agent_path}")
        print("Consider editing:")
        print(f"\t{agent_path}/agent.py")
        print(f"\t{agent_path}/metadata.json")
        print("\nUseful commands:")
        print(f"  > nearai agent interactive {agent_path} --local")
        print(f"  > nearai registry upload {agent_path}")
_fork_agent
_fork_agent(fork: str, namespace: str, new_name: Optional[str]) -> None

Fork an existing agent.

Source code in nearai/cli.py
def _fork_agent(self, fork: str, namespace: str, new_name: Optional[str]) -> None:
    """Fork an existing agent.

    Downloads ``fork`` (``namespace/name/version``) from the registry and copies
    it to ``<registry folder>/<namespace>/<new_name>/0.0.1``. It then rewrites
    ``name`` and ``version`` in the copy's ``metadata.json``.

    Args:
        fork: Location of the agent to fork.
        namespace: Namespace the fork is created under.
        new_name: Name of the fork; prompted for when empty. Whether given or
            prompted, it must match ``[a-zA-Z0-9_\\-.]+`` and not be ``.``/``..``.
    """
    import shutil

    # Parse the fork parameter
    try:
        entry_location = parse_location(fork)
        fork_namespace = entry_location.namespace
        fork_name = entry_location.name
        fork_version = entry_location.version
    except ValueError:
        print("Invalid fork parameter format. Expected format: <namespace>/<agent-name>/<version>")
        return

    # Download the agent from the registry
    agent_location = f"{fork_namespace}/{fork_name}/{fork_version}"
    print(f"Downloading agent '{agent_location}'...")
    registry.download(agent_location, force=False, show_progress=True)
    source_path = get_registry_folder() / fork_namespace / fork_name / fork_version

    # Prompt for the new agent name if not provided
    if not new_name:
        new_name = input("Enter the new agent name: ").strip()
        if not new_name:
            print("Agent name cannot be empty.")
            return

    # Validate the name whether it was prompted for or passed via --name:
    # it becomes a path segment, so separators and "."/".." are rejected.
    identifier_pattern = re.compile(r"^[a-zA-Z0-9_\-.]+$")
    if (
        not isinstance(new_name, str)
        or new_name in (".", "..")
        or identifier_pattern.match(new_name) is None
    ):
        print("Invalid Name, please choose something different")
        return

    # Set the destination path
    dest_path = get_registry_folder() / namespace / new_name / "0.0.1"
    if dest_path.exists():
        # copytree would raise FileExistsError; report it instead.
        print(f"Destination '{dest_path}' already exists; please choose a different name.")
        return

    # Copy the agent files
    shutil.copytree(source_path, dest_path)

    # Update metadata.json
    metadata_path = dest_path / "metadata.json"
    with open(metadata_path, "r") as file:
        metadata = json.load(file)

    metadata["name"] = new_name
    metadata["version"] = "0.0.1"

    with open(metadata_path, "w") as file:
        json.dump(metadata, file, indent=2)

    print(f"\nForked agent '{agent_location}' to '{dest_path}'")
    print(f"Agent '{new_name}' created at '{dest_path}' with updated metadata.")
    print("\nUseful commands:")
    # Same form as the hint printed by _create_new_agent (the local path).
    print(f"  > nearai agent interactive {dest_path} --local")
    print(f"  > nearai registry upload {dest_path}")
_task
_task(agent: str, task: str, thread_id: Optional[str] = None, tool_resources: Optional[Dict[str, Any]] = None, file_ids: Optional[List[str]] = None, last_message_id: Optional[str] = None, local: bool = False, env_vars: Optional[Dict[str, Any]] = None) -> Optional[str]

Runs agent non-interactively with a single task.

Source code in nearai/cli.py
def _task(
    self,
    agent: str,
    task: str,
    thread_id: Optional[str] = None,
    tool_resources: Optional[Dict[str, Any]] = None,
    file_ids: Optional[List[str]] = None,
    last_message_id: Optional[str] = None,
    local: bool = False,
    env_vars: Optional[Dict[str, Any]] = None,
) -> Optional[str]:
    """Runs agent non-interactively with a single task.

    Posts ``task`` as a user message (creating a thread when ``thread_id`` is not
    given) and runs ``agent`` on that thread, either on the hub or locally
    through ``LocalRunner``. It then prints the assistant replies listed after
    ``last_message_id`` and stores the thread id in ``self.last_thread_id``.

    Returns:
        The id of the newest listed message, or the incoming
        ``last_message_id`` when no new messages were listed.
    """
    hub_client = get_hub_client()
    if thread_id:
        thread = hub_client.beta.threads.retrieve(thread_id)
    else:
        thread = hub_client.beta.threads.create(
            tool_resources=tool_resources,
        )

    hub_client.beta.threads.messages.create(
        thread_id=thread.id,
        role="user",
        content=task,
        attachments=[Attachment(file_id=file_id) for file_id in file_ids] if file_ids else None,
    )

    if not local:
        hub_client.beta.threads.runs.create_and_poll(
            thread_id=thread.id,
            assistant_id=agent,
        )
    else:
        # Ask the hub not to execute the run itself; it is executed here.
        run = hub_client.beta.threads.runs.create(
            thread_id=thread.id,
            assistant_id=agent,
            extra_body={"delegate_execution": True},
        )
        params = {
            "api_url": CONFIG.api_url,
            "tool_resources": run.tools,
            "data_source": "local_files",
            "user_env_vars": env_vars,
            "agent_env_vars": {},
        }
        auth = CONFIG.auth
        assert auth is not None
        LocalRunner(agent, agent, thread.id, run.id, auth, params)

    # List new messages
    messages = hub_client.beta.threads.messages.list(thread_id=thread.id, after=last_message_id, order="asc")
    message_list = list(messages)
    if message_list:
        for msg in message_list:
            # Messages tagged with a metadata "message_type" are not echoed.
            if msg.metadata and msg.metadata.get("message_type"):
                continue
            if msg.role != "assistant" or not msg.content:
                continue
            # Skip content parts without a text payload instead of crashing
            # with IndexError/AttributeError.
            text = getattr(msg.content[0], "text", None)
            if text is not None:
                print(f"Assistant: {text.value}")
        last_message_id = message_list[-1].id
    else:
        print("No new messages")

    # Store the thread_id for potential use in interactive mode
    self.last_thread_id = thread.id

    return last_message_id
create
create(name: Optional[str] = None, description: Optional[str] = None, fork: Optional[str] = None) -> None

Create a new agent or fork an existing one.

Usage

nearai agent create
nearai agent create --name <agent_name> --description <description>
nearai agent create --fork <namespace/agent_name/version> [--name <new_agent_name>]

Options

--name          Name of the new agent.
--description   Description of the new agent.
--fork          Fork an existing agent specified by namespace/agent_name/version.

Examples

nearai agent create
nearai agent create --name my_agent --description "My new agent"
nearai agent create --fork agentic.near/summary/0.0.3 --name new_summary_agent

Source code in nearai/cli.py
def create(self, name: Optional[str] = None, description: Optional[str] = None, fork: Optional[str] = None) -> None:
    """Create a new agent or fork an existing one.

    Usage:
      nearai agent create
      nearai agent create --name <agent_name> --description <description>
      nearai agent create --fork <namespace/agent_name/version> [--name <new_agent_name>]

    Options:
      --name          Name of the new agent.
      --description   Description of the new agent.
      --fork          Fork an existing agent specified by namespace/agent_name/version.

    Examples
    --------
      nearai agent create
      nearai agent create --name my_agent --description "My new agent"
      nearai agent create --fork agentic.near/summary/0.0.3 --name new_summary_agent

    """
    # Agents are created under the logged-in user's namespace, so bail out early
    # when there is no authenticated namespace.
    logged_out = CONFIG.auth is None or CONFIG.auth.namespace is None
    if logged_out:
        print("Please login with `nearai login` before creating an agent.")
        return

    user_namespace = CONFIG.auth.namespace

    if not fork:
        # Build a brand-new agent skeleton.
        self._create_new_agent(user_namespace, name, description)
        return

    # Copy an existing registry agent under the user's namespace.
    self._fork_agent(fork, user_namespace, name)
dev
dev() -> int

Run local UI for development of agents that have their own UI.

Source code in nearai/cli.py
def dev(self) -> int:
    """Run local UI for development of agents that have their own UI.

    Copies ``hub/demo/.env.example`` to ``hub/demo/.env`` when the latter is
    missing, installs the demo's npm dependencies and starts its dev server.

    Returns:
        The exit status of ``npm install`` when it fails; otherwise the exit
        status of ``npm run dev``.
    """
    if not os.path.exists("hub/demo/.env"):
        shutil.copy("hub/demo/.env.example", "hub/demo/.env")

    ret_val = os.system("npm install --prefix hub/demo")
    if ret_val != 0:
        print("Node.js is required to run the development server.")
        print("Please install Node.js from https://nodejs.org/")
        # Without installed dependencies the dev server cannot start, so stop
        # here instead of launching a second command that is bound to fail.
        return ret_val
    return os.system("npm run dev --prefix hub/demo")
inspect
inspect(path: str) -> None

Inspect environment from given path.

Source code in nearai/cli.py
def inspect(self, path: str) -> None:
    """Inspect environment from given path.

    Starts the ``streamlit_inspect.py`` app located next to this module with
    ``streamlit run``, forwarding ``path`` to the script after ``--``.
    """
    import subprocess

    module_dir = Path(os.path.abspath(__file__)).parent
    inspector_script = module_dir / "streamlit_inspect.py"
    command = ["streamlit", "run", inspector_script, "--", path]
    subprocess.call(command)
interactive
interactive(agent: str, thread_id: Optional[str] = None, tool_resources: Optional[Dict[str, Any]] = None, local: bool = False, env_vars: Optional[Dict[str, Any]] = None) -> None

Runs agent interactively.

Source code in nearai/cli.py
def interactive(
    self,
    agent: str,
    thread_id: Optional[str] = None,
    tool_resources: Optional[Dict[str, Any]] = None,
    local: bool = False,
    env_vars: Optional[Dict[str, Any]] = None,
) -> None:
    """Runs agent interactively.

    Reads messages from stdin and sends each one to ``agent`` on the same
    thread until the user types ``exit`` or ends input (Ctrl-D / Ctrl-C).
    """
    last_message_id = None
    while True:
        try:
            new_message = input("> ")
        except (EOFError, KeyboardInterrupt):
            # End of input or interrupt at the prompt: leave the session cleanly
            # instead of dumping a traceback.
            print()
            break
        if new_message.strip().lower() == "exit":
            break

        last_message_id = self._task(
            agent=agent,
            task=new_message,
            thread_id=thread_id,
            tool_resources=tool_resources,
            last_message_id=last_message_id,
            local=local,
            env_vars=env_vars,
        )

        # The first call creates a thread when none was given; reuse it so the
        # whole session shares one conversation.
        if thread_id is None:
            thread_id = self.last_thread_id
task
task(agent: str, task: str, thread_id: Optional[str] = None, tool_resources: Optional[Dict[str, Any]] = None, file_ids: Optional[List[str]] = None, local: bool = False, env_vars: Optional[Dict[str, Any]] = None) -> None

CLI wrapper for the _task method.

Source code in nearai/cli.py
def task(
    self,
    agent: str,
    task: str,
    thread_id: Optional[str] = None,
    tool_resources: Optional[Dict[str, Any]] = None,
    file_ids: Optional[List[str]] = None,
    local: bool = False,
    env_vars: Optional[Dict[str, Any]] = None,
) -> None:
    """CLI wrapper for the _task method.

    Runs a single task and, when ``_task`` reports a last message id, prints
    the thread id and that message id.
    """
    final_message_id = self._task(
        agent=agent,
        task=task,
        thread_id=thread_id,
        tool_resources=tool_resources,
        file_ids=file_ids,
        local=local,
        env_vars=env_vars,
    )
    if not final_message_id:
        return
    print(f"Task completed. Thread ID: {self.last_thread_id}")
    print(f"Last message ID: {final_message_id}")

BenchmarkCli

Source code in nearai/cli.py
class BenchmarkCli:
    def __init__(self):
        """Initialize Benchmark API."""
        self.client = BenchmarkApi()

    def _get_or_create_benchmark(self, benchmark_name: str, solver_name: str, args: Dict[str, Any], force: bool) -> int:
        """Return the id of the benchmark for this dataset, solver and args.

        Looks up a benchmark in the logged-in user's namespace. A new one is
        created when the lookup returns -1 or ``force`` is set. Exits the
        process with status 1 when the user is not logged in.
        """
        if CONFIG.auth is None:
            import sys

            print("Please login with `nearai login`")
            # sys.exit is always available; the bare ``exit`` builtin is only
            # installed by the ``site`` module.
            sys.exit(1)
        namespace = CONFIG.auth.namespace

        # Sort the args to have a consistent representation.
        solver_args = json.dumps(OrderedDict(sorted(args.items())))

        benchmark_id = self.client.get_benchmark_v1_benchmark_get_get(
            namespace=namespace,
            benchmark_name=benchmark_name,
            solver_name=solver_name,
            solver_args=solver_args,
        )

        if benchmark_id == -1 or force:
            benchmark_id = self.client.create_benchmark_v1_benchmark_create_get(
                benchmark_name=benchmark_name,
                solver_name=solver_name,
                solver_args=solver_args,
            )

        assert benchmark_id != -1
        return benchmark_id

    def run(
        self,
        dataset: str,
        solver_strategy: str,
        max_concurrent: int = 2,
        force: bool = False,
        subset: Optional[str] = None,
        check_compatibility: bool = True,
        record: bool = False,
        num_inference_retries: int = 10,
        **solver_args: Any,
    ) -> None:
        """Run benchmark on a dataset with a solver strategy.

        It will cache the results in the database and subsequent runs will pull the results from the cache.
        If force is set to True, it will run the benchmark again and update the cache.
        """
        from nearai.benchmark import BenchmarkExecutor, DatasetInfo
        from nearai.dataset import get_dataset, load_dataset
        from nearai.solvers import SolverScoringMethod, SolverStrategy, SolverStrategyRegistry

        CONFIG.num_inference_retries = num_inference_retries

        # Resolve the solver strategy before touching the benchmark database so a
        # mistyped strategy name does not leave an orphan benchmark record.
        solver_strategy_class: Union[SolverStrategy, None] = SolverStrategyRegistry.get(solver_strategy, None)
        assert (
            solver_strategy_class
        ), f"Solver strategy {solver_strategy} not found. Available strategies: {list(SolverStrategyRegistry.keys())}"

        args = dict(solver_args)
        if subset is not None:
            args["subset"] = subset

        benchmark_id = self._get_or_create_benchmark(
            benchmark_name=dataset,
            solver_name=solver_strategy,
            args=args,
            force=force,
        )

        name = dataset
        if solver_strategy_class.scoring_method == SolverScoringMethod.Custom:
            dataset = str(get_dataset(dataset))
        else:
            dataset = load_dataset(dataset)

        solver_strategy_obj: SolverStrategy = solver_strategy_class(dataset_ref=dataset, **solver_args)  # type: ignore
        if check_compatibility:
            assert name in solver_strategy_obj.compatible_datasets() or any(
                map(lambda n: n in name, solver_strategy_obj.compatible_datasets())
            ), f"Solver strategy {solver_strategy} is not compatible with dataset {name}"

        dest_path = get_registry_folder() / name
        metadata_path = dest_path / "metadata.json"
        with open(metadata_path, "r") as file:
            metadata = json.load(file)

        be = BenchmarkExecutor(
            DatasetInfo(name, subset, dataset, metadata), solver_strategy_obj, benchmark_id=benchmark_id
        )

        # A negative max_concurrent means "use every CPU" (1 if undetectable).
        cpu_count = os.cpu_count()
        max_concurrent = (cpu_count if cpu_count is not None else 1) if max_concurrent < 0 else max_concurrent
        be.run(max_concurrent=max_concurrent, record=record)

    def list(
        self,
        namespace: Optional[str] = None,
        benchmark: Optional[str] = None,
        solver: Optional[str] = None,
        args: Optional[str] = None,
        total: int = 32,
        offset: int = 0,
    ):
        """List all executed benchmarks."""
        result = self.client.list_benchmarks_v1_benchmark_list_get(
            namespace=namespace,
            benchmark_name=benchmark,
            solver_name=solver,
            solver_args=args,
            total=total,
            offset=offset,
        )

        header = ["id", "namespace", "benchmark", "solver", "args", "score", "solved", "total"]
        table = []
        for benchmark_output in result:
            # A benchmark with no recorded tasks would otherwise divide by zero.
            task_count = benchmark_output.total
            score = 100 * benchmark_output.solved / task_count if task_count else 0.0
            table.append(
                [
                    fill(str(benchmark_output.id)),
                    fill(benchmark_output.namespace),
                    fill(benchmark_output.benchmark),
                    fill(benchmark_output.solver),
                    fill(benchmark_output.args),
                    fill(f"{score:.2f}%"),
                    fill(str(benchmark_output.solved)),
                    fill(str(benchmark_output.total)),
                ]
            )

        print(tabulate(table, headers=header, tablefmt="simple_grid"))
__init__
__init__()

Initialize Benchmark API.

Source code in nearai/cli.py
def __init__(self):
    """Initialize Benchmark API.

    Builds the ``BenchmarkApi`` client used by the benchmark commands.
    """
    benchmark_api = BenchmarkApi()
    self.client = benchmark_api
list
list(namespace: Optional[str] = None, benchmark: Optional[str] = None, solver: Optional[str] = None, args: Optional[str] = None, total: int = 32, offset: int = 0)

List all executed benchmarks.

Source code in nearai/cli.py
def list(
    self,
    namespace: Optional[str] = None,
    benchmark: Optional[str] = None,
    solver: Optional[str] = None,
    args: Optional[str] = None,
    total: int = 32,
    offset: int = 0,
):
    """List all executed benchmarks."""
    result = self.client.list_benchmarks_v1_benchmark_list_get(
        namespace=namespace,
        benchmark_name=benchmark,
        solver_name=solver,
        solver_args=args,
        total=total,
        offset=offset,
    )

    header = ["id", "namespace", "benchmark", "solver", "args", "score", "solved", "total"]
    table = []
    for benchmark_output in result:
        # A benchmark with no recorded tasks would otherwise divide by zero.
        task_count = benchmark_output.total
        score = 100 * benchmark_output.solved / task_count if task_count else 0.0
        table.append(
            [
                fill(str(benchmark_output.id)),
                fill(benchmark_output.namespace),
                fill(benchmark_output.benchmark),
                fill(benchmark_output.solver),
                fill(benchmark_output.args),
                fill(f"{score:.2f}%"),
                fill(str(benchmark_output.solved)),
                fill(str(benchmark_output.total)),
            ]
        )

    print(tabulate(table, headers=header, tablefmt="simple_grid"))
run
run(dataset: str, solver_strategy: str, max_concurrent: int = 2, force: bool = False, subset: Optional[str] = None, check_compatibility: bool = True, record: bool = False, num_inference_retries: int = 10, **solver_args: Any) -> None

Run benchmark on a dataset with a solver strategy.

It will cache the results in the database and subsequent runs will pull the results from the cache. If force is set to True, it will run the benchmark again and update the cache.

Source code in nearai/cli.py
def run(
    self,
    dataset: str,
    solver_strategy: str,
    max_concurrent: int = 2,
    force: bool = False,
    subset: Optional[str] = None,
    check_compatibility: bool = True,
    record: bool = False,
    num_inference_retries: int = 10,
    **solver_args: Any,
) -> None:
    """Run benchmark on a dataset with a solver strategy.

    It will cache the results in the database and subsequent runs will pull the results from the cache.
    If force is set to True, it will run the benchmark again and update the cache.
    """
    from nearai.benchmark import BenchmarkExecutor, DatasetInfo
    from nearai.dataset import get_dataset, load_dataset
    from nearai.solvers import SolverScoringMethod, SolverStrategy, SolverStrategyRegistry

    CONFIG.num_inference_retries = num_inference_retries

    # Resolve the solver strategy before touching the benchmark database so a
    # mistyped strategy name does not leave an orphan benchmark record.
    solver_strategy_class: Union[SolverStrategy, None] = SolverStrategyRegistry.get(solver_strategy, None)
    assert (
        solver_strategy_class
    ), f"Solver strategy {solver_strategy} not found. Available strategies: {list(SolverStrategyRegistry.keys())}"

    args = dict(solver_args)
    if subset is not None:
        args["subset"] = subset

    benchmark_id = self._get_or_create_benchmark(
        benchmark_name=dataset,
        solver_name=solver_strategy,
        args=args,
        force=force,
    )

    name = dataset
    if solver_strategy_class.scoring_method == SolverScoringMethod.Custom:
        dataset = str(get_dataset(dataset))
    else:
        dataset = load_dataset(dataset)

    solver_strategy_obj: SolverStrategy = solver_strategy_class(dataset_ref=dataset, **solver_args)  # type: ignore
    if check_compatibility:
        assert name in solver_strategy_obj.compatible_datasets() or any(
            map(lambda n: n in name, solver_strategy_obj.compatible_datasets())
        ), f"Solver strategy {solver_strategy} is not compatible with dataset {name}"

    dest_path = get_registry_folder() / name
    metadata_path = dest_path / "metadata.json"
    with open(metadata_path, "r") as file:
        metadata = json.load(file)

    be = BenchmarkExecutor(
        DatasetInfo(name, subset, dataset, metadata), solver_strategy_obj, benchmark_id=benchmark_id
    )

    # A negative max_concurrent means "use every CPU" (1 if undetectable).
    cpu_count = os.cpu_count()
    max_concurrent = (cpu_count if cpu_count is not None else 1) if max_concurrent < 0 else max_concurrent
    be.run(max_concurrent=max_concurrent, record=record)

CLI

Source code in nearai/cli.py
class CLI:
    """Root ``nearai`` command object; sub-command groups are exposed as attributes."""

    def __init__(self) -> None:  # noqa: D107
        # Sub-command groups, constructed in a fixed order.
        self.registry = RegistryCli()
        self.login = LoginCLI()
        self.logout = LogoutCLI()
        self.hub = HubCLI()
        self.log = LogCLI()

        self.config = ConfigCli()
        self.benchmark = BenchmarkCli()
        self.evaluation = EvaluationCli()
        self.agent = AgentCli()
        self.finetune = FinetuneCli()
        self.tensorboard = TensorboardCli()
        self.vllm = VllmCli()
        self.permission = PermissionCli()

    def submit(self, path: Optional[str] = None, worker_kind: str = WorkerKind.GPU_8_A100.value):
        """Submit a task to be executed by a worker.

        Uploads ``path`` to the registry, delegates to the scheduler account for
        one day, and enqueues a job for the uploaded entry. If enqueuing fails,
        the error is printed and the delegation is revoked.

        Args:
        ----
            path: Directory to upload; defaults to the current working directory.
            worker_kind: Value of a ``WorkerKind`` member selecting the worker type.

        """
        source_dir = os.getcwd() if path is None else path
        kind = WorkerKind(worker_kind)

        uploaded_entry = self.registry.upload(source_dir)

        # Delegation expires one day from now.
        delegations = DelegationApi()
        delegations.delegate_v1_delegation_delegate_post(
            delegate_account_id=CONFIG.scheduler_account_id,
            expires_at=datetime.now() + timedelta(days=1),
        )

        try:
            jobs = JobsApi()
            jobs.add_job_v1_jobs_add_job_post(kind, BodyAddJobV1JobsAddJobPost(entry_location=uploaded_entry))
        except Exception as err:
            # Job creation failed: report it and withdraw the delegation granted above.
            print("Error: ", err)
            delegations.revoke_delegation_v1_delegation_revoke_delegation_post(
                delegate_account_id=CONFIG.scheduler_account_id,
            )

    def location(self) -> None:
        """Print the location where nearai is installed, as reported by ``nearai.cli_path()``."""
        from nearai import cli_path

        install_location = cli_path()
        print(install_location)

    def version(self):
        """Print the installed version of the ``nearai`` distribution."""
        installed_version = importlib.metadata.version("nearai")
        print(installed_version)

    def task(self, *args, **kwargs):
        """Run a single task by forwarding every argument to ``self.agent.task_cli``."""
        agent_cli = self.agent
        agent_cli.task_cli(*args, **kwargs)
location
location() -> None

Show location where nearai is installed.

Source code in nearai/cli.py
def location(self) -> None:
    """Print the location where nearai is installed, as reported by ``nearai.cli_path()``."""
    from nearai import cli_path

    install_location = cli_path()
    print(install_location)
submit
submit(path: Optional[str] = None, worker_kind: str = WorkerKind.GPU_8_A100.value)

Submit a task to be executed by a worker.

Source code in nearai/cli.py
def submit(self, path: Optional[str] = None, worker_kind: str = WorkerKind.GPU_8_A100.value):
    """Submit a task to be executed by a worker.

    Uploads ``path`` to the registry, delegates to the scheduler account for
    one day, and enqueues a job for the uploaded entry. If enqueuing fails,
    the error is printed and the delegation is revoked.

    Args:
    ----
        path: Directory to upload; defaults to the current working directory.
        worker_kind: Value of a ``WorkerKind`` member selecting the worker type.

    """
    source_dir = os.getcwd() if path is None else path
    kind = WorkerKind(worker_kind)

    uploaded_entry = self.registry.upload(source_dir)

    # Delegation expires one day from now.
    delegations = DelegationApi()
    delegations.delegate_v1_delegation_delegate_post(
        delegate_account_id=CONFIG.scheduler_account_id,
        expires_at=datetime.now() + timedelta(days=1),
    )

    try:
        jobs = JobsApi()
        jobs.add_job_v1_jobs_add_job_post(kind, BodyAddJobV1JobsAddJobPost(entry_location=uploaded_entry))
    except Exception as err:
        # Job creation failed: report it and withdraw the delegation granted above.
        print("Error: ", err)
        delegations.revoke_delegation_v1_delegation_revoke_delegation_post(
            delegate_account_id=CONFIG.scheduler_account_id,
        )
task
task(*args, **kwargs)

CLI command for running a single task.

Source code in nearai/cli.py
def task(self, *args, **kwargs):
    """Run a single task by forwarding every argument to ``self.agent.task_cli``."""
    agent_cli = self.agent
    agent_cli.task_cli(*args, **kwargs)
version
version()

Show nearai version.

Source code in nearai/cli.py
def version(self):
    """Print the installed version of the ``nearai`` distribution."""
    installed_version = importlib.metadata.version("nearai")
    print(installed_version)

ConfigCli

Source code in nearai/cli.py
class ConfigCli:
    """Commands for reading and writing nearai configuration values."""

    def set(self, key: str, value: str, local: bool = False) -> None:
        """Store ``value`` under ``key`` via ``update_config``.

        Args:
        ----
            key: Configuration key to write.
            value: Value to store.
            local: Forwarded to ``update_config`` (presumably selects a local
                rather than the global config file; confirm in ``update_config``).

        """
        update_config(key, value, local)

    def get(self, key: str) -> None:
        """Print the value of ``key`` from the loaded ``CONFIG``."""
        stored_value = CONFIG.get(key)
        print(stored_value)

    def show(self) -> None:
        """Print every field of ``CONFIG`` as ``key: value``, one per line."""
        config_fields = asdict(CONFIG)
        for field_name, field_value in config_fields.items():
            print(f"{field_name}: {field_value}")
get
get(key: str) -> None

Get value of a key in the config file.

Source code in nearai/cli.py
def get(self, key: str) -> None:
    """Print the value of ``key`` from the loaded ``CONFIG``."""
    stored_value = CONFIG.get(key)
    print(stored_value)
set
set(key: str, value: str, local: bool = False) -> None

Add key-value pair to the config file.

Source code in nearai/cli.py
def set(self, key: str, value: str, local: bool = False) -> None:
    """Store ``value`` under ``key`` via ``update_config``.

    Args:
    ----
        key: Configuration key to write.
        value: Value to store.
        local: Forwarded to ``update_config`` (presumably selects a local
            rather than the global config file; confirm in ``update_config``).

    """
    update_config(key, value, local)

EvaluationCli

Source code in nearai/cli.py
class EvaluationCli:
    """Commands for inspecting evaluation results."""

    def table(
        self,
        all_key_columns: bool = False,
        all_metrics: bool = False,
        num_columns: int = 6,
        metric_name_max_length: int = 30,
    ) -> None:
        """Prints table of evaluations.

        Fetches the evaluation table from ``EvaluationApi`` and renders it with
        ``nearai.evaluation.print_evaluation_table``.

        Args:
        ----
            all_key_columns: Forwarded to ``print_evaluation_table``.
            all_metrics: Forwarded to ``print_evaluation_table``.
            num_columns: Forwarded to ``print_evaluation_table``.
            metric_name_max_length: Forwarded to ``print_evaluation_table``.

        """
        from nearai.evaluation import print_evaluation_table

        api = EvaluationApi()
        table = api.table_v1_evaluation_table_get()

        print_evaluation_table(
            table.rows,
            table.columns,
            table.important_columns,
            all_key_columns,
            all_metrics,
            num_columns,
            metric_name_max_length,
        )

    def read_solutions(self, entry: str, status: Optional[bool] = None, verbose: bool = False) -> None:
        """Reads solutions.json from evaluation entry.

        Downloads ``entry`` from the registry, loads its ``solutions.json``, and
        pages through the solutions interactively: press Enter for the next
        solution, or type ``exit`` to stop.

        Args:
        ----
            entry: Registry identifier of the evaluation entry.
            status: If given, only solutions whose ``"status"`` equals this value are shown.
            verbose: If False, the ``"verbose"`` key (when present) is removed from
                each solution's ``info`` before printing.

        """
        entry_path = registry.download(entry)
        solutions_file = entry_path / "solutions.json"

        if not solutions_file.exists():
            print(f"No solutions file found for entry: {entry}")
            return

        try:
            # JSON text is UTF-8 (RFC 8259); don't depend on the platform locale.
            with open(solutions_file, encoding="utf-8") as f:
                solutions = json.load(f)
        except json.JSONDecodeError:
            print(f"Error reading solutions file for entry: {entry}")
            return

        # Filter solutions if status is specified
        if status is not None:
            solutions = [s for s in solutions if s.get("status") == status]
        if not solutions:
            print("No solutions found matching criteria")
            return
        print(f"\nFound {len(solutions)} solutions{' with status=' + str(status) if status is not None else ''}")

        total = len(solutions)
        for i, solution in enumerate(solutions, 1):
            print("-" * 80)
            print(f"\nSolution {i}/{total}:")
            datum = solution.get("datum")
            print(f"datum: {json.dumps(datum, indent=2, ensure_ascii=False)}")
            # Separate name so the ``status`` filter argument is not overwritten.
            solution_status = solution.get("status")
            print(f"status: {solution_status}")
            info: dict = solution.get("info", {})
            if not verbose:
                # Not every solution records verbose output; a missing key must not crash.
                info.pop("verbose", None)
            print(f"info: {json.dumps(info, indent=2, ensure_ascii=False)}")
            if i == 1:
                print("Enter to continue, type 'exit' to quit.")
            new_message = input("> ")
            if new_message.lower() == "exit":
                break
read_solutions
read_solutions(entry: str, status: Optional[bool] = None, verbose: bool = False) -> None

Reads solutions.json from evaluation entry.

Source code in nearai/cli.py
def read_solutions(self, entry: str, status: Optional[bool] = None, verbose: bool = False) -> None:
    """Reads solutions.json from evaluation entry.

    Downloads ``entry`` from the registry, loads its ``solutions.json``, and
    pages through the solutions interactively: press Enter for the next
    solution, or type ``exit`` to stop.

    Args:
    ----
        entry: Registry identifier of the evaluation entry.
        status: If given, only solutions whose ``"status"`` equals this value are shown.
        verbose: If False, the ``"verbose"`` key (when present) is removed from
            each solution's ``info`` before printing.

    """
    entry_path = registry.download(entry)
    solutions_file = entry_path / "solutions.json"

    if not solutions_file.exists():
        print(f"No solutions file found for entry: {entry}")
        return

    try:
        # JSON text is UTF-8 (RFC 8259); don't depend on the platform locale.
        with open(solutions_file, encoding="utf-8") as f:
            solutions = json.load(f)
    except json.JSONDecodeError:
        print(f"Error reading solutions file for entry: {entry}")
        return

    # Filter solutions if status is specified
    if status is not None:
        solutions = [s for s in solutions if s.get("status") == status]
    if not solutions:
        print("No solutions found matching criteria")
        return
    print(f"\nFound {len(solutions)} solutions{' with status=' + str(status) if status is not None else ''}")

    total = len(solutions)
    for i, solution in enumerate(solutions, 1):
        print("-" * 80)
        print(f"\nSolution {i}/{total}:")
        datum = solution.get("datum")
        print(f"datum: {json.dumps(datum, indent=2, ensure_ascii=False)}")
        # Separate name so the ``status`` filter argument is not overwritten.
        solution_status = solution.get("status")
        print(f"status: {solution_status}")
        info: dict = solution.get("info", {})
        if not verbose:
            # Not every solution records verbose output; a missing key must not crash.
            info.pop("verbose", None)
        print(f"info: {json.dumps(info, indent=2, ensure_ascii=False)}")
        if i == 1:
            print("Enter to continue, type 'exit' to quit.")
        new_message = input("> ")
        if new_message.lower() == "exit":
            break
table
table(all_key_columns: bool = False, all_metrics: bool = False, num_columns: int = 6, metric_name_max_length: int = 30) -> None

Prints table of evaluations.

Source code in nearai/cli.py
def table(
    self,
    all_key_columns: bool = False,
    all_metrics: bool = False,
    num_columns: int = 6,
    metric_name_max_length: int = 30,
) -> None:
    """Print the evaluation table fetched from ``EvaluationApi``.

    Rendering is delegated to ``nearai.evaluation.print_evaluation_table``; all
    display options are forwarded to it positionally.

    Args:
    ----
        all_key_columns: Forwarded to ``print_evaluation_table``.
        all_metrics: Forwarded to ``print_evaluation_table``.
        num_columns: Forwarded to ``print_evaluation_table``.
        metric_name_max_length: Forwarded to ``print_evaluation_table``.

    """
    from nearai.evaluation import print_evaluation_table

    evaluation_table = EvaluationApi().table_v1_evaluation_table_get()
    print_evaluation_table(
        evaluation_table.rows,
        evaluation_table.columns,
        evaluation_table.important_columns,
        all_key_columns,
        all_metrics,
        num_columns,
        metric_name_max_length,
    )

HubCLI

Source code in nearai/cli.py
class HubCLI:
    """Commands for interacting with models on the NEAR AI hub."""

    def chat(self, **kwargs):
        """Chat with a model from the NEAR AI hub.

        All keyword arguments are handed to ``Hub.chat`` as a single dict.

        Args:
        ----
            query (str): User's query to model
            endpoint (str): NEAR AI HUB's url
            model (str): Name of a model
            provider (str): Name of a provider
            info (bool): Display system info
            kwargs (Dict[str, Any]): All cli keyword arguments

        """
        from nearai.hub import Hub

        Hub(CONFIG).chat(kwargs)
chat
chat(**kwargs)

Chat with a model from the NEAR AI hub.

Args:
query (str): User's query to model
endpoint (str): NEAR AI HUB's url
model (str): Name of a model
provider (str): Name of a provider
info (bool): Display system info
kwargs (Dict[str, Any]): All cli keyword arguments
Source code in nearai/cli.py
def chat(self, **kwargs):
    """Chat with a model from the NEAR AI hub.

    All keyword arguments are handed to ``Hub.chat`` as a single dict.

    Args:
    ----
        query (str): User's query to model
        endpoint (str): NEAR AI HUB's url
        model (str): Name of a model
        provider (str): Name of a provider
        info (bool): Display system info
        kwargs (Dict[str, Any]): All cli keyword arguments

    """
    from nearai.hub import Hub

    Hub(CONFIG).chat(kwargs)

LoginCLI

Source code in nearai/cli.py
class LoginCLI:
    """Commands for logging in with a NEAR account and managing auth data."""

    def __call__(self, **kwargs):
        """Log in with a NEAR Mainnet account.

        Without ``remote``, an ``accountId`` together with a ``privateKey`` signs
        locally, and an ``accountId`` alone uses the credentials file. Every other
        combination goes through the auth portal.

        Args:
        ----
            remote (bool): Remote login allows signing message with NEAR Account on a remote machine
            auth_url (str): Url to the auth portal
            accountId (str): AccountId in .near-credentials folder to signMessage
            privateKey (str): Private Key to sign a message
            kwargs (Dict[str, Any]): All cli keyword arguments

        """
        from nearai.login import generate_and_save_signature, login_with_file_credentials, login_with_near_auth

        remote = kwargs.get("remote", False)
        account_id = kwargs.get("accountId")
        private_key = kwargs.get("privateKey")

        if remote or not account_id:
            auth_url = kwargs.get("auth_url", "https://auth.near.ai")
            login_with_near_auth(remote, auth_url)
        elif private_key:
            generate_and_save_signature(account_id, private_key)
        else:
            login_with_file_credentials(account_id)

    def status(self):
        """Print the current NEAR account login status (via ``print_login_status``)."""
        from nearai.login import print_login_status

        print_login_status()

    def save(self, **kwargs):
        """Save NEAR account authorization data.

        Every field listed below must be present and non-empty; otherwise
        ``Missing data`` is printed and nothing is saved.

        Args:
        ----
            accountId (str): Near Account
            signature (str): Signature
            publicKey (str): Public Key used to sign
            callbackUrl (str): Callback Url
            nonce (str): nonce
            kwargs (Dict[str, Any]): All cli keyword arguments

        """
        from nearai.login import update_auth_config

        account_id = kwargs.get("accountId")
        signature = kwargs.get("signature")
        public_key = kwargs.get("publicKey")
        callback_url = kwargs.get("callbackUrl")
        nonce = kwargs.get("nonce")

        if not all((account_id, signature, public_key, callback_url, nonce)):
            print("Missing data")
            return
        update_auth_config(account_id, signature, public_key, callback_url, nonce)
__call__
__call__(**kwargs)

Log in with a NEAR Mainnet account.

Args:
remote (bool): Remote login allows signing message with NEAR Account on a remote machine
auth_url (str): Url to the auth portal
accountId (str): AccountId in .near-credentials folder to signMessage
privateKey (str): Private Key to sign a message
kwargs (Dict[str, Any]): All cli keyword arguments
Source code in nearai/cli.py
def __call__(self, **kwargs):
    """Log in with a NEAR Mainnet account.

    Without ``remote``, an ``accountId`` together with a ``privateKey`` signs
    locally, and an ``accountId`` alone uses the credentials file. Every other
    combination goes through the auth portal.

    Args:
    ----
        remote (bool): Remote login allows signing message with NEAR Account on a remote machine
        auth_url (str): Url to the auth portal
        accountId (str): AccountId in .near-credentials folder to signMessage
        privateKey (str): Private Key to sign a message
        kwargs (Dict[str, Any]): All cli keyword arguments

    """
    from nearai.login import generate_and_save_signature, login_with_file_credentials, login_with_near_auth

    remote = kwargs.get("remote", False)
    account_id = kwargs.get("accountId")
    private_key = kwargs.get("privateKey")

    if remote or not account_id:
        auth_url = kwargs.get("auth_url", "https://auth.near.ai")
        login_with_near_auth(remote, auth_url)
    elif private_key:
        generate_and_save_signature(account_id, private_key)
    else:
        login_with_file_credentials(account_id)
save
save(**kwargs)

Save NEAR account authorization data.

Args:
accountId (str): Near Account
signature (str): Signature
publicKey (str): Public Key used to sign
callbackUrl (str): Callback Url
nonce (str): nonce
kwargs (Dict[str, Any]): All cli keyword arguments
Source code in nearai/cli.py
def save(self, **kwargs):
    """Save NEAR account authorization data.

    Every field listed below must be present and non-empty; otherwise
    ``Missing data`` is printed and nothing is saved.

    Args:
    ----
        accountId (str): Near Account
        signature (str): Signature
        publicKey (str): Public Key used to sign
        callbackUrl (str): Callback Url
        nonce (str): nonce
        kwargs (Dict[str, Any]): All cli keyword arguments

    """
    from nearai.login import update_auth_config

    account_id = kwargs.get("accountId")
    signature = kwargs.get("signature")
    public_key = kwargs.get("publicKey")
    callback_url = kwargs.get("callbackUrl")
    nonce = kwargs.get("nonce")

    if not all((account_id, signature, public_key, callback_url, nonce)):
        print("Missing data")
        return
    update_auth_config(account_id, signature, public_key, callback_url, nonce)
status
status()

Print the current NEAR account login status.

Source code in nearai/cli.py
def status(self):
    """Print the current NEAR account login status (via ``print_login_status``)."""
    from nearai.login import print_login_status

    print_login_status()

LogoutCLI

Source code in nearai/cli.py
class LogoutCLI:
    """Command for removing stored NEAR account auth data."""

    def __call__(self, **kwargs):
        """Clear NEAR account auth data.

        Removes the ``auth`` section from the config file when it holds an
        ``account_id``; otherwise reports that there is nothing to remove.
        Keyword arguments are accepted but unused.
        """
        from nearai.config import load_config_file, save_config_file

        config = load_config_file()
        auth = config.get("auth")
        if auth and auth.get("account_id"):
            config.pop("auth", None)
            save_config_file(config)
            print("Auth data removed")
            return
        print("Auth data does not exist.")
__call__
__call__(**kwargs)

Clear NEAR account auth data.

Source code in nearai/cli.py
def __call__(self, **kwargs):
    """Clear NEAR account auth data.

    Removes the ``auth`` section from the config file when it holds an
    ``account_id``; otherwise reports that there is nothing to remove.
    Keyword arguments are accepted but unused.
    """
    from nearai.config import load_config_file, save_config_file

    config = load_config_file()
    auth = config.get("auth")
    if auth and auth.get("account_id"):
        config.pop("auth", None)
        save_config_file(config)
        print("Auth data removed")
        return
    print("Auth data does not exist.")

PermissionCli

Source code in nearai/cli.py
class PermissionCli:
    """Commands for granting and revoking account permissions through ``PermissionsApi``."""

    def __init__(self) -> None:  # noqa: D107
        self.client = PermissionsApi()

    def grant(self, account_id: str, permission: str):
        """Grant ``permission`` to ``account_id``."""
        api = self.client
        api.grant_permission_v1_permissions_grant_permission_post(account_id, permission)

    def revoke(self, account_id: str, permission: str = ""):
        """Revoke ``permission`` from ``account_id``; an empty permission revokes all of them."""
        api = self.client
        api.revoke_permission_v1_permissions_revoke_permission_post(account_id, permission)
grant
grant(account_id: str, permission: str)

Grant permission to an account.

Source code in nearai/cli.py
def grant(self, account_id: str, permission: str):
    """Grant ``permission`` to ``account_id``."""
    api = self.client
    api.grant_permission_v1_permissions_grant_permission_post(account_id, permission)
revoke
revoke(account_id: str, permission: str = '')

Revoke permission from an account. If permission is empty, all permissions are revoked.

Source code in nearai/cli.py
def revoke(self, account_id: str, permission: str = ""):
    """Revoke ``permission`` from ``account_id``; an empty permission revokes all of them."""
    api = self.client
    api.revoke_permission_v1_permissions_revoke_permission_post(account_id, permission)

RegistryCli

Source code in nearai/cli.py
class RegistryCli:
    """Commands for browsing, creating, updating, uploading and downloading registry entries."""

    def info(self, entry: str) -> None:
        """Show information about an item.

        Prints the entry's metadata as JSON. For ``model`` entries it also prints
        a table of the available provider matches (provider, name).

        Args:
        ----
            entry: Registry location string, parsed with ``parse_location``.

        """
        entry_location = parse_location(entry)
        metadata = registry.info(entry_location)

        if metadata is None:
            print(f"Entry {entry} not found.")
            return

        print(metadata.model_dump_json(indent=2))
        if metadata.category == "model":
            available_provider_matches = ProviderModels(CONFIG.get_client_config()).available_provider_matches(
                NamespacedName(name=metadata.name, namespace=entry_location.namespace)
            )
            if len(available_provider_matches) > 0:
                header = ["provider", "name"]

                table = []
                for provider, name in available_provider_matches.items():
                    table.append(
                        [
                            fill(provider),
                            fill(name),
                        ]
                    )
                print(tabulate(table, headers=header, tablefmt="simple_grid"))

    def metadata_template(self, local_path: str = ".", category: str = "", description: str = ""):
        """Create a metadata template.

        Writes ``metadata.json`` into ``local_path`` and overwrites any existing
        file. The last component of ``local_path`` must be a semantic version,
        and its parent directory's name becomes the entry name, which must not
        itself look like a version. For ``category == "agent"``, ``details.agent``
        is pre-filled with welcome text and default model settings.

        Args:
        ----
            local_path: Version directory (``.../<name>/<version>``) to write into.
            category: Registry category stored in the metadata.
            description: Description stored in the metadata.

        """
        path = Path(local_path)

        metadata_path = path / "metadata.json"

        version = path.name
        pattern = r"^(0|[1-9]\d*)\.(0|[1-9]\d*)\.(0|[1-9]\d*)(?:-((?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*)(?:\.(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*))*))?(?:\+([0-9a-zA-Z-]+(?:\.[0-9a-zA-Z-]+)*))?$"  # noqa: E501
        assert re.match(pattern, version), f"Invalid semantic version format: {version}"
        name = path.parent.name
        assert not re.match(pattern, name), f"Invalid agent name: {name}"

        with open(metadata_path, "w") as f:
            metadata: Dict[str, Any] = {
                "name": name,
                "version": version,
                "description": description,
                "category": category,
                "tags": [],
                "details": {},
                "show_entry": True,
            }

            if category == "agent":
                metadata["details"]["agent"] = {}
                metadata["details"]["agent"]["welcome"] = {
                    "title": name,
                    "description": description,
                }
                metadata["details"]["agent"]["defaults"] = {
                    "model": DEFAULT_MODEL,
                    "model_provider": DEFAULT_PROVIDER,
                    "model_temperature": DEFAULT_MODEL_TEMPERATURE,
                    "model_max_tokens": DEFAULT_MODEL_MAX_TOKENS,
                    "max_iterations": 1,
                }
                metadata["details"]["agent"]["framework"] = "base"

            json.dump(metadata, f, indent=2)

    def list(
        self,
        namespace: str = "",
        category: str = "",
        tags: str = "",
        total: int = 32,
        offset: int = 0,
        show_all: bool = False,
        show_latest_version: bool = True,
        star: str = "",
    ) -> None:
        """List available items.

        Prints at most ``total`` entries; a ``...`` row signals that more exist.
        For an unfiltered ``model`` listing it also hints about common provider
        models that are not yet registered.

        Args:
        ----
            namespace: Forwarded to ``registry.list``.
            category: Forwarded to ``registry.list``.
            tags: Tags filter, normalised by ``parse_tags`` into a comma-separated string.
            total: Maximum number of rows to print.
            offset: Forwarded to ``registry.list``.
            show_all: Forwarded to ``registry.list``.
            show_latest_version: Forwarded to ``registry.list``.
            star: Forwarded to ``registry.list`` as ``starred_by``.

        """
        # Make sure tags is a comma-separated list of tags
        tags_l = parse_tags(tags)
        tags = ",".join(tags_l)

        # Request one extra entry so truncation can be detected.
        entries = registry.list(
            namespace=namespace,
            category=category,
            tags=tags,
            total=total + 1,
            offset=offset,
            show_all=show_all,
            show_latest_version=show_latest_version,
            starred_by=star,
        )

        more_rows = len(entries) > total
        entries = entries[:total]

        header = ["entry", "category", "description", "tags"]

        table = []
        for entry in entries:
            table.append(
                [
                    fill(f"{entry.namespace}/{entry.name}/{entry.version}"),
                    fill(entry.category, 20),
                    fill(entry.description, 50),
                    fill(", ".join(entry.tags), 20),
                ]
            )

        if more_rows:
            table.append(["...", "...", "...", "..."])

        print(tabulate(table, headers=header, tablefmt="simple_grid"))

        if category == "model" and len(entries) < total and namespace == "" and tags == "" and star == "":
            unregistered_common_provider_models = ProviderModels(
                CONFIG.get_client_config()
            ).get_unregistered_common_provider_models(registry.dict_models())
            if len(unregistered_common_provider_models):
                print(
                    f"There are unregistered common provider models: {unregistered_common_provider_models}. Run 'nearai registry upload-unregistered-common-provider-models' to update registry."  # noqa: E501
                )

    def update(self, local_path: str = ".") -> None:
        """Update metadata of a registry item.

        Reads and validates ``metadata.json`` from ``local_path``. It then sends
        every field except ``name`` and ``version`` as the new metadata of
        ``<logged-in namespace>/<name>/<version>`` and prints the result as JSON.

        Args:
        ----
            local_path: Directory containing the item's ``metadata.json``.

        Raises:
        ------
            SystemExit: With status 1 when the user is not logged in.

        """
        import sys

        path = Path(local_path)

        if CONFIG.auth is None:
            print("Please login with `nearai login`")
            # ``sys.exit`` rather than the ``site``-injected ``exit`` builtin, which
            # is missing under ``python -S`` and in some frozen interpreters.
            sys.exit(1)

        metadata_path = path / "metadata.json"
        check_metadata(metadata_path)

        with open(metadata_path) as f:
            metadata: Dict[str, Any] = json.load(f)

        namespace = CONFIG.auth.namespace

        # name/version identify the entry; the remaining fields are the new metadata.
        entry_location = EntryLocation.model_validate(
            dict(
                namespace=namespace,
                name=metadata.pop("name"),
                version=metadata.pop("version"),
            )
        )

        entry_metadata = EntryMetadataInput.model_validate(metadata)
        result = registry.update(entry_location, entry_metadata)
        print(json.dumps(result, indent=2))

    def upload_unregistered_common_provider_models(self, dry_run: bool = True) -> None:
        """Creates new registry items for unregistered common provider models.

        For each model reported by ``get_unregistered_common_provider_models``,
        writes ``metadata.json`` under ``<registry folder>/<DEFAULT_NAMESPACE>/<name>/1.0.0``
        and prints a summary table. The files are written even when ``dry_run``
        is True; only the upload step is skipped.

        Args:
        ----
            dry_run: If True (default), only prepare and print; otherwise upload each created item.

        """
        provider_matches_list = ProviderModels(CONFIG.get_client_config()).get_unregistered_common_provider_models(
            registry.dict_models()
        )
        if len(provider_matches_list) == 0:
            print("No new models to upload.")
            return

        print("Going to create new registry items:")
        header = ["entry", "description"]
        table = []
        paths = []
        for provider_matches in provider_matches_list:
            # Prefer the default provider's model name, else the first match.
            provider_model = provider_matches.get(DEFAULT_PROVIDER) or next(iter(provider_matches.values()))
            _, model = get_provider_namespaced_model(provider_model)
            assert model.namespace == ""
            model.name = create_registry_name(model.name)
            model.namespace = DEFAULT_NAMESPACE
            version = "1.0.0"
            description = " & ".join(provider_matches.values())
            table.append(
                [
                    fill(f"{model.namespace}/{model.name}/{version}"),
                    fill(description, 50),
                ]
            )

            path = get_registry_folder() / model.namespace / model.name / version
            path.mkdir(parents=True, exist_ok=True)
            paths.append(path)
            metadata_path = path / "metadata.json"
            with open(metadata_path, "w") as f:
                metadata: Dict[str, Any] = {
                    "name": model.name,
                    "version": version,
                    "description": description,
                    "category": "model",
                    "tags": [],
                    "details": {},
                    "show_entry": True,
                }
                json.dump(metadata, f, indent=2)

        print(tabulate(table, headers=header, tablefmt="simple_grid"))
        if dry_run:
            print("Please verify, then repeat the command with --dry_run=False")
        else:
            for path in paths:
                self.upload(str(path))

    def upload(self, local_path: str = ".") -> EntryLocation:
        """Upload item to the registry and return its ``EntryLocation``."""
        return registry.upload(Path(local_path), show_progress=True)

    def download(self, entry_location: str, force: bool = False) -> None:
        """Download item; ``force`` is forwarded to ``registry.download``."""
        registry.download(entry_location, force=force, show_progress=True)
download
download(entry_location: str, force: bool = False) -> None

Download item.

Source code in nearai/cli.py
def download(self, entry_location: str, force: bool = False) -> None:
    """Download the registry item ``entry_location`` with a progress display.

    ``force`` is forwarded to ``registry.download`` unchanged.
    """
    registry.download(entry_location, show_progress=True, force=force)
info
info(entry: str) -> None

Show information about an item.

Source code in nearai/cli.py
def info(self, entry: str) -> None:
    """Show information about an item.

    Prints the entry's metadata as JSON. For ``model`` entries it also prints
    a table of the available provider matches (provider, name).

    Args:
    ----
        entry: Registry location string, parsed with ``parse_location``.

    """
    location = parse_location(entry)
    metadata = registry.info(location)
    if metadata is None:
        print(f"Entry {entry} not found.")
        return

    print(metadata.model_dump_json(indent=2))
    if metadata.category != "model":
        return

    matches = ProviderModels(CONFIG.get_client_config()).available_provider_matches(
        NamespacedName(name=metadata.name, namespace=location.namespace)
    )
    if len(matches) == 0:
        return

    rows = [[fill(provider), fill(name)] for provider, name in matches.items()]
    print(tabulate(rows, headers=["provider", "name"], tablefmt="simple_grid"))
list
list(namespace: str = '', category: str = '', tags: str = '', total: int = 32, offset: int = 0, show_all: bool = False, show_latest_version: bool = True, star: str = '') -> None

List available items.

Source code in nearai/cli.py
def list(
    self,
    namespace: str = "",
    category: str = "",
    tags: str = "",
    total: int = 32,
    offset: int = 0,
    show_all: bool = False,
    show_latest_version: bool = True,
    star: str = "",
) -> None:
    """List available items.

    Prints at most ``total`` entries; a ``...`` row signals that more exist.
    For an unfiltered ``model`` listing it also hints about common provider
    models that are not yet registered.

    Args:
    ----
        namespace: Forwarded to ``registry.list``.
        category: Forwarded to ``registry.list``.
        tags: Tags filter, normalised by ``parse_tags`` into a comma-separated string.
        total: Maximum number of rows to print.
        offset: Forwarded to ``registry.list``.
        show_all: Forwarded to ``registry.list``.
        show_latest_version: Forwarded to ``registry.list``.
        star: Forwarded to ``registry.list`` as ``starred_by``.

    """
    tags = ",".join(parse_tags(tags))

    # One extra entry is requested purely to detect whether output is truncated.
    fetched = registry.list(
        namespace=namespace,
        category=category,
        tags=tags,
        total=total + 1,
        offset=offset,
        show_all=show_all,
        show_latest_version=show_latest_version,
        starred_by=star,
    )
    truncated = len(fetched) > total
    shown = fetched[:total]

    rows = [
        [
            fill(f"{item.namespace}/{item.name}/{item.version}"),
            fill(item.category, 20),
            fill(item.description, 50),
            fill(", ".join(item.tags), 20),
        ]
        for item in shown
    ]
    if truncated:
        rows.append(["...", "...", "...", "..."])

    print(tabulate(rows, headers=["entry", "category", "description", "tags"], tablefmt="simple_grid"))

    unfiltered_model_listing = category == "model" and namespace == "" and tags == "" and star == ""
    if not (unfiltered_model_listing and len(shown) < total):
        return

    missing_models = ProviderModels(CONFIG.get_client_config()).get_unregistered_common_provider_models(
        registry.dict_models()
    )
    if len(missing_models):
        print(
            f"There are unregistered common provider models: {missing_models}. Run 'nearai registry upload-unregistered-common-provider-models' to update registry."  # noqa: E501
        )
metadata_template
metadata_template(local_path: str = '.', category: str = '', description: str = '')

Create a metadata template.

Source code in nearai/cli.py
def metadata_template(self, local_path: str = ".", category: str = "", description: str = ""):
    """Write a starter ``metadata.json`` into ``local_path``.

    The last component of ``local_path`` must be a semantic version, and its
    parent directory's name becomes the entry name, which must not itself
    look like a version. Any existing ``metadata.json`` is overwritten. For
    ``category == "agent"``, ``details.agent`` is pre-filled with welcome text
    and default model settings.

    Args:
    ----
        local_path: Version directory (``.../<name>/<version>``) to write into.
        category: Registry category stored in the metadata.
        description: Description stored in the metadata.

    Raises:
    ------
        AssertionError: If the directory names do not follow the layout above.

    """
    semver = r"^(0|[1-9]\d*)\.(0|[1-9]\d*)\.(0|[1-9]\d*)(?:-((?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*)(?:\.(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*))*))?(?:\+([0-9a-zA-Z-]+(?:\.[0-9a-zA-Z-]+)*))?$"  # noqa: E501
    target_dir = Path(local_path)

    version = target_dir.name
    assert re.match(semver, version), f"Invalid semantic version format: {version}"
    name = target_dir.parent.name
    assert not re.match(semver, name), f"Invalid agent name: {name}"

    details: Dict[str, Any] = {}
    if category == "agent":
        details["agent"] = {
            "welcome": {"title": name, "description": description},
            "defaults": {
                "model": DEFAULT_MODEL,
                "model_provider": DEFAULT_PROVIDER,
                "model_temperature": DEFAULT_MODEL_TEMPERATURE,
                "model_max_tokens": DEFAULT_MODEL_MAX_TOKENS,
                "max_iterations": 1,
            },
            "framework": "base",
        }

    metadata: Dict[str, Any] = {
        "name": name,
        "version": version,
        "description": description,
        "category": category,
        "tags": [],
        "details": details,
        "show_entry": True,
    }
    with open(target_dir / "metadata.json", "w") as f:
        json.dump(metadata, f, indent=2)
update
update(local_path: str = '.') -> None

Update metadata of a registry item.

Source code in nearai/cli.py
def update(self, local_path: str = ".") -> None:
    """Update metadata of a registry item.

    Reads and validates ``metadata.json`` from ``local_path``. It then sends
    every field except ``name`` and ``version`` as the new metadata of
    ``<logged-in namespace>/<name>/<version>`` and prints the result as JSON.

    Args:
    ----
        local_path: Directory containing the item's ``metadata.json``.

    Raises:
    ------
        SystemExit: With status 1 when the user is not logged in.

    """
    import sys

    path = Path(local_path)

    if CONFIG.auth is None:
        print("Please login with `nearai login`")
        # ``sys.exit`` rather than the ``site``-injected ``exit`` builtin, which
        # is missing under ``python -S`` and in some frozen interpreters.
        sys.exit(1)

    metadata_path = path / "metadata.json"
    check_metadata(metadata_path)

    with open(metadata_path) as f:
        metadata: Dict[str, Any] = json.load(f)

    namespace = CONFIG.auth.namespace

    # name/version identify the entry; the remaining fields are the new metadata.
    entry_location = EntryLocation.model_validate(
        dict(
            namespace=namespace,
            name=metadata.pop("name"),
            version=metadata.pop("version"),
        )
    )

    entry_metadata = EntryMetadataInput.model_validate(metadata)
    result = registry.update(entry_location, entry_metadata)
    print(json.dumps(result, indent=2))
upload
upload(local_path: str = '.') -> EntryLocation

Upload item to the registry.

Source code in nearai/cli.py
def upload(self, local_path: str = ".") -> EntryLocation:
    """Upload the item stored at ``local_path`` to the registry, showing progress.

    Returns whatever location ``registry.upload`` reports for the uploaded entry.
    """
    item_path = Path(local_path)
    return registry.upload(item_path, show_progress=True)
upload_unregistered_common_provider_models
upload_unregistered_common_provider_models(dry_run: bool = True) -> None

Creates new registry items for unregistered common provider models.

Source code in nearai/cli.py
def upload_unregistered_common_provider_models(self, dry_run: bool = True) -> None:
    """Creates new registry items for unregistered common provider models.

    For each group of matching provider models, a ``metadata.json`` is written to a
    ``<namespace>/<name>/1.0.0`` folder under the local registry folder, and the planned
    entries are printed as a table. When ``dry_run`` is true (the default) nothing is
    uploaded; otherwise each prepared folder is uploaded with ``self.upload``.
    """
    unregistered = ProviderModels(CONFIG.get_client_config()).get_unregistered_common_provider_models(
        registry.dict_models()
    )
    if len(unregistered) == 0:
        print("No new models to upload.")
        return

    print("Going to create new registry items:")
    rows = []
    entry_dirs = []
    for matches in unregistered:
        # Prefer the default provider's model id; otherwise take the first match.
        chosen = matches.get(DEFAULT_PROVIDER) or next(iter(matches.values()))
        _, model = get_provider_namespaced_model(chosen)
        assert model.namespace == ""
        model.name = create_registry_name(model.name)
        model.namespace = DEFAULT_NAMESPACE
        entry_version = "1.0.0"
        entry_description = " & ".join(matches.values())
        rows.append(
            [
                fill(f"{model.namespace}/{model.name}/{entry_version}"),
                fill(entry_description, 50),
            ]
        )

        entry_dir = get_registry_folder() / model.namespace / model.name / entry_version
        entry_dir.mkdir(parents=True, exist_ok=True)
        entry_dirs.append(entry_dir)
        entry_metadata: Dict[str, Any] = {
            "name": model.name,
            "version": entry_version,
            "description": entry_description,
            "category": "model",
            "tags": [],
            "details": {},
            "show_entry": True,
        }
        with open(entry_dir / "metadata.json", "w") as f:
            json.dump(entry_metadata, f, indent=2)

    print(tabulate(rows, headers=["entry", "description"], tablefmt="simple_grid"))
    if dry_run:
        print("Please verify, then repeat the command with --dry_run=False")
        return

    for entry_dir in entry_dirs:
        self.upload(str(entry_dir))

check_update

check_update()

Check if there is a new version of nearai CLI available.

Source code in nearai/cli.py
def check_update():
    """Print a notice when the hub reports a nearai CLI version different from the installed one.

    Any exception raised while checking is deliberately ignored, so this check never
    interrupts the command being run.
    """
    try:
        latest_version = DefaultApi().version_v1_version_get()
        installed_version = importlib.metadata.version("nearai")
        if latest_version == installed_version:
            return
        print(f"New version of nearai CLI available: {latest_version}. Current version: {installed_version}")
        print("Run `pip install --upgrade nearai` to update.")
    except Exception:
        return

config

Config

Bases: BaseModel

Source code in nearai/config.py
class Config(BaseModel):
    """nearai configuration settings (hub, inference, auth and CLI behaviour)."""

    origin: Optional[str] = None
    api_url: Optional[str] = "https://api.near.ai"
    inference_url: str = "http://localhost:5000/v1/"
    inference_api_key: str = "n/a"
    scheduler_account_id: str = "nearaischeduler.near"
    nearai_hub: NearAiHubConfig = NearAiHubConfig()
    confirm_commands: bool = True
    auth: Optional[AuthData] = None
    num_inference_retries: int = 1

    def update_with(self, extra_config: Dict[str, Any], map_key: Callable[[str], str] = lambda x: x) -> "Config":
        """Return a new Config with fields overridden from `extra_config`.

        For every field `name` of this config, `extra_config[map_key(name)]` replaces the
        field's value when that entry exists and is truthy. Empty values such as `""`,
        `0`, `False` or `None` are skipped, even if present. `self` is not modified.

        :param extra_config: Source of override values.
        :param map_key: Maps a field name to the key looked up in `extra_config`
            (for example, to match differently-cased keys). Defaults to the identity.
        :return: A new, validated Config.
        """
        dict_repr = self.model_dump()

        for field_name in list(dict_repr):
            value = extra_config.get(map_key(field_name), None)

            if value:
                # This will skip empty values, even if they are set in the `extra_config`.
                # Store under the field's own name: the mapped key only selects the source
                # entry, and a mapped key would not be recognised as a field by validation.
                dict_repr[field_name] = value

        return Config.model_validate(dict_repr)

    def get(self, key: str, default: Optional[Any] = None) -> Optional[Any]:
        """Get the value of a key in the config if it exists."""
        return getattr(self, key, default)

    def get_client_config(self) -> ClientConfig:
        """Build a ClientConfig from this config's hub settings, auth and retry count."""
        # Read from `self` (not the module-level CONFIG) so that derived configs, e.g. ones
        # returned by `update_with`, produce a client config matching their own values.
        return ClientConfig(
            base_url=self.nearai_hub.base_url,
            auth=self.auth,
            custom_llm_provider=self.nearai_hub.custom_llm_provider,
            default_provider=self.nearai_hub.default_provider,
            num_inference_retries=self.num_inference_retries,
        )
get
get(key: str, default: Optional[Any] = None) -> Optional[Any]

Get the value of a key in the config if it exists.

Source code in nearai/config.py
def get(self, key: str, default: Optional[Any] = None) -> Optional[Any]:
    """Return the config attribute named `key`, or `default` when there is no such attribute."""
    value = getattr(self, key, default)
    return value
update_with
update_with(extra_config: Dict[str, Any], map_key: Callable[[str], str] = lambda x: x) -> Config

Update the config with the given dictionary.

Source code in nearai/config.py
def update_with(self, extra_config: Dict[str, Any], map_key: Callable[[str], str] = lambda x: x) -> "Config":
    """Return a new Config with fields overridden from `extra_config`.

    For every field `name` of this config, `extra_config[map_key(name)]` replaces the
    field's value when that entry exists and is truthy. Empty values such as `""`,
    `0`, `False` or `None` are skipped, even if present. `self` is not modified.

    :param extra_config: Source of override values.
    :param map_key: Maps a field name to the key looked up in `extra_config`
        (for example, to match differently-cased keys). Defaults to the identity.
    :return: A new, validated Config.
    """
    dict_repr = self.model_dump()

    for field_name in list(dict_repr):
        value = extra_config.get(map_key(field_name), None)

        if value:
            # This will skip empty values, even if they are set in the `extra_config`.
            # Store under the field's own name: the mapped key only selects the source
            # entry, and a mapped key would not be recognised as a field by validation.
            dict_repr[field_name] = value

    return Config.model_validate(dict_repr)

NearAiHubConfig

Bases: BaseModel

NearAiHub Config.

login_with_near (Optional[bool]): Indicates whether to attempt login using Near Auth.

api_key (Optional[str]): The API key to use if Near Auth is not being utilized

base_url (Optional[str]): NEAR AI Hub URL

default_provider (Optional[str]): Default provider name

default_model (Optional[str]): Default model name

custom_llm_provider (Optional[str]): provider to be used by litellm proxy

Source code in nearai/config.py
class NearAiHubConfig(BaseModel):
    """NearAiHub Config.

    login_with_near (Optional[bool]): Indicates whether to attempt login using Near Auth. Defaults to True.

    api_key (Optional[str]): The API key to use if Near Auth is not being utilized. Defaults to "".

    base_url (str): NEAR AI Hub URL. Defaults to "https://api.near.ai/v1".

    default_provider (str): Default provider name. Defaults to DEFAULT_PROVIDER.

    default_model (str): Default model name. Defaults to DEFAULT_PROVIDER_MODEL.

    custom_llm_provider (str): Provider to be used by the litellm proxy. Defaults to "openai".
    """

    base_url: str = "https://api.near.ai/v1"
    default_provider: str = DEFAULT_PROVIDER
    default_model: str = DEFAULT_PROVIDER_MODEL
    custom_llm_provider: str = "openai"
    login_with_near: Optional[bool] = True
    api_key: Optional[str] = ""

dataset

get_dataset

get_dataset(name: str, verbose: bool = True) -> Path

Download the dataset from the registry if it hasn't been downloaded locally yet, and return its local path.

:param name: The registry entry of the dataset to download, in the format namespace/name/version.

:return: The path to the downloaded dataset.

Source code in nearai/dataset.py
def get_dataset(name: str, verbose: bool = True) -> Path:
    """Return the local path of a registry dataset, downloading it if it hasn't been downloaded yet.

    :param name: The registry entry of the dataset, in the format namespace/name/version.
    :param verbose: Forwarded to ``registry.download``.
    :return: The path to the downloaded dataset.
    """
    dataset_path = registry.download(name, verbose=verbose)
    return dataset_path

load_dataset

load_dataset(alias_or_name: str, verbose: bool = True) -> Union[Dataset, DatasetDict]

Load a dataset from the registry.

Source code in nearai/dataset.py
def load_dataset(alias_or_name: str, verbose: bool = True) -> Union[Dataset, DatasetDict]:
    """Fetch a dataset from the registry (if needed) and load it from its local folder.

    :param alias_or_name: Registry entry of the dataset, passed to ``get_dataset``.
    :param verbose: Forwarded to ``get_dataset``.
    :return: The result of ``load_from_disk`` on the downloaded dataset folder.
    """
    local_dir = get_dataset(alias_or_name, verbose=verbose).as_posix()
    return load_from_disk(local_dir)

delegation

OnBehalfOf

Create a context manager that allows you to delegate actions to another account.

with OnBehalfOf("scheduler.ai"):
    # Upload is done on behalf of scheduler.ai.
    # If scheduler.ai has not granted delegation permission, this will raise an exception.
    registry.upload()
Source code in nearai/delegation.py
class OnBehalfOf:
    """Create a context manager that allows you to delegate actions to another account.

    While the context is active, the default API client's bearer token is replaced by one
    whose auth data has ``on_behalf_of`` set to the target account; the original token is
    restored on exit.

    ```python
    with OnBehalfOf("scheduler.ai"):
        # Upload is done on behalf of scheduler.ai
        # If delegation permission is not granted, this will raise an exception
        registry.upload()
    ```
    """

    def __init__(self, on_behalf_of: str):
        """Remember the account to act for; the current token is captured later by ``__enter__``."""
        self.original_access_token = None
        self.target_on_behalf_of = on_behalf_of

    def __enter__(self):
        """Swap the default client's token for one acting on behalf of the target account.

        The current token is saved so ``__exit__`` can restore it. Non-string tokens are
        left untouched.
        """
        client = ApiClient.get_default()
        current_token = client.configuration.access_token
        self.original_access_token = current_token

        if isinstance(current_token, str):
            assert current_token.startswith("Bearer ")
            auth_data = AuthData.model_validate_json(current_token[len("Bearer ") :])
            auth_data.on_behalf_of = self.target_on_behalf_of
            client.configuration.access_token = f"Bearer {auth_data.generate_bearer_token()}"

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Restore the token saved by ``__enter__`` and forget it.

        Returns None, so exceptions raised inside the ``with`` block propagate.
        """
        saved_token = self.original_access_token
        ApiClient.get_default().configuration.access_token = saved_token
        self.original_access_token = None
__enter__
__enter__()

Set the default client to the account we are acting on behalf of.

Source code in nearai/delegation.py
def __enter__(self):
    """Swap the default client's token for one acting on behalf of the target account.

    The current token is saved so ``__exit__`` can restore it. Non-string tokens are
    left untouched.
    """
    client = ApiClient.get_default()
    current_token = client.configuration.access_token
    self.original_access_token = current_token

    if isinstance(current_token, str):
        assert current_token.startswith("Bearer ")
        auth_data = AuthData.model_validate_json(current_token[len("Bearer ") :])
        auth_data.on_behalf_of = self.target_on_behalf_of
        client.configuration.access_token = f"Bearer {auth_data.generate_bearer_token()}"
__exit__
__exit__(exc_type, exc_val, exc_tb)

Reset the default client to the original account.

Source code in nearai/delegation.py
def __exit__(self, exc_type, exc_val, exc_tb):
    """Restore the token saved by ``__enter__`` and forget it.

    Returns None, so exceptions raised inside the ``with`` block propagate.
    """
    saved_token = self.original_access_token
    ApiClient.get_default().configuration.access_token = saved_token
    self.original_access_token = None
__init__
__init__(on_behalf_of: str)

Context manager that creates a scope where all actions are done on behalf of another account.

Source code in nearai/delegation.py
def __init__(self, on_behalf_of: str):
    """Context manager that creates a scope where all actions are done on behalf of another account."""
    self.target_on_behalf_of = on_behalf_of
    self.original_access_token = None

check_on_behalf_of

check_on_behalf_of()

Return the access token configured on the delegation API client, which can be inspected to check whether requests are being made on behalf of another account.

Source code in nearai/delegation.py
def check_on_behalf_of():
    """Return the access token configured on a fresh ``DelegationApi`` client.

    The token can be inspected to check whether requests are being made on behalf of
    another account.
    """
    return DelegationApi().api_client.configuration.access_token

revoke_delegation

revoke_delegation(delegate_account_id: str)

Revoke delegation to a specific account.

Source code in nearai/delegation.py
def revoke_delegation(delegate_account_id: str):
    """Revoke the delegation granted to ``delegate_account_id`` via the delegation API."""
    api = DelegationApi()
    api.revoke_delegation_v1_delegation_revoke_delegation_post(delegate_account_id)

evaluation

_print_metrics_tables

_print_metrics_tables(rows: List[Dict[str, str]], metric_names: List[str], num_columns: int, all_key_columns: bool, metric_name_max_length: int)

Builds table(s) and prints them.

Source code in nearai/evaluation.py
def _print_metrics_tables(
    rows: List[Dict[str, str]],
    metric_names: List[str],
    num_columns: int,
    all_key_columns: bool,
    metric_name_max_length: int,
):
    """Builds table(s) and prints them."""
    # Shorten metric names
    short_metric_names = [_shorten_metric_name(name, metric_name_max_length) for name in metric_names]

    # Prepare the base header and rows
    base_header = ["model", "agent"]
    if all_key_columns:
        base_header.extend(["namespace", "version", "provider"])

    base_rows = []
    for row in rows:
        base_row = [fill(row.pop("model", "")), fill(row.pop("agent", ""))]
        namespace = row.pop("namespace", "")
        version = row.pop("version", "")
        provider = row.pop("provider", "")
        if all_key_columns:
            base_row.extend([fill(namespace), fill(version), fill(provider)])
        base_rows.append((base_row, row))

    n_metrics_per_table = max(1, num_columns - len(base_header))
    # Split metrics into groups
    metric_groups = list(
        zip(
            [
                short_metric_names[i : i + n_metrics_per_table]
                for i in range(0, len(short_metric_names), n_metrics_per_table)
            ],
            [metric_names[i : i + n_metrics_per_table] for i in range(0, len(metric_names), n_metrics_per_table)],
        )
    )

    # Print tables
    for short_group, full_group in metric_groups:
        header = base_header + short_group
        table = []
        for base_row, row_metrics in base_rows:
            table_row = base_row + [fill(str(row_metrics.get(metric, ""))) for metric in full_group]
            table.append(table_row)
        print(tabulate(table, headers=header, tablefmt="simple_grid"))

_shorten_metric_name

_shorten_metric_name(name: str, max_length: int) -> str

Shortens metric name if needed.

Source code in nearai/evaluation.py
def _shorten_metric_name(name: str, max_length: int) -> str:
    """Shortens metric name if needed."""
    if len(name) <= max_length:
        return name
    keep = max_length - 2  # 2 dots
    beginning = keep // 3
    ending = keep - beginning
    return name[:beginning] + ".." + name[-ending:]

load_benchmark_entry_info

load_benchmark_entry_info(info: str) -> Any

Deserializes benchmark info entry from db data.

Source code in nearai/evaluation.py
def load_benchmark_entry_info(info: str) -> Any:
    """Deserializes benchmark info entry from db data.

    ``info`` is decoded as JSON. If that yields a string, it is decoded once more (the data
    may be JSON-encoded twice). When the second decode fails with an "Unterminated string"
    error, the text up to the last "}" is tried as a fallback. If nothing better can be
    parsed, the result of the first decode is returned.
    """
    first_decode = json.loads(info)
    if not isinstance(first_decode, str):
        # Already fully decoded (single-encoded JSON); a second json.loads would raise TypeError.
        return first_decode
    try:
        return json.loads(first_decode)
    except json.JSONDecodeError as e:
        if "Unterminated string" in str(e):
            # Possibly truncated data: retry with everything up to the last closing brace.
            last_brace = first_decode.rfind("}")
            if last_brace != -1:
                try:
                    return json.loads(first_decode[: last_brace + 1])
                except json.JSONDecodeError:
                    pass
    return first_decode

print_evaluation_table

print_evaluation_table(rows: List[Dict[str, str]], columns: List[str], important_columns: List[str], all_key_columns: bool, all_metrics: bool, num_columns: int, metric_name_max_length: int) -> None

Prints table of evaluations.

Source code in nearai/evaluation.py
def print_evaluation_table(
    rows: List[Dict[str, str]],
    columns: List[str],
    important_columns: List[str],
    all_key_columns: bool,
    all_metrics: bool,
    num_columns: int,
    metric_name_max_length: int,
) -> None:
    """Prints table of evaluations.

    With ``all_metrics`` the metric columns are ``columns`` minus its first five entries;
    otherwise they are ``important_columns`` minus its first two. NOTE(review): the skipped
    leading entries are presumably key columns (model, agent, ...); confirm against callers.
    """
    if all_metrics:
        metric_names = columns[5:]
    else:
        metric_names = important_columns[2:]
    _print_metrics_tables(rows, metric_names, num_columns, all_key_columns, metric_name_max_length)

record_evaluation_metrics

record_evaluation_metrics(solver_strategy: SolverStrategy, benchmark_id: int, data_tasks: Union[Dataset, List[dict]], metrics: Dict[str, Any], prepend_evaluation_name: bool = True) -> None

Uploads evaluation metrics into registry.

Source code in nearai/evaluation.py
def record_evaluation_metrics(
    solver_strategy: SolverStrategy,
    benchmark_id: int,
    data_tasks: Union[Dataset, List[dict]],
    metrics: Dict[str, Any],
    prepend_evaluation_name: bool = True,
) -> None:
    """Uploads evaluation metrics into registry.

    The evaluation name, model, agent, agent version, evaluated namespace and model provider
    are all taken from ``solver_strategy``. When ``prepend_evaluation_name`` is true, the
    metrics are first passed through ``_prepend_name_to_metrics`` with the evaluation name.
    """
    evaluation_name = solver_strategy.evaluation_name()
    model = solver_strategy.model_name
    agent = solver_strategy.agent_name()
    version = solver_strategy.agent_version()
    if prepend_evaluation_name:
        metrics = _prepend_name_to_metrics(evaluation_name, metrics)

    upload_evaluation(
        evaluation_name,
        benchmark_id,
        data_tasks,
        metrics,
        model,
        agent,
        solver_strategy.evaluated_entry_namespace(),
        version,
        solver_strategy.model_provider(),
    )

record_single_score_evaluation

record_single_score_evaluation(solver_strategy: SolverStrategy, benchmark_id: int, data_tasks: Union[Dataset, List[dict]], score: float) -> None

Uploads single score evaluation into registry.

Source code in nearai/evaluation.py
def record_single_score_evaluation(
    solver_strategy: SolverStrategy, benchmark_id: int, data_tasks: Union[Dataset, List[dict]], score: float
) -> None:
    """Record a single ``score`` as the only metric, keyed by the strategy's evaluation name."""
    name = solver_strategy.evaluation_name()
    # The metric key already is the evaluation name, so don't prepend it again.
    record_evaluation_metrics(
        solver_strategy,
        benchmark_id,
        data_tasks,
        {name: score},
        prepend_evaluation_name=False,
    )

upload_evaluation

upload_evaluation(evaluation_name: str, benchmark_id: int, data_tasks: Union[Dataset, List[dict]], metrics: Dict[str, Any], model: str = '', agent: str = '', namespace: str = '', version: str = '', provider: str = '') -> None

Uploads evaluation into registry.

evaluation_name: a unique name for (benchmark, solver) tuple, e.g. "mbpp" or "live_bench" or "mmlu-5-shot". metrics: metrics from evaluation. model: model that was used. agent: agent that was evaluated, if any. namespace: namespace of evaluated agent or evaluated model. version: version of evaluated agent or evaluated model. provider: provider of model used; pass `local` if running locally.

Source code in nearai/evaluation.py
def upload_evaluation(
    evaluation_name: str,
    benchmark_id: int,
    data_tasks: Union[Dataset, List[dict]],
    metrics: Dict[str, Any],
    model: str = "",
    agent: str = "",
    namespace: str = "",
    version: str = "",
    provider: str = "",
) -> None:
    """Uploads evaluation into registry.

    `evaluation_name`: a unique name for (benchmark, solver) tuple, e.g. "mbpp" or "live_bench" or "mmlu-5-shot".
    `benchmark_id`: id of the benchmark whose cached results are exported to `solutions.json`.
    `data_tasks`: benchmark tasks, indexed by each cached result's `index`.
    `metrics`: metrics from evaluation. The caller's dict is not modified; a copy that also
        carries the evaluated-entry metadata is written to `metrics.json`.
    `model`: model that was used.
    `agent`: agent that was evaluated, if any.
    `namespace`: namespace of evaluated agent or evaluated model.
    `version`: version of evaluated agent or evaluated model.
    `provider`: provider of model used; pass `local` if running locally.

    Empty-string identifiers are left out of both the entry name and the metadata.
    """
    key = f"evaluation_{evaluation_name}"
    evaluated_entry: Dict[str, str] = {}
    # The order of this tuple fixes the layout of the registry entry name.
    for field, value in (
        ("agent", agent),
        ("model", model),
        ("namespace", namespace),
        ("version", version),
        ("provider", provider),
    ):
        if value != "":
            evaluated_entry[field] = value
            key += f"_{field}_{value}"
    # Build a new dict rather than writing into the caller's `metrics`.
    metrics = {**metrics, EVALUATED_ENTRY_METADATA: evaluated_entry}

    entry_path = get_registry_folder() / key
    # Create folder entry_path if not present
    entry_path.mkdir(parents=True, exist_ok=True)
    # Write file metrics.json inside
    metrics_file = entry_path / "metrics.json"
    with metrics_file.open("w") as f:
        json.dump(metrics, f, indent=2)

    # Get solutions from cache in benchmark.py
    cache = BenchmarkApi().get_benchmark_result_v1_benchmark_get_result_get(benchmark_id)
    solutions = []
    for result in cache:
        try:
            solution = {
                "datum": data_tasks[result.index],
                "status": result.solved,
                "info": load_benchmark_entry_info(result.info) if result.info else {},
            }
        except (AttributeError, json.JSONDecodeError, TypeError) as e:
            print(f"Exception while creating solutions data: {str(e)}.")
            # Skip entries that can't be properly formatted
            continue
        solutions.append(solution)

    # Write solutions file
    solutions_file = entry_path / "solutions.json"
    with solutions_file.open("w") as f:
        json.dump(solutions, f, indent=2)

    metadata_path = entry_path / "metadata.json"
    # TODO(#273): Currently that will not update existing evaluation.
    with open(metadata_path, "w") as f:
        json.dump(
            {
                "name": key,
                "version": "0.1.0",
                "description": "",
                "category": "evaluation",
                "tags": [],
                "details": {},
                "show_entry": True,
            },
            f,
            indent=2,
        )

    registry.upload(Path(entry_path), show_progress=True)

finetune

FinetuneCli

Source code in nearai/finetune/__init__.py
class FinetuneCli:
    """Finetuning commands: start a job on the current node or inspect one."""

    def start(
        self,
        model: str,
        tokenizer: str,
        dataset: str,
        num_procs: int,
        format: str,
        upload_checkpoint: bool = True,
        num_nodes: int = 1,
        job_id: Optional[str] = None,
        checkpoint: Optional[str] = None,
        **dataset_kwargs: Any,
    ) -> None:
        """Start a finetuning job on the current node.

        Args:
        ----
            model: Name of a model in the registry. Base model to finetune.
            tokenizer: Name of a tokenizer in the registry. Using tokenizer.model format.
            dataset: Name of a dataset in the registry.
            num_procs: Number of GPUs to use for training
            format: Name of the configuration file to use. For example llama3-70b, llama3-8b. Valid options are in etc/finetune.
            upload_checkpoint: Whether to upload the checkpoint to the registry. Default is True.
            num_nodes: Number of nodes to use for training. Default is 1. Only 1 is supported currently.
            job_id: Unique identifier for the job. Default is None.
            checkpoint: Name of the model checkpoint to start from. Default is None.
            dataset_kwargs: Additional keyword arguments to pass to the dataset constructor. Must include
                `method`, which becomes the `_component_` entry of the dataset config.

        Raises:
        ------
            NotImplementedError: If `num_nodes` is greater than 1 (multi-node training is not supported yet).
            FileNotFoundError: If the config template for `format` does not exist.

        """  # noqa: E501
        from nearai.dataset import get_dataset

        global BACKGROUND_PROCESS

        assert num_nodes >= 1
        if num_nodes > 1:
            # TODO: fetch rank and master addr from environment variables.
            # Rejected up front so nothing is downloaded and the background log thread is
            # never started for a job that cannot run.
            raise NotImplementedError()

        # Prepare job id folder
        if job_id is None:
            job_id = "job"
        job_id = f"{job_id}-{timestamp()}-{randint(10**8, 10**9 - 1)}"
        job_folder = DATA_FOLDER / "finetune" / job_id
        job_folder.mkdir(parents=True, exist_ok=True)

        # Either use the provided config file template or load one predefined one
        if Path(format).exists():
            config_template_path = Path(format)
        else:
            configs = ETC_FOLDER / "finetune"
            config_template_path = configs / f"{format}.yml"

        if not config_template_path.exists():
            raise FileNotFoundError(f"Config file not found: {config_template_path}")

        CONFIG_TEMPLATE = config_template_path.read_text()  # noqa: N806

        # Download model
        model_path = get_model(model)

        # Download tokenizer
        tokenizer_path = registry.download(tokenizer) / "tokenizer.model"
        assert tokenizer_path.exists(), f"tokenizer.model not found in {tokenizer_path}"

        # Download checkpoint if any; the literal "null" is what the config template receives otherwise
        checkpoint_path = get_model(checkpoint) if checkpoint else "null"
        resume_checkpoint = checkpoint_path != "null"

        # Download dataset
        dataset_path = get_dataset(dataset)

        # Set up output directories
        checkpoint_output_dir = job_folder / "checkpoint_output"
        logging_output_dir = job_folder / "logs"
        logging_output_dir.mkdir(parents=True, exist_ok=True)

        # Prepare config file
        dataset_args_dict = deepcopy(dataset_kwargs)

        dataset_args_dict["_component_"] = dataset_args_dict.pop("method")
        dataset_args_dict["source"] = str(dataset_path.absolute())
        dataset_args = "\n".join(f"  {key}: {value}" for key, value in dataset_args_dict.items())

        config = job_folder / "config.yaml"
        with open(config, "w") as f:
            f.write(
                CONFIG_TEMPLATE.format(
                    TOKENIZER=str(tokenizer_path),
                    MODEL=str(model_path),
                    RECIPE_CHECKPOINT=checkpoint_path,
                    RESUME_FROM_CHECKPOINT=resume_checkpoint,
                    CHECKPOINT_OUTPUT_DIR=str(checkpoint_output_dir),
                    DATASET_ARGS=dataset_args,
                    LOGGING_OUTPUT_DIR=str(logging_output_dir),
                )
            )

        # Spawn background thread to read logs and push to database
        threading.Thread(target=find_new_logs_background, args=(logging_output_dir, job_id)).start()

        print("Starting job at", job_folder)
        try:
            run(
                [
                    "tune",
                    "run",
                    "--nproc_per_node",
                    str(num_procs),
                    "lora_finetune_distributed",
                    "--config",
                    str(config),
                ]
            )
        finally:
            # Reset even if the training command raises.
            # NOTE(review): presumably this flag is what stops find_new_logs_background — confirm.
            BACKGROUND_PROCESS = False

        if upload_checkpoint:
            registry.upload(
                job_folder,
                EntryMetadata.from_dict(
                    {
                        "name": f"finetune-{job_id}",
                        "version": "0.0.1",
                        "description": f"Finetuned checkpoint from base model {model} using dataset {dataset}",
                        "category": "finetune",
                        "tags": ["finetune", f"base-model-{model}", f"base-dataset-{dataset}"],
                        "details": dict(
                            model=model,
                            tokenizer=tokenizer,
                            dataset=dataset,
                            num_procs=num_procs,
                            format=format,
                            num_nodes=num_nodes,
                            checkpoint=checkpoint,
                            **dataset_kwargs,
                        ),
                        "show_entry": True,
                    }
                ),
                show_progress=True,
            )

    def inspect(self, job_id: str) -> None:
        """Inspect a finetuning job. Not implemented yet; always raises NotImplementedError."""
        raise NotImplementedError()
start
start(model: str, tokenizer: str, dataset: str, num_procs: int, format: str, upload_checkpoint: bool = True, num_nodes: int = 1, job_id: Optional[str] = None, checkpoint: Optional[str] = None, **dataset_kwargs: Any) -> None

Start a finetuning job on the current node.


model: Name of a model in the registry. Base model to finetune.
tokenizer: Name of a tokenizer in the registry. Using tokenizer.model format.
dataset: Name of a dataset in the registry.
num_procs: Number of GPUs to use for training
format: Name of the configuration file to use. For example llama3-70b, llama3-8b. Valid options are in etc/finetune.
upload_checkpoint: Whether to upload the checkpoint to the registry. Default is True.
num_nodes: Number of nodes to use for training. Default is 1.
job_id: Unique identifier for the job. Default is None.
checkpoint: Name of the model checkpoint to start from. Default is None.
dataset_kwargs: Additional keyword arguments to pass to the dataset constructor.
Source code in nearai/finetune/__init__.py
def start(
    self,
    model: str,
    tokenizer: str,
    dataset: str,
    num_procs: int,
    format: str,
    upload_checkpoint: bool = True,
    num_nodes: int = 1,
    job_id: Optional[str] = None,
    checkpoint: Optional[str] = None,
    **dataset_kwargs: Any,
) -> None:
    """Start a finetuning job on the current node.

    Args:
    ----
        model: Name of a model in the registry. Base model to finetune.
        tokenizer: Name of a tokenizer in the registry. Using tokenizer.model format.
        dataset: Name of a dataset in the registry.
        num_procs: Number of GPUs to use for training
        format: Name of the configuration file to use. For example llama3-70b, llama3-8b. Valid options are in etc/finetune.
        upload_checkpoint: Whether to upload the checkpoint to the registry. Default is True.
        num_nodes: Number of nodes to use for training. Default is 1. Only 1 is supported currently.
        job_id: Unique identifier for the job. Default is None.
        checkpoint: Name of the model checkpoint to start from. Default is None.
        dataset_kwargs: Additional keyword arguments to pass to the dataset constructor. Must include
            `method`, which becomes the `_component_` entry of the dataset config.

    Raises:
    ------
        NotImplementedError: If `num_nodes` is greater than 1 (multi-node training is not supported yet).
        FileNotFoundError: If the config template for `format` does not exist.

    """  # noqa: E501
    from nearai.dataset import get_dataset

    global BACKGROUND_PROCESS

    assert num_nodes >= 1
    if num_nodes > 1:
        # TODO: fetch rank and master addr from environment variables.
        # Rejected up front so nothing is downloaded and the background log thread is
        # never started for a job that cannot run.
        raise NotImplementedError()

    # Prepare job id folder
    if job_id is None:
        job_id = "job"
    job_id = f"{job_id}-{timestamp()}-{randint(10**8, 10**9 - 1)}"
    job_folder = DATA_FOLDER / "finetune" / job_id
    job_folder.mkdir(parents=True, exist_ok=True)

    # Either use the provided config file template or load one predefined one
    if Path(format).exists():
        config_template_path = Path(format)
    else:
        configs = ETC_FOLDER / "finetune"
        config_template_path = configs / f"{format}.yml"

    if not config_template_path.exists():
        raise FileNotFoundError(f"Config file not found: {config_template_path}")

    CONFIG_TEMPLATE = config_template_path.read_text()  # noqa: N806

    # Download model
    model_path = get_model(model)

    # Download tokenizer
    tokenizer_path = registry.download(tokenizer) / "tokenizer.model"
    assert tokenizer_path.exists(), f"tokenizer.model not found in {tokenizer_path}"

    # Download checkpoint if any; the literal "null" is what the config template receives otherwise
    checkpoint_path = get_model(checkpoint) if checkpoint else "null"
    resume_checkpoint = checkpoint_path != "null"

    # Download dataset
    dataset_path = get_dataset(dataset)

    # Set up output directories
    checkpoint_output_dir = job_folder / "checkpoint_output"
    logging_output_dir = job_folder / "logs"
    logging_output_dir.mkdir(parents=True, exist_ok=True)

    # Prepare config file
    dataset_args_dict = deepcopy(dataset_kwargs)

    dataset_args_dict["_component_"] = dataset_args_dict.pop("method")
    dataset_args_dict["source"] = str(dataset_path.absolute())
    dataset_args = "\n".join(f"  {key}: {value}" for key, value in dataset_args_dict.items())

    config = job_folder / "config.yaml"
    with open(config, "w") as f:
        f.write(
            CONFIG_TEMPLATE.format(
                TOKENIZER=str(tokenizer_path),
                MODEL=str(model_path),
                RECIPE_CHECKPOINT=checkpoint_path,
                RESUME_FROM_CHECKPOINT=resume_checkpoint,
                CHECKPOINT_OUTPUT_DIR=str(checkpoint_output_dir),
                DATASET_ARGS=dataset_args,
                LOGGING_OUTPUT_DIR=str(logging_output_dir),
            )
        )

    # Spawn background thread to read logs and push to database
    threading.Thread(target=find_new_logs_background, args=(logging_output_dir, job_id)).start()

    print("Starting job at", job_folder)
    try:
        run(
            [
                "tune",
                "run",
                "--nproc_per_node",
                str(num_procs),
                "lora_finetune_distributed",
                "--config",
                str(config),
            ]
        )
    finally:
        # Reset even if the training command raises.
        # NOTE(review): presumably this flag is what stops find_new_logs_background — confirm.
        BACKGROUND_PROCESS = False

    if upload_checkpoint:
        registry.upload(
            job_folder,
            EntryMetadata.from_dict(
                {
                    "name": f"finetune-{job_id}",
                    "version": "0.0.1",
                    "description": f"Finetuned checkpoint from base model {model} using dataset {dataset}",
                    "category": "finetune",
                    "tags": ["finetune", f"base-model-{model}", f"base-dataset-{dataset}"],
                    "details": dict(
                        model=model,
                        tokenizer=tokenizer,
                        dataset=dataset,
                        num_procs=num_procs,
                        format=format,
                        num_nodes=num_nodes,
                        checkpoint=checkpoint,
                        **dataset_kwargs,
                    ),
                    "show_entry": True,
                }
            ),
            show_progress=True,
        )

parse_line

parse_line(line: str) -> Tuple[int, dict[str, float]]

Example of line to be parsed.

Step 33 | loss:1.5400923490524292 lr:9.9e-05 tokens_per_second_per_gpu:101.22285588141214

Source code in nearai/finetune/__init__.py
def parse_line(line: str) -> Tuple[int, dict[str, float]]:
    """Parse one training-log line into its step number and metrics.

    Example of line to be parsed.

    Step 33 | loss:1.5400923490524292 lr:9.9e-05 tokens_per_second_per_gpu:101.22285588141214

    Returns a ``(step, metrics)`` tuple where ``metrics`` maps each metric name to its float value.
    Any run of whitespace separates fields, and a line with no metrics yields an empty dict.
    """
    step_raw, metrics_raw = (part.strip() for part in line.strip().split("|"))
    # The step number is the last whitespace-separated word before the "|".
    step = int(step_raw.split()[-1])
    metrics: dict[str, float] = {}
    # split() with no argument drops empty fields, so repeated spaces and an empty section are harmless.
    for item in metrics_raw.split():
        name, _, value = item.partition(":")
        metrics[name] = float(value)
    return step, metrics

text_completion

TextCompletionDataset

Bases: Dataset

Freeform dataset for any unstructured text corpus. Quickly load a dataset saved on local disk and tokenize it for your model.


tokenizer (BaseTokenizer): Tokenizer used to encode data. The tokenizer must implement ``encode`` and ``decode`` methods.
source (str): path of a dataset saved on local disk; it is passed to Hugging Face's ``load_from_disk``
    (https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_from_disk)
column (str): name of column in the sample that contains the text data. This is typically required
    for Hugging Face datasets or tabular data. For local datasets with a single column, use the default "text",
    which is what is assigned by Hugging Face datasets when loaded into memory. Default is "text".
max_seq_len (Optional[int]): Maximum number of tokens in the returned input and label token id lists.
    Default is None, disabling truncation. We recommend setting this to the highest you can fit in memory
    and is supported by the model. For example, llama2-7B supports up to 4096 for sequence length.
**load_dataset_kwargs (Any): additional keyword arguments to pass to ``load_from_disk``.
Source code in nearai/finetune/text_completion.py
class TextCompletionDataset(Dataset):
    """Freeform dataset for any unstructured text corpus. Load a dataset saved on local disk and tokenize it for your model.

    Args:
    ----
        tokenizer (BaseTokenizer): Tokenizer used to encode data. The tokenizer must implement ``encode`` and ``decode`` methods.
        source (str): path of the dataset on local disk; it is passed to Hugging Face's ``load_from_disk``.
        column (str): name of column in the sample that contains the text data. This is typically required
            for Hugging Face datasets or tabular data. For local datasets with a single column, use the default "text",
            which is what is assigned by Hugging Face datasets when loaded into memory. Default is "text".
        split (Optional[str]): if set, the loaded data is narrowed to ``data[split]``. Default is None,
            which keeps the loaded data as-is.
        max_seq_len (Optional[int]): Maximum number of tokens in the returned input and label token id lists.
            Default is None, disabling truncation. When set, sequences are truncated to ``max_seq_len - 1`` tokens.
            We recommend setting this to the highest you can fit in memory and is supported by the model.
            For example, llama2-7B supports up to 4096 for sequence length.
        **load_dataset_kwargs (Any): additional keyword arguments to pass to ``load_from_disk``.

    """  # noqa: E501

    def __init__(  # noqa: D107
        self,
        tokenizer: BaseTokenizer,
        source: str,
        column: str = "text",
        split: Optional[str] = None,
        max_seq_len: Optional[int] = None,
        **load_dataset_kwargs: Any,
    ) -> None:
        self._tokenizer = tokenizer
        self._data = load_from_disk(source, **load_dataset_kwargs)
        if split is not None:
            self._data = self._data[split]
        self.max_seq_len = max_seq_len
        self._column = column

    def __len__(self) -> int:  # noqa: D105
        return len(self._data)

    def __getitem__(self, index: int) -> Dict[str, List[int]]:  # noqa: D105
        sample = self._data[index]
        return self._prepare_sample(sample)

    def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, List[int]]:
        """Tokenize the text in ``sample[self._column]`` (with BOS and EOS) into tokens and labels."""
        prompt = sample[self._column]
        tokens = self._tokenizer.encode(text=prompt, add_bos=True, add_eos=True)

        # Truncate to max_seq_len - 1 when a limit is set; no eos_id is passed, so the
        # last kept token is not forced to EOS.
        if self.max_seq_len is not None:
            tokens = truncate(tokens, self.max_seq_len - 1)

        # No need to offset labels by 1 - happens in the recipe
        labels = tokens.copy()

        return {"tokens": tokens, "labels": labels}
truncate
truncate(tokens: List[Any], max_seq_len: int, eos_id: Optional[Any] = None) -> List[Any]

Truncate a list of tokens to a maximum length. If eos_id is provided, the last token will be replaced with eos_id.


tokens (List[Any]): list of tokens to truncate
max_seq_len (int): maximum length of the list
eos_id (Optional[Any]): token to replace the last token with. If None, the
    last token will not be replaced. Default is None.

List[Any]: truncated list of tokens
Source code in nearai/finetune/text_completion.py
def truncate(
    tokens: List[Any],
    max_seq_len: int,
    eos_id: Optional[Any] = None,
) -> List[Any]:
    """Truncate a list of tokens to a maximum length. If eos_id is provided, the last token will be replaced with eos_id.

    Args:
    ----
        tokens (List[Any]): list of tokens to truncate
        max_seq_len (int): maximum length of the list
        eos_id (Optional[Any]): token to replace the last token with. If None, the
            last token will not be replaced. Default is None.

    Returns:
    -------
        List[Any]: truncated list of tokens. This is a new list; ``tokens`` itself is not modified.
        If the truncated list is empty it is returned as-is, even when ``eos_id`` is given.

    """  # noqa: E501
    # Slicing copies, so replacing the last element below never mutates the caller's list.
    tokens_truncated = tokens[:max_seq_len]
    # Guard the empty case: indexing [-1] on an empty list would raise IndexError.
    if eos_id is not None and tokens_truncated and tokens_truncated[-1] != eos_id:
        tokens_truncated[-1] = eos_id
    return tokens_truncated

hub

Hub

Bases: object

Source code in nearai/hub.py
class Hub(object):
    """Command-line client for sending a single chat query to the NEAR AI Hub."""

    def __init__(self, config: Config) -> None:
        """Initializes the Hub class with the given configuration."""
        self.info = None
        self.provider = None
        self.model = None
        self.endpoint = None
        self.query = None
        self._config = config

    def parse_hub_chat_params(self, kwargs):
        """Parses and sets instance attributes from the given keyword arguments, using default values if needed.

        Creates a default ``NearAiHubConfig`` on the config when it has none.
        """
        if self._config.nearai_hub is None:
            self._config.nearai_hub = NearAiHubConfig()

        self.query = kwargs.get("query")
        self.endpoint = kwargs.get("endpoint", f"{self._config.nearai_hub.base_url}/chat/completions")
        self.model = kwargs.get("model", self._config.nearai_hub.default_model)
        self.provider = kwargs.get("provider", self._config.nearai_hub.default_provider)
        self.info = kwargs.get("info", False)

    def chat(self, kwargs):
        """Processes a chat request by sending parameters to the NEAR AI Hub and printing the response.

        All failures are reported by printing a message; nothing is raised to the caller.
        """
        try:
            # Also guarantees self._config.nearai_hub is not None.
            self.parse_hub_chat_params(kwargs)

            if not self.query:
                return print("Error: 'query' is required for the `hub chat` command.")

            data = {
                "max_tokens": 256,
                "temperature": 1,
                "frequency_penalty": 0,
                "n": 1,
                "messages": [{"role": "user", "content": str(self.query)}],
                "model": self.model,
            }

            auth = self._config.auth

            if self._config.nearai_hub.login_with_near:
                if auth is None:
                    # Without auth data there is no way to build a bearer token.
                    return print("Please login with `nearai login`")
                bearer_token = auth.generate_bearer_token()
                headers = {"Content-Type": "application/json", "Authorization": f"Bearer {bearer_token}"}

                data["provider"] = self.provider
            elif self._config.nearai_hub.api_key:
                headers = {
                    "Content-Type": "application/json",
                    "Authorization": "Bearer {}".format(self._config.nearai_hub.api_key),
                }
            else:
                return print("Illegal NEAR AI Hub Config")

            # In api_key mode the user may not be logged in; only report the account when there is one.
            if self.info and auth is not None:
                print(f"Requesting hub using NEAR Account {auth.account_id}")

            response = requests.post(self.endpoint, headers=headers, data=json.dumps(data))
            # Report HTTP errors by status instead of a confusing KeyError on "choices" below.
            response.raise_for_status()

            completion = response.json()

            print(completion["choices"][0]["message"]["content"])

        except Exception as e:
            print(f"Request failed: {e}")
__init__
__init__(config: Config) -> None

Initializes the Hub class with the given configuration.

Source code in nearai/hub.py
def __init__(self, config: Config) -> None:
    """Initializes the Hub class with the given configuration.

    Every chat parameter starts unset (None) until ``parse_hub_chat_params`` fills it in.
    """
    self.info = self.provider = self.model = self.endpoint = self.query = None
    self._config = config
chat
chat(kwargs)

Processes a chat request by sending parameters to the NEAR AI Hub and printing the response.

Source code in nearai/hub.py
def chat(self, kwargs):
    """Processes a chat request by sending parameters to the NEAR AI Hub and printing the response.

    All failures are reported by printing a message; nothing is raised to the caller.
    """
    try:
        # Also guarantees self._config.nearai_hub is not None.
        self.parse_hub_chat_params(kwargs)

        if not self.query:
            return print("Error: 'query' is required for the `hub chat` command.")

        data = {
            "max_tokens": 256,
            "temperature": 1,
            "frequency_penalty": 0,
            "n": 1,
            "messages": [{"role": "user", "content": str(self.query)}],
            "model": self.model,
        }

        auth = self._config.auth

        if self._config.nearai_hub.login_with_near:
            if auth is None:
                # Without auth data there is no way to build a bearer token.
                return print("Please login with `nearai login`")
            bearer_token = auth.generate_bearer_token()
            headers = {"Content-Type": "application/json", "Authorization": f"Bearer {bearer_token}"}

            data["provider"] = self.provider
        elif self._config.nearai_hub.api_key:
            headers = {
                "Content-Type": "application/json",
                "Authorization": "Bearer {}".format(self._config.nearai_hub.api_key),
            }
        else:
            return print("Illegal NEAR AI Hub Config")

        # In api_key mode the user may not be logged in; only report the account when there is one.
        if self.info and auth is not None:
            print(f"Requesting hub using NEAR Account {auth.account_id}")

        response = requests.post(self.endpoint, headers=headers, data=json.dumps(data))
        # Report HTTP errors by status instead of a confusing KeyError on "choices" below.
        response.raise_for_status()

        completion = response.json()

        print(completion["choices"][0]["message"]["content"])

    except Exception as e:
        print(f"Request failed: {e}")
parse_hub_chat_params
parse_hub_chat_params(kwargs)

Parses and sets instance attributes from the given keyword arguments, using default values if needed.

Source code in nearai/hub.py
def parse_hub_chat_params(self, kwargs):
    """Fill query, endpoint, model, provider and info from ``kwargs``.

    Missing keys fall back to the hub config defaults; a default ``NearAiHubConfig``
    is installed on the config first if it has none.
    """
    if self._config.nearai_hub is None:
        self._config.nearai_hub = NearAiHubConfig()
    hub_config = self._config.nearai_hub

    self.query = kwargs.get("query")
    default_endpoint = f"{hub_config.base_url}/chat/completions"
    self.endpoint = kwargs.get("endpoint", default_endpoint)
    self.model = kwargs.get("model", hub_config.default_model)
    self.provider = kwargs.get("provider", hub_config.default_provider)
    self.info = kwargs.get("info", False)

lib

parse_location

parse_location(entry_location: str) -> EntryLocation

Create an EntryLocation from a string in the format namespace/name/version.

Source code in nearai/lib.py
def parse_location(entry_location: str) -> EntryLocation:
    """Build an EntryLocation from a ``namespace/name/version`` string.

    Raises:
        ValueError: if ``entry_location`` does not match ``entry_location_pattern``.
    """
    found = entry_location_pattern.match(entry_location)
    if not found:
        raise ValueError(f"Invalid entry format: {entry_location}. Should have the format <namespace>/<name>/<version>")

    namespace, name, version = (found.group(key) for key in ("namespace", "name", "version"))
    return EntryLocation(namespace=namespace, name=name, version=version)

login

AuthHandler

Bases: SimpleHTTPRequestHandler

Source code in nearai/login.py
class AuthHandler(http.server.SimpleHTTPRequestHandler):
    """Request handler for the local login callback server.

    Serves ``auth_capture.html`` for paths starting with ``/capture``. For paths starting
    with ``/auth`` it verifies and saves the auth parameters from the query string, serves
    ``auth_complete.html`` and schedules a server shutdown. Any other path gets a 404.
    """

    def log_message(self, format, *args):
        """Webserver logging method."""
        pass  # Override to suppress logging

    def _send_html(self, filename):
        """Send the HTML asset ``filename`` from the ``assets`` folder next to this module with status 200."""
        assets_folder = Path(__file__).resolve().parent / "assets"
        with open(assets_folder / filename, "r", encoding="utf-8") as file:
            content = file.read()
        self.send_response(200)
        self.send_header("Content-type", "text/html")
        self.end_headers()
        self.wfile.write(content.encode("utf-8"))

    def do_GET(self):  # noqa: N802
        """Webserver GET method."""
        if self.path.startswith("/capture"):
            self._send_html("auth_capture.html")
            return

        if not self.path.startswith("/auth"):
            # Without this, unknown paths (e.g. /favicon.ico) got no HTTP response at all.
            self.send_error(404)
            return

        params = urlparse.parse_qs(urlparse.urlparse(self.path).query)

        required_params = ["accountId", "signature", "publicKey"]

        if all(param in params for param in required_params):
            update_auth_config(
                params["accountId"][0],
                params["signature"][0],
                params["publicKey"][0],
                callback_url=generate_callback_url(PORT),
                nonce=NONCE,
            )
        else:
            print("Required parameters not found")

        self._send_html("auth_complete.html")

        # Give the server some time to read the response before shutting it down.
        # shutdown() must run on a thread other than the one in serve_forever(), or it deadlocks.
        def shutdown_server():
            time.sleep(2)  # Wait 2 seconds before shutting down
            if httpd:
                httpd.shutdown()

        threading.Thread(target=shutdown_server).start()
do_GET
do_GET()

Webserver GET method.

Source code in nearai/login.py
def do_GET(self):  # noqa: N802
    """Webserver GET method.

    Serves ``auth_capture.html`` for ``/capture``. Handles ``/auth`` by verifying and saving
    the auth parameters, serving ``auth_complete.html`` and scheduling a server shutdown.
    Any other path gets a 404.
    """
    assets_folder = Path(__file__).resolve().parent / "assets"

    def send_html(filename):
        # Send one bundled HTML asset with status 200.
        with open(assets_folder / filename, "r", encoding="utf-8") as file:
            content = file.read()
        self.send_response(200)
        self.send_header("Content-type", "text/html")
        self.end_headers()
        self.wfile.write(content.encode("utf-8"))

    if self.path.startswith("/capture"):
        send_html("auth_capture.html")
    elif self.path.startswith("/auth"):
        params = urlparse.parse_qs(urlparse.urlparse(self.path).query)

        required_params = ["accountId", "signature", "publicKey"]

        if all(param in params for param in required_params):
            update_auth_config(
                params["accountId"][0],
                params["signature"][0],
                params["publicKey"][0],
                callback_url=generate_callback_url(PORT),
                nonce=NONCE,
            )
        else:
            print("Required parameters not found")

        send_html("auth_complete.html")

        # Give the server some time to read the response before shutting it down.
        # shutdown() must run on a thread other than the one in serve_forever(), or it deadlocks.
        def shutdown_server():
            time.sleep(2)  # Wait 2 seconds before shutting down
            if httpd:
                httpd.shutdown()

        threading.Thread(target=shutdown_server).start()
    else:
        # Without this, unknown paths (e.g. /favicon.ico) got no HTTP response at all.
        self.send_error(404)
log_message
log_message(format, *args)

Webserver logging method.

Source code in nearai/login.py
def log_message(self, format, *args):
    """Webserver logging method.

    Deliberately does nothing, so the local auth server prints no per-request log lines.
    """
    return None

find_open_port

find_open_port() -> int

Finds and returns an open port number by binding to a free port on the local machine.

Source code in nearai/login.py
def find_open_port() -> int:
    """Return a TCP port number that was free at the moment of the call.

    The OS picks the port when binding to port 0; the probe socket is then closed.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        probe.bind(("", 0))
        _host, port = probe.getsockname()
    return port

generate_and_save_signature

generate_and_save_signature(account_id, private_key)

Generates a signature for the given account ID and private key, then updates the auth configuration.

Source code in nearai/login.py
def generate_and_save_signature(account_id, private_key):
    """Sign the login message with ``private_key`` and store the result for ``account_id``.

    On a successful save (``update_auth_config`` returns truthy) the new login status is printed.
    """
    nonce = generate_nonce()
    signature, public_key = near.create_signature(private_key, near.Payload(MESSAGE, nonce, RECIPIENT, None))

    saved = update_auth_config(account_id, signature, public_key, None, nonce)
    if saved:
        print_login_status()

generate_callback_url

generate_callback_url(port)

Generates a callback URL using the specified port number.

Source code in nearai/login.py
def generate_callback_url(port):
    """Return the local ``/capture`` callback URL served on ``port``."""
    return "http://localhost:{}/capture".format(port)

generate_nonce

generate_nonce()

Generates a nonce based on the current time in milliseconds.

Source code in nearai/login.py
def generate_nonce():
    """Return the current Unix time in whole milliseconds, as a decimal string."""
    millis = int(time.time() * 1000)
    return str(millis)

login_with_file_credentials

login_with_file_credentials(account_id)

Logs in using credentials from a file for the specified account ID, generating and saving a signature.

Source code in nearai/login.py
def login_with_file_credentials(account_id):
    """Logs in using credentials from a file for the specified account ID, generating and saving a signature.

    Reads ``~/.near-credentials/mainnet/<account_id>.json``. If the file or its ``private_key`` is
    missing, a message is printed and None is returned.
    """
    file_path = os.path.expanduser(os.path.join("~/.near-credentials/", "mainnet", f"{account_id}.json"))

    if not os.path.exists(file_path):
        return print(f"Account data is missing for {account_id}")

    # Credential files are JSON, which is UTF-8 by specification.
    with open(file_path, "r", encoding="utf-8") as file:
        account_data = json.load(file)

    private_key = account_data.get("private_key", None)
    if not private_key:
        return print(f"Private key is missing for {account_id} on mainnet")
    generate_and_save_signature(account_id, private_key)

login_with_near_auth

login_with_near_auth(remote, auth_url)

Initiates the login process using NEAR authentication, either starting a local server to handle the callback or providing a URL for remote authentication.

Source code in nearai/login.py
def login_with_near_auth(remote, auth_url):
    """Initiates the login process using NEAR authentication, either starting a local server to handle the callback or providing a URL for remote authentication.

    In local mode this blocks in ``serve_forever`` until the callback handler shuts the server down.
    """  # noqa: E501
    global NONCE, PORT, httpd
    NONCE = generate_nonce()

    params = {
        "message": MESSAGE,
        "nonce": NONCE,
        "recipient": RECIPIENT,
    }

    if remote:
        print_url_message(f"{auth_url}?{urlparse.urlencode(params)}")
        print("After visiting the URL, follow the instructions to save your auth signature")
        return

    PORT = find_open_port()

    with socketserver.TCPServer(("", PORT), AuthHandler) as httpd:
        # Use the same helper AuthHandler uses when verifying the signature, so the two URLs cannot drift apart.
        params["callbackUrl"] = generate_callback_url(PORT)

        print_url_message(f"{auth_url}?{urlparse.urlencode(params)}")

        httpd.serve_forever()

print_login_status

print_login_status()

Prints the current authentication status if available in the config file.

Source code in nearai/login.py
def print_login_status():
    """Print the saved auth fields from the config file, or a notice when no login is stored."""
    auth = load_config_file().get("auth")
    if not (auth and auth.get("account_id")):
        print("Near auth details not found")
        return

    print(f'Auth data for: {auth["account_id"]}')
    for key in ("signature", "public_key", "nonce", "message", "recipient"):
        print(f"{key}: {auth[key]}")

print_url_message

print_url_message(url)

Prints a message instructing the user to visit the given URL to complete the login process.

Source code in nearai/login.py
def print_url_message(url):
    """Tell the user which URL to open to finish logging in."""
    text = f"Please visit the following URL to complete the login process: {url}"
    print(text)

update_auth_config

update_auth_config(account_id, signature, public_key, callback_url, nonce)

Update authentication configuration if the provided signature is valid.

Source code in nearai/login.py
def update_auth_config(account_id, signature, public_key, callback_url, nonce):
    """Save the auth data to the config file if the signature verifies.

    Returns True after saving, or False (after printing a message) when verification fails.
    """
    is_valid = near.verify_signed_message(
        account_id,
        public_key,
        signature,
        MESSAGE,
        nonce,
        RECIPIENT,
        callback_url,
    )
    if not is_valid:
        print("Signature verification failed. Abort")
        return False

    config = load_config_file()

    auth_fields = {
        "account_id": account_id,
        "signature": signature,
        "public_key": public_key,
        "callback_url": callback_url,
        "nonce": nonce,
        "recipient": RECIPIENT,
        "message": MESSAGE,
    }
    config["auth"] = AuthData.model_validate(auth_fields).model_dump()
    save_config_file(config)

    print(f"Auth data has been successfully saved! You are now logged in with account ID: {account_id}")
    return True

model

get_model

get_model(name: str) -> Path

Return the local path of a model, downloading it from the registry if it hasn't been downloaded yet.

:param name: The registry entry of the model, in the format namespace/name/version. :return: The path to the downloaded model

Source code in nearai/model.py
def get_model(name: str) -> Path:
    """Return the local path of a model, downloading it from the registry if it hasn't been downloaded yet.

    :param name: The registry entry of the model, in the format namespace/name/version.
    :return: The path to the downloaded model
    """
    model_path = registry.download(name)
    return model_path

registry

Registry

Source code in nearai/registry.py
class Registry:
    """Client for the NEAR AI registry: upload, download and query entries."""

    def __init__(self):
        """Create Registry object to interact with the registry programmatically."""
        self.download_folder = DATA_FOLDER / "registry"
        self.api = RegistryApi()

        if not self.download_folder.exists():
            self.download_folder.mkdir(parents=True, exist_ok=True)

    def update(self, entry_location: EntryLocation, metadata: EntryMetadataInput) -> Dict[str, Any]:
        """Update metadata of an entry in the registry."""
        result = self.api.upload_metadata_v1_registry_upload_metadata_post(
            BodyUploadMetadataV1RegistryUploadMetadataPost(metadata=metadata, entry_location=entry_location)
        )
        return result

    def info(self, entry_location: EntryLocation) -> Optional[EntryMetadata]:
        """Get metadata of an entry in the registry, or None if the entry does not exist."""
        try:
            return self.api.download_metadata_v1_registry_download_metadata_post(
                BodyDownloadMetadataV1RegistryDownloadMetadataPost.from_dict(dict(entry_location=entry_location))
            )
        except NotFoundException:
            return None

    def upload_file(self, entry_location: EntryLocation, local_path: Path, path: Path) -> bool:
        """Upload a file to the registry.

        Returns True on upload and False if the registry reports that the file already exists.
        Any other BadRequestException is re-raised.
        """
        with open(local_path, "rb") as file:
            data = file.read()

        try:
            self.api.upload_file_v1_registry_upload_file_post(
                path=str(path),
                file=data,
                namespace=entry_location.namespace,
                name=entry_location.name,
                version=entry_location.version,
            )
            return True
        except BadRequestException as e:
            if isinstance(e.body, str) and "already exists" in e.body:
                return False

            raise

    def download_file(self, entry_location: EntryLocation, path: Path, local_path: Path):
        """Download a file from the registry to ``local_path``, creating parent folders as needed."""
        result = self.api.download_file_v1_registry_download_file_post_without_preload_content(
            BodyDownloadFileV1RegistryDownloadFilePost.from_dict(
                dict(
                    entry_location=entry_location,
                    path=str(path),
                )
            )
        )

        local_path.parent.mkdir(parents=True, exist_ok=True)

        with open(local_path, "wb") as f:
            copyfileobj(result, f)

    def download(
        self,
        entry_location: Union[str, EntryLocation],
        force: bool = False,
        show_progress: bool = False,
        verbose: bool = True,
    ) -> Path:
        """Download entry from the registry locally.

        If the target folder already exists and ``force`` is False, it is returned without downloading.
        Raises ValueError if the entry has no metadata in the registry.
        """
        if isinstance(entry_location, str):
            entry_location = parse_location(entry_location)

        download_path = get_registry_folder() / entry_location.namespace / entry_location.name / entry_location.version

        if download_path.exists() and not force:
            if verbose:
                print(
                    f"Entry {entry_location} already exists at {download_path}. Use --force to overwrite the entry."
                )
            return download_path

        files = self.list_files(entry_location)

        # Look the entry up before creating the folder: a missing entry must not leave an empty
        # folder behind that later calls would report as "already exists".
        metadata = self.info(entry_location)

        if metadata is None:
            raise ValueError(f"Entry {entry_location} not found.")

        download_path.mkdir(parents=True, exist_ok=True)

        metadata_path = download_path / "metadata.json"
        with open(metadata_path, "w") as f:
            f.write(metadata.model_dump_json(indent=2))

        for file in (pbar := tqdm(files, disable=not show_progress)):
            pbar.set_description(file)
            self.download_file(entry_location, file, download_path / file)

        return download_path

    def upload(
        self,
        local_path: Path,
        metadata: Optional[EntryMetadata] = None,
        show_progress: bool = False,
    ) -> EntryLocation:
        """Upload entry to the registry.

        If metadata is provided it will overwrite the metadata in the directory,
        otherwise it will use the metadata.json found on the root of the directory.

        Calls ``exit(1)`` when not logged in, when the metadata sets ``details._source``, or when a
        new entry's name is similar (by canonical name) to an existing visible entry.
        """
        path = Path(local_path).absolute()

        if not path.exists():
            # try path in local registry if original path not exists
            path = get_registry_folder() / local_path

        if CONFIG.auth is None:
            print("Please login with `nearai login`")
            exit(1)

        metadata_path = path / "metadata.json"

        if metadata is not None:
            with open(metadata_path, "w") as f:
                f.write(metadata.model_dump_json(indent=2))

        check_metadata(metadata_path)

        with open(metadata_path) as f:
            plain_metadata: Dict[str, Any] = json.load(f)

        namespace = get_namespace(local_path)
        name = plain_metadata.pop("name")

        entry_location = EntryLocation.model_validate(
            dict(
                namespace=namespace,
                name=name,
                version=plain_metadata.pop("version"),
            )
        )

        entry_metadata = EntryMetadataInput.model_validate(plain_metadata)
        source = entry_metadata.details.get("_source", None)

        if source is not None:
            print(f"Only default source is allowed, found: {source}. Remove details._source from metadata.")
            exit(1)

        if self.info(entry_location) is None:
            # New entry location. Check for similar names in registry.
            entries = self.list_all_visible()
            canonical_namespace = get_canonical_name(namespace)
            canonical_name = get_canonical_name(name)

            for entry in entries:
                if entry.name == name and entry.namespace == namespace:
                    break
                if (
                    get_canonical_name(entry.name) == canonical_name
                    and get_canonical_name(entry.namespace) == canonical_namespace
                ):
                    print(f"A registry item with a similar name already exists: {entry.namespace}/{entry.name}")
                    exit(1)

        self.update(entry_location, entry_metadata)

        all_files = []
        total_size = 0

        # Traverse all files in the directory `path`
        for file in path.rglob("*"):
            if not file.is_file():
                continue

            relative = file.relative_to(path)

            # Don't upload metadata file.
            if file == metadata_path:
                continue

            # Don't upload backup files.
            if file.name.endswith("~"):
                continue

            # Don't upload configuration files.
            if relative.parts[0] == ".nearai":
                continue

            size = file.stat().st_size
            total_size += size

            all_files.append((file, relative, size))

        # The context manager closes the progress bar even if an upload raises.
        with tqdm(total=total_size, unit="B", unit_scale=True, disable=not show_progress) as pbar:
            for file, relative, size in all_files:
                self.upload_file(entry_location, file, relative)
                pbar.update(size)

        return entry_location

    def list_files(self, entry_location: EntryLocation) -> List[str]:
        """List files of an entry in the registry.

        Return the relative paths to all files with respect to the root of the entry.
        """
        result = self.api.list_files_v1_registry_list_files_post(
            BodyListFilesV1RegistryListFilesPost.from_dict(dict(entry_location=entry_location))
        )
        return [file.filename for file in result]

    def list(
        self,
        namespace: str,
        category: str,
        tags: str,
        total: int,
        offset: int,
        show_all: bool,
        show_latest_version: bool,
        starred_by: str = "",
    ) -> List[EntryInformation]:
        """List and filter entries in the registry."""
        return self.api.list_entries_v1_registry_list_entries_post(
            namespace=namespace,
            category=category,
            tags=tags,
            total=total,
            offset=offset,
            show_hidden=show_all,
            show_latest_version=show_latest_version,
            starred_by=starred_by,
        )

    def list_all_visible(self, category: str = "") -> List[EntryInformation]:
        """List all visible entries (latest version only), optionally filtered by category."""
        total = 10000
        entries = self.list(
            namespace="",
            category=category,
            tags="",
            total=total,
            offset=0,
            show_all=False,
            show_latest_version=True,
        )
        # A full page suggests the listing was cut off; fail instead of returning a partial list.
        assert len(entries) < total
        return entries

    def dict_models(self) -> Dict[NamespacedName, NamespacedName]:
        """Returns a mapping canonical->name.

        Raises ValueError if two visible models share the same canonical name.
        """
        entries = self.list_all_visible(category="model")
        result: Dict[NamespacedName, NamespacedName] = {}
        for entry in entries:
            namespaced_name = NamespacedName(name=entry.name, namespace=entry.namespace)
            canonical_namespaced_name = namespaced_name.canonical()
            if canonical_namespaced_name in result:
                raise ValueError(
                    f"Duplicate registry entry for model {namespaced_name}, canonical {canonical_namespaced_name}"
                )
            result[canonical_namespaced_name] = namespaced_name
        return result
__init__
__init__()

Create Registry object to interact with the registry programmatically.

Source code in nearai/registry.py
def __init__(self):
    """Create a Registry client for programmatic access to the registry.

    Also creates the local download folder (``DATA_FOLDER / "registry"``) when it
    does not exist yet.
    """
    folder = DATA_FOLDER / "registry"
    self.download_folder = folder
    self.api = RegistryApi()

    if folder.exists():
        return
    folder.mkdir(parents=True, exist_ok=True)
dict_models
dict_models() -> Dict[NamespacedName, NamespacedName]

Returns a mapping canonical->name.

Source code in nearai/registry.py
def dict_models(self) -> Dict[NamespacedName, NamespacedName]:
    """Map each visible model's canonical name to its registered name.

    Raises
    ------
        ValueError: If two visible models share the same canonical name.

    """
    mapping: Dict[NamespacedName, NamespacedName] = {}
    for model_entry in self.list_all_visible(category="model"):
        original = NamespacedName(name=model_entry.name, namespace=model_entry.namespace)
        key = original.canonical()
        if key not in mapping:
            mapping[key] = original
            continue
        raise ValueError(f"Duplicate registry entry for model {original}, canonical {key}")
    return mapping
download
download(entry_location: Union[str, EntryLocation], force: bool = False, show_progress: bool = False, verbose: bool = True) -> Path

Download entry from the registry locally.

Source code in nearai/registry.py
def download(
    self,
    entry_location: Union[str, EntryLocation],
    force: bool = False,
    show_progress: bool = False,
    verbose: bool = True,
) -> Path:
    """Download an entry from the registry into the local registry folder.

    Args:
    ----
        entry_location: Entry to download, either an ``EntryLocation`` or a string
            accepted by ``parse_location``.
        force: Download again even if the destination folder already exists.
        show_progress: Show a tqdm progress bar over the files being downloaded.
        verbose: Print a notice when the entry is already present locally.

    Returns:
    -------
        The local folder ``<registry folder>/<namespace>/<name>/<version>``.

    Raises:
    ------
        ValueError: If the entry does not exist in the registry.

    """
    if isinstance(entry_location, str):
        entry_location = parse_location(entry_location)

    download_path = get_registry_folder() / entry_location.namespace / entry_location.name / entry_location.version

    if download_path.exists() and not force:
        if verbose:
            print(f"Entry {entry_location} already exists at {download_path}. Use --force to overwrite the entry.")
        return download_path

    # Look the entry up before touching the filesystem. Creating the folder first would
    # leave an empty directory behind for a missing entry, and later calls without
    # `force` would then report it as already downloaded.
    metadata = self.info(entry_location)
    if metadata is None:
        raise ValueError(f"Entry {entry_location} not found.")

    files = self.list_files(entry_location)

    download_path.mkdir(parents=True, exist_ok=True)

    metadata_path = download_path / "metadata.json"
    with open(metadata_path, "w") as f:
        f.write(metadata.model_dump_json(indent=2))

    for file in (pbar := tqdm(files, disable=not show_progress)):
        pbar.set_description(file)
        self.download_file(entry_location, file, download_path / file)

    return download_path
download_file
download_file(entry_location: EntryLocation, path: Path, local_path: Path)

Download a file from the registry.

Source code in nearai/registry.py
def download_file(self, entry_location: EntryLocation, path: Path, local_path: Path):
    """Download a single file of a registry entry to ``local_path``.

    Args:
    ----
        entry_location: Entry that owns the file.
        path: Path of the file relative to the root of the entry.
        local_path: Destination on disk; missing parent folders are created.

    Raises:
    ------
        ValueError: If the server answers with a non-2xx HTTP status.

    """
    result = self.api.download_file_v1_registry_download_file_post_without_preload_content(
        BodyDownloadFileV1RegistryDownloadFilePost.from_dict(
            dict(
                entry_location=entry_location,
                path=str(path),
            )
        )
    )

    try:
        # NOTE(review): the `_without_preload_content` variant returns the raw HTTP
        # response without checking its status. Without this check, an error body would
        # be written to disk as if it were the file. Assumes the response exposes an
        # integer `status` (urllib3 HTTPResponse); confirm against the generated client.
        if not 200 <= result.status < 300:
            raise ValueError(f"Failed to download {path} from {entry_location}: HTTP status {result.status}")

        local_path.parent.mkdir(parents=True, exist_ok=True)

        with open(local_path, "wb") as f:
            copyfileobj(result, f)
    finally:
        # Release the underlying connection even if writing fails part-way.
        result.close()
info
info(entry_location: EntryLocation) -> Optional[EntryMetadata]

Get metadata of an entry in the registry.

Source code in nearai/registry.py
def info(self, entry_location: EntryLocation) -> Optional[EntryMetadata]:
    """Fetch the metadata of an entry in the registry.

    Returns ``None`` when the registry reports the entry as not found.
    """
    request = BodyDownloadMetadataV1RegistryDownloadMetadataPost.from_dict({"entry_location": entry_location})
    try:
        metadata = self.api.download_metadata_v1_registry_download_metadata_post(request)
    except NotFoundException:
        metadata = None
    return metadata
list
list(namespace: str, category: str, tags: str, total: int, offset: int, show_all: bool, show_latest_version: bool, starred_by: str = '') -> List[EntryInformation]

List and filter entries in the registry.

Source code in nearai/registry.py
def list(
    self,
    namespace: str,
    category: str,
    tags: str,
    total: int,
    offset: int,
    show_all: bool,
    show_latest_version: bool,
    starred_by: str = "",
) -> List[EntryInformation]:
    """Query the registry for entries matching the given filters.

    All arguments are forwarded to the registry ``list_entries`` endpoint.
    ``show_all`` is sent as the API's ``show_hidden`` flag.
    """
    query = dict(
        namespace=namespace,
        category=category,
        tags=tags,
        total=total,
        offset=offset,
        # The API calls this flag `show_hidden`; the public name here is `show_all`.
        show_hidden=show_all,
        show_latest_version=show_latest_version,
        starred_by=starred_by,
    )
    return self.api.list_entries_v1_registry_list_entries_post(**query)
list_all_visible
list_all_visible(category: str = '') -> List[EntryInformation]

List all visible entries.

Source code in nearai/registry.py
def list_all_visible(self, category: str = "") -> List[EntryInformation]:
    """List all visible entries, optionally restricted to one category.

    Fetches up to 10000 entries in a single request. Only the latest version of
    each entry is included, and hidden entries are excluded.

    Raises
    ------
        RuntimeError: If the request returns a full page, meaning the result may
            have been truncated.

    """
    total = 10000
    entries = self.list(
        namespace="",
        category=category,
        tags="",
        total=total,
        offset=0,
        show_all=False,
        show_latest_version=True,
    )
    # A full page means the registry may hold more entries than were returned.
    # Fail loudly instead of returning a silently truncated list. This is an
    # explicit check rather than `assert` so it still runs under `python -O`.
    if len(entries) >= total:
        raise RuntimeError(f"Registry returned {len(entries)} entries (limit {total}); results may be truncated.")
    return entries
list_files
list_files(entry_location: EntryLocation) -> List[str]

List files in an entry in the registry.

Return the relative paths to all files with respect to the root of the entry.

Source code in nearai/registry.py
def list_files(self, entry_location: EntryLocation) -> List[str]:
    """List the files stored in a registry entry.

    Returns the path of every file relative to the root of the entry.
    """
    request = BodyListFilesV1RegistryListFilesPost.from_dict({"entry_location": entry_location})
    listed = self.api.list_files_v1_registry_list_files_post(request)
    return [entry_file.filename for entry_file in listed]
update
update(entry_location: EntryLocation, metadata: EntryMetadataInput) -> Dict[str, Any]

Update metadata of an entry in the registry.

Source code in nearai/registry.py
def update(self, entry_location: EntryLocation, metadata: EntryMetadataInput) -> Dict[str, Any]:
    """Upload new metadata for an entry in the registry and return the API response."""
    body = BodyUploadMetadataV1RegistryUploadMetadataPost(metadata=metadata, entry_location=entry_location)
    return self.api.upload_metadata_v1_registry_upload_metadata_post(body)
upload
upload(local_path: Path, metadata: Optional[EntryMetadata] = None, show_progress: bool = False) -> EntryLocation

Upload entry to the registry.

If metadata is provided, it overwrites the metadata in the directory; otherwise the metadata.json found at the root of the directory is used.

Source code in nearai/registry.py
def upload(
    self,
    local_path: Path,
    metadata: Optional[EntryMetadata] = None,
    show_progress: bool = False,
) -> EntryLocation:
    """Upload entry to the registry.

    If metadata is provided it will overwrite the metadata in the directory,
    otherwise it will use the metadata.json found at the root of the directory.

    Args:
    ----
        local_path: Folder to upload. If it does not exist as given, it is looked
            up relative to the local registry folder.
        metadata: Optional metadata written to ``metadata.json`` before uploading.
        show_progress: Show a byte-based tqdm progress bar while uploading files.

    Returns:
    -------
        The location (namespace, name, version) the entry was uploaded to.

    The process exits with status 1 in three cases: the user is not logged in,
    the metadata sets ``details._source``, or a registry item with a similar
    (canonical) name already exists.

    """
    import sys

    path = Path(local_path).absolute()

    if not path.exists():
        # try path in local registry if original path not exists
        path = get_registry_folder() / local_path

    if CONFIG.auth is None:
        print("Please login with `nearai login`")
        sys.exit(1)

    metadata_path = path / "metadata.json"

    if metadata is not None:
        with open(metadata_path, "w") as f:
            f.write(metadata.model_dump_json(indent=2))

    check_metadata(metadata_path)

    with open(metadata_path) as f:
        plain_metadata: Dict[str, Any] = json.load(f)

    # Derive the namespace from the resolved path. The raw `local_path` may be relative
    # (or a str), so matching it against the registry folder would depend on how the
    # caller spelled the path rather than on where the entry actually lives.
    namespace = get_namespace(path)
    name = plain_metadata.pop("name")

    entry_location = EntryLocation.model_validate(
        dict(
            namespace=namespace,
            name=name,
            version=plain_metadata.pop("version"),
        )
    )

    entry_metadata = EntryMetadataInput.model_validate(plain_metadata)
    source = entry_metadata.details.get("_source", None)

    if source is not None:
        print(f"Only default source is allowed, found: {source}. Remove details._source from metadata.")
        sys.exit(1)

    if self.info(entry_location) is None:
        # New entry location. Check for similar names in registry.
        entries = self.list_all_visible()
        canonical_namespace = get_canonical_name(namespace)
        canonical_name = get_canonical_name(name)

        for entry in entries:
            if entry.name == name and entry.namespace == namespace:
                # Exact match: a new version of an existing entry, not a look-alike.
                break
            if (
                get_canonical_name(entry.name) == canonical_name
                and get_canonical_name(entry.namespace) == canonical_namespace
            ):
                print(f"A registry item with a similar name already exists: {entry.namespace}/{entry.name}")
                sys.exit(1)

    self.update(entry_location, entry_metadata)

    all_files = []
    total_size = 0

    # Traverse all files in the directory `path`
    for file in path.rglob("*"):
        if not file.is_file():
            continue

        relative = file.relative_to(path)

        # Don't upload metadata file.
        if file == metadata_path:
            continue

        # Don't upload backup files.
        if file.name.endswith("~"):
            continue

        # Don't upload configuration files.
        if relative.parts[0] == ".nearai":
            continue

        size = file.stat().st_size
        total_size += size

        all_files.append((file, relative, size))

    # Context manager so the progress bar is closed even if an upload raises.
    with tqdm(total=total_size, unit="B", unit_scale=True, disable=not show_progress) as pbar:
        for file, relative, size in all_files:
            self.upload_file(entry_location, file, relative)
            pbar.update(size)

    return entry_location
upload_file
upload_file(entry_location: EntryLocation, local_path: Path, path: Path) -> bool

Upload a file to the registry.

Source code in nearai/registry.py
def upload_file(self, entry_location: EntryLocation, local_path: Path, path: Path) -> bool:
    """Upload one local file into a registry entry.

    Args:
    ----
        entry_location: Entry (namespace, name, version) that receives the file.
        local_path: File on disk whose bytes are uploaded.
        path: Destination path of the file inside the entry.

    Returns:
    -------
        ``True`` when the upload succeeds. ``False`` when the registry rejects it
        with a bad-request error whose body mentions "already exists".

    Raises:
    ------
        BadRequestException: For any other bad-request response.

    """
    with open(local_path, "rb") as handle:
        payload = handle.read()

    try:
        self.api.upload_file_v1_registry_upload_file_post(
            path=str(path),
            file=payload,
            namespace=entry_location.namespace,
            name=entry_location.name,
            version=entry_location.version,
        )
    except BadRequestException as err:
        already_there = isinstance(err.body, str) and "already exists" in err.body
        if not already_there:
            raise
        return False
    return True

get_namespace

get_namespace(local_path: Path) -> str

Returns the namespace of a registry item, or the user's namespace if the path is not inside the local registry.

Source code in nearai/registry.py
def get_namespace(local_path: Path) -> str:
    """Return the namespace of a registry item, falling back to the user's namespace.

    A path of the form ``<registry folder>/<namespace>/<name>/<version>`` yields
    ``<namespace>``. Any other path yields ``CONFIG.auth.namespace``.

    Raises:
    ------
        ValueError: If the fallback is needed but no auth data is configured.

    """
    registry_folder = get_registry_folder()

    try:
        parts = local_path.relative_to(registry_folder).parts
    except ValueError:
        # relative_to() raises ValueError when local_path lies outside registry_folder.
        parts = ()

    if len(parts) == 3:
        return str(parts[0])

    auth = CONFIG.auth
    if auth is None:
        raise ValueError("AuthData is None")
    return auth.namespace

get_registry_folder

get_registry_folder() -> Path

Path to local registry.

Source code in nearai/registry.py
def get_registry_folder() -> Path:
    """Return the root directory of the local registry (``DATA_FOLDER / REGISTRY_FOLDER``)."""
    registry_root = DATA_FOLDER / REGISTRY_FOLDER
    return registry_root

shared

auth_data

AuthData

Bases: BaseModel

Source code in nearai/shared/auth_data.py
class AuthData(BaseModel):
    account_id: str
    signature: str
    public_key: str
    callback_url: str
    nonce: str
    recipient: str
    message: str
    # Account the request is made on behalf of; when set it becomes `namespace`.
    on_behalf_of: Optional[str] = None

    def generate_bearer_token(self):
        """Generate a JSON-encoded bearer token containing the authentication data.

        The token always contains the signing fields below and adds ``on_behalf_of``
        only when it is set. Keys are emitted in a fixed order, so the same auth data
        always produces the same token string.

        Raises
        ------
            ValueError: If any required field is ``None``.

        """
        # A tuple, not a set: iterating a set of strings follows hash randomization, which
        # would make the key order (and therefore the token text) differ between processes.
        required_keys = ("account_id", "public_key", "signature", "callback_url", "message", "nonce", "recipient")

        for key in required_keys:
            if getattr(self, key) is None:
                raise ValueError(f"Missing required auth data: {key}")

        keys = list(required_keys)
        if self.on_behalf_of is not None:
            keys.append("on_behalf_of")

        bearer_data = {key: getattr(self, key) for key in keys}

        return json.dumps(bearer_data)

    @property
    def namespace(self):
        """Get the account ID for the auth data.

        When the request is made on behalf of another account (``on_behalf_of`` is
        set), this returns that account's ID instead of ``account_id``.
        """
        if self.on_behalf_of is not None:
            return self.on_behalf_of
        return self.account_id
namespace property
namespace

Get the account ID for the auth data.

When the request is made on behalf of another account (`on_behalf_of` is set), this returns that account's ID instead of `account_id`.

generate_bearer_token
generate_bearer_token()

Generates a JSON-encoded bearer token containing authentication data.

Source code in nearai/shared/auth_data.py
def generate_bearer_token(self):
    """Generate a JSON-encoded bearer token containing the authentication data.

    The token always contains the signing fields below and adds ``on_behalf_of``
    only when it is set. Keys are emitted in a fixed order, so the same auth data
    always produces the same token string.

    Raises
    ------
        ValueError: If any required field is ``None``.

    """
    # A tuple, not a set: iterating a set of strings follows hash randomization, which
    # would make the key order (and therefore the token text) differ between processes.
    required_keys = ("account_id", "public_key", "signature", "callback_url", "message", "nonce", "recipient")

    for key in required_keys:
        if getattr(self, key) is None:
            raise ValueError(f"Missing required auth data: {key}")

    keys = list(required_keys)
    if self.on_behalf_of is not None:
        keys.append("on_behalf_of")

    bearer_data = {key: getattr(self, key) for key in keys}

    return json.dumps(bearer_data)

cache

mem_cache_with_timeout
mem_cache_with_timeout(timeout: int)

Decorator to cache function results for a specified timeout period.

Source code in nearai/shared/cache.py
def mem_cache_with_timeout(timeout: int):
    """Decorator to cache function results in memory for ``timeout`` seconds.

    Results are keyed on the positional and keyword arguments, so all of them must
    be hashable. Calls that raise are not cached. Each decorated function gets its
    own cache, which is never trimmed: an expired entry is only replaced when the
    same arguments are used again.

    Args:
    ----
        timeout: Seconds a cached result stays valid; ``0`` effectively disables caching.

    """

    def decorator(func):
        cache = {}

        @wraps(func)
        def wrapper(*args, **kwargs):
            # Monotonic clock: wall-clock time can jump backwards (NTP sync, manual
            # changes), which would make stale entries look fresh.
            now = time.monotonic()
            key = (args, frozenset(kwargs.items()))
            if key in cache:
                result, timestamp = cache[key]
                if now - timestamp < timeout:
                    return result
            result = func(*args, **kwargs)
            cache[key] = (result, now)
            return result

        return wrapper

    return decorator

client_config

ClientConfig

Bases: BaseModel

Source code in nearai/shared/client_config.py
class ClientConfig(BaseModel):
    base_url: str = "https://api.near.ai/v1"
    custom_llm_provider: str = "openai"
    auth: Optional[AuthData] = None
    default_provider: Optional[str] = None  # future: remove in favor of api decision
    # Total attempts (not extra retries) made by InferenceClient.completions per request.
    num_inference_retries: int = 1

    def get_hub_client(self):
        """Return an OpenAI client for ``base_url`` authenticated with the serialized ``auth`` data.

        Raises
        ------
            ValueError: If no auth data is configured.

        """
        # Fail with an actionable message instead of an AttributeError on None.
        if self.auth is None:
            raise ValueError("Missing auth data: please login with `nearai login` before using the hub client")
        signature = f"Bearer {self.auth.model_dump_json()}"
        base_url = self.base_url
        return openai.OpenAI(
            base_url=base_url, api_key=signature, timeout=DEFAULT_TIMEOUT, max_retries=DEFAULT_MAX_RETRIES
        )
get_hub_client
get_hub_client()

Get the hub client.

Source code in nearai/shared/client_config.py
def get_hub_client(self):
    """Return an OpenAI client for ``base_url`` authenticated with the serialized ``auth`` data.

    Raises
    ------
        ValueError: If no auth data is configured.

    """
    # Fail with an actionable message instead of an AttributeError on None.
    if self.auth is None:
        raise ValueError("Missing auth data: please login with `nearai login` before using the hub client")
    signature = f"Bearer {self.auth.model_dump_json()}"
    base_url = self.base_url
    return openai.OpenAI(
        base_url=base_url, api_key=signature, timeout=DEFAULT_TIMEOUT, max_retries=DEFAULT_MAX_RETRIES
    )

inference_client

InferenceClient

Bases: object

Source code in nearai/shared/inference_client.py
class InferenceClient(object):
    def __init__(self, config: ClientConfig, runner_api_key: str = "", agent_identifier: str = "") -> None:  # noqa: D107
        self._config = config
        self.runner_api_key = runner_api_key
        self.agent_identifier = agent_identifier
        self._auth = None
        self.generate_auth_for_current_agent(config, agent_identifier)
        self.client = openai.OpenAI(base_url=self._config.base_url, api_key=self._auth)

    def generate_auth_for_current_agent(self, config, agent_identifier):
        """Regenerate auth for the current agent.

        Builds a bearer token from ``config.auth`` with an added ``runner_data`` field
        carrying the agent identifier and the runner API key, and stores it in
        ``self._auth`` (``None`` when there is no auth data). If ``self.client``
        already exists, it is rebuilt so that later requests use the new token.
        """
        self.agent_identifier = agent_identifier
        if config.auth is not None:
            auth_bearer_token = config.auth.generate_bearer_token()
            new_token = json.loads(auth_bearer_token)
            new_token["runner_data"] = json.dumps({"agent": agent_identifier, "runner_api_key": self.runner_api_key})
            auth_bearer_token = json.dumps(new_token)
            self._auth = auth_bearer_token
        else:
            self._auth = None
        # `self.client` was constructed with the previous token and nothing else updates
        # it, so rebuild it. During __init__ the client does not exist yet; __init__
        # builds it right after this call.
        if hasattr(self, "client"):
            self.client = openai.OpenAI(base_url=self._config.base_url, api_key=self._auth)

    # This makes sense in the CLI where we don't mind doing this request and caching it.
    # In the aws_runner this is an extra request every time we run.
    # TODO(#233): add a choice of a provider model in aws_runner, and then this step can be skipped.
    @cached_property
    def provider_models(self) -> ProviderModels:  # noqa: D102
        return ProviderModels(self._config)

    def get_agent_public_key(self, agent_name: str) -> str:
        """Request agent public key."""
        headers = {
            "Content-Type": "application/json",
        }

        data = {"agent_name": agent_name}

        endpoint = f"{self._config.base_url}/get_agent_public_key"

        try:
            response = requests.post(endpoint, headers=headers, params=data)
            response.raise_for_status()
            return response.json()
        except requests.RequestException as e:
            raise ValueError(f"Failed to get agent public key: {e}") from None

    def completions(
        self,
        model: str,
        messages: Iterable[ChatCompletionMessageParam],
        stream: bool = False,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        **kwargs: Any,
    ) -> Union[ModelResponse, CustomStreamWrapper]:
        """Takes a `model` and `messages` and returns completions.

        `model` can be:
        1. full path `provider::model_full_path`.
        2. `model_short_name`. Default provider will be used.

        The request is attempted up to ``num_inference_retries`` times (at least once).
        Extra ``kwargs`` are forwarded to ``litellm_completion``.

        Raises
        ------
            ValueError: If the last attempt fails.

        """
        provider, model = self.provider_models.match_provider_model(model)

        if temperature is None:
            temperature = DEFAULT_MODEL_TEMPERATURE

        if max_tokens is None:
            max_tokens = DEFAULT_MODEL_MAX_TOKENS

        # NOTE(#246): this is to disable "Provider List" messages.
        litellm.suppress_debug_info = True

        # Always make at least one attempt. With num_inference_retries <= 0 the loop would
        # never run and `result` would be unbound at the return below.
        attempts = max(1, self._config.num_inference_retries)
        for i in range(attempts):
            try:
                result: Union[ModelResponse, CustomStreamWrapper] = litellm_completion(
                    model,
                    messages,
                    stream=stream,
                    custom_llm_provider=self._config.custom_llm_provider,
                    input_cost_per_token=0,
                    output_cost_per_token=0,
                    temperature=temperature,
                    max_tokens=max_tokens,
                    base_url=self._config.base_url,
                    provider=provider,
                    api_key=self._auth,
                    **kwargs,
                )
                break
            except Exception as e:
                if i == attempts - 1:
                    raise ValueError(f"Bad request: {e}") from None

        return result

    def query_vector_store(
        self, vector_store_id: str, query: str, full_files: bool = False
    ) -> Union[List[SimilaritySearch], List[SimilaritySearchFile]]:
        """Query a vector store."""
        if self._config is None:
            raise ValueError("Missing NEAR AI Hub config")

        auth_bearer_token = self._auth

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {auth_bearer_token}",
        }

        data = {"query": query, "full_files": full_files}

        endpoint = f"{self._config.base_url}/vector_stores/{vector_store_id}/search"

        try:
            response = requests.post(endpoint, headers=headers, json=data)
            response.raise_for_status()
            return response.json()
        except requests.RequestException as e:
            raise ValueError(f"Error querying vector store: {e}") from None

    def upload_file(
        self,
        file_content: str,
        purpose: Literal["assistants", "batch", "fine-tune", "vision"],
        encoding: str = "utf-8",
        file_name="file.txt",
        file_type="text/plain",
    ) -> FileObject:
        """Uploads a file."""
        client = openai.OpenAI(base_url=self._config.base_url, api_key=self._auth)
        file_data = io.BytesIO(file_content.encode(encoding))
        return client.files.create(file=(file_name, file_data, file_type), purpose=purpose)

    def add_file_to_vector_store(self, vector_store_id: str, file_id: str) -> VectorStoreFile:
        """Adds a file to vector store."""
        client = openai.OpenAI(base_url=self._config.base_url, api_key=self._auth)
        return client.beta.vector_stores.files.create(vector_store_id=vector_store_id, file_id=file_id)

    def create_vector_store_from_source(
        self,
        name: str,
        source: Union[GitHubSource, GitLabSource],
        source_auth: Optional[str] = None,
        chunking_strategy: Optional[ChunkingStrategy] = None,
        expires_after: Optional[ExpiresAfter] = None,
        metadata: Optional[Dict[str, str]] = None,
    ) -> VectorStore:
        """Creates a vector store from the given source.

        Args:
        ----
            name (str): The name of the vector store.
            source (Union[GitHubSource, GitLabSource]): The source from which to create the vector store.
            source_auth (Optional[str]): The source authentication token.
            chunking_strategy (Optional[ChunkingStrategy]): The chunking strategy to use.
            expires_after (Optional[ExpiresAfter]): The expiration policy.
            metadata (Optional[Dict[str, str]]): Additional metadata.

        Returns:
        -------
            VectorStore: The created vector store.

        """
        print(f"Creating vector store from source: {source}")
        headers = {
            "Authorization": f"Bearer {self._auth}",
            "Content-Type": "application/json",
        }
        data = {
            "name": name,
            "source": source,
            "source_auth": source_auth,
            "chunking_strategy": chunking_strategy,
            "expires_after": expires_after,
            "metadata": metadata,
        }
        endpoint = f"{self._config.base_url}/vector_stores/from_source"

        try:
            response = requests.post(endpoint, headers=headers, json=data)
            print(response.json())
            response.raise_for_status()
            return VectorStore(**response.json())
        except requests.RequestException as e:
            raise ValueError(f"Failed to create vector store: {e}") from None

    def create_vector_store(
        self,
        name: str,
        file_ids: List[str],
        expires_after: Union[ExpiresAfter, NotGiven] = NOT_GIVEN,
        chunking_strategy: Union[AutoFileChunkingStrategyParam, StaticFileChunkingStrategyParam, NotGiven] = NOT_GIVEN,
        metadata: Optional[Dict[str, str]] = None,
    ) -> VectorStore:
        """Creates Vector Store.

        :param name: Vector store name.
        :param file_ids: Files to be added to the vector store.
        :param expires_after: Expiration policy.
        :param chunking_strategy: Chunking strategy.
        :param metadata: Additional metadata.
        :return: Returns the created vector store or error.
        """
        client = openai.OpenAI(base_url=self._config.base_url, api_key=self._auth)
        return client.beta.vector_stores.create(
            file_ids=file_ids,
            name=name,
            expires_after=expires_after,
            chunking_strategy=chunking_strategy,
            metadata=metadata,
        )

    def get_vector_store(self, vector_store_id: str) -> VectorStore:
        """Gets a vector store by id.

        Raises
        ------
            requests.HTTPError: If the hub answers with an error status.

        """
        endpoint = f"{self._config.base_url}/vector_stores/{vector_store_id}"
        # Send the same bearer token as every other hub request made by this client.
        headers = {
            "Authorization": f"Bearer {self._auth}",
        }
        response = requests.get(endpoint, headers=headers)
        response.raise_for_status()
        return VectorStore(**response.json())

    def create_thread(self, messages):
        """Create a thread."""
        return self.client.beta.threads.create(messages=messages)

    def threads_messages_create(self, thread_id: str, content: str, role: Literal["user", "assistant"]):
        """Create a message in a thread."""
        return self.client.beta.threads.messages.create(thread_id=thread_id, content=content, role=role)

    def threads_create_and_run_poll(self, assistant_id: str, model: str, messages: List[ChatCompletionMessageParam]):
        """Create a thread and run the assistant."""
        thread = self.create_thread(messages)
        return self.client.beta.threads.create_and_run_poll(thread=thread, assistant_id=assistant_id, model=model)

    def threads_list_messages(self, thread_id: str, order: Literal["asc", "desc"] = "asc"):
        """List messages in a thread."""
        return self.client.beta.threads.messages.list(thread_id=thread_id, order=order)

    def threads_fork(self, thread_id: str):
        """Fork a thread."""
        forked_thread = self.client.post(path=f"{self._config.base_url}/threads/{thread_id}/fork", cast_to=Thread)
        return forked_thread

    def threads_runs_create(self, thread_id: str, assistant_id: str, model: str):
        """Create a run in a thread."""
        return self.client.beta.threads.runs.create(thread_id=thread_id, assistant_id=assistant_id, model=model)

    def run_agent(self, current_run_id: str, child_thread_id: str, assistant_id: str):
        """Starts a child agent run from a parent agent run."""
        return self.client.beta.threads.runs.create(
            thread_id=child_thread_id,
            assistant_id=assistant_id,
            extra_body={"parent_run_id": current_run_id},
        )

    def schedule_run(
        self,
        agent: str,
        input_message: str,
        thread_id: Optional[str],
        run_params: Optional[Dict[str, str]],
        run_at: datetime,
    ):
        """Schedule a run of ``agent`` with ``input_message`` at ``run_at`` via the hub.

        Returns the decoded JSON response of the ``/schedule_run`` endpoint.

        Raises
        ------
            ValueError: If the hub config is missing or the request fails.

        """
        if self._config is None:
            raise ValueError("Missing NearAI Hub config")

        auth_bearer_token = self._auth

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {auth_bearer_token}",
        }

        if run_params is None:
            run_params = {}

        data = {
            "agent": agent,
            "input_message": input_message,
            "thread_id": thread_id,
            "run_params": run_params,
            # `datetime` is not JSON-serializable, so `requests(json=...)` would raise
            # TypeError. Send it as an ISO 8601 string instead.
            "run_at": run_at.isoformat(),
        }

        endpoint = f"{self._config.base_url}/schedule_run"

        try:
            response = requests.post(endpoint, headers=headers, json=data)
            response.raise_for_status()
            return response.json()
        except requests.RequestException as e:
            raise ValueError(f"Error querying schedule_run: {e}") from None

    def query_user_memory(self, query: str):
        """Query the user memory."""
        return self.client.post(
            path=f"{self._config.base_url}/vector_stores/memory/query",
            body={"query": query},
            cast_to=str,
        )

    def add_user_memory(self, memory: str):
        """Add user memory."""
        return self.client.post(
            path=f"{self._config.base_url}/vector_stores/memory",
            body={"memory": memory},
            cast_to=str,
        )

    def generate_image(self, prompt: str):
        """Generate an image."""
        return self.client.images.generate(prompt=prompt)

    def save_agent_data(self, key: str, agent_data: Dict[str, Any]):
        """Save agent data for the agent this client was initialized with."""
        return self.client.post(
            path=f"{self._config.base_url}/agent_data",
            body={
                "key": key,
                "value": agent_data,
            },
            cast_to=Dict[str, Any],
        )

    def get_agent_data(self):
        """Get agent data for the agent this client was initialized with."""
        return self.client.get(
            path=f"{self._config.base_url}/agent_data",
            cast_to=Dict[str, str],
        )

    def get_agent_data_by_key(self, key: str):
        """Get agent data by key for the agent this client was initialized with."""
        return self.client.get(
            path=f"{self._config.base_url}/agent_data/{key}",
            cast_to=Dict[str, str],
        )
add_file_to_vector_store
add_file_to_vector_store(vector_store_id: str, file_id: str) -> VectorStoreFile

Adds a file to vector store.

Source code in nearai/shared/inference_client.py
def add_file_to_vector_store(self, vector_store_id: str, file_id: str) -> VectorStoreFile:
    """Attach an already-uploaded file (``file_id``) to the vector store ``vector_store_id``."""
    hub = openai.OpenAI(base_url=self._config.base_url, api_key=self._auth)
    store_files = hub.beta.vector_stores.files
    return store_files.create(vector_store_id=vector_store_id, file_id=file_id)
add_user_memory
add_user_memory(memory: str)

Add user memory.

Source code in nearai/shared/inference_client.py
def add_user_memory(self, memory: str):
    """Store ``memory`` in the user's memory on the hub; the response is cast to ``str``."""
    url = "{}/vector_stores/memory".format(self._config.base_url)
    payload = {"memory": memory}
    return self.client.post(path=url, body=payload, cast_to=str)
completions
completions(model: str, messages: Iterable[ChatCompletionMessageParam], stream: bool = False, temperature: Optional[float] = None, max_tokens: Optional[int] = None, **kwargs: Any) -> Union[ModelResponse, CustomStreamWrapper]

Takes a model and messages and returns completions.

model can be either (1) a full path provider::model_full_path, or (2) a model_short_name, in which case the default provider is used.

Source code in nearai/shared/inference_client.py
def completions(
    self,
    model: str,
    messages: Iterable[ChatCompletionMessageParam],
    stream: bool = False,
    temperature: Optional[float] = None,
    max_tokens: Optional[int] = None,
    **kwargs: Any,
) -> Union[ModelResponse, CustomStreamWrapper]:
    """Takes a `model` and `messages` and returns completions.

    `model` can be:
    1. full path `provider::model_full_path`.
    2. `model_short_name`. Default provider will be used.

    The request is attempted up to ``num_inference_retries`` times (at least once).
    Extra ``kwargs`` are forwarded to ``litellm_completion``.

    Raises
    ------
        ValueError: If the last attempt fails.

    """
    provider, model = self.provider_models.match_provider_model(model)

    if temperature is None:
        temperature = DEFAULT_MODEL_TEMPERATURE

    if max_tokens is None:
        max_tokens = DEFAULT_MODEL_MAX_TOKENS

    # NOTE(#246): this is to disable "Provider List" messages.
    litellm.suppress_debug_info = True

    # Always make at least one attempt. With num_inference_retries <= 0 the loop would
    # never run and `result` would be unbound at the return below.
    attempts = max(1, self._config.num_inference_retries)
    for i in range(attempts):
        try:
            result: Union[ModelResponse, CustomStreamWrapper] = litellm_completion(
                model,
                messages,
                stream=stream,
                custom_llm_provider=self._config.custom_llm_provider,
                input_cost_per_token=0,
                output_cost_per_token=0,
                temperature=temperature,
                max_tokens=max_tokens,
                base_url=self._config.base_url,
                provider=provider,
                api_key=self._auth,
                **kwargs,
            )
            break
        except Exception as e:
            if i == attempts - 1:
                raise ValueError(f"Bad request: {e}") from None

    return result
create_thread
create_thread(messages)

Create a thread.

Source code in nearai/shared/inference_client.py
def create_thread(self, messages):
    """Create a new thread seeded with `messages` via the beta threads API."""
    threads_api = self.client.beta.threads
    return threads_api.create(messages=messages)
create_vector_store
create_vector_store(name: str, file_ids: List[str], expires_after: Union[ExpiresAfter, NotGiven] = NOT_GIVEN, chunking_strategy: Union[AutoFileChunkingStrategyParam, StaticFileChunkingStrategyParam, NotGiven] = NOT_GIVEN, metadata: Optional[Dict[str, str]] = None) -> VectorStore

Creates Vector Store.

Parameters: name — vector store name; file_ids — files to add to the vector store; expires_after — expiration policy; chunking_strategy — chunking strategy; metadata — additional metadata. Returns the created vector store, or raises on error.

Source code in nearai/shared/inference_client.py
def create_vector_store(
    self,
    name: str,
    file_ids: List[str],
    expires_after: Union[ExpiresAfter, NotGiven] = NOT_GIVEN,
    chunking_strategy: Union[AutoFileChunkingStrategyParam, StaticFileChunkingStrategyParam, NotGiven] = NOT_GIVEN,
    metadata: Optional[Dict[str, str]] = None,
) -> VectorStore:
    """Creates Vector Store.

    A fresh OpenAI-compatible client pointed at the hub's base URL is built for
    this call, authenticated with the current auth token.

    :param name: Vector store name.
    :param file_ids: Files to be added to the vector store.
    :param expires_after: Expiration policy.
    :param chunking_strategy: Chunking strategy.
    :param metadata: Additional metadata.
    :return: Returns the created vector store or error.
    """
    hub_client = openai.OpenAI(base_url=self._config.base_url, api_key=self._auth)
    vector_stores_api = hub_client.beta.vector_stores
    return vector_stores_api.create(
        name=name,
        file_ids=file_ids,
        metadata=metadata,
        expires_after=expires_after,
        chunking_strategy=chunking_strategy,
    )
create_vector_store_from_source
create_vector_store_from_source(name: str, source: Union[GitHubSource, GitLabSource], source_auth: Optional[str] = None, chunking_strategy: Optional[ChunkingStrategy] = None, expires_after: Optional[ExpiresAfter] = None, metadata: Optional[Dict[str, str]] = None) -> VectorStore

Creates a vector store from the given source.


Args:
name (str): The name of the vector store.
source (Union[GitHubSource, GitLabSource]): The source from which to create the vector store.
source_auth (Optional[str]): The source authentication token.
chunking_strategy (Optional[ChunkingStrategy]): The chunking strategy to use.
expires_after (Optional[ExpiresAfter]): The expiration policy.
metadata (Optional[Dict[str, str]]): Additional metadata.

Returns:
VectorStore: The created vector store.
Source code in nearai/shared/inference_client.py
def create_vector_store_from_source(
    self,
    name: str,
    source: Union[GitHubSource, GitLabSource],
    source_auth: Optional[str] = None,
    chunking_strategy: Optional[ChunkingStrategy] = None,
    expires_after: Optional[ExpiresAfter] = None,
    metadata: Optional[Dict[str, str]] = None,
) -> VectorStore:
    """Creates a vector store from the given source.

    Args:
    ----
        name (str): The name of the vector store.
        source (Union[GitHubSource, GitLabSource]): The source from which to create the vector store.
        source_auth (Optional[str]): The source authentication token.
        chunking_strategy (Optional[ChunkingStrategy]): The chunking strategy to use.
        expires_after (Optional[ExpiresAfter]): The expiration policy.
        metadata (Optional[Dict[str, str]]): Additional metadata.

    Returns:
    -------
        VectorStore: The created vector store.

    Raises:
    ------
        ValueError: If the request fails or the hub returns an error status.

    """
    print(f"Creating vector store from source: {source}")

    def _jsonable(value):
        # `requests` cannot JSON-encode pydantic models (e.g. `ChunkingStrategy`),
        # so dump them to plain dicts. Dicts, TypedDicts and None pass through.
        model_dump = getattr(value, "model_dump", None)
        return model_dump() if callable(model_dump) else value

    headers = {
        "Authorization": f"Bearer {self._auth}",
        "Content-Type": "application/json",
    }
    data = {
        "name": name,
        "source": _jsonable(source),
        "source_auth": source_auth,
        "chunking_strategy": _jsonable(chunking_strategy),
        "expires_after": _jsonable(expires_after),
        "metadata": metadata,
    }
    endpoint = f"{self._config.base_url}/vector_stores/from_source"

    try:
        response = requests.post(endpoint, headers=headers, json=data)
        # Check the status before decoding. An error body may not be JSON, and
        # decoding it first would mask the real HTTP error.
        response.raise_for_status()
        return VectorStore(**response.json())
    except requests.RequestException as e:
        raise ValueError(f"Failed to create vector store: {e}") from None
generate_auth_for_current_agent
generate_auth_for_current_agent(config, agent_identifier)

Regenerate auth for the current agent.

Source code in nearai/shared/inference_client.py
def generate_auth_for_current_agent(self, config, agent_identifier):
    """Regenerate auth for the current agent.

    Records `agent_identifier` on the client. If `config.auth` is set, a fresh
    bearer token is generated, decoded as JSON, extended with a JSON-encoded
    `runner_data` entry (agent identifier and runner API key), and re-encoded
    into `self._auth`. Otherwise `self._auth` is cleared to None.
    """
    self.agent_identifier = agent_identifier
    if config.auth is None:
        self._auth = None
        return

    token = json.loads(config.auth.generate_bearer_token())
    runner_data = {"agent": agent_identifier, "runner_api_key": self.runner_api_key}
    token["runner_data"] = json.dumps(runner_data)
    self._auth = json.dumps(token)
generate_image
generate_image(prompt: str)

Generate an image.

Source code in nearai/shared/inference_client.py
def generate_image(self, prompt: str):
    """Generate an image from a text `prompt` via the client's images API."""
    images_api = self.client.images
    return images_api.generate(prompt=prompt)
get_agent_data
get_agent_data()

Get agent data for the agent this client was initialized with.

Source code in nearai/shared/inference_client.py
def get_agent_data(self):
    """Get agent data for the agent this client was initialized with.

    Issues a GET to `{base_url}/agent_data` and casts the response to
    `Dict[str, str]`.
    """
    getter = self.client.get
    agent_data_url = f"{self._config.base_url}/agent_data"
    return getter(path=agent_data_url, cast_to=Dict[str, str])
get_agent_data_by_key
get_agent_data_by_key(key: str)

Get agent data by key for the agent this client was initialized with.

Source code in nearai/shared/inference_client.py
def get_agent_data_by_key(self, key: str):
    """Get agent data by key for the agent this client was initialized with.

    Issues a GET to `{base_url}/agent_data/{key}`. `key` is interpolated as-is,
    without URL-encoding. The response is cast to `Dict[str, str]`.
    """
    getter = self.client.get
    key_url = f"{self._config.base_url}/agent_data/{key}"
    return getter(path=key_url, cast_to=Dict[str, str])
get_agent_public_key
get_agent_public_key(agent_name: str) -> str

Request agent public key.

Source code in nearai/shared/inference_client.py
def get_agent_public_key(self, agent_name: str) -> str:
    """Request agent public key.

    POSTs to `{base_url}/get_agent_public_key` with `agent_name` sent as a query
    parameter, and returns the decoded JSON response.

    Raises:
    ------
        ValueError: If the request fails or the hub returns an error status.

    """
    request_headers = {"Content-Type": "application/json"}
    query_params = {"agent_name": agent_name}
    url = f"{self._config.base_url}/get_agent_public_key"

    try:
        reply = requests.post(url, headers=request_headers, params=query_params)
        reply.raise_for_status()
        return reply.json()
    except requests.RequestException as err:
        raise ValueError(f"Failed to get agent public key: {err}") from None
get_vector_store
get_vector_store(vector_store_id: str) -> VectorStore

Gets a vector store by id.

Source code in nearai/shared/inference_client.py
def get_vector_store(self, vector_store_id: str) -> VectorStore:
    """Gets a vector store by id.

    Args:
    ----
        vector_store_id: ID of the vector store to fetch.

    Returns:
    -------
        The `VectorStore` built from the hub's JSON response.

    Raises:
    ------
        requests.HTTPError: If the hub responds with an error status.

    """
    endpoint = f"{self._config.base_url}/vector_stores/{vector_store_id}"
    # Authenticate the same way as the other vector-store requests in this
    # client (e.g. query_vector_store). Previously this request went out anonymously.
    headers = {"Authorization": f"Bearer {self._auth}"}
    response = requests.get(endpoint, headers=headers)
    response.raise_for_status()
    return VectorStore(**response.json())
query_user_memory
query_user_memory(query: str)

Query the user memory.

Source code in nearai/shared/inference_client.py
def query_user_memory(self, query: str):
    """Query the user memory.

    POSTs `{"query": query}` to `{base_url}/vector_stores/memory/query` and casts
    the response to `str`.
    """
    memory_url = f"{self._config.base_url}/vector_stores/memory/query"
    request_body = {"query": query}
    return self.client.post(path=memory_url, body=request_body, cast_to=str)
query_vector_store
query_vector_store(vector_store_id: str, query: str, full_files: bool = False) -> Union[List[SimilaritySearch], List[SimilaritySearchFile]]

Query a vector store.

Source code in nearai/shared/inference_client.py
def query_vector_store(
    self, vector_store_id: str, query: str, full_files: bool = False
) -> Union[List[SimilaritySearch], List[SimilaritySearchFile]]:
    """Query a vector store.

    POSTs the query to `{base_url}/vector_stores/{vector_store_id}/search` with a
    bearer token and returns the decoded JSON body as-is. The items are plain
    JSON values, not model instances.

    Raises:
    ------
        ValueError: If the hub config is missing or the request fails.

    """
    if self._config is None:
        raise ValueError("Missing NEAR AI Hub config")

    request_headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {self._auth}",
    }
    payload = {"query": query, "full_files": full_files}
    search_url = f"{self._config.base_url}/vector_stores/{vector_store_id}/search"

    try:
        reply = requests.post(search_url, headers=request_headers, json=payload)
        reply.raise_for_status()
        return reply.json()
    except requests.RequestException as err:
        raise ValueError(f"Error querying vector store: {err}") from None
run_agent
run_agent(current_run_id: str, child_thread_id: str, assistant_id: str)

Starts a child agent run from a parent agent run.

Source code in nearai/shared/inference_client.py
def run_agent(self, current_run_id: str, child_thread_id: str, assistant_id: str):
    """Starts a child agent run from a parent agent run.

    The parent's run id is passed through `extra_body` as `parent_run_id`.
    """
    runs_api = self.client.beta.threads.runs
    parent_link = {"parent_run_id": current_run_id}
    return runs_api.create(thread_id=child_thread_id, assistant_id=assistant_id, extra_body=parent_link)
save_agent_data
save_agent_data(key: str, agent_data: Dict[str, Any])

Save agent data for the agent this client was initialized with.

Source code in nearai/shared/inference_client.py
def save_agent_data(self, key: str, agent_data: Dict[str, Any]):
    """Save agent data for the agent this client was initialized with.

    POSTs `{"key": key, "value": agent_data}` to `{base_url}/agent_data` and
    casts the response to `Dict[str, Any]`.
    """
    record = {"key": key, "value": agent_data}
    save_url = f"{self._config.base_url}/agent_data"
    return self.client.post(path=save_url, body=record, cast_to=Dict[str, Any])
schedule_run
schedule_run(agent: str, input_message: str, thread_id: Optional[str], run_params: Optional[Dict[str, str]], run_at: datetime)

Schedule an agent run at a given time.

Source code in nearai/shared/inference_client.py
def schedule_run(
    self,
    agent: str,
    input_message: str,
    thread_id: Optional[str],
    run_params: Optional[Dict[str, str]],
    run_at: datetime,
):
    """Schedules `agent` to run on `input_message` at `run_at`.

    Args:
    ----
        agent: Agent identifier to run.
        input_message: Message the scheduled run starts with.
        thread_id: Thread to run in, or None.
        run_params: Extra run parameters; an empty dict is sent when None.
        run_at: When to run. Values with an `isoformat()` method (e.g. a
            `datetime`) are sent as ISO-8601 strings; other values are sent unchanged.

    Returns:
    -------
        The decoded JSON response from `{base_url}/schedule_run`.

    Raises:
    ------
        ValueError: If the hub config is missing or the request fails.

    """
    if self._config is None:
        raise ValueError("Missing NearAI Hub config")

    auth_bearer_token = self._auth

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {auth_bearer_token}",
    }

    if run_params is None:
        run_params = {}

    # `requests` encodes `json=` with the stdlib JSON encoder, which raises
    # TypeError on datetime objects. Send an ISO-8601 string instead.
    run_at_value = run_at.isoformat() if hasattr(run_at, "isoformat") else run_at

    data = {
        "agent": agent,
        "input_message": input_message,
        "thread_id": thread_id,
        "run_params": run_params,
        "run_at": run_at_value,
    }

    endpoint = f"{self._config.base_url}/schedule_run"

    try:
        response = requests.post(endpoint, headers=headers, json=data)
        response.raise_for_status()
        return response.json()
    except requests.RequestException as e:
        raise ValueError(f"Error querying schedule_run: {e}") from None
threads_create_and_run_poll
threads_create_and_run_poll(assistant_id: str, model: str, messages: List[ChatCompletionMessageParam])

Create a thread and run the assistant.

Source code in nearai/shared/inference_client.py
def threads_create_and_run_poll(self, assistant_id: str, model: str, messages: List[ChatCompletionMessageParam]):
    """Create a thread from `messages`, then run `assistant_id` on it and poll until done."""
    new_thread = self.create_thread(messages)
    threads_api = self.client.beta.threads
    return threads_api.create_and_run_poll(thread=new_thread, assistant_id=assistant_id, model=model)
threads_fork
threads_fork(thread_id: str)

Fork a thread.

Source code in nearai/shared/inference_client.py
def threads_fork(self, thread_id: str):
    """Fork a thread by POSTing to `{base_url}/threads/{thread_id}/fork`; returns a `Thread`."""
    fork_url = f"{self._config.base_url}/threads/{thread_id}/fork"
    return self.client.post(path=fork_url, cast_to=Thread)
threads_list_messages
threads_list_messages(thread_id: str, order: Literal['asc', 'desc'] = 'asc')

List messages in a thread.

Source code in nearai/shared/inference_client.py
def threads_list_messages(self, thread_id: str, order: Literal["asc", "desc"] = "asc"):
    """List the messages of `thread_id`, ordered `asc` (default) or `desc`."""
    messages_api = self.client.beta.threads.messages
    return messages_api.list(thread_id=thread_id, order=order)
threads_messages_create
threads_messages_create(thread_id: str, content: str, role: Literal['user', 'assistant'])

Create a message in a thread.

Source code in nearai/shared/inference_client.py
def threads_messages_create(self, thread_id: str, content: str, role: Literal["user", "assistant"]):
    """Append a message with the given `role` and `content` to `thread_id`."""
    messages_api = self.client.beta.threads.messages
    return messages_api.create(thread_id=thread_id, content=content, role=role)
threads_runs_create
threads_runs_create(thread_id: str, assistant_id: str, model: str)

Create a run in a thread.

Source code in nearai/shared/inference_client.py
def threads_runs_create(self, thread_id: str, assistant_id: str, model: str):
    """Start a run of `assistant_id` with `model` on `thread_id`."""
    runs_api = self.client.beta.threads.runs
    return runs_api.create(thread_id=thread_id, assistant_id=assistant_id, model=model)
upload_file
upload_file(file_content: str, purpose: Literal['assistants', 'batch', 'fine-tune', 'vision'], encoding: str = 'utf-8', file_name='file.txt', file_type='text/plain') -> FileObject

Uploads a file.

Source code in nearai/shared/inference_client.py
def upload_file(
    self,
    file_content: str,
    purpose: Literal["assistants", "batch", "fine-tune", "vision"],
    encoding: str = "utf-8",
    file_name="file.txt",
    file_type="text/plain",
) -> FileObject:
    """Uploads a file.

    `file_content` is encoded with `encoding` and sent as an in-memory file named
    `file_name` with MIME type `file_type`, through a fresh OpenAI-compatible
    client pointed at the hub.
    """
    hub_client = openai.OpenAI(base_url=self._config.base_url, api_key=self._auth)
    buffer = io.BytesIO(file_content.encode(encoding))
    upload = (file_name, buffer, file_type)
    return hub_client.files.create(file=upload, purpose=purpose)

models

AutoFileChunkingStrategyParam

Bases: TypedDict

Source code in nearai/shared/models.py
class AutoFileChunkingStrategyParam(TypedDict, total=False):
    """Chunking-strategy parameter that selects automatic chunking.

    `type` is marked `Required`, so it must be present even though the
    TypedDict is declared with `total=False`.
    """

    type: Required[Literal["auto"]]
    """Always `auto`."""
type instance-attribute
type: Required[Literal['auto']]

Always auto.

ChunkingStrategy

Bases: BaseModel

Defines the chunking strategy for vector stores.

Source code in nearai/shared/models.py
class ChunkingStrategy(BaseModel):
    """Defines the chunking strategy for vector stores.

    NOTE(review): this model declares no fields, so under pydantic's default
    handling of extra input any keys supplied to it are presumably discarded.
    Confirm this is intended.
    """
CreateVectorStoreRequest

Bases: BaseModel

Request model for creating a new vector store.

Source code in nearai/shared/models.py
class CreateVectorStoreRequest(BaseModel):
    """Request model for creating a new vector store.

    Only `name` is required; every other field defaults to None.
    """

    chunking_strategy: Union[AutoFileChunkingStrategyParam, StaticFileChunkingStrategyParam, None] = None
    """The chunking strategy to use for the vector store."""
    expires_after: Optional[ExpiresAfter] = None
    """The expiration time for the vector store."""
    file_ids: Optional[List[str]] = None
    """The file IDs to attach to the vector store."""
    metadata: Optional[Dict[str, str]] = None
    """The metadata to attach to the vector store."""
    name: str
    """The name of the vector store."""
chunking_strategy class-attribute instance-attribute
chunking_strategy: Union[AutoFileChunkingStrategyParam, StaticFileChunkingStrategyParam, None] = None

The chunking strategy to use for the vector store.

expires_after class-attribute instance-attribute
expires_after: Optional[ExpiresAfter] = None

The expiration time for the vector store.

file_ids class-attribute instance-attribute
file_ids: Optional[List[str]] = None

The file IDs to attach to the vector store.

metadata class-attribute instance-attribute
metadata: Optional[Dict[str, str]] = None

The metadata to attach to the vector store.

name instance-attribute
name: str

The name of the vector store.

ExpiresAfter

Bases: TypedDict

Source code in nearai/shared/models.py
class ExpiresAfter(TypedDict, total=False):
    """Expiration policy for a vector store.

    Both keys are marked `Required`, so they must be present even though the
    TypedDict is declared with `total=False`.
    """

    anchor: Required[Literal["last_active_at"]]
    """Anchor timestamp after which the expiration policy applies.

    Supported anchors: `last_active_at`.
    """

    days: Required[int]
    """The number of days after the anchor time that the vector store will expire."""
anchor instance-attribute
anchor: Required[Literal['last_active_at']]

Anchor timestamp after which the expiration policy applies.

Supported anchors: last_active_at.

days instance-attribute
days: Required[int]

The number of days after the anchor time that the vector store will expire.

StaticFileChunkingStrategyParam

Bases: TypedDict

Source code in nearai/shared/models.py
class StaticFileChunkingStrategyParam(TypedDict, total=False):
    """Chunking-strategy parameter with an explicit chunk size and overlap.

    Both keys are marked `Required`. The value limits in the field docs below
    are not checked by this type itself.
    """

    chunk_overlap_tokens: Required[int]
    """The number of tokens that overlap between chunks. The default value is `400`.

    Note that the overlap must not exceed half of `max_chunk_size_tokens`.
    """

    max_chunk_size_tokens: Required[int]
    """The maximum number of tokens in each chunk.

    The default value is `800`. The minimum value is `100` and the maximum value is
    `4096`.
    """
chunk_overlap_tokens instance-attribute
chunk_overlap_tokens: Required[int]

The number of tokens that overlap between chunks. The default value is 400.

Note that the overlap must not exceed half of max_chunk_size_tokens.

max_chunk_size_tokens instance-attribute
max_chunk_size_tokens: Required[int]

The maximum number of tokens in each chunk.

The default value is 800. The minimum value is 100 and the maximum value is 4096.

VectorStoreFileCreate

Bases: BaseModel

Request model for creating a vector store file.

Source code in nearai/shared/models.py
class VectorStoreFileCreate(BaseModel):
    """Request model for creating a vector store file.

    References a previously uploaded file by its `file_id`.
    """

    file_id: str
    """File ID returned from upload file endpoint."""
file_id instance-attribute
file_id: str

File ID returned from upload file endpoint.

naming

NamespacedName
Source code in nearai/shared/naming.py
class NamespacedName:
    """A name with an optional namespace, rendered as `namespace/name`.

    Equality and hashing use the `(name, namespace)` pair.
    """

    def __init__(self, name: str, namespace: str = ""):  # noqa: D107
        self.name = name
        self.namespace = namespace

    def __eq__(self, other):  # noqa: D105
        if isinstance(other, NamespacedName):
            return (self.name, self.namespace) == (other.name, other.namespace)
        return NotImplemented

    def __hash__(self):  # noqa: D105
        return hash((self.name, self.namespace))

    def __str__(self):  # noqa: D105
        return f"{self.namespace}/{self.name}" if self.namespace else self.name

    def __repr__(self):  # noqa: D105
        return f"NamespacedName(name='{self.name}', namespace='{self.namespace}')"

    def canonical(self) -> "NamespacedName":
        """Returns canonical NamespacedName.

        The name is passed through `get_canonical_name`. A namespace equal to
        `DEFAULT_NAMESPACE` collapses to "", and any other namespace is canonicalized.
        """
        canonical_name = get_canonical_name(self.name)
        if self.namespace == DEFAULT_NAMESPACE:
            canonical_namespace = ""
        else:
            canonical_namespace = get_canonical_name(self.namespace)
        return NamespacedName(name=canonical_name, namespace=canonical_namespace)
canonical
canonical() -> NamespacedName

Returns canonical NamespacedName.

Source code in nearai/shared/naming.py
def canonical(self) -> "NamespacedName":
    """Returns canonical NamespacedName.

    The name is passed through `get_canonical_name`. A namespace equal to
    `DEFAULT_NAMESPACE` collapses to "", and any other namespace is canonicalized.
    """
    canonical_name = get_canonical_name(self.name)
    if self.namespace == DEFAULT_NAMESPACE:
        canonical_namespace = ""
    else:
        canonical_namespace = get_canonical_name(self.namespace)
    return NamespacedName(name=canonical_name, namespace=canonical_namespace)
create_registry_name
create_registry_name(name: str) -> str

Formats name for a suitable registry name.

Source code in nearai/shared/naming.py
def create_registry_name(name: str) -> str:
    """Formats `name` for a suitable registry name.

    Steps, in order: lowercase; '.' between digits -> 'p'; '<digit>v<digit>' ->
    '<digit>-<digit>'; '<not letter>v<digit>' -> '<not letter><digit>'; a run of
    non-alphanumerics between digits -> '-'; drop every other character except
    [a-z0-9-]; finally shorten 'metallama'/'meta-llama' to 'llama' and
    'qwenq'/'qwen-q' to 'q'. Regex matches do not overlap, so "1.2.3" becomes
    "1p2-3".
    """
    regex_steps = (
        (r"(\d)\.(\d)", r"\1p\2"),
        (r"(\d)v(\d)", r"\1-\2"),
        (r"(^|[^a-z])v(\d)", r"\1\2"),
        (r"(\d)[^a-z0-9]+(\d)", r"\1-\2"),
        (r"[^a-z0-9-]", ""),
    )
    literal_steps = (
        ("metallama", "llama"),
        ("meta-llama", "llama"),
        ("qwenq", "q"),
        ("qwen-q", "q"),
    )

    formatted = name.lower()
    for pattern, replacement in regex_steps:
        formatted = re.sub(pattern, replacement, formatted)
    for old, new in literal_steps:
        formatted = formatted.replace(old, new)
    return formatted
get_canonical_name
get_canonical_name(name: str) -> str

Returns a name that can be used for matching entities.

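Applies these transformations in order: 1. All letters lowercase. 2. Convert '.' between digits to 'p'. 3. Convert '<digit>v<digit>' -> '<digit>_<digit>', then '<not letter>v<digit>' -> '<not letter><digit>'. 4. Replace non-alphanumeric runs between digits with '_' and remove all other non-alphanumeric characters, including any '_' not between two digits. 5. Convert 'metallama' -> 'llama' and 'qwenq' -> 'q'.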

e.g. "llama-3.1-70b-instruct" -> "llama3p1_70binstruct"

Source code in nearai/shared/naming.py
def get_canonical_name(name: str) -> str:
    """Returns a name that can be used for matching entities.

    Applies these transformations, in order:
    1. All letters lowercase.
    2. Convert '.' between digits to 'p'.
    3. Convert '<digit>v<digit>' -> '<digit>_<digit>', then
       '<not letter>v<digit>' -> '<not letter><digit>'.
    4. Replace a run of non-alphanumeric characters between two digits with '_',
       remove every other non-alphanumeric character, then remove any '_' that
       is not between two digits. Matches do not overlap, so "1-2-3" becomes
       "1_23".
    5. Convert 'metallama' -> 'llama' and 'qwenq' -> 'q'.

    e.g. "llama-3.1-70b-instruct" -> "llama3p1_70binstruct"
    """
    regex_steps = (
        (r"(\d)\.(\d)", r"\1p\2"),
        (r"(\d)v(\d)", r"\1_\2"),
        (r"(^|[^a-z])v(\d)", r"\1\2"),
        (r"(\d)[^a-z0-9]+(\d)", r"\1_\2"),
        (r"[^a-z0-9_]", ""),
        (r"(?<!\d)_|_(?!\d)", ""),
    )

    canonical = name.lower()
    for pattern, replacement in regex_steps:
        canonical = re.sub(pattern, replacement, canonical)
    for alias, target in (("metallama", "llama"), ("qwenq", "q")):
        canonical = canonical.replace(alias, target)
    return canonical

near

sign
SignatureVerificationResult

Bases: Enum

Source code in nearai/shared/near/sign.py
class SignatureVerificationResult(Enum):
    """Result of verifying a signed message.

    Only `TRUE` is truthy. `VERIFY_ACCESS_KEY_OWNER_SERVICE_NOT_AVAILABLE` marks a
    failed key-owner lookup and is falsy.
    """

    TRUE = True
    FALSE = False
    VERIFY_ACCESS_KEY_OWNER_SERVICE_NOT_AVAILABLE = "verify_access_key_owner_not_available"

    @classmethod
    def from_bool(cls, value: bool):
        """Gets VerificationResult based on a boolean value."""
        if value:
            return cls.TRUE
        return cls.FALSE

    def __bool__(self):
        """Overrides the behavior when checking for truthiness."""
        truthy_member = SignatureVerificationResult.TRUE
        return self == truthy_member
__bool__
__bool__()

Overrides the behavior when checking for truthiness.

Source code in nearai/shared/near/sign.py
def __bool__(self):
    """Overrides the behavior when checking for truthiness: only `TRUE` is truthy."""
    truthy_member = SignatureVerificationResult.TRUE
    return self == truthy_member
from_bool classmethod
from_bool(value: bool)

Gets VerificationResult based on a boolean value.

Source code in nearai/shared/near/sign.py
@classmethod
def from_bool(cls, value: bool):
    """Gets VerificationResult based on a boolean value: truthy -> `TRUE`, else `FALSE`."""
    if value:
        return cls.TRUE
    return cls.FALSE
convert_nonce
convert_nonce(value: Union[str, bytes, list[int]])

Converts a given value to a 32-byte nonce.

Source code in nearai/shared/near/sign.py
def convert_nonce(value: Union[str, bytes, list[int]]):
    """Converts a given value to a 32-byte nonce.

    - `list[int]`: must contain exactly 32 items; converted with `bytes(...)`.
    - `str`: UTF-8 encoded, then handled like `bytes`.
    - `bytes`: at most 32 bytes; shorter values are left-padded with ASCII
      b"0" so a numeric timestamp keeps its value.

    Raises:
    ------
        ValueError: "Invalid nonce length" for a bad length, or
            "Invalid nonce format" for any other type.

    """
    if isinstance(value, list):
        if len(value) != 32:
            raise ValueError("Invalid nonce length")
        return bytes(value)

    if isinstance(value, str):
        value = value.encode("utf-8")
    elif not isinstance(value, bytes):
        raise ValueError("Invalid nonce format")

    if len(value) > 32:
        raise ValueError("Invalid nonce length")
    if len(value) < 32:
        value = value.rjust(32, b"0")
    return value
create_inference_signature
create_inference_signature(private_key: str, payload: CompletionSignaturePayload) -> tuple[str, str]

Creates a cryptographic signature for a given extended inference payload using a specified private key.

Source code in nearai/shared/near/sign.py
def create_inference_signature(private_key: str, payload: CompletionSignaturePayload) -> tuple[str, str]:
    """Creates a cryptographic signature for a given extended inference payload using a specified private key.

    The payload is Borsh-serialized with `COMPLETION_PAYLOAD_SCHEMA` and hashed
    with SHA-256. The digest is then signed with the ed25519 key rebuilt from the
    first 32 bytes of the decoded private key.

    Args:
    ----
        private_key: `ED_PREFIX` followed by the base58-encoded 64-byte key.
        payload: Completion payload to sign.

    Returns:
    -------
        `(signature, public_key)`: the base64 signature and `ED_PREFIX` plus the
        base58-encoded public key.

    Raises:
    ------
        ValueError: If the decoded private key is not exactly 64 bytes.

    """
    serializer = BinarySerializer(dict(COMPLETION_PAYLOAD_SCHEMA))
    digest = hashlib.sha256(serializer.serialize(payload)).digest()

    key_bytes = base58.b58decode(private_key[len(ED_PREFIX) :])
    if len(key_bytes) != 64:
        raise ValueError("The private key must be exactly 64 bytes long")

    signer = nacl.signing.SigningKey(key_bytes[:32])
    verify_key = signer.verify_key
    raw_signature = signer.sign(digest).signature

    encoded_signature = base64.b64encode(raw_signature).decode("utf-8")
    encoded_public_key = ED_PREFIX + base58.b58encode(verify_key.encode()).decode("utf-8")
    return encoded_signature, encoded_public_key
create_signature
create_signature(private_key: str, payload: Payload) -> tuple[str, str]

Creates a cryptographic signature for a given payload using a specified private key.

Source code in nearai/shared/near/sign.py
def create_signature(private_key: str, payload: Payload) -> tuple[str, str]:
    """Creates a cryptographic signature for a given payload using a specified private key.

    The payload is Borsh-serialized with `PAYLOAD_SCHEMA` and hashed with SHA-256.
    The digest is then signed with the ed25519 key rebuilt from the first 32
    bytes (the seed) of the decoded private key.

    Args:
    ----
        private_key: `ED_PREFIX` followed by the base58-encoded 64-byte key.
        payload: Payload to sign.

    Returns:
    -------
        `(signature, public_key)`: the base64 signature and `ED_PREFIX` plus the
        base58-encoded public key.

    Raises:
    ------
        ValueError: If the decoded private key is not exactly 64 bytes.

    """
    serializer = BinarySerializer(dict(PAYLOAD_SCHEMA))
    digest = hashlib.sha256(serializer.serialize(payload)).digest()

    key_bytes = base58.b58decode(private_key[len(ED_PREFIX) :])
    if len(key_bytes) != 64:
        raise ValueError("The private key must be exactly 64 bytes long")

    # Only the leading 32 bytes (the seed) are needed to rebuild the signing key.
    signer = nacl.signing.SigningKey(key_bytes[:32])
    verify_key = signer.verify_key
    raw_signature = signer.sign(digest).signature

    encoded_signature = base64.b64encode(raw_signature).decode("utf-8")
    encoded_public_key = ED_PREFIX + base58.b58encode(verify_key.encode()).decode("utf-8")
    return encoded_signature, encoded_public_key
validate_completion_signature
validate_completion_signature(public_key: str, signature: str, payload: CompletionSignaturePayload)

Validates a cryptographic signature for a given payload using a specified public key.

Source code in nearai/shared/near/sign.py
def validate_completion_signature(public_key: str, signature: str, payload: CompletionSignaturePayload):
    """Validates a cryptographic signature for a given payload using a specified public key.

    Returns True if the base64 `signature` verifies against the SHA-256 digest of
    the Borsh-serialized payload, and False on `BadSignatureError`. Decoding
    errors for a malformed key or signature are not caught.
    """
    serializer = BinarySerializer(dict(COMPLETION_PAYLOAD_SCHEMA))
    digest = hashlib.sha256(serializer.serialize(payload)).digest()
    signature_bytes = base64.b64decode(signature)
    key = nacl.signing.VerifyKey(base58.b58decode(public_key[len(ED_PREFIX) :]))

    try:
        key.verify(digest, signature_bytes)
    except nacl.exceptions.BadSignatureError:
        return False
    return True
validate_nonce
validate_nonce(value: Union[str, bytes, list[int]])

Ensures that the nonce is a valid timestamp.

Source code in nearai/shared/near/sign.py
def validate_nonce(value: Union[str, bytes, list[int]]):
    """Ensures that the nonce is a valid timestamp.

    The nonce is normalized with `convert_nonce`, and its bytes are decoded as a
    decimal millisecond timestamp compared against `time.time() * 1000`.

    Returns:
    -------
        The 32-byte nonce produced by `convert_nonce`.

    Raises:
    ------
        ValueError: If the nonce cannot be converted or parsed as an integer, lies
            in the future, or is more than ten years old.

    """
    nonce = convert_nonce(value)
    nonce_int = int(nonce.decode("utf-8"))

    now = int(time.time() * 1000)
    # Maximum accepted nonce age, in milliseconds (10 years of 365 days).
    max_age_ms = 10 * 365 * 24 * 60 * 60 * 1000

    if nonce_int > now:
        raise ValueError("Nonce is in the future")
    # A timestamp older than 10 years is considered invalid, forcing apps to use
    # unique nonces.
    if now - nonce_int > max_age_ms:
        raise ValueError("Nonce is too old")

    return nonce
validate_signature
validate_signature(public_key: str, signature: str, payload: Payload)

Validates a cryptographic signature for a given payload using a specified public key.

Source code in nearai/shared/near/sign.py
def validate_signature(public_key: str, signature: str, payload: Payload):
    """Validates a cryptographic signature for a given payload using a specified public key.

    Returns True if the base64 `signature` verifies against the SHA-256 digest of
    the Borsh-serialized payload, and False on `BadSignatureError`. Decoding
    errors for a malformed key or signature are not caught.
    """
    serializer = BinarySerializer(dict(PAYLOAD_SCHEMA))
    digest = hashlib.sha256(serializer.serialize(payload)).digest()
    signature_bytes = base64.b64decode(signature)
    key = nacl.signing.VerifyKey(base58.b58decode(public_key[len(ED_PREFIX) :]))

    try:
        key.verify(digest, signature_bytes)
    except nacl.exceptions.BadSignatureError:
        return False
    return True
verify_access_key_owner
verify_access_key_owner(public_key, account_id) -> SignatureVerificationResult

Verifies if a given public key belongs to a specified account ID using FastNEAR API.

Source code in nearai/shared/near/sign.py
@mem_cache_with_timeout(300)
def verify_access_key_owner(public_key, account_id) -> SignatureVerificationResult:
    """Verifies if a given public key belongs to a specified account ID using FastNEAR API.

    Args:
    ----
        public_key: Public key to look up.
        account_id: Account expected among the key's `account_ids`.

    Returns:
    -------
        `TRUE` or `FALSE` depending on whether `account_id` is listed for the key,
        or `VERIFY_ACCESS_KEY_OWNER_SERVICE_NOT_AVAILABLE` if the lookup fails
        (HTTP error, timeout, connection or decoding error).

    """
    # NOTE(review): results are memoized by `mem_cache_with_timeout(300)`,
    # presumably including the "service not available" outcome. Confirm.
    try:
        logger.info("Verifying access key owner for public key: %s, account_id: %s", public_key, account_id)
        url = f"https://api.fastnear.com/v0/public_key/{public_key}"
        # Bound the wait so a stalled FastNEAR connection is reported as
        # "service not available" (via the handlers below) instead of blocking
        # the caller indefinitely.
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        content = response.json()
        account_ids = content.get("account_ids", [])
        key_owner_verified = account_id in account_ids
        if not key_owner_verified:
            logger.info("Key's owner verification failed. Only NEAR Mainnet accounts are supported.")
        return SignatureVerificationResult.from_bool(key_owner_verified)
    except requests.exceptions.HTTPError as http_err:
        logger.error("HTTP error occurred: %s", http_err)
    except Exception as err:
        logger.error("Other error occurred: %s", err)

    return SignatureVerificationResult.VERIFY_ACCESS_KEY_OWNER_SERVICE_NOT_AVAILABLE
verify_signed_message
verify_signed_message(account_id, public_key, signature, message, nonce, recipient, callback_url) -> SignatureVerificationResult

Verifies a signed message and ensures the public key belongs to the specified account.

Source code in nearai/shared/near/sign.py
def verify_signed_message(
    account_id, public_key, signature, message, nonce, recipient, callback_url
) -> SignatureVerificationResult:
    """Verifies a signed message and ensures the public key belongs to the specified account.

    The signature is checked against a payload that includes `callback_url`. If
    that fails and a callback URL was given, it is checked again against the
    payload without it. A valid signature is then followed by the key-owner
    lookup; an invalid one yields `FALSE`.
    """
    signature_ok = validate_signature(public_key, signature, Payload(message, nonce, recipient, callback_url))

    # Also accept a signature that was produced without the callback URL.
    if not signature_ok and callback_url is not None:
        signature_ok = validate_signature(public_key, signature, Payload(message, nonce, recipient, None))

    if not signature_ok:
        # TODO verifies that key is a FULL ACCESS KEY
        return SignatureVerificationResult.FALSE

    # verify that key belongs to `account_id`
    return verify_access_key_owner(public_key, account_id)

provider_models

ProviderModels
Source code in nearai/shared/provider_models.py
class ProviderModels:
    """Resolves model identifiers to the provider model ids listed by the hub."""

    def __init__(self, config: ClientConfig) -> None:  # noqa: D107
        self._config = config

    @cached_property
    def provider_models(self) -> Dict[NamespacedName, Dict[str, str]]:
        """Returns a mapping canonical->provider->model_full_name.

        Built once from the hub's model listing and cached on the instance.

        Raises:
        ------
            AssertionError: if the hub lists no models.
            ValueError: if two listed models map to the same (canonical model, provider) pair.
            RuntimeError: if the listing raises ``requests.RequestException``.

        """
        hub_client = self._config.get_hub_client()

        try:
            listing = hub_client.models.list()

            assert len(listing.data) > 0, "No models found"
            mapping: Dict[NamespacedName, Dict[str, str]] = {}
            for entry in listing.data:
                provider, namespaced = get_provider_namespaced_model(entry.id)
                canonical = namespaced.canonical()
                per_provider = mapping.setdefault(canonical, {})
                if provider in per_provider:
                    raise ValueError(f"Duplicate entry for provider {provider} and model {canonical}")
                per_provider[provider] = entry.id
        except requests.RequestException as e:
            raise RuntimeError(f"Error fetching models: {str(e)}") from e

        return mapping

    def available_provider_matches(self, model: NamespacedName) -> Dict[str, str]:
        """Returns provider matches for `model` (empty dict when the model is unknown)."""
        canonical = model.canonical()
        return self.provider_models.get(canonical, {})

    def match_provider_model(self, model: str, provider: Optional[str] = None) -> Tuple[str, str]:
        """Returns provider and model_full_path for given `model` and optional `provider`.

        `model` may take different formats. Supported ones:
        1. model_full_path, e.g. "fireworks::accounts/yi-01-ai/models/yi-large"
        2. model_full_path without provider, e.g. "accounts/yi-01-ai/models/yi-large"
        3. model_short_name as used by provider, e.g. "llama-v3-70b-instruct"
        4. namespace/model_short_name as used by provider, e.g. "yi-01-ai/yi-large"
        5. model_name as used in registry, e.g. "llama-3-70b-instruct"
        6. namespace/model_name as used in registry, e.g. "near.ai/llama-3-70b-instruct"

        An empty-string `provider` is treated as `None`. If the provider parsed from
        `model` does not serve it, the first provider listed for the model is used.
        """
        if not provider:
            provider = None
        matched_provider, namespaced = get_provider_namespaced_model(model, provider)
        canonical = namespaced.canonical()
        if canonical not in self.provider_models:
            raise ValueError(f"Model {canonical} not present in provider models {self.provider_models}")
        candidates = self.provider_models[canonical]
        if matched_provider not in candidates:
            # Fall back to the first listed provider (unchanged if there are none).
            matched_provider = next(iter(candidates), matched_provider)
        if provider and provider != matched_provider:
            raise ValueError(
                f"Requested provider {provider} for model {model} does not match matched_provider {matched_provider}"
            )
        return matched_provider, candidates[matched_provider]

    def get_unregistered_common_provider_models(
        self, registry_models: Dict[NamespacedName, NamespacedName]
    ) -> List[Dict[str, str]]:
        """Returns provider matches for unregistered provider models with default namespace."""
        return [
            matches
            for name, matches in self.provider_models.items()
            if name.namespace == "" and name not in registry_models
        ]
provider_models cached property
provider_models: Dict[NamespacedName, Dict[str, str]]

Returns a mapping canonical->provider->model_full_name.

available_provider_matches
available_provider_matches(model: NamespacedName) -> Dict[str, str]

Returns provider matches for model.

Source code in nearai/shared/provider_models.py
def available_provider_matches(self, model: NamespacedName) -> Dict[str, str]:
    """Returns provider matches for `model`.

    Looks up the canonical form of `model` in `self.provider_models`; an unknown
    model yields an empty dict.
    """
    canonical_name = model.canonical()
    return self.provider_models.get(canonical_name, {})
get_unregistered_common_provider_models
get_unregistered_common_provider_models(registry_models: Dict[NamespacedName, NamespacedName]) -> List[Dict[str, str]]

Returns provider matches for unregistered provider models with default namespace.

Source code in nearai/shared/provider_models.py
def get_unregistered_common_provider_models(
    self, registry_models: Dict[NamespacedName, NamespacedName]
) -> List[Dict[str, str]]:
    """Returns provider matches for unregistered provider models with default namespace.

    Only models whose namespace is empty and that are not keys of `registry_models`
    are included.
    """
    return [
        matches
        for name, matches in self.provider_models.items()
        if name.namespace == "" and name not in registry_models
    ]
match_provider_model
match_provider_model(model: str, provider: Optional[str] = None) -> Tuple[str, str]

Returns provider and model_full_path for given model and optional provider.

`model` may take different formats. Supported ones:

1. model_full_path, e.g. "fireworks::accounts/yi-01-ai/models/yi-large"
2. model_full_path without provider, e.g. "accounts/yi-01-ai/models/yi-large"
3. model_short_name as used by provider, e.g. "llama-v3-70b-instruct"
4. namespace/model_short_name as used by provider, e.g. "yi-01-ai/yi-large"
5. model_name as used in registry, e.g. "llama-3-70b-instruct"
6. namespace/model_name as used in registry, e.g. "near.ai/llama-3-70b-instruct"

Source code in nearai/shared/provider_models.py
def match_provider_model(self, model: str, provider: Optional[str] = None) -> Tuple[str, str]:
    """Returns provider and model_full_path for given `model` and optional `provider`.

    `model` may take different formats. Supported ones:
    1. model_full_path, e.g. "fireworks::accounts/yi-01-ai/models/yi-large"
    2. model_full_path without provider, e.g. "accounts/yi-01-ai/models/yi-large"
    3. model_short_name as used by provider, e.g. "llama-v3-70b-instruct"
    4. namespace/model_short_name as used by provider, e.g. "yi-01-ai/yi-large"
    5. model_name as used in registry, e.g. "llama-3-70b-instruct"
    6. namespace/model_name as used in registry, e.g. "near.ai/llama-3-70b-instruct"

    An empty-string `provider` is treated as `None`. If the provider parsed from
    `model` does not serve it, the first provider listed for the model is used.
    Raises `ValueError` for unknown models or a mismatching explicit `provider`.
    """
    if not provider:
        provider = None
    matched_provider, namespaced = get_provider_namespaced_model(model, provider)
    canonical = namespaced.canonical()
    if canonical not in self.provider_models:
        raise ValueError(f"Model {canonical} not present in provider models {self.provider_models}")
    candidates = self.provider_models[canonical]
    if matched_provider not in candidates:
        # Fall back to the first listed provider (unchanged if there are none).
        matched_provider = next(iter(candidates), matched_provider)
    if provider and provider != matched_provider:
        raise ValueError(
            f"Requested provider {provider} for model {model} does not match matched_provider {matched_provider}"
        )
    return matched_provider, candidates[matched_provider]
get_provider_model
get_provider_model(provider: Optional[str], model: str) -> Tuple[Optional[str], str]

Splits the `model` string based on a predefined separator and returns the components.

Parameters:

- provider (Optional[str]): The default provider name. Can be `None` if the provider is included in the `model` string.
- model (str): The model identifier, which may include the provider name separated by a specific delimiter (defined by `PROVIDER_MODEL_SEP`, e.g. `::`).
Source code in nearai/shared/provider_models.py
def get_provider_model(provider: Optional[str], model: str) -> Tuple[Optional[str], str]:
    """Splits the `model` string based on a predefined separator and returns the components.

    Args:
    ----
        provider (Optional[str]): The default provider name. Can be `None` if the provider
                                  is included in the `model` string.
        model (str): The model identifier, which may include the provider name separated by
                     a specific delimiter (defined by `PROVIDER_MODEL_SEP`, e.g. `::`).

    Returns:
    -------
        A `(provider, model)` tuple. When `model` contains the separator, both parts come
        from `model` and the `provider` argument is ignored; otherwise both arguments are
        returned unchanged.

    Raises:
    ------
        ValueError: if `model` contains the separator more than once.

    """
    if PROVIDER_MODEL_SEP not in model:
        return provider, model
    parts = model.split(PROVIDER_MODEL_SEP)
    if len(parts) != 2:
        # Explicit check instead of `assert`: asserts are stripped under `python -O`, which
        # would silently discard everything after the second separator.
        raise ValueError(f"Invalid model format {model!r}: expected at most one {PROVIDER_MODEL_SEP!r} separator")
    return parts[0], parts[1]
get_provider_namespaced_model
get_provider_namespaced_model(provider_model: str, provider: Optional[str] = None) -> Tuple[str, NamespacedName]

Given provider_model returns provider and namespaced model.

Source code in nearai/shared/provider_models.py
def get_provider_namespaced_model(provider_model: str, provider: Optional[str] = None) -> Tuple[str, NamespacedName]:
    """Given `provider_model` returns provider and namespaced model.

    Every occurrence of "accounts/", "fireworks/" and "models/" is removed from
    `provider_model` first. The provider is then taken from the string (via
    `get_provider_model`), defaulting to `provider` or, if that is falsy, `DEFAULT_PROVIDER`.

    Raises:
    ------
        ValueError: for a Fireworks model with more than one "/" after cleanup, or for an
            unrecognized provider.

    """
    for fragment in ("accounts/", "fireworks/", "models/"):
        provider_model = provider_model.replace(fragment, "")
    provider_opt, model = get_provider_model(provider if provider else DEFAULT_PROVIDER, provider_model)
    resolved = cast(str, provider_opt)

    if resolved == "hyperbolic":
        # Keep only the text after the last "/".
        return resolved, NamespacedName(re.sub(r".*/", "", model))

    if resolved == "fireworks":
        segments = model.split("/")
        if len(segments) == 2:
            return resolved, NamespacedName(namespace=segments[0], name=segments[1])
        if len(segments) == 1:
            return resolved, NamespacedName(name=segments[0])
        raise ValueError(f"Invalid model format for Fireworks: {model}")

    if resolved == "local":
        # Keep only the text after the last "/".
        return resolved, NamespacedName(name=re.sub(r".*/", "", model))

    raise ValueError(f"Unrecognized provider: {resolved}")

solvers

DDOTSV0Solver

Bases: SolverStrategy

Solver strategy for competitive programming problems live on DDOTS.

This solver runs agents in a previously prepared Agent environment whose workspace contains:

    workspace/
        .id             -- Id of the problem
        PROBLEM.txt     -- Description of the problem

The agent should call env.submit_python(code) to submit the code to the DDOTS server.

Source code in nearai/solvers/ddot_v0_solver.py
class DDOTSV0Solver(SolverStrategy):
    """Solver strategy for competitive programming problems live on DDOTS.

    This solver runs agents in a previously prepared Agent environment whose
    workspace looks like:

    workspace/
        .id             -- Id of the problem
        PROBLEM.txt     -- Description of the problem

    The agent should call env.submit_python(code) to submit the code to the DDOTS server.

    """

    def __init__(self, dataset_ref: Dataset, agents: str, max_iterations: int, save_snapshots: bool = False):  # noqa: D107
        # NOTE(review): SolverStrategy.__init__ is not called, so attributes it sets
        # (client, model/agent fields) are absent on this solver — confirm intended.
        client_config = self._make_client_config()
        self.agents = [Agent.load_agent(agent_name, client_config) for agent_name in agents.split(",")]
        self.max_iterations = max_iterations

        # Each solver instance writes to its own timestamped directory with a random suffix.
        run_tag = f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}_{random.randint(10**8, 10**9 - 1)}"
        self._saved_trajectories = DATA_FOLDER / "data" / "ddots_v0_trajectories" / run_tag
        self._saved_trajectories.mkdir(parents=True, exist_ok=True)

        self.save_snapshots = save_snapshots
        print("Saving trajectories to", self._saved_trajectories)

    @staticmethod
    def _make_client_config() -> ClientConfig:
        """Build a hub client config from the global `CONFIG`."""
        return ClientConfig(
            base_url=CONFIG.nearai_hub.base_url,
            auth=CONFIG.auth,
        )

    def evaluation_name(self) -> str:  # noqa: D102
        return "ddots"

    def compatible_datasets(self) -> List[str]:  # noqa: D102
        return ["ddots_codeforces_small/v0", "datasets/ddots_codeforces_medium_A_B/v0"]

    def solve(self, datum: dict) -> bool:  # noqa: D102
        problem_id, description = datum["problem_id"], datum["description"]

        inference_client = InferenceClient(self._make_client_config())
        env = DDOTSEnvironment(self.agents, problem_id, description, inference_client)
        # Start out marked unsolved; overwritten with the real status only if the run completes.
        env.write_file(".solved", str(False))

        try:
            env.run(description, max_iterations=self.max_iterations)
            env.write_file(".solved", str(env.solved))
        except Exception as e:
            print(f"Error running task: {e}")
        finally:
            if self.save_snapshots:
                snapshot_path = self._saved_trajectories / f"{problem_id}.tar.gz"
                snapshot = env.create_snapshot()
                with open(snapshot_path, "wb") as out:
                    out.write(snapshot)

        return env.solved

GSM8KSolverStrategy

Bases: SolverStrategy

Solver strategy for the GSM8K dataset.

Source code in nearai/solvers/gsm8k_solver.py
class GSM8KSolverStrategy(SolverStrategy):
    """Solver strategy for the GSM8K dataset.

    Each problem goes through two inference sessions: a few-shot, step-by-step answer,
    then a refinement prompt that asks for the bare final number. The normalized number
    is compared as a string with the dataset's answer.
    """

    SHOTS = 8  # number of leading "train" examples used as few-shot demonstrations

    def __init__(self, dataset_ref: Union[Dataset, DatasetDict], model: str = "", agent: str = "") -> None:  # noqa: D107
        super().__init__(model, agent)
        self.dataset_ref = dataset_ref

    def evaluation_name(self) -> str:  # noqa: D102
        return "gsm8k"

    def compatible_datasets(self) -> List[str]:  # noqa: D102
        return ["gsm8k"]

    def solve(self, datum: dict) -> bool:  # noqa: D102
        parsed_datum: GSM8KDatum = GSM8KDatum(**datum)

        problem_shots = [GSM8KDatum(**self.dataset_ref["train"][i]).model_dump() for i in range(self.SHOTS)]
        examples = "\n\n".join(f"Question: {shot['question']}\nAnswer: {shot['answer']}" for shot in problem_shots)

        session = self.start_inference_session("")
        # Built by plain concatenation instead of dedent(): the example text is unindented,
        # which stops dedent() from stripping the instruction's indentation. A blank line
        # separates the header from the first example.
        session.add_system_message(
            "You are a helpful assistant. You're goal is to answer word based math questions.\n\n"
            "Here are some examples of math questions and their answers:\n\n"
            + examples
            + "\n\nNow, answer the next question provided in the user prompt. "
            "Think step by step about how to solve the problem. "
            "Then, provide the answer."
        )
        res_output = session.run_task(parsed_datum.question).strip()

        ## cleanup the output
        session = self.start_inference_session("")
        # dedent() the template *before* inserting the response: a multi-line, unindented
        # response would otherwise stop dedent() from removing the template's indentation.
        refine_prompt = dedent(
            """
            You are a helpful assistant. You're goal is to answer math questions.

            You have just answered a math question with the following response:

            --- BEGIN RESPONSE ---
            {res_output}
            --- END RESPONSE ---

            Please refine your answer.

            Only output the final number *without units* as your answer. Nothing else.
            """
        ).format(res_output=res_output)
        res_refined_output = session.run_task(refine_prompt).strip()

        res_refined_output = res_refined_output.replace("$", "").replace(",", "")
        if " " in res_refined_output:
            res_refined_output = res_refined_output.split(" ")[0]
        try:
            res_refined_output = str(int(res_refined_output))
        except ValueError:
            # Fall back to float parsing (e.g. "42.0") only when int parsing fails: routing an
            # integer string through float loses precision for values above 2**53.
            try:
                res_refined_output = str(int(float(res_refined_output)))
            except (ValueError, OverflowError):
                pass

        refined_answer = parsed_datum.answer.replace("$", "").replace(",", "")
        print(res_refined_output, refined_answer)
        return res_refined_output == refined_answer

HellaswagSolverStrategy

Bases: SolverStrategy

Solver strategy for the HellaSwag dataset.

Source code in nearai/solvers/hellaswag_solver.py
class HellaswagSolverStrategy(SolverStrategy):
    """Solver strategy for the HellaSwag dataset.

    Few-shot examples come from the "validation" split. The model answers verbosely, then
    a second session extracts a single choice letter, which is compared with the datum label.
    """

    def __init__(  # noqa: D107
        self, dataset_ref: Union[Dataset, DatasetDict], model: str = "", agent: str = "", shots: int = 8
    ) -> None:
        super().__init__(model, agent)
        self.dataset_ref = dataset_ref
        self.shots = shots

    def evaluation_name(self) -> str:  # noqa: D102
        return f"hellaswag_{self.shots}shots"

    def compatible_datasets(self) -> List[str]:  # noqa: D102
        return ["hellaswag"]

    @staticmethod
    def _render_prompt(template_name: str, **context: object) -> str:
        """Render `PROMPTS_FOLDER / template_name` as a Jinja template, closing the file afterwards."""
        with open(PROMPTS_FOLDER / template_name) as f:
            template_source = f.read()
        return Template(template_source, trim_blocks=True).render(**context)

    def solve(self, datum: dict) -> bool:  # noqa: D102
        datum = HellaswagDatum(**datum).model_dump()

        choices = ["A", "B", "C", "D"]
        # Few-shot examples: every 5th validation row, `self.shots` of them.
        example_problems = [
            HellaswagDatum(**self.dataset_ref["validation"][i]).model_dump() for i in range(0, 5 * self.shots, 5)
        ]
        base_prompt = self._render_prompt(
            "hellaswag_verbose_answer.j2",
            example_problems=example_problems,
            challenge_problem=datum,
            choices=choices,
        )
        response = self.start_inference_session("").run_task(base_prompt)

        ## Extract the answer from the response
        extract_answer_prompt = self._render_prompt(
            "hellaswag_extract_answer.j2",
            challenge_problem=datum,
            answer_text=response,
            choices=choices,
        )
        response = self.start_inference_session("").run_task(extract_answer_prompt)

        try:
            # Strip surrounding whitespace so e.g. "B\n" still matches choice "B".
            answer = choices.index(response.strip())
            return bool(answer == int(datum["label"]))
        except Exception:
            print("Failed to parse answer")
            return False

LeanSolverStrategy

Bases: SolverStrategy

Solver strategy to evaluate against Lean problems.

Source code in nearai/solvers/lean_solver.py
class LeanSolverStrategy(SolverStrategy):
    """Solver strategy to evaluate against Lean problems.

    The model is asked for proof tactics, which are extracted from between
    `BEGIN_MARKER`/`END_MARKER`, parsed, and checked with `check_solution`. Checking is
    attempted up to three times because it can fail with timeout errors.
    """

    def __init__(  # noqa: D107
        self, dataset_ref: Union[Dataset, DatasetDict], model: str = "", agent: str = ""
    ) -> None:
        super().__init__(model, agent)

    def evaluation_name(self) -> str:  # noqa: D102
        assert self.dataset_evaluation_name
        return self.dataset_evaluation_name

    def compatible_datasets(self) -> List[str]:  # noqa: D102
        return ["lean"]

    def solve(self, datum: dict) -> Tuple[bool, dict]:  # noqa: D102
        lean_datum = LeanDatum.model_validate(datum)
        lean_datum.url = load_repository(lean_datum.url)

        info: dict = {"verbose": {}}

        lean_task = LeanTaskInfo(
            lean_datum.url,
            lean_datum.commit,
            lean_datum.filename,
            lean_datum.theorem,
            load_theorem(lean_datum),
        )
        info["verbose"]["theorem_raw"] = lean_task.theorem_raw

        # Read the template inside `with` so the file handle is closed promptly.
        with open(PROMPTS_FOLDER / "lean_answer.j2") as f:
            template_source = f.read()
        base_prompt = Template(template_source, trim_blocks=True).render(
            url=lean_task.url,
            commit=lean_task.commit,
            filepath=lean_task.filename,
            theorem_name=lean_task.theorem,
            theorem_raw=lean_task.theorem_raw,
            begin_marker=BEGIN_MARKER,
            end_marker=END_MARKER,
        )
        response = self.start_inference_session("").run_task(base_prompt)

        json_response = extract_between_markers(response)
        if not json_response:
            info["error"] = "Failed to extract between markers."
            info["verbose"]["response"] = response
            return False, info

        tactics = parse_tactics(json_response)
        if not tactics:
            info["error"] = "Failed to parse tactics."
            info["verbose"]["response"] = json_response
            return False, info

        # Sometimes, there are timeout errors.
        num_attempts = 3
        info["tactics"] = tactics
        for i in range(num_attempts):
            if i != 0:
                info["check_solution_attempts"] = f"{i+1} (max: {num_attempts})"
            try:
                r, m = check_solution(lean_datum, tactics)
                if r:
                    info["verbose"]["check_solution_message"] = m
                else:
                    info["check_solution_message"] = m
                return r, info
            except Exception as e:
                # Retry silently; only the final failure is reported.
                if i == num_attempts - 1:
                    error_message = f"Exception while checking solution: {str(e)}."
                    print(error_message)
                    info["error"] = error_message
        return False, info

LiveBenchSolverStrategy

Bases: SolverStrategy

Solver strategy for the live bench dataset.

Source code in nearai/solvers/livebench_solver.py
class LiveBenchSolverStrategy(SolverStrategy):
    """Solver strategy for the LiveBench dataset.

    `get_custom_tasks` returns a single placeholder task; `solve` then runs the pipeline
    stage(s) selected by `step`: "gen_model_answer", "gen_ground_truth_judgement",
    "show_livebench_results", or "all".
    """

    def __init__(  # noqa: D107
        self, dataset_ref: str, model: str = "", agent: str = "", step: str = "all"
    ) -> None:
        super().__init__(model, agent)
        self.dataset_ref = dataset_ref
        self.step = step

    def evaluation_name(self) -> str:  # noqa: D102
        return "live_bench"

    def compatible_datasets(self) -> List[str]:  # noqa: D102
        return ["live_bench"]

    def get_custom_tasks(self) -> List[dict]:  # noqa: D102
        return [{"summary": "all"}]

    @property
    def evaluated_entry_name(self) -> str:  # noqa: D102
        name = ""
        if self.agent:
            name = self.agent_name()
            if self.model_name != "":
                name += f"_with_model_{self.model_name}"
        else:
            name = self.model_name
        assert "/" not in name
        return name.lower()

    @SolverStrategyClassProperty
    def scoring_method(self) -> SolverScoringMethod:  # noqa: D102
        return SolverScoringMethod.Custom

    def solve(self, _datum: dict) -> Tuple[bool, dict]:  # noqa: D102
        if self.step == "gen_model_answer":
            self.gen_model_answer()
            return True, {}
        if self.step == "gen_ground_truth_judgement":
            return self.gen_ground_truth_judgement(), {}
        if self.step == "show_livebench_results":
            return self.show_livebench_results()
        if self.step == "all":
            self.gen_model_answer()
            if not self.gen_ground_truth_judgement():
                return False, {}
            return self.show_livebench_results()
        return False, {}

    def gen_model_answer(self) -> None:  # noqa: D102
        print("")
        print("----------- Step gen_model_answer -----------")
        print("")
        list_of_question_files = glob.glob(f"{self.dataset_ref}/**/question.jsonl", recursive=True)
        for question_file in list_of_question_files:
            questions = load_questions_jsonl(question_file)
            bench_name = os.path.dirname(question_file).split(str(self.dataset_ref))[-1]
            answer_file = _get_answer_file_path(bench_name, self.evaluated_entry_name)
            print(f"Questions from {question_file}")
            print(f"Output to {answer_file}")
            self.run_eval(questions, answer_file)

    def run_eval(self, questions, answer_file) -> None:  # noqa: D102
        answer_file = os.path.expanduser(answer_file)

        # Load existing answers
        existing_answers = set()
        if os.path.exists(answer_file):
            print(
                f"Answer file {answer_file} exists. Will skip already answered questions. Delete this file if that is not intended."  # noqa: E501
            )
            with open(answer_file, "r") as fin:
                for line in fin:
                    answer = json.loads(line)
                    existing_answers.add(answer["question_id"])

        for question in tqdm(questions):
            if question["question_id"] in existing_answers:
                continue
            choices = self.answer_question(question)

            ans_json = {
                "question_id": question["question_id"],
                "answer_id": shortuuid.uuid(),
                "model_id": self.evaluated_entry_name,
                "choices": choices,
                "tstamp": time.time(),
            }

            # Append each answer as soon as it is produced.
            os.makedirs(os.path.dirname(answer_file), exist_ok=True)
            with open(answer_file, "a") as fout:
                fout.write(json.dumps(ans_json) + "\n")

    def answer_question(self, question) -> List[dict]:  # noqa: D102
        turns = []
        session = self.start_inference_session(question["question_id"])
        for qs in question["turns"]:
            output = session.run_task(qs)
            turns.append(output)

        return [{"index": 0, "turns": turns}]

    def gen_ground_truth_judgement(self) -> bool:  # noqa: D102
        print("")
        print("----------- Step gen_ground_truth_judgement -----------")
        print("")
        script_path = "nearai/projects/live_bench/gen_ground_truth_judgement.sh"

        try:
            # Run the script without capturing output
            subprocess.run(["/bin/bash", script_path, self.evaluated_entry_name, self.dataset_ref], check=True)
            return True

        except subprocess.CalledProcessError as e:
            print(f"An error occurred while running the script: {e}")
            return False

    def show_livebench_results(self) -> Tuple[bool, dict]:  # noqa: D102
        print("")
        print("----------- Step show_livebench_results -----------")
        print("")
        script_path = "nearai/projects/live_bench/show_livebench_results.sh"

        try:
            # Run the script without capturing output
            subprocess.run(["/bin/bash", script_path, self.evaluated_entry_name], check=True)

        except subprocess.CalledProcessError as e:
            print(f"An error occurred while running the script: {e}")
            return False, {}

        return self.create_result_dict()

    def read_csv_to_dict(self, file_path) -> dict:  # noqa: D102
        file_path = os.path.expanduser(file_path)
        # newline="" is required by the csv module so quoted fields containing newlines parse correctly.
        with open(file_path, "r", newline="") as f:
            reader = csv.DictReader(f)
            matching_rows = [row for row in reader if row["model"] == self.evaluated_entry_name]
            return matching_rows[-1] if matching_rows else {}  # Get the last matching row

    def create_result_dict(self) -> Tuple[bool, dict]:  # noqa: D102
        tasks_data = self.read_csv_to_dict(_get_all_tasks_csv_file())
        groups_data = self.read_csv_to_dict(_get_all_groups_csv_file())

        if not tasks_data or not groups_data:
            return False, {}  # The model was not found in at least one of the two files

        result: dict = {"tasks": {}, "groups": {}}

        for key, value in tasks_data.items():
            if key != "model":
                result["tasks"][key] = float(value)

        for key, value in groups_data.items():
            if key != "model":
                result["groups"][key] = float(value)

        return True, result

    def get_evaluation_metrics(self, tasks_results: List[Tuple[bool, Any]]) -> Dict[str, Any]:  # noqa: D102
        # Only the last task result carries the aggregated scores; an absent or empty result
        # (e.g. an empty list) means nothing was cached.
        if not tasks_results or not tasks_results[-1][1]:
            raise ValueError("Cache empty. Rerun the job with --force. Use --step arg to specify a step.")
        results: Dict[str, Dict[str, Any]] = tasks_results[-1][1]
        metrics: Dict[str, Any] = {"average": results["groups"]["average"]}

        for group, score in results["groups"].items():
            if group == "average":
                continue
            metrics[f"group/{group}"] = score

        for task, score in results["tasks"].items():
            metrics[f"task/{task}"] = score

        return metrics

MBPPSolverStrategy

Bases: SolverStrategy

Solver strategy for the MBPP dataset.

Source code in nearai/solvers/mbpp_solver.py
class MBPPSolverStrategy(SolverStrategy):
    """Solver strategy for the MBPP dataset.

    The model answers verbosely, a second session extracts the code, and the code is
    executed against the datum's `test_list` and `challenge_test_list`.
    """

    def __init__(  # noqa: D107
        self, dataset_ref: Union[Dataset, DatasetDict], model: str = "", agent: str = "", shots: int = 3
    ) -> None:
        super().__init__(model, agent)
        self.dataset_ref = dataset_ref
        self.shots = shots

    def evaluation_name(self) -> str:  # noqa: D102
        prefix = self.dataset_evaluation_name if self.dataset_evaluation_name else "mbpp"
        return f"{prefix}_{self.shots}shots"

    def compatible_datasets(self) -> List[str]:  # noqa: D102
        return ["mbpp"]

    @staticmethod
    def _render_prompt(template_name: str, **context: object) -> str:
        """Render `PROMPTS_FOLDER / template_name` as a Jinja template, closing the file afterwards."""
        with open(PROMPTS_FOLDER / template_name) as f:
            template_source = f.read()
        return Template(template_source, trim_blocks=True).render(**context)

    def solve(self, datum: dict) -> bool:  # noqa: D102
        datum = MBPPDatum(**datum).model_dump()
        session_id = str(datum["task_id"])

        ## Allow LLM to think "out loud" for its answer
        function_name = get_function_name(datum["code"])
        example_problems = list(islice(self.dataset_ref["prompt"], self.shots))
        base_prompt = self._render_prompt(
            "mbpp_verbose_answer.j2",
            function_name=function_name,
            example_problems=example_problems,
            challenge_problem=datum,
        )
        response = self.start_inference_session(session_id).run_task(base_prompt)

        ## Extract the answer from the response
        extract_answer_prompt = self._render_prompt(
            "mbpp_extract_answer.j2",
            function_name=function_name,
            answer_text=response,
        )
        response = self.start_inference_session(session_id).run_task(extract_answer_prompt)

        ## Parse the python code; fall back to the raw response when no code block is found
        python_code_blocks = parse_python_code_block(response) + parse_code_block(response)
        code = python_code_blocks[0] if python_code_blocks else response

        ## Evaluate the code: every test must pass
        try:
            for test in datum["test_list"] + datum["challenge_test_list"]:
                test_code = code + "\n" + test
                if not run_with_timeout(test_code):
                    return False
            return True
        except Exception:
            return False

MMLUSolverStrategy

Bases: SolverStrategy

Solver strategy for the MMLU dataset.

Source code in nearai/solvers/mmlu_solver.py
class MMLUSolverStrategy(SolverStrategy):
    """Solver strategy for the MMLU dataset.

    Few-shot examples come from the "dev" split. The model answers verbosely, then a
    second session extracts a single choice letter, which is compared with `datum["answer"]`.
    """

    def __init__(  # noqa: D107
        self, dataset_ref: Union[Dataset, DatasetDict], model: str = "", agent: str = "", shots: int = 8
    ) -> None:
        super().__init__(model, agent)
        self.dataset_ref = dataset_ref
        self.shots = shots

    def evaluation_name(self) -> str:  # noqa: D102
        prefix = self.dataset_evaluation_name if self.dataset_evaluation_name else "mmlu"
        return f"{prefix}_{self.shots}shots"

    def compatible_datasets(self) -> List[str]:  # noqa: D102
        return ["mmlu"]

    @staticmethod
    def _render_prompt(template_name: str, **context: object) -> str:
        """Render `PROMPTS_FOLDER / template_name` as a Jinja template, closing the file afterwards."""
        with open(PROMPTS_FOLDER / template_name) as f:
            template_source = f.read()
        return Template(template_source, trim_blocks=True).render(**context)

    def solve(self, datum: dict) -> bool:  # noqa: D102
        datum = MMLUDatum(**datum).model_dump()

        choices = ["A", "B", "C", "D"]
        # Few-shot examples: every 5th "dev" row, `self.shots` of them.
        example_problems = [
            MMLUDatum(**self.dataset_ref["dev"][i]).model_dump() for i in range(0, 5 * self.shots, 5)
        ]
        base_prompt = self._render_prompt(
            "mmlu_verbose_answer.j2",
            example_problems=example_problems,
            challenge_problem=datum,
            choices=choices,
        )

        response = self.start_inference_session("").run_task(base_prompt)

        ## Extract the answer from the response
        extract_answer_prompt = self._render_prompt(
            "mmlu_extract_answer.j2",
            challenge_problem=datum,
            answer_text=response,
            choices=choices,
        )
        response = self.start_inference_session("").run_task(extract_answer_prompt)

        try:
            # Strip surrounding whitespace so e.g. "B\n" still matches choice "B".
            answer = choices.index(response.strip())
            return bool(answer == datum["answer"])
        except Exception:
            print("Failed to parse answer")
            return False

SolverStrategy

Bases: ABC

Abstract class for solver strategies.

Source code in nearai/solvers/__init__.py
class SolverStrategy(ABC, metaclass=SolverStrategyMeta):
    """Abstract class for solver strategies.

    Concrete subclasses are registered automatically by ``SolverStrategyMeta`` and must
    implement ``evaluation_name``, ``compatible_datasets`` and ``solve``.
    """

    def __init__(self, model: str = "", agent: str = "") -> None:
        """Set up the inference client and resolve the model and/or agent to evaluate.

        Args:
            model: Model to evaluate; resolved through ``provider_models.match_provider_model``.
            agent: Path of the agent to evaluate.

        At least one of ``model`` or ``agent`` must be non-empty.
        """
        # NOTE: this mutates the global CONFIG and disables interactive command confirmation.
        CONFIG.confirm_commands = False
        self.client_config = CONFIG.get_client_config()
        self.client = InferenceClient(self.client_config)
        assert model != "" or agent != "", "SolverStrategy requires a non-empty `model` or `agent`."
        # Optional override read by some subclasses' evaluation_name(); presumably set by the caller.
        self.dataset_evaluation_name = ""

        self.provider = ""
        self.model_namespace = ""
        self.model_full_path = ""
        self.model_name = ""
        if model != "":
            self.provider, self.model_full_path = self.client.provider_models.match_provider_model(model)
            self.provider, namespaced_model = get_provider_namespaced_model(self.model_full_path, self.provider)
            self.model_namespace = namespaced_model.namespace
            self.model_name = namespaced_model.name

        self.agent = agent
        # Parameters forwarded to SolverInferenceSession by start_inference_session().
        self.agent_params = {
            "api_url": CONFIG.api_url,
            "data_source": "local_files",
            "temperature": 0.0,
            "record_run": False,
            "verbose": False,
            "change_to_agent_temp_dir": False,
        }
        if self.model_full_path:
            self.agent_params["model"] = self.model_full_path

    @property
    def name(self) -> str:
        """Returns the name of the solver strategy."""
        return type(self).__name__

    @SolverStrategyClassProperty
    def scoring_method(self) -> SolverScoringMethod:
        """Scoring method for this solver; defaults to a per-datum True/False list."""
        return SolverScoringMethod.TrueOrFalseList

    @abstractmethod
    def evaluation_name(self) -> str:
        """Returns a unique name for (benchmark, solver) tuple, e.g. 'mbpp' or 'live_bench' or 'mmlu-5-shot'."""
        ...

    @abstractmethod
    def compatible_datasets(self) -> List[str]:
        """Returns the list of datasets that the solver strategy is compatible with."""
        ...

    def agent_name(self) -> str:
        """Returns agent name that is evaluated."""
        if not self.agent:
            return ""
        # Assumes an agent path shaped like .../<namespace>/<name>/<version> — TODO confirm.
        path = Path(self.agent)
        return path.parent.name

    def agent_version(self) -> str:
        """Returns the version of the agent that is evaluated (last component of the agent path)."""
        if not self.agent:
            return ""
        path = Path(self.agent)
        return path.name

    def evaluated_entry_namespace(self) -> str:
        """Returns namespace of a model or agent to be evaluated."""
        if self.agent:
            path = Path(self.agent)
            return path.parent.parent.name
        return self.model_namespace

    def model_provider(self) -> str:
        """Returns model provider.

        Uses the provider resolved from ``model`` if any. Otherwise the agent is loaded to
        read its provider. Returns "" when neither is available.
        """
        if self.provider != "":
            return self.provider
        if self.agent != "":
            agent_obj = Agent.load_agent(self.agent, self.client_config)
            return agent_obj.model_provider
        return ""

    @abstractmethod
    def solve(self, datum: dict) -> Union[bool, Tuple[bool, Any]]:
        """Solves the task for the given datum."""
        ...

    def get_custom_tasks(self) -> List[dict]:
        """Custom tasks for custom benchmark."""
        if self.scoring_method == SolverScoringMethod.Custom:
            raise NotImplementedError("get_custom_tasks must be implemented for Custom scoring method")
        else:
            raise AttributeError("get_custom_tasks is only applicable for Custom scoring method")

    def get_evaluation_metrics(self, tasks_results: List[Tuple[bool, Any]]) -> Dict[str, Any]:
        """Given results for all datums, returns evaluation metrics.

        Not used by TrueOrFalseList scoring method.
        Do not prepend with evaluation_name. If hierarchical, use slashes /.
        Expected metrics is a dict of scores, e.g.: {"average": <val>, "group/coding": <val>}.
        """
        raise NotImplementedError("get_evaluation_metrics not implemented")

    def start_inference_session(self, task_id: str) -> SolverInferenceSession:
        """Start a new inference session for ``task_id``.

        The session uses this solver's agent, agent params, model, client and evaluation name.
        """
        return SolverInferenceSession(
            self.agent, self.agent_params, self.model_full_path, self.client, self.evaluation_name()
        ).start_inference_session(task_id)
name property
name: str

Returns the name of the solver strategy.

agent_name
agent_name() -> str

Returns agent name that is evaluated.

Source code in nearai/solvers/__init__.py
def agent_name(self) -> str:
    """Return the evaluated agent's name (the parent directory of the agent path), or "" without an agent."""
    return Path(self.agent).parent.name if self.agent else ""
agent_version
agent_version() -> str

Returns the version of the agent that is evaluated.

Source code in nearai/solvers/__init__.py
def agent_version(self) -> str:
    """Returns the version of the agent that is evaluated (last component of the agent path).

    Returns "" when no agent is configured.
    """
    if not self.agent:
        return ""
    return Path(self.agent).name
compatible_datasets abstractmethod
compatible_datasets() -> List[str]

Returns the list of datasets that the solver strategy is compatible with.

Source code in nearai/solvers/__init__.py
@abstractmethod
def compatible_datasets(self) -> List[str]:
    """Return the names of the datasets this solver strategy can evaluate."""
    pass
evaluated_entry_namespace
evaluated_entry_namespace() -> str

Returns namespace of a model or agent to be evaluated.

Source code in nearai/solvers/__init__.py
def evaluated_entry_namespace(self) -> str:
    """Return the namespace of the evaluated entry: the agent's grandparent directory, else the model namespace."""
    if not self.agent:
        return self.model_namespace
    return Path(self.agent).parent.parent.name
evaluation_name abstractmethod
evaluation_name() -> str

Returns a unique name for (benchmark, solver) tuple, e.g. 'mbpp' or 'live_bench' or 'mmlu-5-shot'.

Source code in nearai/solvers/__init__.py
@abstractmethod
def evaluation_name(self) -> str:
    """Return a unique name for the (benchmark, solver) pair, such as 'mbpp', 'live_bench' or 'mmlu-5-shot'."""
    pass
get_custom_tasks
get_custom_tasks() -> List[dict]

Custom tasks for custom benchmark.

Source code in nearai/solvers/__init__.py
def get_custom_tasks(self) -> List[dict]:
    """Return the custom tasks of a custom benchmark; subclasses using Custom scoring must override."""
    if self.scoring_method != SolverScoringMethod.Custom:
        raise AttributeError("get_custom_tasks is only applicable for Custom scoring method")
    raise NotImplementedError("get_custom_tasks must be implemented for Custom scoring method")
get_evaluation_metrics
get_evaluation_metrics(tasks_results: List[Tuple[bool, Any]]) -> Dict[str, Any]

Given results for all datums, returns evaluation metrics.

Not used by the TrueOrFalseList scoring method. Do not prefix metric names with evaluation_name. For hierarchical metrics, use slashes (/). The expected metrics are a dict of scores, e.g.: {"average": <val>, "group/coding": <val>}.

Source code in nearai/solvers/__init__.py
def get_evaluation_metrics(self, tasks_results: List[Tuple[bool, Any]]) -> Dict[str, Any]:
    """Turn per-datum results into evaluation metrics.

    The TrueOrFalseList scoring method does not use this. Metric names must not be prefixed with
    evaluation_name; hierarchical names use slashes, e.g. {"average": <val>, "group/coding": <val>}.
    """
    message = "get_evaluation_metrics not implemented"
    raise NotImplementedError(message)
model_provider
model_provider() -> str

Returns model provider.

Source code in nearai/solvers/__init__.py
def model_provider(self) -> str:
    """Return the resolved model provider; otherwise load the agent to ask it, or "" with neither."""
    if self.provider != "":
        return self.provider
    if self.agent == "":
        return ""
    loaded = Agent.load_agent(self.agent, self.client_config)
    return loaded.model_provider
solve abstractmethod
solve(datum: dict) -> Union[bool, Tuple[bool, Any]]

Solves the task for the given datum.

Source code in nearai/solvers/__init__.py
@abstractmethod
def solve(self, datum: dict) -> Union[bool, Tuple[bool, Any]]:
    """Solve the task described by a single datum."""
    pass

SolverStrategyMeta

Bases: ABCMeta

Metaclass that automatically registers subclasses in the SolverStrategyRegistry.

Source code in nearai/solvers/__init__.py
class SolverStrategyMeta(ABCMeta):
    """Metaclass that records every concrete solver class in SolverStrategyRegistry, keyed by class name."""

    def __new__(cls, name: str, bases: tuple, namespace: dict) -> Any:
        created = super().__new__(cls, name, bases, namespace)
        # The abstract root class derives directly and only from ABC; it is not registered.
        is_abstract_root = bases == (ABC,)
        if not is_abstract_root:
            SolverStrategyRegistry[created.__name__] = created  # type: ignore
        return created

ddot_v0_solver

DDOTSEnvironment

Bases: Environment

Source code in nearai/solvers/ddot_v0_solver.py
class DDOTSEnvironment(Environment):
    """Agent environment for a single DDOTS problem.

    Creates a temporary workspace seeded with the problem files. Agents submit code to the
    DDOTS server through ``submit_python``.
    """

    def __init__(self, agents: List[Agent], problem_id: str, description: str, client):  # noqa: D107
        import os  # local import keeps this block self-contained

        self.tdir = TemporaryDirectory()
        self.hub_client = get_hub_client()
        thread = self.hub_client.beta.threads.create()
        super().__init__(
            self.tdir.name,
            agents,
            client,
            self.hub_client,
            thread.id,
            "todo",
            approvals={"confirm_execution": lambda _: False},
        )

        self.problem_id = problem_id
        self.solved = False

        files = {
            ".id": problem_id,
            "PROBLEM.txt": description,
            "solution.py": "",
            "test.in": "",
            "test.sh": "#!/bin/bash\npython3 solution.py < test.in",
        }
        for fname, content in files.items():
            # Explicit UTF-8: with the platform default encoding, a description containing
            # characters outside that encoding would raise UnicodeEncodeError.
            with open(os.path.join(self.tdir.name, fname), "w", encoding="utf-8") as f:
                f.write(content)

    async def async_submit(self, code: str) -> Tuple[bool, str]:
        """Submit ``code`` to DDOTS and wait for the verdict.

        Returns ``(True, "")`` when the submission is accepted. Returns ``(False, reason)``
        when it times out, and ``(False, checker_output)`` when it is rejected. Acceptance
        and timeout mark the environment done. A rejection leaves it open (presumably so the
        agent can retry).
        """
        submission_id = await submit_problem(self.problem_id, code, Extensions.PYTHON)

        try:
            await is_output_ready(submission_id)
        except Exception:
            print("WARNING: Submission took too long to execute on DDOTS")
            self.mark_done()
            return False, "Submission took too long to execute on the platform"

        ok = await submission_accepted(submission_id)

        if ok:
            self.solved = True
            self.mark_done()
            return True, ""

        output = await get_output(submission_id)

        return False, output

    def submit_python(self, code: str) -> Tuple[bool, str]:
        """Returns True if the submission was accepted, False otherwise.

        The second element of the tuple is the output of the checker if the submission was rejected.
        Uses ``asyncio.run``, so it must not be called from inside a running event loop.
        """
        return asyncio.run(self.async_submit(code))
submit_python
submit_python(code: str) -> Tuple[bool, str]

Returns True if the submission was accepted, False otherwise.

The second element of the tuple is the output of the checker if the submission was rejected.

Source code in nearai/solvers/ddot_v0_solver.py
def submit_python(self, code: str) -> Tuple[bool, str]:
    """Synchronously submit ``code``; returns (accepted, checker output when rejected)."""
    pending = self.async_submit(code)
    return asyncio.run(pending)
DDOTSV0Solver

Bases: SolverStrategy

Solver strategy for competitive programming problems live on DDOTS.

This solver runs agents in a previously prepared Agent environment with the following workspace layout:

workspace/
    .id          -- Id of the problem
    PROBLEM.txt  -- Description of the problem

The agent should call env.submit_python(code) to submit the code to the DDOTS server.

Source code in nearai/solvers/ddot_v0_solver.py
class DDOTSV0Solver(SolverStrategy):
    """Solver strategy for competitive programming problems live on DDOTS.

    This solver runs agents in a previously prepared Agent environment with this workspace:

    workspace/
        .id             -- Id of the problem
        PROBLEM.txt     -- Description of the problem
        solution.py     -- Empty file for the solution
        test.in         -- Empty test input
        test.sh         -- Runs solution.py on test.in

    The agent should call env.submit_python(code) to submit the code to the DDOTS server.

    """

    def __init__(self, dataset_ref: Dataset, agents: str, max_iterations: int, save_snapshots: bool = False):  # noqa: D107
        # NOTE(review): SolverStrategy.__init__ is not called, so base attributes such as `agent`,
        # `provider` and `client_config` are never set. Base methods that read them (e.g.
        # model_provider, evaluated_entry_namespace) would raise AttributeError. `dataset_ref` is unused.
        client_config = self._build_client_config()
        # `agents` is a comma-separated list of agent identifiers.
        self.agents = [Agent.load_agent(agent, client_config) for agent in agents.split(",")]
        self.max_iterations = max_iterations

        date = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        rnd_id = random.randint(10**8, 10**9 - 1)
        self._saved_trajectories = DATA_FOLDER / "data" / "ddots_v0_trajectories" / f"{date}_{rnd_id}"
        self._saved_trajectories.mkdir(parents=True, exist_ok=True)

        self.save_snapshots = save_snapshots
        print("Saving trajectories to", self._saved_trajectories)

    @staticmethod
    def _build_client_config() -> ClientConfig:
        """Build a hub client config from the current global CONFIG."""
        return ClientConfig(
            base_url=CONFIG.nearai_hub.base_url,
            auth=CONFIG.auth,
        )

    def evaluation_name(self) -> str:  # noqa: D102
        return "ddots"

    def compatible_datasets(self) -> List[str]:  # noqa: D102
        # NOTE(review): only the second entry carries a "datasets/" prefix — confirm which form is expected.
        return ["ddots_codeforces_small/v0", "datasets/ddots_codeforces_medium_A_B/v0"]

    def solve(self, datum: dict) -> bool:
        """Run the agents on one DDOTS problem and report whether a submission was accepted."""
        problem_id = datum["problem_id"]
        description = datum["description"]

        client = InferenceClient(self._build_client_config())
        env = DDOTSEnvironment(self.agents, problem_id, description, client)
        # Write an "unsolved" marker up front so it exists even if the run fails.
        env.write_file(".solved", str(False))

        try:
            env.run(description, max_iterations=self.max_iterations)
            env.write_file(".solved", str(env.solved))

        except Exception as e:
            # A failing agent run is reported and counted via env.solved instead of aborting the evaluation.
            print(f"Error running task: {e}")

        finally:
            if self.save_snapshots:
                snapshot = env.create_snapshot()
                with open(self._saved_trajectories / f"{problem_id}.tar.gz", "wb") as f:
                    f.write(snapshot)

        return env.solved

gsm8k_solver

GSM8KSolverStrategy

Bases: SolverStrategy

Solver strategy for the GSM8K dataset.

Source code in nearai/solvers/gsm8k_solver.py
class GSM8KSolverStrategy(SolverStrategy):
    """Solver strategy for the GSM8K dataset."""

    # Number of few-shot examples, taken from the start of the "train" split.
    SHOTS = 8

    def __init__(self, dataset_ref: Union[Dataset, DatasetDict], model: str = "", agent: str = "") -> None:  # noqa: D107
        super().__init__(model, agent)
        self.dataset_ref = dataset_ref

    def evaluation_name(self) -> str:  # noqa: D102
        return "gsm8k"

    def compatible_datasets(self) -> List[str]:  # noqa: D102
        return ["gsm8k"]

    def solve(self, datum: dict) -> bool:
        """Answer one GSM8K question and compare the normalized final number with the gold answer.

        Two inference sessions are used. The first answers step by step with few-shot
        examples. The second reduces that response to a bare number.
        """
        parsed_datum: GSM8KDatum = GSM8KDatum(**datum)

        problem_shots = [GSM8KDatum(**self.dataset_ref["train"][i]).model_dump() for i in range(self.SHOTS)]
        examples = "\n\n".join(f"Question: {shot['question']}\nAnswer: {shot['answer']}" for shot in problem_shots)

        # Build the prompt without dedent: dedent on text that was concatenated with unindented
        # examples is a no-op, and the original also glued the header to the first example.
        system_message = (
            "You are a helpful assistant. You're goal is to answer word based math questions.\n\n"
            "Here are some examples of math questions and their answers:\n\n"
            f"{examples}\n\n"
            "Now, answer the next question provided in the user prompt. "
            "Think step by step about how to solve the problem. "
            "Then, provide the answer."
        )
        session = self.start_inference_session("")
        session.add_system_message(system_message)
        res_output = session.run_task(parsed_datum.question).strip()

        ## cleanup the output
        # Dedent the template before inserting the (possibly multi-line, unindented) response;
        # otherwise dedent finds no common indentation and leaves every line indented.
        refine_prompt = dedent(
            """
            You are a helpful assistant. You're goal is to answer math questions.

            You have just answered a math question with the following response:

            --- BEGIN RESPONSE ---
            {res_output}
            --- END RESPONSE ---

            Please refine your answer.

            Only output the final number *without units* as your answer. Nothing else.
            """
        ).format(res_output=res_output)
        session = self.start_inference_session("")
        res_refined_output = session.run_task(refine_prompt).strip()

        res_refined_output = res_refined_output.replace("$", "").replace(",", "")
        if " " in res_refined_output:
            res_refined_output = res_refined_output.split(" ")[0]
        # Normalize numbers ("042" -> "42", "18.0" -> "18"). The float route is used only when
        # the text is not an integer, so large integers keep full precision. Non-numeric text is kept.
        try:
            res_refined_output = str(int(res_refined_output))
        except ValueError:
            try:
                res_refined_output = str(int(float(res_refined_output)))
            except (ValueError, OverflowError):
                pass

        refined_answer = parsed_datum.answer.replace("$", "").replace(",", "")
        print(res_refined_output, refined_answer)
        return res_refined_output == refined_answer

hellaswag_solver

HellaswagSolverStrategy

Bases: SolverStrategy

Solver strategy for the HellaSwag dataset.

Source code in nearai/solvers/hellaswag_solver.py
class HellaswagSolverStrategy(SolverStrategy):
    """Solver strategy for the HellaSwag dataset."""

    def __init__(  # noqa: D107
        self, dataset_ref: Union[Dataset, DatasetDict], model: str = "", agent: str = "", shots: int = 8
    ) -> None:
        super().__init__(model, agent)
        self.dataset_ref = dataset_ref
        self.shots = shots

    def evaluation_name(self) -> str:  # noqa: D102
        return f"hellaswag_{self.shots}shots"

    def compatible_datasets(self) -> List[str]:  # noqa: D102
        return ["hellaswag"]

    def solve(self, datum: dict) -> bool:
        """Answer one HellaSwag item with few-shot prompting and check it against ``datum["label"]``."""
        datum = HellaswagDatum(**datum).model_dump()

        choices = ["A", "B", "C", "D"]
        # Few-shot examples: every 5th item of the "validation" split, `self.shots` of them.
        example_problems = [
            HellaswagDatum(**self.dataset_ref["validation"][i]).model_dump() for i in range(0, 5 * self.shots, 5)
        ]
        # Read templates via `with` so the file handle is closed deterministically.
        with open(PROMPTS_FOLDER / "hellaswag_verbose_answer.j2") as f:
            verbose_template = f.read()
        base_prompt = Template(verbose_template, trim_blocks=True).render(
            example_problems=example_problems,
            challenge_problem=datum,
            choices=choices,
        )
        response = self.start_inference_session("").run_task(base_prompt)

        ## Extract the answer from the response
        with open(PROMPTS_FOLDER / "hellaswag_extract_answer.j2") as f:
            extract_template = f.read()
        extract_answer_prompt = Template(extract_template, trim_blocks=True).render(
            challenge_problem=datum,
            answer_text=response,
            choices=choices,
        )
        response = self.start_inference_session("").run_task(extract_answer_prompt)

        try:
            # Strip whitespace so replies like "B\n" still match a choice letter.
            answer = choices.index(str(response).strip())
            # int() may fail on a missing or non-numeric label; that is treated as unparseable too.
            return bool(answer == int(datum["label"]))
        except (ValueError, TypeError):
            print("Failed to parse answer")
            return False

lean_solver

LeanSolverStrategy

Bases: SolverStrategy

Solver strategy to evaluate against Lean problems.

Source code in nearai/solvers/lean_solver.py
class LeanSolverStrategy(SolverStrategy):
    """Solver strategy to evaluate against Lean problems."""

    def __init__(  # noqa: D107
        self, dataset_ref: Union[Dataset, DatasetDict], model: str = "", agent: str = ""
    ) -> None:
        # `dataset_ref` is accepted for interface compatibility but not used by this solver.
        super().__init__(model, agent)

    def evaluation_name(self) -> str:
        """Return the dataset evaluation name; it must have been set before evaluation."""
        assert self.dataset_evaluation_name
        return self.dataset_evaluation_name

    def compatible_datasets(self) -> List[str]:  # noqa: D102
        return ["lean"]

    def solve(self, datum: dict) -> Tuple[bool, dict]:
        """Ask the model for proof tactics for one theorem and check them.

        Returns ``(solved, info)``. ``info`` holds the tactics, any error message and
        verbose diagnostics.
        """
        lean_datum = LeanDatum.model_validate(datum)
        lean_datum.url = load_repository(lean_datum.url)

        info: dict = {}
        info["verbose"] = {}

        lean_task = LeanTaskInfo(
            lean_datum.url,
            lean_datum.commit,
            lean_datum.filename,
            lean_datum.theorem,
            load_theorem(lean_datum),
        )
        info["verbose"]["theorem_raw"] = lean_task.theorem_raw

        # Read the template via `with` so the file handle is closed deterministically.
        with open(PROMPTS_FOLDER / "lean_answer.j2") as f:
            answer_template = f.read()
        base_prompt = Template(answer_template, trim_blocks=True).render(
            url=lean_task.url,
            commit=lean_task.commit,
            filepath=lean_task.filename,
            theorem_name=lean_task.theorem,
            theorem_raw=lean_task.theorem_raw,
            begin_marker=BEGIN_MARKER,
            end_marker=END_MARKER,
        )
        response = self.start_inference_session("").run_task(base_prompt)

        json_response = extract_between_markers(response)
        if not json_response:
            info["error"] = "Failed to extract between markers."
            info["verbose"]["response"] = response
            return False, info

        tactics = parse_tactics(json_response)
        if not tactics:
            info["error"] = "Failed to parse tactics."
            info["verbose"]["response"] = json_response
            return False, info

        # Sometimes, there are timeout errors, so check_solution is retried on exceptions.
        num_attempts = 3
        info["tactics"] = tactics
        for i in range(0, num_attempts):
            if i != 0:
                info["check_solution_attempts"] = f"{i+1} (max: {num_attempts})"
            try:
                r, m = check_solution(lean_datum, tactics)
                if r:
                    info["verbose"]["check_solution_message"] = m
                else:
                    info["check_solution_message"] = m
                return r, info
            except Exception as e:
                # Only the final failed attempt is reported.
                if i == num_attempts - 1:
                    error_message = f"Exception while checking solution: {str(e)}."
                    print(error_message)
                    info["error"] = error_message
        return False, info
load_theorem
load_theorem(task: LeanDatum) -> str

Use local copy of the repository.

Source code in nearai/solvers/lean_solver.py
def load_theorem(task: LeanDatum) -> str:
    """Open the task's theorem in a Dojo, using the local repository copy, and return the initial state's `pp`."""
    theorem = Theorem(LeanGitRepo(task.url, task.commit), task.filename, task.theorem)
    with Dojo(theorem) as (_dojo, initial_state):
        return initial_state.pp

livebench_solver

LiveBenchSolverStrategy

Bases: SolverStrategy

Solver strategy for the live bench dataset.

Source code in nearai/solvers/livebench_solver.py
class LiveBenchSolverStrategy(SolverStrategy):
    """Solver strategy for the live bench dataset.

    Runs as one custom task made of pipeline steps selected by ``step``: ``gen_model_answer``,
    ``gen_ground_truth_judgement``, ``show_livebench_results``, or ``all``.
    """

    def __init__(  # noqa: D107
        self, dataset_ref: str, model: str = "", agent: str = "", step: str = "all"
    ) -> None:
        super().__init__(model, agent)
        self.dataset_ref = dataset_ref
        self.step = step

    def evaluation_name(self) -> str:  # noqa: D102
        return "live_bench"

    def compatible_datasets(self) -> List[str]:  # noqa: D102
        return ["live_bench"]

    def get_custom_tasks(self) -> List[dict]:  # noqa: D102
        return [{"summary": "all"}]

    @property
    def evaluated_entry_name(self) -> str:
        """Lower-cased id of the evaluated agent/model, used for answer files and as the CSV "model" key."""
        name = ""
        if self.agent:
            name = self.agent_name()
            if self.model_name != "":
                name += f"_with_model_{self.model_name}"
        else:
            name = self.model_name
        # The name is passed to file-path helpers, so it must not contain a path separator.
        assert "/" not in name
        return name.lower()

    @SolverStrategyClassProperty
    def scoring_method(self) -> SolverScoringMethod:  # noqa: D102
        return SolverScoringMethod.Custom

    def solve(self, _datum: dict) -> Tuple[bool, dict]:
        """Run the configured pipeline step(s); returns (success, results dict)."""
        if self.step == "gen_model_answer":
            self.gen_model_answer()
            return True, {}
        if self.step == "gen_ground_truth_judgement":
            return self.gen_ground_truth_judgement(), {}
        if self.step == "show_livebench_results":
            return self.show_livebench_results()
        if self.step == "all":
            self.gen_model_answer()
            if not self.gen_ground_truth_judgement():
                return False, {}
            return self.show_livebench_results()
        return False, {}

    def gen_model_answer(self) -> None:
        """Answer every question.jsonl found under the dataset directory."""
        print("")
        print("----------- Step gen_model_answer -----------")
        print("")
        list_of_question_files = glob.glob(f"{self.dataset_ref}/**/question.jsonl", recursive=True)
        for question_file in list_of_question_files:
            questions = load_questions_jsonl(question_file)
            bench_name = os.path.dirname(question_file).split(str(self.dataset_ref))[-1]
            answer_file = _get_answer_file_path(bench_name, self.evaluated_entry_name)
            print(f"Questions from {question_file}")
            print(f"Output to {answer_file}")
            self.run_eval(questions, answer_file)

    def run_eval(self, questions, answer_file) -> None:
        """Answer questions missing from ``answer_file`` and append one JSON line per answer."""
        answer_file = os.path.expanduser(answer_file)

        # Load existing answers so an interrupted run resumes instead of re-answering.
        existing_answers = set()
        if os.path.exists(answer_file):
            print(
                f"Answer file {answer_file} exists. Will skip already answered questions. Delete this file if that is not intended."  # noqa: E501
            )
            with open(answer_file, "r") as fin:
                for line in fin:
                    answer = json.loads(line)
                    existing_answers.add(answer["question_id"])

        for question in tqdm(questions):
            if question["question_id"] in existing_answers:
                continue
            choices = self.answer_question(question)

            ans_json = {
                "question_id": question["question_id"],
                "answer_id": shortuuid.uuid(),
                "model_id": self.evaluated_entry_name,
                "choices": choices,
                "tstamp": time.time(),
            }

            # Append per answer so completed answers are persisted even if a later one fails.
            os.makedirs(os.path.dirname(answer_file), exist_ok=True)
            with open(answer_file, "a") as fout:
                fout.write(json.dumps(ans_json) + "\n")

    def answer_question(self, question) -> List[dict]:
        """Run every turn of ``question`` in one session; returns a single choice holding all turn outputs."""
        turns = []
        session = self.start_inference_session(question["question_id"])
        for qs in question["turns"]:
            output = session.run_task(qs)
            turns.append(output)

        return [{"index": 0, "turns": turns}]

    def gen_ground_truth_judgement(self) -> bool:
        """Run the ground-truth judgement script; returns False if it exits with an error."""
        print("")
        print("----------- Step gen_ground_truth_judgement -----------")
        print("")
        script_path = "nearai/projects/live_bench/gen_ground_truth_judgement.sh"

        try:
            # Run the script without capturing output
            subprocess.run(["/bin/bash", script_path, self.evaluated_entry_name, self.dataset_ref], check=True)
            return True

        except subprocess.CalledProcessError as e:
            print(f"An error occurred while running the script: {e}")
            return False

    def show_livebench_results(self) -> Tuple[bool, dict]:
        """Run the results script, then collect this entry's scores from the generated CSV files."""
        print("")
        print("----------- Step show_livebench_results -----------")
        print("")
        script_path = "nearai/projects/live_bench/show_livebench_results.sh"

        try:
            # Run the script without capturing output
            subprocess.run(["/bin/bash", script_path, self.evaluated_entry_name], check=True)

        except subprocess.CalledProcessError as e:
            print(f"An error occurred while running the script: {e}")
            return False, {}

        return self.create_result_dict()

    def read_csv_to_dict(self, file_path) -> dict:
        """Return the last CSV row whose "model" column equals this entry's name, or {} if none."""
        file_path = os.path.expanduser(file_path)
        # newline="" is what the csv module expects, so quoted fields with embedded newlines parse correctly.
        with open(file_path, "r", newline="") as f:
            reader = csv.DictReader(f)
            matching_rows = [row for row in reader if row["model"] == self.evaluated_entry_name]
            return matching_rows[-1] if matching_rows else {}  # Get the last matching row

    def create_result_dict(self) -> Tuple[bool, dict]:
        """Build {"tasks": {...}, "groups": {...}} float scores for this entry."""
        tasks_data = self.read_csv_to_dict(_get_all_tasks_csv_file())
        groups_data = self.read_csv_to_dict(_get_all_groups_csv_file())

        if not tasks_data or not groups_data:
            return False, {}  # The model is missing from at least one of the files.

        result: dict = {
            "tasks": {key: float(value) for key, value in tasks_data.items() if key != "model"},
            "groups": {key: float(value) for key, value in groups_data.items() if key != "model"},
        }
        return True, result

    def get_evaluation_metrics(self, tasks_results: List[Tuple[bool, Any]]) -> Dict[str, Any]:
        """Flatten the last task result into {"average", "group/<name>", "task/<name>"} metrics.

        Raises:
            ValueError: if there are no results (empty list or empty result dict).
        """
        # Guard the empty list too, which previously surfaced as an IndexError.
        if not tasks_results or not tasks_results[-1][1]:
            raise ValueError("Cache empty. Rerun the job with --force. Use --step arg to specify a step.")
        results: Dict[str, Dict[str, Any]] = tasks_results[-1][1]
        metrics: Dict[str, Any] = {"average": results["groups"]["average"]}

        for group, score in results["groups"].items():
            if group == "average":
                continue
            metrics[f"group/{group}"] = score

        for task, score in results["tasks"].items():
            metrics[f"task/{task}"] = score

        return metrics

mbpp_solver

MBPPSolverStrategy

Bases: SolverStrategy

Solver strategy for the MBPP dataset.

Source code in nearai/solvers/mbpp_solver.py
class MBPPSolverStrategy(SolverStrategy):
    """Solver strategy for the MBPP dataset."""

    def __init__(  # noqa: D107
        self, dataset_ref: Union[Dataset, DatasetDict], model: str = "", agent: str = "", shots: int = 3
    ) -> None:
        super().__init__(model, agent)
        self.dataset_ref = dataset_ref
        self.shots = shots

    def evaluation_name(self) -> str:  # noqa: D102
        prefix = self.dataset_evaluation_name if self.dataset_evaluation_name else "mbpp"
        return f"{prefix}_{self.shots}shots"

    def compatible_datasets(self) -> List[str]:  # noqa: D102
        return ["mbpp"]

    def solve(self, datum: dict) -> bool:
        """Generate code for one MBPP problem and run it against the problem's tests.

        Returns True only if every test in ``test_list`` and ``challenge_test_list`` passes.
        """
        datum = MBPPDatum(**datum).model_dump()

        ## Allow LLM to think "out loud" for its answer
        function_name = get_function_name(datum["code"])
        example_problems = list(islice(self.dataset_ref["prompt"], self.shots))
        # Read templates via `with` so the file handle is closed deterministically.
        with open(PROMPTS_FOLDER / "mbpp_verbose_answer.j2") as f:
            verbose_template = f.read()
        base_prompt = Template(verbose_template, trim_blocks=True).render(
            function_name=function_name,
            example_problems=example_problems,
            challenge_problem=datum,
        )
        response = self.start_inference_session(str(datum["task_id"])).run_task(base_prompt)

        ## Extract the answer from the response
        with open(PROMPTS_FOLDER / "mbpp_extract_answer.j2") as f:
            extract_template = f.read()
        extract_answer_prompt = Template(extract_template, trim_blocks=True).render(
            function_name=function_name,
            answer_text=response,
        )
        response = self.start_inference_session(str(datum["task_id"])).run_task(extract_answer_prompt)

        ## Parse the python code; fall back to the raw response when no code block is found
        python_code_blocks = parse_python_code_block(response) + parse_code_block(response)
        code = python_code_blocks[0] if python_code_blocks else response

        ## Evaluate the code
        # SECURITY NOTE(review): this executes model-generated code through run_with_timeout.
        # Confirm that helper isolates execution before running against untrusted models.
        try:
            for test in datum["test_list"] + datum["challenge_test_list"]:
                test_code = code + "\n" + test
                if not run_with_timeout(test_code):
                    return False
            return True
        except Exception:
            # Any error while executing the generated code counts as a failed solution.
            return False

mmlu_solver

MMLUSolverStrategy

Bases: SolverStrategy

Solver strategy for the MMLU dataset.

Source code in nearai/solvers/mmlu_solver.py
class MMLUSolverStrategy(SolverStrategy):
    """Solver strategy for the MMLU dataset.

    Builds a few-shot prompt from the ``dev`` split of ``dataset_ref``, lets the
    model answer verbosely, then asks it to reduce that answer to one of the
    choice letters ``A``-``D`` and compares it to the datum's answer index.
    """

    def __init__(  # noqa: D107
        self, dataset_ref: Union[Dataset, DatasetDict], model: str = "", agent: str = "", shots: int = 8
    ) -> None:
        super().__init__(model, agent)
        # Indexed as dataset_ref["dev"] by solve() to draw few-shot examples.
        self.dataset_ref = dataset_ref
        # Number of few-shot examples placed in each prompt.
        self.shots = shots

    def evaluation_name(self) -> str:  # noqa: D102
        """Return the evaluation label, e.g. ``mmlu_8shots``."""
        prefix = self.dataset_evaluation_name if self.dataset_evaluation_name else "mmlu"
        return f"{prefix}_{self.shots}shots"

    def compatible_datasets(self) -> List[str]:  # noqa: D102
        """Return the names of the datasets this solver can evaluate."""
        return ["mmlu"]

    def solve(self, datum: dict) -> bool:  # noqa: D102
        """Ask the model one MMLU question and check its chosen letter.

        Args:
            datum: A raw MMLU record; it is normalized through ``MMLUDatum``.

        Returns:
            True if the extracted letter's index in ``choices`` equals
            ``datum["answer"]``; False otherwise, including when the model's
            reply cannot be parsed as one of the choice letters.
        """
        datum = MMLUDatum(**datum).model_dump()

        choices = ["A", "B", "C", "D"]
        # Every 5th dev item (0, 5, 10, ...) is used as an example -- presumably to
        # spread examples across subjects; TODO confirm the dev-split layout.
        dev_split = self.dataset_ref["dev"]
        example_problems = [MMLUDatum(**dev_split[i]).model_dump() for i in range(0, 5 * self.shots, 5)]

        # read_text() closes the file (a bare open(...).read() leaks the handle) and
        # decodes explicitly as UTF-8 instead of the platform's locale encoding.
        verbose_template = (PROMPTS_FOLDER / "mmlu_verbose_answer.j2").read_text(encoding="utf-8")
        base_prompt = Template(verbose_template, trim_blocks=True).render(
            example_problems=example_problems,
            challenge_problem=datum,
            choices=choices,
        )

        response = self.start_inference_session("").run_task(base_prompt)

        ## Extract the answer from the response
        extract_template = (PROMPTS_FOLDER / "mmlu_extract_answer.j2").read_text(encoding="utf-8")
        extract_answer_prompt = Template(extract_template, trim_blocks=True).render(
            challenge_problem=datum,
            answer_text=response,
            choices=choices,
        )
        response = self.start_inference_session("").run_task(extract_answer_prompt)

        try:
            # Strip surrounding whitespace/newlines so a reply like "B\n" still matches.
            answer = choices.index(response.strip())
        except (AttributeError, ValueError):
            # AttributeError: the reply was not a string (e.g. None);
            # ValueError: the stripped reply is not one of the choice letters.
            print("Failed to parse answer")
            return False
        return bool(answer == datum["answer"])

tests

test_provider_models

TestMatchProviderModel

Bases: TestCase

Unit tests for get_provider_namespaced_model.

Source code in nearai/tests/test_provider_models.py
class TestMatchProviderModel(unittest.TestCase):
    """Unit tests for get_provider_namespaced_model."""

    def __init__(self, method_name="runTest"):  # noqa: D107
        """Initialise the test case and build the ProviderModels instance the tests query."""
        super().__init__(method_name)
        client_config = CONFIG.get_client_config()
        self.provider_models = ProviderModels(client_config)

    def test_fireworks(self):  # noqa: D102
        """Fireworks models resolve from prefixed, account-path, bare, and ``org/model`` names."""
        yi_large = ("fireworks", "fireworks::accounts/yi-01-ai/models/yi-large")
        cases = [
            ("fireworks::accounts/yi-01-ai/models/yi-large", yi_large),
            ("accounts/yi-01-ai/models/yi-large", yi_large),
            (
                "llama-v3-70b-instruct",
                ("fireworks", "fireworks::accounts/fireworks/models/llama-v3-70b-instruct"),
            ),
            ("yi-01-ai/yi-large", yi_large),
        ]
        # Checked in order; the first mismatch fails the test, as with individual asserts.
        for query, expected in cases:
            self.assertEqual(self.provider_models.match_provider_model(query), expected)

    def test_hyperbolic(self):  # noqa: D102
        """Hyperbolic-prefixed names resolve, including one given without its ``meta-llama/`` org."""
        llama_70b = ("hyperbolic", "hyperbolic::meta-llama/Meta-Llama-3.1-70B-Instruct")
        expectations = (
            ("hyperbolic::StableDiffusion", ("hyperbolic", "hyperbolic::StableDiffusion")),
            ("hyperbolic::meta-llama/Meta-Llama-3.1-70B-Instruct", llama_70b),
            ("hyperbolic::Meta-Llama-3.1-70B-Instruct", llama_70b),
        )
        for model_name, want in expectations:
            got = self.provider_models.match_provider_model(model_name)
            self.assertEqual(got, want)

    def test_registry_with_multiple_providers(self):  # noqa: D102
        self.assertEqual(
            self.provider_models.match_provider_model("llama-3.1-70b-instruct"),
            ("fireworks", "fireworks::accounts/fireworks/models/llama-v3p1-70b-instruct"),
        )
        self.assertEqual(
            self.provider_models.match_provider_model("llama-3.1-70b-instruct", provider="hyperbolic"),
            ("hyperbolic", "hyperbolic::meta-llama/Meta-Llama-3.1-70B-Instruct"),
        )
        self.assertEqual(
            self.provider_models.match_provider_model("near.ai/llama-3.1-70b-instruct", provider="hyperbolic"),
            ("hyperbolic", "hyperbolic::meta-llama/Meta-Llama-3.1-70B-Instruct"),
        )