// contracts/lib/forge-std/src/StdStorage.sol
// SPDX-License-Identifier: MIT
pragma solidity >=0.6.2 <0.9.0;
import {Vm} from "./Vm.sol";
// Cached outcome of locating the storage that backs one (target, selector, params + depth) query.
struct FindData {
    // Storage slot holding the value returned by the queried call.
    uint256 slot;
    // Number of bits between the most significant end of the slot word and the value (0 when packed-slot detection is off).
    uint256 offsetLeft;
    // Number of bits between the least significant end of the slot word and the value (0 when packed-slot detection is off).
    uint256 offsetRight;
    // True once a search has populated this entry.
    bool found;
}
// Query state configured by the stdStorage/stdStorageSafe builder functions.
// `clear` resets every field except `finds` and `_set`. Field order is part of the storage layout; do not reorder.
struct StdStorage {
    // Search-result cache: target => selector => keccak256(abi.encodePacked(callParams, depth)) => result.
    mapping(address => mapping(bytes4 => mapping(bytes32 => FindData))) finds;
    // 32-byte call arguments, used when `_calldata` is empty.
    bytes32[] _keys;
    // Selector of the function called on the target.
    bytes4 _sig;
    // Index of the 32-byte word of the return data to inspect.
    uint256 _depth;
    // Contract whose storage is queried.
    address _target;
    // Not referenced by the functions in this file; presumably kept for layout compatibility - TODO confirm.
    bytes32 _set;
    // When true, `find` also determines the bit offsets of a value packed inside the slot.
    bool _enable_packed_slots;
    // Raw call arguments; when non-empty they are used instead of `_keys`.
    bytes _calldata;
}
/// @title stdStorageSafe
/// @notice Finds, reads and inspects the storage slot behind a call on a target contract.
/// @dev A query is configured on a `StdStorage` struct (`target`, `sig`, `with_key`, `with_calldata`, `depth`,
///      `enable_packed_slots`) and resolved with `find`, `read_*`, `parent` or `root`. Slot discovery records the
///      slots read by the call, then overwrites each candidate through cheatcodes to see which one drives the result.
library stdStorageSafe {
    // Emitted when `find` identifies the slot backing the configured call.
    // `keysHash` is keccak256(abi.encodePacked(callParams, depth)).
    event SlotFound(address who, bytes4 fsig, bytes32 keysHash, uint256 slot);
    // Emitted when `find` inspects a candidate slot whose current value is zero.
    event WARNING_UninitedSlot(address who, uint256 slot);

    // Cheatcode contract at address(uint160(uint256(keccak256("hevm cheat code")))).
    Vm private constant vm = Vm(address(uint160(uint256(keccak256("hevm cheat code")))));
    // 2**256 - 1, written out as a literal.
    uint256 constant UINT256_MAX = 115792089237316195423570985008687907853269984665640564039457584007913129639935;

    /// @notice Computes the 4-byte selector of a function signature string.
    /// @param sigStr Signature such as "balanceOf(address)".
    /// @return The first four bytes of keccak256(sigStr).
    function sigs(string memory sigStr) internal pure returns (bytes4) {
        return bytes4(keccak256(bytes(sigStr)));
    }

    /// @notice Returns the argument bytes appended to the selector when calling the target.
    /// @dev Raw calldata set via `with_calldata` takes precedence; otherwise the `with_key` words are concatenated.
    /// @return The encoded call arguments.
    function getCallParams(StdStorage storage self) internal view returns (bytes memory) {
        if (self._calldata.length == 0) {
            return flatten(self._keys);
        } else {
            return self._calldata;
        }
    }

    /// @notice Static-calls the target with the configured selector and arguments.
    /// @dev The return data is decoded even if the call reverted; callers check the success flag where needed.
    /// @return Whether the call succeeded.
    /// @return The 32-byte word at index `_depth` of the return data, zero-filled where the data is too short.
    function callTarget(StdStorage storage self) internal view returns (bool, bytes32) {
        bytes memory cald = abi.encodePacked(self._sig, getCallParams(self));
        (bool success, bytes memory rdat) = self._target.staticcall(cald);
        bytes32 result = bytesToBytes32(rdat, 32 * self._depth);

        return (success, result);
    }

    /// @notice Checks whether overwriting `slot` changes the value returned by the configured call.
    /// @dev If the current return value is 0 the slot is set to all ones, otherwise to 0, so a slot backing the
    ///      returned value is expected to change it. The slot's original contents are restored afterwards.
    /// @param slot Candidate storage slot of the target.
    /// @return True if the first call succeeded and the two return values differ.
    function checkSlotMutatesCall(StdStorage storage self, bytes32 slot) internal returns (bool) {
        bytes32 prevSlotValue = vm.load(self._target, slot);
        (bool success, bytes32 prevReturnValue) = callTarget(self);

        bytes32 testVal = prevReturnValue == bytes32(0) ? bytes32(UINT256_MAX) : bytes32(0);
        vm.store(self._target, slot, testVal);

        (, bytes32 newReturnValue) = callTarget(self);

        vm.store(self._target, slot, prevSlotValue);

        return (success && (prevReturnValue != newReturnValue));
    }

    /// @notice Probes `slot` with single-bit words to find where the returned value begins.
    /// @dev Moves a single set bit inward from the most significant end (`left == true`) or the least significant
    ///      end, stopping at the first position where the call succeeds with a non-zero result. The last probed word
    ///      is left in the slot; `findOffsets` restores the original contents.
    /// @param slot Storage slot being probed.
    /// @param left Probe from the most significant bit if true, from the least significant bit otherwise.
    /// @return Whether such a position was found.
    /// @return Number of bits skipped from the chosen end (0 if none was found).
    function findOffset(StdStorage storage self, bytes32 slot, bool left) internal returns (bool, uint256) {
        for (uint256 offset = 0; offset < 256; offset++) {
            uint256 valueToPut = left ? (1 << (255 - offset)) : (1 << offset);
            vm.store(self._target, slot, bytes32(valueToPut));

            (bool success, bytes32 data) = callTarget(self);

            if (success && (uint256(data) > 0)) {
                return (true, offset);
            }
        }
        return (false, 0);
    }

    /// @notice Determines the left and right bit offsets of a value packed inside `slot`.
    /// @dev Restores the slot's original contents after probing.
    /// @param slot Storage slot being probed.
    /// @return True only if both offsets were found.
    /// @return Bits between the most significant end of the word and the value.
    /// @return Bits between the least significant end of the word and the value.
    function findOffsets(StdStorage storage self, bytes32 slot) internal returns (bool, uint256, uint256) {
        bytes32 prevSlotValue = vm.load(self._target, slot);

        (bool foundLeft, uint256 offsetLeft) = findOffset(self, slot, true);
        (bool foundRight, uint256 offsetRight) = findOffset(self, slot, false);

        // `findOffset` leaves its last probe in the slot, so put the original value back.
        vm.store(self._target, slot, prevSlotValue);
        return (foundLeft && foundRight, offsetLeft, offsetRight);
    }

    /// @notice Same as `find(self, true)`: locates the slot, then clears the query configuration.
    /// @return Storage pointer to the cached result for this query.
    function find(StdStorage storage self) internal returns (FindData storage) {
        return find(self, true);
    }

    /// @notice Finds the storage slot holding the value returned by the configured call and caches the result.
    /// @dev Slot layouts this can locate:
    ///      - flat: bytes32(uint256(uint))
    ///      - map: keccak256(abi.encode(key, uint(slot)))
    ///      - deep map: keccak256(abi.encode(key1, keccak256(abi.encode(key0, uint(slot)))))
    ///      - map struct: bytes32(uint256(keccak256(abi.encode(key1, keccak256(abi.encode(key0, uint(slot)))))) + structFieldDepth)
    ///      Candidates are the slots read by the call (from `vm.record` / `vm.accesses`), tried from the last read to
    ///      the first. Reverts if the call reads no storage or if no candidate matches.
    /// @param _clear Whether to reset the query configuration before returning.
    /// @return Storage pointer to the cached result for this query.
    function find(StdStorage storage self, bool _clear) internal returns (FindData storage) {
        address who = self._target;
        bytes4 fsig = self._sig;
        // Cache key: the call arguments packed with the inspected return-word index.
        bytes32 findKey = keccak256(abi.encodePacked(getCallParams(self), self._depth));
        FindData storage data = self.finds[who][fsig][findKey];

        if (!data.found) {
            vm.record();
            (, bytes32 callResult) = callTarget(self);
            (bytes32[] memory reads,) = vm.accesses(address(who));

            if (reads.length == 0) {
                revert("stdStorage find(StdStorage): No storage use detected for target.");
            }

            // Walk the recorded reads from last to first. Stopping at i == 0 (instead of testing an unsigned
            // index for >= 0) lets an unmatched search reach the `require` below rather than underflowing.
            for (uint256 i = reads.length; i > 0; i--) {
                bytes32 slot = reads[i - 1];
                (bool matched, uint256 offsetLeft, uint256 offsetRight) = matchCandidateSlot(self, slot, callResult);
                if (!matched) {
                    continue;
                }

                emit SlotFound(who, fsig, findKey, uint256(slot));
                data.slot = uint256(slot);
                data.offsetLeft = offsetLeft;
                data.offsetRight = offsetRight;
                data.found = true;
                break;
            }

            require(data.found, "stdStorage find(StdStorage): Slot(s) not found.");
        }

        if (_clear) {
            clear(self);
        }
        return data;
    }

    /// @notice Tests whether `slot` holds the value returned by the configured call.
    /// @dev Emits `WARNING_UninitedSlot` if the slot is currently zero. The slot must change the call's result when
    ///      overwritten; with packed slots enabled its bit offsets must also be found. Finally the bits between the
    ///      offsets (read before any probing) must equal `callResult`. The probing helpers restore the slot.
    /// @param slot Candidate storage slot.
    /// @param callResult Return word of the unmodified call.
    /// @return matched True if the slot holds the returned value.
    /// @return offsetLeft Left bit offset of the value (0 unless packed detection ran).
    /// @return offsetRight Right bit offset of the value (0 unless packed detection ran).
    function matchCandidateSlot(StdStorage storage self, bytes32 slot, bytes32 callResult)
        private
        returns (bool matched, uint256 offsetLeft, uint256 offsetRight)
    {
        bytes32 prev = vm.load(self._target, slot);
        if (prev == bytes32(0)) {
            emit WARNING_UninitedSlot(self._target, uint256(slot));
        }

        if (!checkSlotMutatesCall(self, slot)) {
            return (false, 0, 0);
        }

        if (self._enable_packed_slots) {
            bool found;
            (found, offsetLeft, offsetRight) = findOffsets(self, slot);
            if (!found) {
                return (false, 0, 0);
            }
        }

        // The value between the offsets must equal what the unmodified call returned.
        uint256 curVal = (uint256(prev) & getMaskByOffsets(offsetLeft, offsetRight)) >> offsetRight;
        matched = uint256(callResult) == curVal;
    }

    /// @notice Sets the contract whose storage is queried.
    /// @param _target Target contract address.
    function target(StdStorage storage self, address _target) internal returns (StdStorage storage) {
        self._target = _target;
        return self;
    }

    /// @notice Sets the selector of the function to call.
    /// @param _sig 4-byte selector.
    function sig(StdStorage storage self, bytes4 _sig) internal returns (StdStorage storage) {
        self._sig = _sig;
        return self;
    }

    /// @notice Sets the selector of the function to call from its signature string.
    /// @param _sig Signature such as "balanceOf(address)".
    function sig(StdStorage storage self, string memory _sig) internal returns (StdStorage storage) {
        self._sig = sigs(_sig);
        return self;
    }

    /// @notice Sets raw call arguments, used instead of the `with_key` words when non-empty.
    /// @param _calldata ABI-encoded arguments (without selector).
    function with_calldata(StdStorage storage self, bytes memory _calldata) internal returns (StdStorage storage) {
        self._calldata = _calldata;
        return self;
    }

    /// @notice Appends an address argument, left-padded to 32 bytes.
    /// @param who Address argument.
    function with_key(StdStorage storage self, address who) internal returns (StdStorage storage) {
        self._keys.push(bytes32(uint256(uint160(who))));
        return self;
    }

    /// @notice Appends a uint256 argument.
    /// @param amt Integer argument.
    function with_key(StdStorage storage self, uint256 amt) internal returns (StdStorage storage) {
        self._keys.push(bytes32(amt));
        return self;
    }

    /// @notice Appends a raw 32-byte argument.
    /// @param key Word argument.
    function with_key(StdStorage storage self, bytes32 key) internal returns (StdStorage storage) {
        self._keys.push(key);
        return self;
    }

    /// @notice Enables detection of values packed into part of a slot.
    function enable_packed_slots(StdStorage storage self) internal returns (StdStorage storage) {
        self._enable_packed_slots = true;
        return self;
    }

    /// @notice Selects which 32-byte word of the return data is inspected.
    /// @param _depth Zero-based word index.
    function depth(StdStorage storage self, uint256 _depth) internal returns (StdStorage storage) {
        self._depth = _depth;
        return self;
    }

    /// @notice Finds the slot, extracts the (possibly packed) value and clears the query.
    /// @return The extracted value ABI-encoded as a single 32-byte word.
    function read(StdStorage storage self) private returns (bytes memory) {
        FindData storage data = find(self, false);
        uint256 mask = getMaskByOffsets(data.offsetLeft, data.offsetRight);
        uint256 value = (uint256(vm.load(self._target, bytes32(data.slot))) & mask) >> data.offsetRight;
        clear(self);
        return abi.encode(value);
    }

    /// @notice Reads the located value as bytes32.
    /// @return The stored value.
    function read_bytes32(StdStorage storage self) internal returns (bytes32) {
        return abi.decode(read(self), (bytes32));
    }

    /// @notice Reads the located value as a bool; reverts unless the stored value is 0 or 1.
    /// @return The stored value.
    function read_bool(StdStorage storage self) internal returns (bool) {
        int256 v = read_int(self);
        if (v == 0) return false;
        if (v == 1) return true;
        revert("stdStorage read_bool(StdStorage): Cannot decode. Make sure you are reading a bool.");
    }

    /// @notice Reads the located value as an address.
    /// @return The stored value.
    function read_address(StdStorage storage self) internal returns (address) {
        return abi.decode(read(self), (address));
    }

    /// @notice Reads the located value as uint256.
    /// @return The stored value.
    function read_uint(StdStorage storage self) internal returns (uint256) {
        return abi.decode(read(self), (uint256));
    }

    /// @notice Reads the located value as int256.
    /// @return The stored value.
    function read_int(StdStorage storage self) internal returns (int256) {
        return abi.decode(read(self), (int256));
    }

    /// @notice Returns the mapping slot and key from which the located slot was derived.
    /// @dev Starts mapping recording, runs `find` (clearing the query), subtracts `_depth` from the found slot and
    ///      looks the result up with `vm.getMappingKeyAndParentOf`. Reverts if no parent is known.
    /// @return The parent slot.
    /// @return The mapping key.
    function parent(StdStorage storage self) internal returns (uint256, bytes32) {
        address who = self._target;
        uint256 field_depth = self._depth;
        vm.startMappingRecording();
        uint256 child = find(self, true).slot - field_depth;
        (bool found, bytes32 key, bytes32 parent_slot) = vm.getMappingKeyAndParentOf(who, bytes32(child));
        if (!found) {
            revert(
                "stdStorage parent(StdStorage): Cannot find parent. Make sure you give a slot and startMappingRecording() has been called."
            );
        }
        return (uint256(parent_slot), key);
    }

    /// @notice Returns the outermost slot reached by repeatedly following mapping parents from the located slot.
    /// @dev Reverts if the located slot (minus `_depth`) has no recorded parent. Clears the query.
    /// @return The root slot.
    function root(StdStorage storage self) internal returns (uint256) {
        address who = self._target;
        uint256 field_depth = self._depth;
        vm.startMappingRecording();
        uint256 child = find(self, true).slot - field_depth;
        bool found;
        bytes32 root_slot;
        bytes32 parent_slot;
        (found,, parent_slot) = vm.getMappingKeyAndParentOf(who, bytes32(child));
        if (!found) {
            revert(
                "stdStorage root(StdStorage): Cannot find parent. Make sure you give a slot and startMappingRecording() has been called."
            );
        }
        // Climb until the current slot has no recorded parent.
        while (found) {
            root_slot = parent_slot;
            (found,, parent_slot) = vm.getMappingKeyAndParentOf(who, bytes32(root_slot));
        }
        return uint256(root_slot);
    }

    /// @notice Reads up to 32 bytes of `b` starting at `offset` into a left-aligned bytes32.
    /// @dev Bytes past the end of `b` are treated as zero. Returns zero when `offset` is at or beyond the end,
    ///      instead of indexing out of bounds.
    /// @param b Source bytes.
    /// @param offset Byte offset of the first byte to read.
    /// @return The collected word.
    function bytesToBytes32(bytes memory b, uint256 offset) private pure returns (bytes32) {
        bytes32 out;
        if (offset >= b.length) {
            return out;
        }

        uint256 available = b.length - offset;
        uint256 max = available > 32 ? 32 : available;
        for (uint256 i = 0; i < max; i++) {
            out |= bytes32(b[offset + i] & 0xFF) >> (i * 8);
        }
        return out;
    }

    /// @notice Concatenates 32-byte words into a byte array of length 32 * b.length.
    /// @param b Words to concatenate.
    /// @return The concatenated bytes.
    function flatten(bytes32[] memory b) private pure returns (bytes memory) {
        bytes memory result = new bytes(b.length * 32);
        for (uint256 i = 0; i < b.length; i++) {
            bytes32 k = b[i];
            /// @solidity memory-safe-assembly
            assembly {
                mstore(add(result, add(32, mul(32, i))), k)
            }
        }

        return result;
    }

    /// @notice Resets the query configuration.
    /// @dev The `finds` cache and `_set` are left untouched.
    function clear(StdStorage storage self) internal {
        delete self._target;
        delete self._sig;
        delete self._keys;
        delete self._depth;
        delete self._enable_packed_slots;
        delete self._calldata;
    }

    /// @notice Builds a mask whose set bits cover the value between `offsetLeft` and `offsetRight`.
    /// @dev (slotValue & mask) >> offsetRight yields the packed value. Equivalent to
    ///      ((1 << (256 - (offsetRight + offsetLeft))) - 1) << offsetRight, computed in assembly because a shift
    ///      by 256 yields 0 there instead of failing, which makes the (0, 0) case produce an all-ones mask.
    /// @param offsetLeft Bits between the most significant end of the word and the value.
    /// @param offsetRight Bits between the least significant end of the word and the value.
    /// @return mask The bit mask.
    function getMaskByOffsets(uint256 offsetLeft, uint256 offsetRight) internal pure returns (uint256 mask) {
        assembly {
            mask := shl(offsetRight, sub(shl(sub(256, add(offsetRight, offsetLeft)), 1), 1))
        }
    }

    /// @notice Returns `curValue` with the field between the offsets replaced by `varValue`.
    /// @dev `varValue` is not masked; callers must ensure it fits in 256 - (offsetLeft + offsetRight) bits.
    /// @param curValue Current slot word.
    /// @param varValue New field value.
    /// @param offsetLeft Bits between the most significant end of the word and the field.
    /// @param offsetRight Bits between the least significant end of the word and the field.
    /// @return newValue The updated slot word.
    function getUpdatedSlotValue(bytes32 curValue, uint256 varValue, uint256 offsetLeft, uint256 offsetRight)
        internal
        pure
        returns (bytes32 newValue)
    {
        return bytes32((uint256(curValue) & ~getMaskByOffsets(offsetLeft, offsetRight)) | (varValue << offsetRight));
    }
}
/// @title stdStorage
/// @notice Convenience layer over `stdStorageSafe`: `find` yields the raw slot number, and `checked_write`
///         overwrites the located value and verifies it through the target's own call.
library stdStorage {
    // Cheatcode contract at address(uint160(uint256(keccak256("hevm cheat code")))).
    Vm private constant vm = Vm(address(uint160(uint256(keccak256("hevm cheat code")))));

    /// @notice Computes a 4-byte selector from a signature string.
    /// @param sigStr Signature such as "balanceOf(address)".
    /// @return The selector.
    function sigs(string memory sigStr) internal pure returns (bytes4) {
        return stdStorageSafe.sigs(sigStr);
    }

    /// @notice Locates the slot for the configured query, then clears the query.
    /// @return The located slot number.
    function find(StdStorage storage self) internal returns (uint256) {
        return stdStorageSafe.find(self, true).slot;
    }

    /// @notice Locates the slot for the configured query.
    /// @param _clear Whether to reset the query configuration afterwards.
    /// @return The located slot number.
    function find(StdStorage storage self, bool _clear) internal returns (uint256) {
        return stdStorageSafe.find(self, _clear).slot;
    }

    /// @notice Sets the contract whose storage is queried.
    /// @param _target Target contract address.
    function target(StdStorage storage self, address _target) internal returns (StdStorage storage) {
        return stdStorageSafe.target(self, _target);
    }

    /// @notice Sets the selector of the function to call.
    /// @param _sig 4-byte selector.
    function sig(StdStorage storage self, bytes4 _sig) internal returns (StdStorage storage) {
        return stdStorageSafe.sig(self, _sig);
    }

    /// @notice Sets the selector of the function to call from its signature string.
    /// @param _sig Signature string.
    function sig(StdStorage storage self, string memory _sig) internal returns (StdStorage storage) {
        return stdStorageSafe.sig(self, _sig);
    }

    /// @notice Appends an address call argument.
    /// @param who Address argument.
    function with_key(StdStorage storage self, address who) internal returns (StdStorage storage) {
        return stdStorageSafe.with_key(self, who);
    }

    /// @notice Appends a uint256 call argument.
    /// @param amt Integer argument.
    function with_key(StdStorage storage self, uint256 amt) internal returns (StdStorage storage) {
        return stdStorageSafe.with_key(self, amt);
    }

    /// @notice Appends a raw 32-byte call argument.
    /// @param key Word argument.
    function with_key(StdStorage storage self, bytes32 key) internal returns (StdStorage storage) {
        return stdStorageSafe.with_key(self, key);
    }

    /// @notice Sets raw call arguments that take precedence over the `with_key` words.
    /// @param _calldata ABI-encoded arguments.
    function with_calldata(StdStorage storage self, bytes memory _calldata) internal returns (StdStorage storage) {
        return stdStorageSafe.with_calldata(self, _calldata);
    }

    /// @notice Enables detection of values packed into part of a slot.
    function enable_packed_slots(StdStorage storage self) internal returns (StdStorage storage) {
        return stdStorageSafe.enable_packed_slots(self);
    }

    /// @notice Selects which 32-byte word of the return data is inspected.
    /// @param _depth Zero-based word index.
    function depth(StdStorage storage self, uint256 _depth) internal returns (StdStorage storage) {
        return stdStorageSafe.depth(self, _depth);
    }

    /// @notice Resets the query configuration (the search cache is kept).
    function clear(StdStorage storage self) internal {
        stdStorageSafe.clear(self);
    }

    /// @notice Writes an address into the located slot.
    /// @param who Value to write.
    function checked_write(StdStorage storage self, address who) internal {
        bytes32 word = bytes32(uint256(uint160(who)));
        checked_write(self, word);
    }

    /// @notice Writes a uint256 into the located slot.
    /// @param amt Value to write.
    function checked_write(StdStorage storage self, uint256 amt) internal {
        bytes32 word = bytes32(amt);
        checked_write(self, word);
    }

    /// @notice Writes an int256 (two's complement bit pattern) into the located slot.
    /// @param val Value to write.
    function checked_write_int(StdStorage storage self, int256 val) internal {
        bytes32 word = bytes32(uint256(val));
        checked_write(self, word);
    }

    /// @notice Writes a bool into the located slot.
    /// @param write Value to write.
    function checked_write(StdStorage storage self, bool write) internal {
        bytes32 word;
        /// @solidity memory-safe-assembly
        assembly {
            word := write
        }
        checked_write(self, word);
    }

    /// @notice Writes `set` into the located slot (or packed field) and verifies the call now returns it.
    /// @dev Locates the slot first if it is not cached. For packed fields, reverts if `set` does not fit. If the
    ///      call fails or returns a different word, the slot is restored and the function reverts. Clears the
    ///      query on success.
    /// @param set Value to write.
    function checked_write(StdStorage storage self, bytes32 set) internal {
        address tgt = self._target;
        bytes32 findKey = keccak256(abi.encodePacked(stdStorageSafe.getCallParams(self), self._depth));
        FindData storage data = self.finds[tgt][self._sig][findKey];

        // Search only on a cache miss; `false` keeps the query configured for the verification call below.
        if (!data.found) {
            find(self, false);
        }

        uint256 packedBits = data.offsetLeft + data.offsetRight;
        if (packedBits > 0) {
            // A packed field is only (256 - packedBits) bits wide.
            uint256 maxVal = 2 ** (256 - packedBits);
            require(
                uint256(set) < maxVal,
                string(
                    abi.encodePacked(
                        "stdStorage find(StdStorage): Packed slot. We can't fit value greater than ",
                        vm.toString(maxVal)
                    )
                )
            );
        }

        bytes32 slotKey = bytes32(data.slot);
        bytes32 originalWord = vm.load(tgt, slotKey);
        vm.store(
            tgt,
            slotKey,
            stdStorageSafe.getUpdatedSlotValue(originalWord, uint256(set), data.offsetLeft, data.offsetRight)
        );

        // Read back through the configured call to confirm the write is observed.
        (bool ok, bytes32 observed) = stdStorageSafe.callTarget(self);
        if (!(ok && observed == set)) {
            vm.store(tgt, slotKey, originalWord);
            revert("stdStorage find(StdStorage): Failed to write value.");
        }
        clear(self);
    }

    /// @notice Reads the located value as bytes32.
    /// @return The stored value.
    function read_bytes32(StdStorage storage self) internal returns (bytes32) {
        return stdStorageSafe.read_bytes32(self);
    }

    /// @notice Reads the located value as a bool.
    /// @return The stored value.
    function read_bool(StdStorage storage self) internal returns (bool) {
        return stdStorageSafe.read_bool(self);
    }

    /// @notice Reads the located value as an address.
    /// @return The stored value.
    function read_address(StdStorage storage self) internal returns (address) {
        return stdStorageSafe.read_address(self);
    }

    /// @notice Reads the located value as uint256.
    /// @return The stored value.
    function read_uint(StdStorage storage self) internal returns (uint256) {
        return stdStorageSafe.read_uint(self);
    }

    /// @notice Reads the located value as int256.
    /// @return The stored value.
    function read_int(StdStorage storage self) internal returns (int256) {
        return stdStorageSafe.read_int(self);
    }

    /// @notice Returns the parent mapping slot and key of the located slot.
    /// @return The parent slot.
    /// @return The mapping key.
    function parent(StdStorage storage self) internal returns (uint256, bytes32) {
        return stdStorageSafe.parent(self);
    }

    /// @notice Returns the outermost mapping slot the located slot derives from.
    /// @return The root slot.
    function root(StdStorage storage self) internal returns (uint256) {
        return stdStorageSafe.root(self);
    }
}