// contracts/lib/forge-std/test/StdStorage.t.sol
// SPDX-License-Identifier: MIT
pragma solidity >=0.7.0 <0.9.0;
import {stdStorage, StdStorage} from "../src/StdStorage.sol";
import {Test} from "../src/Test.sol";
/// @title StdStorageTest
/// @notice Tests for the `stdStorage` library: slot discovery (`find`, `parent`, `root`), the
///         `checked_write*` helpers, the typed `read_*` helpers and packed-slot support. All of
///         them run against the `StorageTest` fixture deployed in `setUp`.
/// @dev Hard-coded slot numbers below mirror `StorageTest`'s state-variable declaration order
///      (exists = 0, map_addr = 1, map_uint = 2, map_struct = 4, deep_map = 5,
///      deep_map_struct = 6, basic = 7..8, tI = 16). Reordering that contract breaks these tests.
contract StdStorageTest is Test {
    using stdStorage for StdStorage;

    /// @dev Fixture under inspection; redeployed before every test.
    StorageTest internal test;

    function setUp() public {
        test = new StorageTest();
    }

    /// @notice A value stored at a custom slot (`keccak256("my.random.var")`) is located.
    function test_StorageHidden() public {
        assertEq(uint256(keccak256("my.random.var")), stdstore.target(address(test)).sig("hidden()").find());
    }

    /// @notice A plain state variable in slot 0 is located.
    function test_StorageObvious() public {
        assertEq(uint256(0), stdstore.target(address(test)).sig("exists()").find());
    }

    /// @notice The getter first loads an unrelated slot (tE, 12) and then returns tI.
    ///         `find` must report tI's slot (16).
    function test_StorageExtraSload() public {
        assertEq(16, stdstore.target(address(test)).sig(test.extra_sload.selector).find());
    }

    function test_StorageCheckedWriteHidden() public {
        stdstore.target(address(test)).sig(test.hidden.selector).checked_write(100);
        assertEq(uint256(test.hidden()), 100);
    }

    function test_StorageCheckedWriteObvious() public {
        stdstore.target(address(test)).sig(test.exists.selector).checked_write(100);
        assertEq(test.exists(), 100);
    }

    function test_StorageCheckedWriteSignedIntegerHidden() public {
        stdstore.target(address(test)).sig(test.hidden.selector).checked_write_int(-100);
        assertEq(int256(uint256(test.hidden())), -100);
    }

    function test_StorageCheckedWriteSignedIntegerObvious() public {
        stdstore.target(address(test)).sig(test.tG.selector).checked_write_int(-100);
        assertEq(test.tG(), -100);
    }

    /// @notice Field 0 of a struct stored in a mapping lives at keccak256(key, mappingSlot).
    function test_StorageMapStructA() public {
        uint256 slot =
            stdstore.target(address(test)).sig(test.map_struct.selector).with_key(address(this)).depth(0).find();
        assertEq(uint256(keccak256(abi.encode(address(this), uint256(4)))), slot);
    }

    /// @notice Field 1 of the same struct lives one slot after field 0.
    function test_StorageMapStructB() public {
        uint256 slot =
            stdstore.target(address(test)).sig(test.map_struct.selector).with_key(address(this)).depth(1).find();
        assertEq(uint256(keccak256(abi.encode(address(this), uint256(4)))) + 1, slot);
    }

    /// @notice Nested mapping slot: keccak256(innerKey, keccak256(outerKey, mappingSlot)).
    function test_StorageDeepMap() public {
        uint256 slot = stdstore.target(address(test)).sig(test.deep_map.selector).with_key(address(this)).with_key(
            address(this)
        ).find();
        assertEq(uint256(keccak256(abi.encode(address(this), keccak256(abi.encode(address(this), uint256(5)))))), slot);
    }

    function test_StorageCheckedWriteDeepMap() public {
        stdstore.target(address(test)).sig(test.deep_map.selector).with_key(address(this)).with_key(address(this))
            .checked_write(100);
        assertEq(100, test.deep_map(address(this), address(this)));
    }

    function test_StorageDeepMapStructA() public {
        uint256 slot = stdstore.target(address(test)).sig(test.deep_map_struct.selector).with_key(address(this))
            .with_key(address(this)).depth(0).find();
        assertEq(
            bytes32(uint256(keccak256(abi.encode(address(this), keccak256(abi.encode(address(this), uint256(6)))))) + 0),
            bytes32(slot)
        );
    }

    function test_StorageDeepMapStructB() public {
        uint256 slot = stdstore.target(address(test)).sig(test.deep_map_struct.selector).with_key(address(this))
            .with_key(address(this)).depth(1).find();
        assertEq(
            bytes32(uint256(keccak256(abi.encode(address(this), keccak256(abi.encode(address(this), uint256(6)))))) + 1),
            bytes32(slot)
        );
    }

    function test_StorageCheckedWriteDeepMapStructA() public {
        stdstore.target(address(test)).sig(test.deep_map_struct.selector).with_key(address(this)).with_key(
            address(this)
        ).depth(0).checked_write(100);
        (uint256 a, uint256 b) = test.deep_map_struct(address(this), address(this));
        assertEq(100, a);
        assertEq(0, b);
    }

    function test_StorageCheckedWriteDeepMapStructB() public {
        stdstore.target(address(test)).sig(test.deep_map_struct.selector).with_key(address(this)).with_key(
            address(this)
        ).depth(1).checked_write(100);
        (uint256 a, uint256 b) = test.deep_map_struct(address(this), address(this));
        assertEq(0, a);
        assertEq(100, b);
    }

    function test_StorageCheckedWriteMapStructA() public {
        stdstore.target(address(test)).sig(test.map_struct.selector).with_key(address(this)).depth(0).checked_write(100);
        (uint256 a, uint256 b) = test.map_struct(address(this));
        assertEq(a, 100);
        assertEq(b, 0);
    }

    function test_StorageCheckedWriteMapStructB() public {
        stdstore.target(address(test)).sig(test.map_struct.selector).with_key(address(this)).depth(1).checked_write(100);
        (uint256 a, uint256 b) = test.map_struct(address(this));
        assertEq(a, 0);
        assertEq(b, 100);
    }

    function test_StorageStructA() public {
        uint256 slot = stdstore.target(address(test)).sig(test.basic.selector).depth(0).find();
        assertEq(uint256(7), slot);
    }

    function test_StorageStructB() public {
        uint256 slot = stdstore.target(address(test)).sig(test.basic.selector).depth(1).find();
        assertEq(uint256(7) + 1, slot);
    }

    /// @notice Writing one struct field leaves the other at its constructor value (1337).
    function test_StorageCheckedWriteStructA() public {
        stdstore.target(address(test)).sig(test.basic.selector).depth(0).checked_write(100);
        (uint256 a, uint256 b) = test.basic();
        assertEq(a, 100);
        assertEq(b, 1337);
    }

    function test_StorageCheckedWriteStructB() public {
        stdstore.target(address(test)).sig(test.basic.selector).depth(1).checked_write(100);
        (uint256 a, uint256 b) = test.basic();
        assertEq(a, 1337);
        assertEq(b, 100);
    }

    function test_StorageMapAddrFound() public {
        uint256 slot = stdstore.target(address(test)).sig(test.map_addr.selector).with_key(address(this)).find();
        assertEq(uint256(keccak256(abi.encode(address(this), uint256(1)))), slot);
    }

    /// @notice `parent` yields the mapping's own slot and the last key; `root` yields the mapping slot.
    function test_StorageMapAddrRoot() public {
        (uint256 slot, bytes32 key) =
            stdstore.target(address(test)).sig(test.map_addr.selector).with_key(address(this)).parent();
        assertEq(address(uint160(uint256(key))), address(this));
        assertEq(uint256(1), slot);
        slot = stdstore.target(address(test)).sig(test.map_addr.selector).with_key(address(this)).root();
        assertEq(uint256(1), slot);
    }

    function test_StorageMapUintFound() public {
        uint256 slot = stdstore.target(address(test)).sig(test.map_uint.selector).with_key(100).find();
        assertEq(uint256(keccak256(abi.encode(uint256(100), uint256(2)))), slot);
    }

    function test_StorageCheckedWriteMapUint() public {
        stdstore.target(address(test)).sig(test.map_uint.selector).with_key(100).checked_write(100);
        assertEq(100, test.map_uint(100));
    }

    function test_StorageCheckedWriteMapAddr() public {
        stdstore.target(address(test)).sig(test.map_addr.selector).with_key(address(this)).checked_write(100);
        assertEq(100, test.map_addr(address(this)));
    }

    function test_StorageCheckedWriteMapBool() public {
        stdstore.target(address(test)).sig(test.map_bool.selector).with_key(address(this)).checked_write(true);
        assertTrue(test.map_bool(address(this)));
    }

    /// @notice Packed writes through getters that expose the lower and upper 128-bit halves of one word.
    function testFuzz_StorageCheckedWriteMapPacked(address addr, uint128 value) public {
        stdstore.enable_packed_slots().target(address(test)).sig(test.read_struct_lower.selector).with_key(addr)
            .checked_write(value);
        assertEq(test.read_struct_lower(addr), value);

        stdstore.enable_packed_slots().target(address(test)).sig(test.read_struct_upper.selector).with_key(addr)
            .checked_write(value);
        assertEq(test.read_struct_upper(addr), value);
    }

    /// @notice Without packed-slot support, a full-word write that preserves the upper half still succeeds.
    function test_StorageCheckedWriteMapPackedFullSuccess() public {
        uint256 full = test.map_packed(address(1337));
        // keep upper 128, set lower 128 to 1337
        full = (full & (uint256((1 << 128) - 1) << 128)) | 1337;
        stdstore.target(address(test)).sig(test.map_packed.selector).with_key(address(uint160(1337))).checked_write(
            full
        );
        assertEq(1337, test.read_struct_lower(address(1337)));
    }

    /// @notice `find` on a `pure` getter reverts because no storage access is recorded.
    function test_RevertStorageConst() public {
        StorageTestTarget target = new StorageTestTarget(test);

        vm.expectRevert("stdStorage find(StdStorage): No storage use detected for target.");
        target.expectRevertStorageConst();
    }

    /// @notice Compiler-packed fields (tA+tB share a slot, tC+tD share the next) are written independently.
    function testFuzz_StorageNativePack(uint248 val1, uint248 val2, bool boolVal1, bool boolVal2) public {
        stdstore.enable_packed_slots().target(address(test)).sig(test.tA.selector).checked_write(val1);
        stdstore.enable_packed_slots().target(address(test)).sig(test.tB.selector).checked_write(boolVal1);
        stdstore.enable_packed_slots().target(address(test)).sig(test.tC.selector).checked_write(boolVal2);
        stdstore.enable_packed_slots().target(address(test)).sig(test.tD.selector).checked_write(val2);

        assertEq(test.tA(), val1);
        assertEq(test.tB(), boolVal1);
        assertEq(test.tC(), boolVal2);
        assertEq(test.tD(), val2);
    }

    function test_StorageReadBytes32() public {
        bytes32 val = stdstore.target(address(test)).sig(test.tE.selector).read_bytes32();
        assertEq(val, hex"1337");
    }

    function test_StorageReadBool_False() public {
        bool val = stdstore.target(address(test)).sig(test.tB.selector).read_bool();
        assertEq(val, false);
    }

    function test_StorageReadBool_True() public {
        bool val = stdstore.target(address(test)).sig(test.tH.selector).read_bool();
        assertEq(val, true);
    }

    function test_RevertIf_ReadingNonBoolValue() public {
        vm.expectRevert("stdStorage read_bool(StdStorage): Cannot decode. Make sure you are reading a bool.");
        this.readNonBoolValue();
    }

    /// @dev Called externally so `vm.expectRevert` in the caller applies to this call.
    function readNonBoolValue() public {
        stdstore.target(address(test)).sig(test.tE.selector).read_bool();
    }

    function test_StorageReadAddress() public {
        address val = stdstore.target(address(test)).sig(test.tF.selector).read_address();
        assertEq(val, address(1337));
    }

    function test_StorageReadUint() public {
        uint256 val = stdstore.target(address(test)).sig(test.exists.selector).read_uint();
        assertEq(val, 1);
    }

    function test_StorageReadInt() public {
        int256 val = stdstore.target(address(test)).sig(test.tG.selector).read_int();
        assertEq(val, type(int256).min);
    }

    /// @notice Reads one element out of slots packed with 1..4 fields of 8, 16, 24 and 32 bits.
    function testFuzz_Packed(uint256 val, uint8 elemToGet) public {
        // This test tries an assortment of packed slots. `shifts` is the number of packed elements and
        // `shiftSizes[j]` is the bit width of element j (8 means an 8-bit type, 16 a 16-bit type, etc.).
        // Together they determine how the slot is packed. Fully random layouts would hit the global
        // rejection limit and be slow, so the layouts are fixed.

        // change the number of shifts
        for (uint256 i = 1; i < 5; i++) {
            uint256 shifts = i;

            // Bound into a fresh local rather than overwriting `elemToGet`. Re-assigning the fuzz input
            // would collapse it to 0 on the first iteration (range [0, 0]) and pin it at 0 for every
            // later iteration, so the fuzzed element index would never be exercised.
            uint8 elem = uint8(bound(elemToGet, 0, shifts - 1));

            uint256[] memory shiftSizes = new uint256[](shifts);
            for (uint256 j; j < shifts; j++) {
                shiftSizes[j] = 8 * (j + 1);
            }

            test.setRandomPacking(val);

            uint256 leftBits;
            uint256 rightBits;
            for (uint256 j; j < shiftSizes.length; j++) {
                if (j < elem) {
                    leftBits += shiftSizes[j];
                } else if (elem != j) {
                    rightBits += shiftSizes[j];
                }
            }

            // bits not occupied by any element sit above element 0; count them as left bits
            leftBits += 256 - (leftBits + shiftSizes[elem] + rightBits);
            // clear left bits, then clear right bits and realign
            uint256 expectedValToRead = (val << leftBits) >> (leftBits + rightBits);

            uint256 readVal = stdstore.target(address(test)).enable_packed_slots().sig(
                "getRandomPacked(uint8,uint8[],uint8)"
            ).with_calldata(abi.encode(shifts, shiftSizes, elem)).read_uint();

            assertEq(readVal, expectedValToRead);
        }
    }

    /// @notice Packs 1..20 values of random widths and offsets into one word, then reads each back.
    function testFuzz_Packed2(uint256 nvars, uint256 seed) public {
        // Number of random variables to generate.
        nvars = bound(nvars, 1, 20);

        // This will decrease as we generate values in the below loop.
        uint256 bitsRemaining = 256;

        // Generate a random value and size for each variable.
        uint256[] memory vals = new uint256[](nvars);
        uint256[] memory sizes = new uint256[](nvars);
        uint256[] memory offsets = new uint256[](nvars);

        for (uint256 i = 0; i < nvars; i++) {
            // Each variable starts right after the previous one.
            offsets[i] = i == 0 ? 0 : offsets[i - 1] + sizes[i - 1];

            // Leave at least one bit for every variable still to be generated.
            uint256 nvarsRemaining = nvars - i;
            uint256 maxVarSize = bitsRemaining - nvarsRemaining + 1;
            sizes[i] = bound(uint256(keccak256(abi.encodePacked(seed, i + 256))), 1, maxVarSize);
            bitsRemaining -= sizes[i];

            uint256 maxVal;
            uint256 varSize = sizes[i];
            assembly {
                // mask = (1 << varSize) - 1; shl(256, 1) is 0, so a 256-bit size yields all ones
                maxVal := sub(shl(varSize, 1), 1)
            }
            vals[i] = bound(uint256(keccak256(abi.encodePacked(seed, i))), 0, maxVal);
        }

        // Pack all values into the slot.
        for (uint256 i = 0; i < nvars; i++) {
            stdstore.enable_packed_slots().target(address(test)).sig("getRandomPacked(uint256,uint256)").with_key(
                sizes[i]
            ).with_key(offsets[i]).checked_write(vals[i]);
        }

        // Verify the read data matches.
        for (uint256 i = 0; i < nvars; i++) {
            uint256 readVal = stdstore.enable_packed_slots().target(address(test)).sig(
                "getRandomPacked(uint256,uint256)"
            ).with_key(sizes[i]).with_key(offsets[i]).read_uint();

            uint256 retVal = test.getRandomPacked(sizes[i], offsets[i]);

            assertEq(readVal, vals[i]);
            assertEq(retVal, vals[i]);
        }
    }

    /// @notice Writing an element of an array whose length equals its element values.
    function testEdgeCaseArray() public {
        stdstore.target(address(test)).sig("edgeCaseArray(uint256)").with_key(uint256(0)).checked_write(1);
        assertEq(test.edgeCaseArray(0), 1);
    }
}
/// @title StorageTestTarget
/// @notice Helper that owns its own `StdStorage` instance and runs a slot lookup that is expected to revert.
/// @dev The lookup runs in this separate contract, so the caller reaches it through an external call.
///      Presumably this lets `vm.expectRevert` in the caller catch the revert. TODO confirm.
contract StorageTestTarget {
    using stdStorage for StdStorage;

    StdStorage internal stdstore;
    StorageTest internal test;

    /// @param test_ Fixture whose getters are probed.
    constructor(StorageTest test_) {
        test = test_;
    }

    /// @notice Asks `stdStorage` to locate the slot behind `StorageTest.const()`.
    /// @dev `const()` is `pure`, so the caller expects `find` to revert with the
    ///      "No storage use detected for target." message.
    function expectRevertStorageConst() public {
        address probed = address(test);
        StdStorage storage lookup = stdstore;
        lookup.target(probed).sig("const()").find();
    }
}
/// @title StorageTest
/// @notice Fixture contract whose storage layout the `StdStorageTest` suite probes with `stdStorage`.
/// @dev The declaration order of the state variables below determines their storage slots. The tests
///      hard-code those slot numbers, so do not reorder, insert, or remove state variables.
contract StorageTest {
    uint256 public exists = 1;
    mapping(address => uint256) public map_addr;
    mapping(uint256 => uint256) public map_uint;
    // One word holding two hand-packed 128-bit halves; see `read_struct_upper` / `read_struct_lower`.
    mapping(address => uint256) public map_packed;
    mapping(address => UnpackedStruct) public map_struct;
    mapping(address => mapping(address => uint256)) public deep_map;
    mapping(address => mapping(address => UnpackedStruct)) public deep_map_struct;
    UnpackedStruct public basic;

    // Packed by the compiler: `tA` (31 bytes) + `tB` (1 byte) share one slot, `tC` + `tD` share the next.
    uint248 public tA;
    bool public tB;

    bool public tC = false;
    uint248 public tD = 1;

    struct UnpackedStruct {
        uint256 a;
        uint256 b;
    }

    mapping(address => bool) public map_bool;

    bytes32 public tE = hex"1337";
    address public tF = address(1337);
    int256 public tG = type(int256).min;
    bool public tH = true;
    // Private (no getter); only returned through `extra_sload`.
    bytes32 private tI = ~bytes32(hex"1337");

    // Scratch word read and written by the packed-slot helpers below.
    uint256 randomPacking;

    // Array whose length (3) equals the value of every element (3); presumably an edge case for
    // telling the length slot apart from element slots. TODO confirm intent.
    uint256[] public edgeCaseArray = [3, 3, 3];

    constructor() {
        basic = UnpackedStruct({a: 1337, b: 1337});

        // Deployer: both 128-bit halves set to 1.
        uint256 two = (1 << 128) | 1;
        map_packed[msg.sender] = two;
        // address(1337): upper half 1, lower half 0.
        map_packed[address(uint160(1337))] = 1 << 128;
    }

    /// @notice Upper 128 bits of `map_packed[who]`.
    function read_struct_upper(address who) public view returns (uint256) {
        return map_packed[who] >> 128;
    }

    /// @notice Lower 128 bits of `map_packed[who]`.
    function read_struct_lower(address who) public view returns (uint256) {
        return map_packed[who] & ((1 << 128) - 1);
    }

    /// @notice Returns the word stored at the custom slot `keccak256("my.random.var")`.
    function hidden() public view returns (bytes32 t) {
        bytes32 slot = keccak256("my.random.var");
        /// @solidity memory-safe-assembly
        assembly {
            t := sload(slot)
        }
    }

    /// @notice Pure getter that touches no storage.
    function const() public pure returns (bytes32 t) {
        t = bytes32(hex"1337");
    }

    /// @notice Loads slot `tE` before returning `tI`, so two slots are read during the call.
    function extra_sload() public view returns (bytes32 t) {
        // trigger read on slot `tE`, and make a staticcall to make sure compiler doesn't optimize this SLOAD away
        assembly {
            pop(staticcall(gas(), sload(tE.slot), 0, 0, 0, 0))
        }
        t = tI;
    }

    /// @notice Overwrites the whole scratch word.
    function setRandomPacking(uint256 val) public {
        randomPacking = val;
    }

    /// @dev Returns a mask of the low `size` bits. Sizes >= 256 yield all ones, because `shl` by 256 or
    ///      more gives 0 and `0 - 1` wraps in assembly.
    function _getMask(uint256 size) internal pure returns (uint256 mask) {
        assembly {
            // mask = (1 << size) - 1
            mask := sub(shl(size, 1), 1)
        }
    }

    /// @notice Stores `val` in the `size`-bit field starting at bit `offset` of the scratch word.
    /// @dev Bits of `val` above `size` are discarded, so neighbouring fields are never disturbed.
    function setRandomPacking(uint256 val, uint256 size, uint256 offset) public {
        // Generate mask based on the size of the value
        uint256 mask = _getMask(size);
        // Zero out all bits of the field we're about to set
        uint256 cleanedWord = randomPacking & ~(mask << offset);
        // Truncate val to `size` bits (as the getter does), then place it in the cleaned field
        randomPacking = cleanedWord | ((val & mask) << offset);
    }

    /// @notice Reads the `size`-bit field starting at bit `offset` of the scratch word.
    function getRandomPacked(uint256 size, uint256 offset) public view returns (uint256) {
        // Generate mask based on the size of the value
        uint256 mask = _getMask(size);
        // Shift to place the bits in the correct position, and use mask to zero out remaining bits
        return (randomPacking >> offset) & mask;
    }

    /// @notice Reads element `elem` of the scratch word, interpreted as `shiftSizes.length` packed fields.
    /// @dev Layout, from most to least significant bit: unused bits, then element 0, element 1, ...,
    ///      with the last element in the lowest bits. `shifts` is only used to bound `elem`; the layout
    ///      itself comes from `shiftSizes`. Reverts with "!elem" if `elem >= shifts`. If the sizes sum
    ///      to more than 256, the subtraction below underflows; that reverts under Solidity >= 0.8
    ///      checked arithmetic.
    function getRandomPacked(uint8 shifts, uint8[] memory shiftSizes, uint8 elem) public view returns (uint256) {
        require(elem < shifts, "!elem");
        uint256 leftBits;
        uint256 rightBits;

        for (uint256 i; i < shiftSizes.length; i++) {
            if (i < elem) {
                leftBits += shiftSizes[i];
            } else if (elem != i) {
                rightBits += shiftSizes[i];
            }
        }

        // bits not occupied by any element sit above element 0; count them as left bits
        leftBits += 256 - (leftBits + shiftSizes[elem] + rightBits);

        // clear left bits, then clear right bits and realign
        return (randomPacking << leftBits) >> (leftBits + rightBits);
    }
}