contracts/lib/forge-std/scripts/vm.py 17.3 K raw
1
#!/usr/bin/env python3
2
3
import argparse
4
import copy
5
import json
6
import re
7
import subprocess
8
from enum import Enum as PyEnum
9
from pathlib import Path
10
from typing import Callable
11
from urllib import request
12
13
# Signature of the zero-argument callbacks the printer uses for nested output.
VoidFn = Callable[[], None]

# Cheatcode definitions published by Foundry; downloaded when `--from` is not given.
CHEATCODES_JSON_URL = "https://raw.githubusercontent.com/foundry-rs/foundry/master/crates/cheatcodes/assets/cheatcodes.json"
# Output file, relative to the current working directory.
OUT_PATH = "src/Vm.sol"

# Doc comment emitted verbatim above `interface VmSafe`.
VM_SAFE_DOC = """\
/// The `VmSafe` interface does not allow manipulation of the EVM state or other actions that may
/// result in Script simulations differing from on-chain execution. It is recommended to only use
/// these cheats in scripts.
"""

# Doc comment emitted verbatim above `interface Vm`.
VM_DOC = """\
/// The `Vm` interface does allow manipulation of the EVM state. These are all intended to be used
/// in tests, but it is not recommended to use these cheats in scripts.
"""
28
29
30
def main():
    """Regenerate the Solidity cheatcode interfaces at ``OUT_PATH``.

    The cheatcodes JSON is read from ``--from PATH`` when given, otherwise it is downloaded
    from ``CHEATCODES_JSON_URL``. Cheatcodes whose status is "experimental" or "internal"
    are dropped; the rest are split by safety into ``VmSafe`` ("safe") and ``Vm``
    ("unsafe", declared as ``Vm is VmSafe``). The result is written to ``OUT_PATH`` and
    then formatted with ``forge fmt``.

    Raises:
        RuntimeError: if ``forge fmt`` exits with a non-zero status.
    """
    parser = argparse.ArgumentParser(
        description="Generate Vm.sol based on the cheatcodes json created by Foundry")
    parser.add_argument(
        "--from",
        metavar="PATH",
        dest="path",
        required=False,
        help="path to a json file containing the Vm interface, as generated by Foundry")
    args = parser.parse_args()

    if args.path is None:
        # Close the HTTP response deterministically instead of leaving it to the GC.
        with request.urlopen(CHEATCODES_JSON_URL) as resp:
            json_str = resp.read().decode("utf-8")
    else:
        json_str = Path(args.path).read_text(encoding="utf-8")
    contract = Cheatcodes.from_json(json_str)

    ccs = contract.cheatcodes
    ccs = list(filter(lambda cc: cc.status not in ["experimental", "internal"], ccs))
    ccs.sort(key=lambda cc: cc.func.id)

    safe = list(filter(lambda cc: cc.safety == "safe", ccs))
    safe.sort(key=CmpCheatcode)
    unsafe = list(filter(lambda cc: cc.safety == "unsafe", ccs))
    unsafe.sort(key=CmpCheatcode)
    # Every remaining cheatcode must be classified as exactly one of the two.
    assert len(safe) + len(unsafe) == len(ccs)

    prefix_with_group_headers(safe)
    prefix_with_group_headers(unsafe)

    out = ""

    out += "// Automatically @generated by scripts/vm.py. Do not modify manually.\n\n"

    pp = CheatcodesPrinter(
        spdx_identifier="MIT OR Apache-2.0",
        solidity_requirement=">=0.6.2 <0.9.0",
        abicoder_pragma=True,
    )
    pp.p_prelude()
    # The prelude is emitted once above; stop p_contract from repeating it per interface.
    pp.prelude = False
    out += pp.finish()

    out += "\n\n"
    out += VM_SAFE_DOC
    vm_safe = Cheatcodes(
        # TODO: Custom errors were introduced in 0.8.4
        errors=[],  # contract.errors
        events=contract.events,
        enums=contract.enums,
        structs=contract.structs,
        cheatcodes=safe,
    )
    pp.p_contract(vm_safe, "VmSafe")
    out += pp.finish()

    out += "\n\n"
    out += VM_DOC
    # Shared events/enums/structs are declared only on VmSafe; Vm inherits them.
    vm_unsafe = Cheatcodes(
        errors=[],
        events=[],
        enums=[],
        structs=[],
        cheatcodes=unsafe,
    )
    pp.p_contract(vm_unsafe, "Vm", "VmSafe")
    out += pp.finish()

    # Compatibility with <0.8.0
    # Rewrites the leftmost " memory " that is followed by "returns" on the same line into
    # " calldata "; occurrences after the last "returns" (return values) are never touched.
    # NOTE(review): only one occurrence per line is rewritten - confirm that is intended.
    def memory_to_calldata(m: re.Match) -> str:
        return " calldata " + m.group(1)

    out = re.sub(r" memory (.*returns)", memory_to_calldata, out)

    with open(OUT_PATH, "w", encoding="utf-8") as f:
        f.write(out)

    forge_fmt = ["forge", "fmt", OUT_PATH]
    res = subprocess.run(forge_fmt)
    # Explicit check: unlike `assert`, this still runs under `python -O`.
    if res.returncode != 0:
        raise RuntimeError(f"command failed: {forge_fmt}")

    print(f"Wrote to {OUT_PATH}")
