| 1 | // SPDX-License-Identifier: MIT |
| 2 | pragma solidity >=0.7.0 <0.9.0; |
| 3 | |
| 4 | import {stdStorage, StdStorage} from "../src/StdStorage.sol"; |
| 5 | import {Test} from "../src/Test.sol"; |
| 6 | |
/// @notice Tests for the `stdStorage` library (slot discovery via `find`, guarded writes via
///         `checked_write`, and typed reads) run against the `StorageTest` fixture contract.
/// @dev Expected slot numbers follow the declaration order of `StorageTest`'s state variables:
///      `exists` = 0, `map_addr` = 1, `map_uint` = 2, `map_struct` = 4, `deep_map` = 5,
///      `deep_map_struct` = 6, `basic` = 7 (fields at 7 and 8), `tI` = 16. A mapping entry lives
///      at keccak256(abi.encode(key, mappingSlot)); a nested mapping hashes the outer result again.
contract StdStorageTest is Test {
    using stdStorage for StdStorage;

    StorageTest internal test;

    function setUp() public {
        test = new StorageTest();
    }

    /// @dev `hidden()` reads the slot keccak256("my.random.var") through inline assembly.
    function test_StorageHidden() public {
        assertEq(uint256(keccak256("my.random.var")), stdstore.target(address(test)).sig("hidden()").find());
    }

    function test_StorageObvious() public {
        assertEq(uint256(0), stdstore.target(address(test)).sig("exists()").find());
    }

    /// @dev `extra_sload()` loads `tE`'s slot before returning `tI`; the expected answer is `tI`'s slot (16).
    function test_StorageExtraSload() public {
        assertEq(16, stdstore.target(address(test)).sig(test.extra_sload.selector).find());
    }

    function test_StorageCheckedWriteHidden() public {
        stdstore.target(address(test)).sig(test.hidden.selector).checked_write(100);
        assertEq(uint256(test.hidden()), 100);
    }

    function test_StorageCheckedWriteObvious() public {
        stdstore.target(address(test)).sig(test.exists.selector).checked_write(100);
        assertEq(test.exists(), 100);
    }

    function test_StorageCheckedWriteSignedIntegerHidden() public {
        stdstore.target(address(test)).sig(test.hidden.selector).checked_write_int(-100);
        assertEq(int256(uint256(test.hidden())), -100);
    }

    function test_StorageCheckedWriteSignedIntegerObvious() public {
        stdstore.target(address(test)).sig(test.tG.selector).checked_write_int(-100);
        assertEq(test.tG(), -100);
    }

    /// @dev `depth(0)` / `depth(1)` address the struct's first / second field (base slot + 0 / + 1).
    function test_StorageMapStructA() public {
        uint256 slot =
            stdstore.target(address(test)).sig(test.map_struct.selector).with_key(address(this)).depth(0).find();
        assertEq(uint256(keccak256(abi.encode(address(this), uint256(4)))), slot);
    }

    function test_StorageMapStructB() public {
        uint256 slot =
            stdstore.target(address(test)).sig(test.map_struct.selector).with_key(address(this)).depth(1).find();
        assertEq(uint256(keccak256(abi.encode(address(this), uint256(4)))) + 1, slot);
    }

    function test_StorageDeepMap() public {
        uint256 slot = stdstore.target(address(test)).sig(test.deep_map.selector).with_key(address(this)).with_key(
            address(this)
        ).find();
        assertEq(uint256(keccak256(abi.encode(address(this), keccak256(abi.encode(address(this), uint256(5)))))), slot);
    }

    function test_StorageCheckedWriteDeepMap() public {
        stdstore.target(address(test)).sig(test.deep_map.selector).with_key(address(this)).with_key(address(this))
            .checked_write(100);
        assertEq(100, test.deep_map(address(this), address(this)));
    }

    function test_StorageDeepMapStructA() public {
        uint256 slot = stdstore.target(address(test)).sig(test.deep_map_struct.selector).with_key(address(this))
            .with_key(address(this)).depth(0).find();
        assertEq(
            bytes32(uint256(keccak256(abi.encode(address(this), keccak256(abi.encode(address(this), uint256(6)))))) + 0),
            bytes32(slot)
        );
    }

    function test_StorageDeepMapStructB() public {
        uint256 slot = stdstore.target(address(test)).sig(test.deep_map_struct.selector).with_key(address(this))
            .with_key(address(this)).depth(1).find();
        assertEq(
            bytes32(uint256(keccak256(abi.encode(address(this), keccak256(abi.encode(address(this), uint256(6)))))) + 1),
            bytes32(slot)
        );
    }

    function test_StorageCheckedWriteDeepMapStructA() public {
        stdstore.target(address(test)).sig(test.deep_map_struct.selector).with_key(address(this)).with_key(
            address(this)
        ).depth(0).checked_write(100);
        (uint256 a, uint256 b) = test.deep_map_struct(address(this), address(this));
        assertEq(100, a);
        assertEq(0, b);
    }

    function test_StorageCheckedWriteDeepMapStructB() public {
        stdstore.target(address(test)).sig(test.deep_map_struct.selector).with_key(address(this)).with_key(
            address(this)
        ).depth(1).checked_write(100);
        (uint256 a, uint256 b) = test.deep_map_struct(address(this), address(this));
        assertEq(0, a);
        assertEq(100, b);
    }

    function test_StorageCheckedWriteMapStructA() public {
        stdstore.target(address(test)).sig(test.map_struct.selector).with_key(address(this)).depth(0).checked_write(100);
        (uint256 a, uint256 b) = test.map_struct(address(this));
        assertEq(a, 100);
        assertEq(b, 0);
    }

    function test_StorageCheckedWriteMapStructB() public {
        stdstore.target(address(test)).sig(test.map_struct.selector).with_key(address(this)).depth(1).checked_write(100);
        (uint256 a, uint256 b) = test.map_struct(address(this));
        assertEq(a, 0);
        assertEq(b, 100);
    }

    function test_StorageStructA() public {
        uint256 slot = stdstore.target(address(test)).sig(test.basic.selector).depth(0).find();
        assertEq(uint256(7), slot);
    }

    function test_StorageStructB() public {
        uint256 slot = stdstore.target(address(test)).sig(test.basic.selector).depth(1).find();
        assertEq(uint256(7) + 1, slot);
    }

    /// @dev `basic` is initialised to {a: 1337, b: 1337}; writing one field must leave the other intact.
    function test_StorageCheckedWriteStructA() public {
        stdstore.target(address(test)).sig(test.basic.selector).depth(0).checked_write(100);
        (uint256 a, uint256 b) = test.basic();
        assertEq(a, 100);
        assertEq(b, 1337);
    }

    function test_StorageCheckedWriteStructB() public {
        stdstore.target(address(test)).sig(test.basic.selector).depth(1).checked_write(100);
        (uint256 a, uint256 b) = test.basic();
        assertEq(a, 1337);
        assertEq(b, 100);
    }

    function test_StorageMapAddrFound() public {
        uint256 slot = stdstore.target(address(test)).sig(test.map_addr.selector).with_key(address(this)).find();
        assertEq(uint256(keccak256(abi.encode(address(this), uint256(1)))), slot);
    }

    /// @dev `parent()` returns the mapping's own slot plus the key; `root()` returns the mapping's slot.
    function test_StorageMapAddrRoot() public {
        (uint256 slot, bytes32 key) =
            stdstore.target(address(test)).sig(test.map_addr.selector).with_key(address(this)).parent();
        assertEq(address(uint160(uint256(key))), address(this));
        assertEq(uint256(1), slot);
        slot = stdstore.target(address(test)).sig(test.map_addr.selector).with_key(address(this)).root();
        assertEq(uint256(1), slot);
    }

    function test_StorageMapUintFound() public {
        uint256 slot = stdstore.target(address(test)).sig(test.map_uint.selector).with_key(100).find();
        assertEq(uint256(keccak256(abi.encode(uint256(100), uint256(2)))), slot);
    }

    function test_StorageCheckedWriteMapUint() public {
        stdstore.target(address(test)).sig(test.map_uint.selector).with_key(100).checked_write(100);
        assertEq(100, test.map_uint(100));
    }

    function test_StorageCheckedWriteMapAddr() public {
        stdstore.target(address(test)).sig(test.map_addr.selector).with_key(address(this)).checked_write(100);
        assertEq(100, test.map_addr(address(this)));
    }

    function test_StorageCheckedWriteMapBool() public {
        stdstore.target(address(test)).sig(test.map_bool.selector).with_key(address(this)).checked_write(true);
        assertTrue(test.map_bool(address(this)));
    }

    /// @dev Writes `value` into the lower and then the upper 128-bit half of `map_packed[addr]`.
    function testFuzz_StorageCheckedWriteMapPacked(address addr, uint128 value) public {
        stdstore.enable_packed_slots().target(address(test)).sig(test.read_struct_lower.selector).with_key(addr)
            .checked_write(value);
        assertEq(test.read_struct_lower(addr), value);

        stdstore.enable_packed_slots().target(address(test)).sig(test.read_struct_upper.selector).with_key(addr)
            .checked_write(value);
        assertEq(test.read_struct_upper(addr), value);
    }

    /// @dev Writes a full 256-bit word through the unpacked getter and checks the lower half.
    function test_StorageCheckedWriteMapPackedFullSuccess() public {
        address who = address(1337);
        uint256 full = test.map_packed(who);
        // Keep the upper 128 bits, replace the lower 128 bits with 1337.
        full = (full & (uint256((1 << 128) - 1) << 128)) | 1337;
        stdstore.target(address(test)).sig(test.map_packed.selector).with_key(who).checked_write(full);
        assertEq(1337, test.read_struct_lower(who));
    }

    /// @dev `const()` is `pure`, so `find()` has no storage read to report and must revert.
    ///      The lookup runs inside a separate contract so the revert surfaces from an external call.
    function test_RevertStorageConst() public {
        StorageTestTarget target = new StorageTestTarget(test);

        vm.expectRevert("stdStorage find(StdStorage): No storage use detected for target.");
        target.expectRevertStorageConst();
    }

    /// @dev `tA`/`tB` and `tC`/`tD` each share a slot; writes to one field must not disturb its neighbour.
    function testFuzz_StorageNativePack(uint248 val1, uint248 val2, bool boolVal1, bool boolVal2) public {
        stdstore.enable_packed_slots().target(address(test)).sig(test.tA.selector).checked_write(val1);
        stdstore.enable_packed_slots().target(address(test)).sig(test.tB.selector).checked_write(boolVal1);
        stdstore.enable_packed_slots().target(address(test)).sig(test.tC.selector).checked_write(boolVal2);
        stdstore.enable_packed_slots().target(address(test)).sig(test.tD.selector).checked_write(val2);

        assertEq(test.tA(), val1);
        assertEq(test.tB(), boolVal1);
        assertEq(test.tC(), boolVal2);
        assertEq(test.tD(), val2);
    }

    function test_StorageReadBytes32() public {
        bytes32 val = stdstore.target(address(test)).sig(test.tE.selector).read_bytes32();
        assertEq(val, hex"1337");
    }

    function test_StorageReadBool_False() public {
        bool val = stdstore.target(address(test)).sig(test.tB.selector).read_bool();
        assertEq(val, false);
    }

    function test_StorageReadBool_True() public {
        bool val = stdstore.target(address(test)).sig(test.tH.selector).read_bool();
        assertEq(val, true);
    }

    function test_RevertIf_ReadingNonBoolValue() public {
        vm.expectRevert("stdStorage read_bool(StdStorage): Cannot decode. Make sure you are reading a bool.");
        this.readNonBoolValue();
    }

    /// @dev Helper for `test_RevertIf_ReadingNonBoolValue`: reads the bytes32 `tE` as a bool.
    ///      It is invoked via `this.` so that the revert comes from an external call.
    function readNonBoolValue() public {
        stdstore.target(address(test)).sig(test.tE.selector).read_bool();
    }

    function test_StorageReadAddress() public {
        address val = stdstore.target(address(test)).sig(test.tF.selector).read_address();
        assertEq(val, address(1337));
    }

    function test_StorageReadUint() public {
        uint256 val = stdstore.target(address(test)).sig(test.exists.selector).read_uint();
        assertEq(val, 1);
    }

    function test_StorageReadInt() public {
        int256 val = stdstore.target(address(test)).sig(test.tG.selector).read_int();
        assertEq(val, type(int256).min);
    }

    function testFuzz_Packed(uint256 val, uint8 elemToGet) public {
        // Try an assortment of packed layouts. `shifts` is the number of elements packed into the
        // slot and `shiftSizes[j]` is the bit width of element j (8, 16, 24, ...). Together they
        // determine how the slot is packed. Fully random layouts would be slow and would hit the
        // fuzzer's global rejection limit, so the widths are fixed per element count.
        for (uint256 shifts = 1; shifts < 5; shifts++) {
            // Bound into a fresh local. Re-binding `elemToGet` itself would collapse it to 0 on the
            // first iteration (range [0, 0]), and every later iteration would then keep reading element 0.
            uint8 elem = uint8(bound(elemToGet, 0, shifts - 1));

            uint256[] memory shiftSizes = new uint256[](shifts);
            for (uint256 j; j < shifts; j++) {
                shiftSizes[j] = 8 * (j + 1);
            }

            test.setRandomPacking(val);

            // Bits to the left (more significant) and to the right (less significant) of `elem`.
            uint256 leftBits;
            uint256 rightBits;
            for (uint256 j; j < shiftSizes.length; j++) {
                if (j < elem) {
                    leftBits += shiftSizes[j];
                } else if (j != elem) {
                    rightBits += shiftSizes[j];
                }
            }

            // Unused high bits of the word also count as bits to the left of `elem`.
            leftBits += 256 - (leftBits + shiftSizes[elem] + rightBits);
            // Clear the left bits, then drop the right bits so the element is realigned at bit 0.
            uint256 expectedValToRead = (val << leftBits) >> (leftBits + rightBits);

            uint256 readVal = stdstore.target(address(test)).enable_packed_slots().sig(
                "getRandomPacked(uint8,uint8[],uint8)"
            ).with_calldata(abi.encode(shifts, shiftSizes, elem)).read_uint();

            assertEq(readVal, expectedValToRead);
        }
    }

    /// @dev Packs `nvars` (1..20) fields of pseudo-random widths and values into one word with
    ///      `checked_write`, then reads each back both via `stdstore` and via the getter.
    function testFuzz_Packed2(uint256 nvars, uint256 seed) public {
        // Number of random variables to generate.
        nvars = bound(nvars, 1, 20);

        // Bits still available in the word; shrinks as sizes are assigned below.
        uint256 bitsRemaining = 256;

        uint256[] memory vals = new uint256[](nvars);
        uint256[] memory sizes = new uint256[](nvars);
        uint256[] memory offsets = new uint256[](nvars);

        for (uint256 i = 0; i < nvars; i++) {
            // Each field starts right after the previous one.
            offsets[i] = i == 0 ? 0 : offsets[i - 1] + sizes[i - 1];

            // Leave at least one bit for every field still to be generated.
            uint256 nvarsRemaining = nvars - i;
            uint256 maxVarSize = bitsRemaining - nvarsRemaining + 1;
            sizes[i] = bound(uint256(keccak256(abi.encodePacked(seed, i + 256))), 1, maxVarSize);
            bitsRemaining -= sizes[i];

            uint256 maxVal;
            uint256 varSize = sizes[i];
            assembly {
                // maxVal = (1 << varSize) - 1; shl yields 0 for varSize == 256, so this wraps to 2^256 - 1.
                maxVal := sub(shl(varSize, 1), 1)
            }
            vals[i] = bound(uint256(keccak256(abi.encodePacked(seed, i))), 0, maxVal);
        }

        // Pack all values into the slot.
        for (uint256 i = 0; i < nvars; i++) {
            stdstore.enable_packed_slots().target(address(test)).sig("getRandomPacked(uint256,uint256)").with_key(
                sizes[i]
            ).with_key(offsets[i]).checked_write(vals[i]);
        }

        // Verify the read data matches.
        for (uint256 i = 0; i < nvars; i++) {
            uint256 readVal = stdstore.enable_packed_slots().target(address(test)).sig(
                "getRandomPacked(uint256,uint256)"
            ).with_key(sizes[i]).with_key(offsets[i]).read_uint();

            uint256 retVal = test.getRandomPacked(sizes[i], offsets[i]);

            assertEq(readVal, vals[i]);
            assertEq(retVal, vals[i]);
        }
    }

    /// @dev `edgeCaseArray` is [3, 3, 3], so its length word equals every element's value; the
    ///      write must land on element 0.
    function testEdgeCaseArray() public {
        stdstore.target(address(test)).sig("edgeCaseArray(uint256)").with_key(uint256(0)).checked_write(1);
        assertEq(test.edgeCaseArray(0), 1);
    }
}
| 356 | |
/// @notice Runs a `stdstore` slot lookup from its own contract context, so the calling test can
///         observe the resulting revert through an external call.
contract StorageTestTarget {
    using stdStorage for StdStorage;

    StdStorage internal stdstore;
    StorageTest internal test;

    /// @param test_ Fixture contract whose getters are probed.
    constructor(StorageTest test_) {
        test = test_;
    }

    /// @notice Asks `stdstore` for the slot behind `const()`.
    /// @dev `const()` is `pure` and touches no storage, so `find()` is expected to revert here.
    function expectRevertStorageConst() public {
        StdStorage storage query = stdstore.target(address(test));
        query.sig("const()").find();
    }
}
| 371 | |
/// @notice Fixture contract whose storage layout and getters are probed by `StdStorageTest`.
/// @dev The declaration order of the state variables below fixes their storage slots. The tests
///      hard-code several of those slot numbers (e.g. `map_addr` = 1, `map_struct` = 4,
///      `deep_map` = 5, `basic` = 7, `tI` = 16). Do not reorder, insert, or remove state variables.
///      Getter bodies are also part of the contract with the tests: `stdstore` discovers slots
///      by watching which storage the getter reads, so their SLOAD patterns must stay as they are.
contract StorageTest {
    uint256 public exists = 1;
    mapping(address => uint256) public map_addr;
    mapping(uint256 => uint256) public map_uint;
    // Each value is treated as two 128-bit halves; see `read_struct_upper` / `read_struct_lower`.
    mapping(address => uint256) public map_packed;
    mapping(address => UnpackedStruct) public map_struct;
    mapping(address => mapping(address => uint256)) public deep_map;
    mapping(address => mapping(address => UnpackedStruct)) public deep_map_struct;
    UnpackedStruct public basic;

    // uint248 (31 bytes) + bool (1 byte): `tA`/`tB` share one slot, as do `tC`/`tD`.
    uint248 public tA;
    bool public tB;

    bool public tC = false;
    uint248 public tD = 1;

    struct UnpackedStruct {
        uint256 a;
        uint256 b;
    }

    mapping(address => bool) public map_bool;

    bytes32 public tE = hex"1337";
    address public tF = address(1337);
    int256 public tG = type(int256).min;
    bool public tH = true;
    // No public getter; read through `extra_sload`.
    bytes32 private tI = ~bytes32(hex"1337");

    // Backing word for the `setRandomPacking` / `getRandomPacked` helpers.
    uint256 internal randomPacking;

    // Array whose length (3) equals the value of each element.
    uint256[] public edgeCaseArray = [3, 3, 3];

    constructor() {
        basic = UnpackedStruct({a: 1337, b: 1337});

        // Upper and lower 128-bit halves both equal to 1.
        uint256 two = (1 << 128) | 1;
        map_packed[msg.sender] = two;
        // Upper half 1, lower half 0.
        map_packed[address(uint160(1337))] = 1 << 128;
    }

    /// @notice Upper 128 bits of `map_packed[who]`.
    function read_struct_upper(address who) public view returns (uint256) {
        return map_packed[who] >> 128;
    }

    /// @notice Lower 128 bits of `map_packed[who]`.
    function read_struct_lower(address who) public view returns (uint256) {
        return map_packed[who] & ((1 << 128) - 1);
    }

    /// @notice Reads the raw word stored at slot keccak256("my.random.var").
    function hidden() public view returns (bytes32 t) {
        bytes32 slot = keccak256("my.random.var");
        /// @solidity memory-safe-assembly
        assembly {
            t := sload(slot)
        }
    }

    /// @notice Returns a constant without touching storage.
    function const() public pure returns (bytes32 t) {
        t = bytes32(hex"1337");
    }

    /// @notice Returns `tI` after first reading `tE`'s slot.
    function extra_sload() public view returns (bytes32 t) {
        // trigger read on slot `tE`, and make a staticcall to make sure compiler doesn't optimize this SLOAD away
        assembly {
            pop(staticcall(gas(), sload(tE.slot), 0, 0, 0, 0))
        }
        t = tI;
    }

    /// @notice Overwrites the whole packed word.
    function setRandomPacking(uint256 val) public {
        randomPacking = val;
    }

    /// @dev Returns a mask of the low `size` bits. For size >= 256, `shl` yields 0 and the
    ///      subtraction wraps to all ones.
    function _getMask(uint256 size) internal pure returns (uint256 mask) {
        assembly {
            // mask = (1 << size) - 1
            mask := sub(shl(size, 1), 1)
        }
    }

    /// @notice Writes `val` into the `size`-bit field starting at bit `offset` of the packed word.
    /// @dev `val` is truncated to `size` bits so an oversized value cannot spill into adjacent fields.
    function setRandomPacking(uint256 val, uint256 size, uint256 offset) public {
        uint256 mask = _getMask(size);
        // Zero out only the target field, keeping every other bit of the word.
        uint256 cleanedWord = randomPacking & ~(mask << offset);
        // Place the (masked) value in the cleared field.
        randomPacking = cleanedWord | ((val & mask) << offset);
    }

    /// @notice Reads the `size`-bit field starting at bit `offset` of the packed word.
    function getRandomPacked(uint256 size, uint256 offset) public view returns (uint256) {
        uint256 mask = _getMask(size);
        // Shift the field down to bit 0, then mask off everything above it.
        return (randomPacking >> offset) & mask;
    }

    /// @notice Reads element `elem` of a word packed (from the most significant end) with fields
    ///         of widths `shiftSizes`. Any unused high bits sit to the left of element 0.
    /// @param shifts Number of packed elements; `elem` must be below it.
    /// @param shiftSizes Bit width of each element.
    /// @param elem Index of the element to extract.
    function getRandomPacked(uint8 shifts, uint8[] memory shiftSizes, uint8 elem) public view returns (uint256) {
        require(elem < shifts, "!elem");
        uint256 leftBits;
        uint256 rightBits;

        for (uint256 i; i < shiftSizes.length; i++) {
            if (i < elem) {
                leftBits += shiftSizes[i];
            } else if (elem != i) {
                rightBits += shiftSizes[i];
            }
        }

        // The fields must fit in a single word. Checked explicitly because under solc 0.7
        // (permitted by this file's pragma) the subtraction below would silently wrap.
        uint256 usedBits = leftBits + shiftSizes[elem] + rightBits;
        require(usedBits <= 256, "!size");
        // Unused high bits of the word also count as bits to the left of `elem`.
        leftBits += 256 - usedBits;

        // Clear the left bits, then drop the right bits so the element is realigned at bit 0.
        return (randomPacking << leftBits) >> (leftBits + rightBits);
    }
}