108
109
110
class CmpCheatcode:
    """Sort-key wrapper that orders cheatcodes with `cmp_cheatcode`.

    Intended for ``list.sort(key=CmpCheatcode)``. Comparisons with anything that is not a
    ``CmpCheatcode`` return ``NotImplemented``, so Python falls back to its default
    handling (``==`` is False, ordering raises TypeError) instead of an AttributeError.
    """

    cheatcode: "Cheatcode"

    # Defining __eq__ already disables hashing implicitly; state it explicitly.
    __hash__ = None

    def __init__(self, cheatcode: "Cheatcode"):
        self.cheatcode = cheatcode

    def _cmp(self, other: object):
        """Three-way compare with `other`, or NotImplemented for foreign types."""
        if not isinstance(other, CmpCheatcode):
            return NotImplemented
        return cmp_cheatcode(self.cheatcode, other.cheatcode)

    def __lt__(self, other: object) -> bool:
        c = self._cmp(other)
        return c if c is NotImplemented else c < 0

    def __le__(self, other: object) -> bool:
        c = self._cmp(other)
        return c if c is NotImplemented else c <= 0

    def __eq__(self, other: object) -> bool:
        c = self._cmp(other)
        return c if c is NotImplemented else c == 0

    def __ge__(self, other: object) -> bool:
        c = self._cmp(other)
        return c if c is NotImplemented else c >= 0

    def __gt__(self, other: object) -> bool:
        c = self._cmp(other)
        return c if c is NotImplemented else c > 0
124
125
126
def cmp_cheatcode(a: "Cheatcode", b: "Cheatcode") -> int:
    """Three-way comparison of two cheatcodes.

    Keys are checked in priority order: group, status, safety, then function id. The
    first key that differs decides the result; later keys are not even read.

    Returns:
        -1 if ``a`` sorts before ``b``, 1 if after, 0 if all four keys are equal.
    """
    key_getters = (
        lambda cheat: cheat.group,
        lambda cheat: cheat.status,
        lambda cheat: cheat.safety,
        lambda cheat: cheat.func.id,
    )
    for get_key in key_getters:
        left, right = get_key(a), get_key(b)
        if left == right:
            continue
        if left < right:
            return -1
        return 1
    return 0
136
137
138
# HACK: A way to add group header comments without having to modify printer code
def prefix_with_group_headers(cheats: list["Cheatcode"]):
    """Insert a pseudo-cheatcode header before the first cheatcode of each group.

    Each header is a deep copy of that first cheatcode with its description cleared and its
    declaration replaced by ``// ======== <Group> ========``, so the printer emits it like
    any other function line. Only the first occurrence of a group gets a header, so a group
    split into non-adjacent runs is labelled once.

    Args:
        cheats: the cheatcodes to annotate; modified in place.

    Returns:
        The same list object, now including the header entries.
    """
    seen: set[str] = set()
    result = []
    for cheat in cheats:
        if cheat.group not in seen:
            seen.add(cheat.group)
            header = copy.deepcopy(cheat)
            header.func.description = ""
            header.func.declaration = f"// ======== {group(header.group)} ========"
            result.append(header)
        result.append(cheat)
    # Replace the contents in place instead of inserting while enumerating (which only
    # worked because the shifted element was visited twice, and cost O(n) per insert).
    cheats[:] = result
    return cheats
152
153
154
def group(s: str) -> str:
    """Return the display name of a cheatcode group for section headers.

    ``"evm"`` and ``"json"`` are fully upper-cased; any other name gets only its first
    character upper-cased. An empty name is returned unchanged.
    """
    acronyms = {"evm": "EVM", "json": "JSON"}
    if s in acronyms:
        return acronyms[s]
    # Slice instead of s[0] so an empty group name does not raise IndexError.
    return s[:1].upper() + s[1:]
160
161
162
class Visibility(PyEnum):
    """Solidity function visibility, parsed from a cheatcode's ``visibility`` field."""

    EXTERNAL = "external"
    PUBLIC = "public"
    INTERNAL = "internal"
    PRIVATE = "private"

    def __str__(self):
        # Render as the bare Solidity keyword rather than "Visibility.EXTERNAL".
        keyword = self.value
        return keyword
170
171
172
class Mutability(PyEnum):
    """Solidity state-mutability specifier; ``NONE`` (empty string) means none is written."""

    PURE = "pure"
    VIEW = "view"
    NONE = ""

    def __str__(self):
        # Render as the bare Solidity keyword (possibly empty).
        keyword = self.value
        return keyword
179
180
181
class Function:
    """A cheatcode's Solidity function, as given under ``func`` in the cheatcodes JSON."""

    id: str  # identifier; cheatcodes are sorted by it
    description: str  # printed as a doc comment above the declaration
    declaration: str  # printed verbatim as the interface member
    visibility: Visibility
    mutability: Mutability
    signature: str
    selector: str
    selector_bytes: bytes

    def __init__(
        self,
        id: str,
        description: str,
        declaration: str,
        visibility: Visibility,
        mutability: Mutability,
        signature: str,
        selector: str,
        selector_bytes: bytes,
    ):
        # `id` shadows the builtin, but it is part of the public keyword interface.
        self.id, self.description, self.declaration = id, description, declaration
        self.visibility, self.mutability = visibility, mutability
        self.signature, self.selector, self.selector_bytes = signature, selector, selector_bytes

    @staticmethod
    def from_dict(d: dict) -> "Function":
        """Build a Function from one ``func`` JSON object.

        ``visibility`` and ``mutability`` are parsed into their enums (an unknown value
        raises ValueError); ``selectorBytes`` is converted with ``bytes()``.
        """
        return Function(
            id=d["id"],
            description=d["description"],
            declaration=d["declaration"],
            visibility=Visibility(d["visibility"]),
            mutability=Mutability(d["mutability"]),
            signature=d["signature"],
            selector=d["selector"],
            selector_bytes=bytes(d["selectorBytes"]),
        )
223
224
225
class Cheatcode:
    """One cheatcode entry: its Solidity function plus classification metadata."""

    func: Function
    group: str  # section the cheatcode is listed under
    status: str  # "experimental"/"internal" entries are not emitted by main()
    safety: str  # "safe" entries go to VmSafe, "unsafe" ones to Vm

    def __init__(self, func: Function, group: str, status: str, safety: str):
        self.func = func
        self.group, self.status, self.safety = group, status, safety

    @staticmethod
    def from_dict(d: dict) -> "Cheatcode":
        """Build a Cheatcode from one element of the JSON ``cheatcodes`` array.

        ``group``, ``status`` and ``safety`` are coerced with ``str()``.
        """
        func = Function.from_dict(d["func"])
        group, status, safety = (str(d[key]) for key in ("group", "status", "safety"))
        return Cheatcode(func, group, status, safety)
245
246
247
class Error:
    """A custom Solidity error declared in the cheatcodes JSON."""

    name: str
    description: str  # printed as a doc comment above the declaration
    declaration: str  # printed verbatim

    def __init__(self, name: str, description: str, declaration: str):
        self.name, self.description, self.declaration = name, description, declaration

    @staticmethod
    def from_dict(d: dict) -> "Error":
        """Build an Error from its JSON object.

        The keys must be exactly ``name``, ``description`` and ``declaration``; a missing
        or extra key raises TypeError.
        """
        fields = d
        return Error(**fields)
260
261
262
class Event:
    """A Solidity event declared in the cheatcodes JSON."""

    name: str
    description: str  # printed as a doc comment above the declaration
    declaration: str  # printed verbatim

    def __init__(self, name: str, description: str, declaration: str):
        self.name, self.description, self.declaration = name, description, declaration

    @staticmethod
    def from_dict(d: dict) -> "Event":
        """Build an Event from its JSON object.

        The keys must be exactly ``name``, ``description`` and ``declaration``; a missing
        or extra key raises TypeError.
        """
        fields = d
        return Event(**fields)
275
276
277
class EnumVariant:
    """A single member of a Solidity enum from the cheatcodes JSON."""

    name: str
    description: str  # printed as a plain (non-doc) comment above the member

    def __init__(self, name: str, description: str):
        self.name, self.description = name, description
284
285
286
class Enum:
    """A Solidity enum declared in the cheatcodes JSON (not Python's `enum.Enum`)."""

    name: str
    description: str  # printed as a doc comment above the enum
    variants: list[EnumVariant]  # members in declaration order

    def __init__(self, name: str, description: str, variants: list[EnumVariant]):
        self.name, self.description, self.variants = name, description, variants

    @staticmethod
    def from_dict(d: dict) -> "Enum":
        """Build an Enum from its JSON object; each variant dict is passed as kwargs."""
        name = d["name"]
        description = d["description"]
        variants = [EnumVariant(**raw) for raw in d["variants"]]
        return Enum(name, description, variants)
303
304
305
class StructField:
    """A single field of a Solidity struct from the cheatcodes JSON."""

    name: str
    ty: str  # Solidity type, printed before the name
    description: str  # printed as a plain (non-doc) comment above the field

    def __init__(self, name: str, ty: str, description: str):
        self.name, self.ty, self.description = name, ty, description
314
315
316
class Struct:
    """A Solidity struct declared in the cheatcodes JSON."""

    name: str
    description: str  # printed as a doc comment above the struct
    fields: list[StructField]  # fields in declaration order

    def __init__(self, name: str, description: str, fields: list[StructField]):
        self.name, self.description, self.fields = name, description, fields

    @staticmethod
    def from_dict(d: dict) -> "Struct":
        """Build a Struct from its JSON object; each field dict is passed as kwargs."""
        name = d["name"]
        description = d["description"]
        fields = [StructField(**raw) for raw in d["fields"]]
        return Struct(name, description, fields)
333
334
335
class Cheatcodes:
336
    errors: list[Error]
337
    events: list[Event]
338
    enums: list[Enum]
339
    structs: list[Struct]
340
    cheatcodes: list[Cheatcode]
341
342
    def __init__(
343
        self,
344
        errors: list[Error],
345
        events: list[Event],
346
        enums: list[Enum],
347
        structs: list[Struct],
348
        cheatcodes: list[Cheatcode],
349
    ):
350
        self.errors = errors
351
        self.events = events
352
        self.enums = enums
353
        self.structs = structs
354
        self.cheatcodes = cheatcodes
355
356
    @staticmethod
357
    def from_dict(d: dict) -> "Cheatcodes":
358
        return Cheatcodes(
359
            errors=[Error.from_dict(e) for e in d["errors"]],
360
            events=[Event.from_dict(e) for e in d["events"]],
361
            enums=[Enum.from_dict(e) for e in d["enums"]],
362
            structs=[Struct.from_dict(e) for e in d["structs"]],
363
            cheatcodes=[Cheatcode.from_dict(e) for e in d["cheatcodes"]],
364
        )
365
366
    @staticmethod
367
    def from_json(s) -> "Cheatcodes":
368
        return Cheatcodes.from_dict(json.loads(s))
369
370
    @staticmethod
371
    def from_json_file(file_path: str) -> "Cheatcodes":
372
        with open(file_path, "r") as f:
373
            return Cheatcodes.from_dict(json.load(f))
374
375
376
class Item(PyEnum):
    """Kinds of interface member; an `ItemOrder` lists them in print order."""

    ERROR = "error"
    EVENT = "event"
    ENUM = "enum"
    STRUCT = "struct"
    FUNCTION = "function"
382
383
384
class ItemOrder:
    """Order in which a printer emits the member kinds of an interface.

    A kind may be left out (it is then not printed at all) but may not appear twice.
    """

    _list: list["Item"]

    def __init__(self, list: list["Item"]) -> None:
        """Validate and store the order.

        Args:
            list: ``Item`` kinds in print order. (The name shadows the builtin but is kept
                for keyword-argument compatibility.)

        Raises:
            ValueError: if there are more entries than ``Item`` members, or duplicates.
        """
        # Copy so later mutation of the caller's sequence cannot change this order.
        items = [*list]
        # Explicit raises instead of `assert`, which is stripped under `python -O`.
        if len(items) > len(Item):
            raise ValueError("list must not contain more items than Item")
        if len(items) != len(set(items)):
            raise ValueError("list must not contain duplicates")
        self._list = items

    def get_list(self) -> list["Item"]:
        """Return a copy of the configured order."""
        return self._list.copy()

    @staticmethod
    def default() -> "ItemOrder":
        """Errors, events, enums, structs, then functions."""
        return ItemOrder(
            [
                Item.ERROR,
                Item.EVENT,
                Item.ENUM,
                Item.STRUCT,
                Item.FUNCTION,
            ]
        )
407
408
409
class CheatcodesPrinter:
410
    buffer: str
411
412
    prelude: bool
413
    spdx_identifier: str
414
    solidity_requirement: str
415
    abicoder_v2: bool
416
417
    block_doc_style: bool
418
419
    indent_level: int
420
    _indent_str: str
421
422
    nl_str: str
423
424
    items_order: ItemOrder
425
426
    def __init__(
427
        self,
428
        buffer: str = "",
429
        prelude: bool = True,
430
        spdx_identifier: str = "UNLICENSED",
431
        solidity_requirement: str = "",
432
        abicoder_pragma: bool = False,
433
        block_doc_style: bool = False,
434
        indent_level: int = 0,
435
        indent_with: int | str = 4,
436
        nl_str: str = "\n",
437
        items_order: ItemOrder = ItemOrder.default(),
438
    ):
439
        self.prelude = prelude
440
        self.spdx_identifier = spdx_identifier
441
        self.solidity_requirement = solidity_requirement
442
        self.abicoder_v2 = abicoder_pragma
443
        self.block_doc_style = block_doc_style
444
        self.buffer = buffer
445
        self.indent_level = indent_level
446
        self.nl_str = nl_str
447
448
        if isinstance(indent_with, int):
449
            assert indent_with >= 0
450
            self._indent_str = " " * indent_with
451
        elif isinstance(indent_with, str):
452
            self._indent_str = indent_with
453
        else:
454
            assert False, "indent_with must be int or str"
455
456
        self.items_order = items_order
457
458
    def finish(self) -> str:
459
        ret = self.buffer.rstrip()
460
        self.buffer = ""
461
        return ret
462
463
    def p_contract(self, contract: Cheatcodes, name: str, inherits: str = ""):
464
        if self.prelude:
465
            self.p_prelude(contract)
466
467
        self._p_str("interface ")
468
        name = name.strip()
469
        if name != "":
470
            self._p_str(name)
471
            self._p_str(" ")
472
        if inherits != "":
473
            self._p_str("is ")
474
            self._p_str(inherits)
475
            self._p_str(" ")
476
        self._p_str("{")
477
        self._p_nl()
478
        self._with_indent(lambda: self._p_items(contract))
479
        self._p_str("}")
480
        self._p_nl()
481
482
    def _p_items(self, contract: Cheatcodes):
483
        for item in self.items_order.get_list():
484
            if item == Item.ERROR:
485
                self.p_errors(contract.errors)
486
            elif item == Item.EVENT:
487
                self.p_events(contract.events)
488
            elif item == Item.ENUM:
489
                self.p_enums(contract.enums)
490
            elif item == Item.STRUCT:
491
                self.p_structs(contract.structs)
492
            elif item == Item.FUNCTION:
493
                self.p_functions(contract.cheatcodes)
494
            else:
495
                assert False, f"unknown item {item}"
496
497
    def p_prelude(self, contract: Cheatcodes | None = None):
498
        self._p_str(f"// SPDX-License-Identifier: {self.spdx_identifier}")
499
        self._p_nl()
500
501
        if self.solidity_requirement != "":
502
            req = self.solidity_requirement
503
        elif contract and len(contract.errors) > 0:
504
            req = ">=0.8.4 <0.9.0"
505
        else:
506
            req = ">=0.6.0 <0.9.0"
507
        self._p_str(f"pragma solidity {req};")
508
        self._p_nl()
509
510
        if self.abicoder_v2:
511
            self._p_str("pragma experimental ABIEncoderV2;")
512
            self._p_nl()
513
514
        self._p_nl()
515
516
    def p_errors(self, errors: list[Error]):
517
        for error in errors:
518
            self._p_line(lambda: self.p_error(error))
519
520
    def p_error(self, error: Error):
521
        self._p_comment(error.description, doc=True)
522
        self._p_line(lambda: self._p_str(error.declaration))
523
524
    def p_events(self, events: list[Event]):
525
        for event in events:
526
            self._p_line(lambda: self.p_event(event))
527
528
    def p_event(self, event: Event):
529
        self._p_comment(event.description, doc=True)
530
        self._p_line(lambda: self._p_str(event.declaration))
531
532
    def p_enums(self, enums: list[Enum]):
533
        for enum in enums:
534
            self._p_line(lambda: self.p_enum(enum))
535
536
    def p_enum(self, enum: Enum):
537
        self._p_comment(enum.description, doc=True)
538
        self._p_line(lambda: self._p_str(f"enum {enum.name} {{"))
539
        self._with_indent(lambda: self.p_enum_variants(enum.variants))
540
        self._p_line(lambda: self._p_str("}"))
541
542
    def p_enum_variants(self, variants: list[EnumVariant]):
543
        for i, variant in enumerate(variants):
544
            self._p_indent()
545
            self._p_comment(variant.description)
546
547
            self._p_indent()
548
            self._p_str(variant.name)
549
            if i < len(variants) - 1:
550
                self._p_str(",")
551
            self._p_nl()
552
553
    def p_structs(self, structs: list[Struct]):
554
        for struct in structs:
555
            self._p_line(lambda: self.p_struct(struct))
556
557
    def p_struct(self, struct: Struct):
558
        self._p_comment(struct.description, doc=True)
559
        self._p_line(lambda: self._p_str(f"struct {struct.name} {{"))
560
        self._with_indent(lambda: self.p_struct_fields(struct.fields))
561
        self._p_line(lambda: self._p_str("}"))
562
563
    def p_struct_fields(self, fields: list[StructField]):
564
        for field in fields:
565
            self._p_line(lambda: self.p_struct_field(field))
566
567
    def p_struct_field(self, field: StructField):
568
        self._p_comment(field.description)
569
        self._p_indented(lambda: self._p_str(f"{field.ty} {field.name};"))
570
571
    def p_functions(self, cheatcodes: list[Cheatcode]):
572
        for cheatcode in cheatcodes:
573
            self._p_line(lambda: self.p_function(cheatcode.func))
574
575
    def p_function(self, func: Function):
576
        self._p_comment(func.description, doc=True)
577
        self._p_line(lambda: self._p_str(func.declaration))
578
579
    def _p_comment(self, s: str, doc: bool = False):
580
        s = s.strip()
581
        if s == "":
582
            return
583
584
        s = map(lambda line: line.lstrip(), s.split("\n"))
585
        if self.block_doc_style:
586
            self._p_str("/*")
587
            if doc:
588
                self._p_str("*")
589
            self._p_nl()
590
            for line in s:
591
                self._p_indent()
592
                self._p_str(" ")
593
                if doc:
594
                    self._p_str("* ")
595
                self._p_str(line)
596
                self._p_nl()
597
            self._p_indent()
598
            self._p_str(" */")
599
            self._p_nl()
600
        else:
601
            first_line = True
602
            for line in s:
603
                if not first_line:
604
                    self._p_indent()
605
                first_line = False
606
607
                if doc:
608
                    self._p_str("/// ")
609
                else:
610
                    self._p_str("// ")
611
                self._p_str(line)
612
                self._p_nl()
613
614
    def _with_indent(self, f: VoidFn):
615
        self._inc_indent()
616
        f()
617
        self._dec_indent()
618
619
    def _p_line(self, f: VoidFn):
620
        self._p_indent()
621
        f()
622
        self._p_nl()
623
624
    def _p_indented(self, f: VoidFn):
625
        self._p_indent()
626
        f()
627
628
    def _p_indent(self):
629
        for _ in range(self.indent_level):
630
            self._p_str(self._indent_str)
631
632
    def _p_nl(self):
633
        self._p_str(self.nl_str)
634
635
    def _p_str(self, txt: str):
636
        self.buffer += txt
637
638
    def _inc_indent(self):
639
        self.indent_level += 1
640
641
    def _dec_indent(self):
642
        self.indent_level -= 1
643
644
645
if __name__ == "__main__":
    # Run the generator only when this file is executed as a script.
    main()