-
Notifications
You must be signed in to change notification settings - Fork 300
Expand file tree
/
Copy pathpatch_prefetch_offset.py
More file actions
1241 lines (1065 loc) · 53.7 KB
/
Copy pathpatch_prefetch_offset.py
File metadata and controls
1241 lines (1065 loc) · 53.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""Two-pass instruction-prefetch offset patcher.

Round 1: build with koffset=0 so the compiler emits s_prefetch_inst_pc_rel
         with placeholder operands.
Round 2: assemble the GPU .s via llvm-mc, disassemble with llvm-objdump to
         get exact hex addresses, compute correct PC-relative koffset/klength,
         then patch both the .s file and the GPU ELF inside the fat .o via
         direct binary patching (no recompilation needed, only a relink).

If the computed prefetch region has zero in-bounds cachelines, the 8-byte
s_prefetch_inst_pc_rel is replaced with 2× 4-byte s_nop 0.

Labels are discovered automatically from [ck_prefetch] / [ck_label] comments
in the generated .s assembly file — no source path needed.

Standalone usage (runs both rounds):

    python patch_prefetch_offset.py \\
        --build-dir /path/to/build \\
        --target <cmake-target> \\
        --objdump-mcpu gfx1201 \\
        [--dry-run]

CMake PRE_LINK usage (round 1 already done by cmake, only patch the .o):

    python patch_prefetch_offset.py \\
        --build-dir /path/to/build \\
        --target <cmake-target> \\
        --objdump-mcpu gfx1201 \\
        --skip-build-round1
"""
# The docstring must be the first statement to become the module __doc__;
# a future import is still allowed directly after it.
from __future__ import annotations
import argparse
import multiprocessing
import re
import shutil
import subprocess
import sys
from pathlib import Path
from typing import NamedTuple
# ---------------------------------------------------------------------------
# Module-level constants
# ---------------------------------------------------------------------------
CACHELINE_SIZE = 128  # bytes per instruction cache line
# s_prefetch_inst_pc_rel encoding fields (patched in place by _patch_one_prefetch).
KLENGTH_SHIFT = 6
KLENGTH_MASK = 0x7F << KLENGTH_SHIFT  # klength occupies bits [12:6] of dw0
KOFFSET_MASK = 0x00FFFFFF  # 24-bit signed PC-relative offset in dw1[23:0]
S_NOP_ENCODING = 0xBF800000  # s_nop 0 — SOPP opcode 0, simm16=0
NOP_KLENGTH_SENTINEL = -1  # klength sentinel: replace prefetch with 2× s_nop
# Values of the ``mode=`` field of [ck_label] comments.
PLACE_MODE_DEFAULT = 0  # target = first instruction after the [ck_label] comment
PLACE_MODE_BLOCK_ENTRY = 1  # target = first instruction after the nearest preceding block label
# Values of the ``dir=`` field of [ck_prefetch] comments.
DIR_FORWARD = "forward"
DIR_BACKWARD = "backward"
# ---------------------------------------------------------------------------
# Module-level regex patterns
# ---------------------------------------------------------------------------
# Function-header label in .s files, optionally followed by a ";" comment.
# The first character must be a letter or "_", so local labels such as
# ".LBB0_1:" are never treated as function headers.
FUNC_LABEL_RE = re.compile(r"^([A-Za-z_][A-Za-z0-9_$.]*):\s*(?:;.*)?$")
# objdump function header (e.g. "0000000000001000 <funcname>:")
OBJDUMP_FUNC_RE = re.compile(r'^[0-9a-fA-F]+ <(.+?)>:\s*$')
# objdump instruction address from trailing comment (e.g. "// 00001000: F4...")
OBJDUMP_ADDR_RE = re.compile(r'//\s*([0-9a-fA-F]+):\s+[0-9a-fA-F]')
# Block label in .s (e.g. ".LBB1_3:"); note this matches any "."-prefixed label.
BLOCK_LABEL_RE = re.compile(r'^\.[A-Za-z_]\w*:')
# ---------------------------------------------------------------------------
# Structured types for label classification
# ---------------------------------------------------------------------------
class PrefetchSite(NamedTuple):
    """A [ck_prefetch] marker in the merged .s ↔ objdump table.

    Fields are parsed from the marker comment by find_best_koffset_hybrid.
    """
    idx: int  # index of the marker line in the merged list
    direction: str  # DIR_FORWARD or DIR_BACKWARD (``dir=`` field; forward when absent)
    offset_cl: int  # cacheline offset from target (``offset=`` field, may be negative; 0 when absent)


class TargetSite(NamedTuple):
    """A [ck_label] marker (INST_PREFETCH_TARGET) in the merged table."""
    idx: int  # index of the marker line in the merged list
    mode: int  # PLACE_MODE_DEFAULT or PLACE_MODE_BLOCK_ENTRY (``mode=`` field; default when absent)
class _Tee:
    """Write to both stdout and a log file simultaneously.

    The stream that is ``sys.stdout`` at construction time is captured; every
    write goes to it and to *log_path* (UTF-8, truncated on open). close()
    closes the log file and restores the captured stream as ``sys.stdout``.

    Any other attribute (``encoding``, ``isatty``, ``fileno``, ...) is
    forwarded to the captured stream. Without this, code that inspects
    ``sys.stdout`` fails with AttributeError while the tee is installed.
    """

    def __init__(self, log_path: Path):
        self._file = log_path.open("w", encoding="utf-8")
        self._stdout = sys.stdout

    def write(self, data: str) -> int:
        self._stdout.write(data)
        self._file.write(data)
        return len(data)

    def flush(self) -> None:
        self._stdout.flush()
        self._file.flush()

    def close(self) -> None:
        self._file.close()
        sys.stdout = self._stdout

    def __getattr__(self, name: str):
        # Only reached when normal lookup fails. Read the instance dict
        # directly, so a half-initialised instance cannot recurse forever.
        stream = self.__dict__.get("_stdout")
        if stream is None:
            raise AttributeError(name)
        return getattr(stream, name)
# ---------------------------------------------------------------------------
# Instruction classification
# ---------------------------------------------------------------------------
def is_asm_instruction(line: str) -> bool:
    """Classify a single .s line as a real machine instruction.

    Blank lines, comments (``;``, ``/``, ``#`` leaders), directives
    (``.``-prefixed) and label definitions (first token ends with ``:``)
    are not instructions.
    """
    stripped = line.strip()
    if stripped == "":
        return False
    leading_char_rejects = stripped[0] in ";/.#"
    first_token = stripped.split()[0]
    return not (leading_char_rejects or first_token.endswith(':'))
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def run(cmd: list[str], cwd: Path) -> subprocess.CompletedProcess:
    """Echo *cmd*, execute it in *cwd*, and relay its captured output.

    The child's stdout goes to our stdout and its stderr to our stderr.
    A non-zero exit status terminates the script via ``sys.exit``.
    """
    shown = ' '.join(cmd)
    print(f"[run] {shown}", flush=True)
    proc = subprocess.run(cmd, cwd=cwd, text=True, capture_output=True)
    if proc.stdout:
        print(proc.stdout, end="")
    if proc.stderr:
        print(proc.stderr, end="", file=sys.stderr)
    if proc.returncode == 0:
        return proc
    sys.exit(f"Command failed with exit code {proc.returncode}")
def cmake_build(build_dir: Path, target: str, jobs: int) -> None:
    """Build *target* in *build_dir* with ``cmake --build`` using *jobs* parallel jobs."""
    argv = ["cmake", "--build", str(build_dir)]
    argv += ["--target", target]
    argv += ["-j", str(jobs), "--"]
    run(argv, build_dir)
def find_asm_file(search_dir: Path, cpp_stem: str, gpu_arch: str = "") -> Path:
    """Find the GPU .s file produced by --save-temps.

    Candidates are files under *search_dir* matching ``{cpp_stem}*.s``,
    newest first. A candidate is device-side when its name does not contain
    ``-host-`` and does contain ``-hip-`` (or *gpu_arch*, when given).
    Preference order:

    1. device-side files whose name contains *gpu_arch* (only when given).
       With several offload archs, ``-hip-`` alone cannot distinguish them.
    2. any device-side file;
    3. the newest candidate of any kind, with a warning.

    Exits the script if no candidate exists.
    """
    all_candidates = sorted(
        search_dir.rglob(f"{cpp_stem}*.s"),
        key=lambda p: p.stat().st_mtime, reverse=True,
    )
    if not all_candidates:
        sys.exit(
            f"No .s file matching '{cpp_stem}*.s' found under {search_dir}.\n"
            "Make sure --save-temps is in the target's compile options."
        )

    def is_gpu(p: Path) -> bool:
        n = p.name
        if "-host-" in n:
            return False
        if "-hip-" in n:
            return True
        return bool(gpu_arch) and gpu_arch in n

    gpu = [p for p in all_candidates if is_gpu(p)]
    if gpu_arch:
        arch_matches = [p for p in gpu if gpu_arch in p.name]
        if arch_matches:
            gpu = arch_matches
    chosen = gpu[0] if gpu else all_candidates[0]
    if not gpu:
        print(f"[warn] No GPU .s found; falling back to {chosen.name}")
    return chosen
def find_obj_file(build_dir: Path, target: str) -> Path:
    """Return the newest object file belonging to CMake target *target*.

    The search pattern is ``{target}.dir/**/*.o`` anywhere below *build_dir*.
    The ``**`` also reaches the per-config subdirectories used by multi-config
    generators (e.g. ``{target}.dir/Release/``). Ties in modification time
    resolve to the first file found. Exits the script when nothing matches.
    """
    objects = list(build_dir.rglob(f"{target}.dir/**/*.o"))
    if not objects:
        sys.exit(
            f"No .o file found under '*/{target}.dir/' in {build_dir}.\n"
            "Check that the target was built before running the patch script."
        )
    return max(objects, key=lambda obj: obj.stat().st_mtime)
def run_objdump(objdump_path: str, mcpu: str, obj_path: Path) -> str:
    """Disassemble *obj_path* for *mcpu* and return objdump's stdout.

    Exits the script, including objdump's stderr in the message, if the
    command returns a non-zero status.
    """
    argv = [objdump_path, f"--mcpu={mcpu}", "-d", str(obj_path)]
    print(f"[run] {' '.join(argv)}", flush=True)
    proc = subprocess.run(argv, text=True, capture_output=True)
    if proc.returncode:
        sys.exit(f"objdump failed:\n{proc.stderr}")
    return proc.stdout
def detect_mcpu_from_asm(asm_text: str) -> str:
    """Return the ``gfx*`` architecture named by the ``.amdgcn_target`` directive.

    For ``.amdgcn_target "amdgcn-amd-amdhsa--gfx1201"`` this yields
    ``"gfx1201"``. Target-feature suffixes such as ``:xnack-`` are not
    included. Returns an empty string when no such directive is present.
    """
    directive = re.search(r'\.amdgcn_target\s+"[^"]*--(gfx[0-9a-zA-Z]+)', asm_text)
    if directive is None:
        return ""
    return directive.group(1)
# ---------------------------------------------------------------------------
# Label discovery from .s
# ---------------------------------------------------------------------------
def find_prefetch_labels_from_asm(asm_text: str) -> list[str]:
    """Collect the ``name=`` values of all [ck_prefetch] comments.

    Each label appears once, in order of first occurrence.
    """
    label_re = re.compile(r';\s*\[ck_prefetch\].*\bname\s*=\s*(\w+)')
    hits = (label_re.search(text) for text in asm_text.splitlines())
    return list(dict.fromkeys(hit.group(1) for hit in hits if hit))
# ---------------------------------------------------------------------------
# Assembly / objdump helpers
# ---------------------------------------------------------------------------
def assemble_gpu_asm(asm_file: Path, mcpu: str, objdump_path: str) -> Path:
    """Assemble GPU .s → temp .o via llvm-mc. Returns path (caller deletes).

    llvm-mc is taken from the directory of *objdump_path*. If *objdump_path*
    is a bare program name, the bare name ``llvm-mc`` is used and resolved
    through PATH.

    The command runs with ``cwd=asm_file.parent``. Relative executable and
    file paths would be re-interpreted against that directory, so they are
    made absolute first. ``absolute()`` is used rather than ``resolve()`` so
    that symlinked toolchain directories are kept as given.
    """
    asm_file = asm_file.absolute()
    objdump = Path(objdump_path)
    if objdump.parent == Path("."):
        llvm_mc = "llvm-mc"
    else:
        llvm_mc = str(objdump.absolute().parent / "llvm-mc")
    out_obj = asm_file.with_suffix(".ck_tmp_patching.o")
    run([llvm_mc, f"--mcpu={mcpu}", "--triple=amdgcn-amd-amdhsa",
         "--filetype=obj", "-o", str(out_obj), str(asm_file)], asm_file.parent)
    return out_obj
def parse_objdump_functions(objdump_text: str) -> dict[str, list[tuple[int, str]]]:
    """Split ``llvm-objdump -d`` output into per-function instruction lists.

    Each entry is ``(address, instruction_text)``. The address comes from the
    trailing ``// ADDR: HEX`` comment and the text is everything before that
    comment, stripped. Lines before the first function header are ignored.
    If a function name repeats, the last occurrence wins.
    """
    instr_re = re.compile(r'^\t(.+?)//\s*([0-9a-fA-F]+):\s+[0-9a-fA-F]')
    functions: dict[str, list[tuple[int, str]]] = {}
    entries: list[tuple[int, str]] | None = None
    for raw in objdump_text.splitlines():
        header = OBJDUMP_FUNC_RE.match(raw)
        if header is not None:
            entries = []
            functions[header.group(1)] = entries
            continue
        if entries is None:
            continue
        hit = instr_re.match(raw)
        if hit is not None:
            entries.append((int(hit.group(2), 16), hit.group(1).strip()))
    return functions
def split_functions(asm_text: str) -> list[tuple[str, list[str]]]:
    """Cut the .s text into ``(function_name, lines)`` chunks.

    A new chunk starts at every line matching FUNC_LABEL_RE, and that label
    line becomes the chunk's first line. Text before the first label is
    reported under ``"<top>"``. Empty chunks are not emitted.
    """
    chunks: list[tuple[str, list[str]]] = []
    name = "<top>"
    body: list[str] = []
    for text in asm_text.splitlines():
        header = FUNC_LABEL_RE.match(text)
        if header is None:
            body.append(text)
            continue
        if body:
            chunks.append((name, body))
        name, body = header.group(1), [text]
    if body:
        chunks.append((name, body))
    return chunks
# ---------------------------------------------------------------------------
# Merge .s ↔ objdump and compute koffsets
# ---------------------------------------------------------------------------
def _merge_s_and_objdump(s_lines: list[str],
                         obj_entries: list[tuple[int, str]]) -> list[tuple[int | None, str]]:
    """Attach an objdump address to every instruction line of a .s function.

    A cursor walks *obj_entries*. Each .s instruction looks ahead up to 32
    entries for one with the same mnemonic (case-insensitive) and takes that
    entry's address. If none matches, it takes the entry at the cursor. The
    cursor then moves past the used entry. This re-synchronises after
    assembler-inserted padding or classification mismatches.

    Any line containing ``.p2align N`` first moves the cursor to the next
    entry whose address is a multiple of ``2**N``. Comments, directives,
    labels, and instructions left over after objdump runs out get ``None``.
    """
    MAX_LOOKAHEAD = 32
    p2align_re = re.compile(r'\.p2align\s+(\d+)')
    n_obj = len(obj_entries)
    cursor = 0
    out: list[tuple[int | None, str]] = []
    for text in s_lines:
        align_match = p2align_re.search(text)
        if align_match is not None:
            if cursor < n_obj:
                boundary = 1 << int(align_match.group(1))
                while cursor < n_obj and obj_entries[cursor][0] % boundary:
                    cursor += 1
            out.append((None, text))
            continue
        if not is_asm_instruction(text) or cursor >= n_obj:
            out.append((None, text))
            continue
        wanted = text.strip().split()[0].lower()
        window_end = min(cursor + MAX_LOOKAHEAD, n_obj)
        hit = next(
            (i for i in range(cursor, window_end)
             if obj_entries[i][1].split()[0].lower() == wanted),
            cursor,
        )
        out.append((obj_entries[hit][0], text))
        cursor = hit + 1
    return out
def _merge_all_functions(asm_text: str, objdump_text: str,
                         dump_dir: Path | None = None
                         ) -> dict[str, list[tuple[int | None, str]]]:
    """Merge .s ↔ objdump once per function. Returns {funcname: merged_list}.

    Only functions present in both the .s text and the objdump output are
    merged. When *dump_dir* is given, one ``merged_<name>.txt`` table is
    written per function (not per label).

    Dump names are the function name with characters outside
    ``[A-Za-z0-9_]`` replaced by ``_`` and truncated to 80 characters.
    Mangled template-kernel names often share far more than 80 leading
    characters. A numeric suffix is therefore appended on collision, rather
    than letting later functions silently overwrite earlier dumps.
    """
    s_blocks = split_functions(asm_text)
    obj_funcs = parse_objdump_functions(objdump_text)
    merged_funcs: dict[str, list[tuple[int | None, str]]] = {}
    used_stems: set[str] = set()
    # Same width as f'0x{addr:08x}' so the text column stays aligned.
    blank_addr = ' ' * 10
    for name, s_lines in s_blocks:
        if name not in obj_funcs:
            continue
        merged = _merge_s_and_objdump(s_lines, obj_funcs[name])
        merged_funcs[name] = merged
        if dump_dir is None:
            continue
        safe = re.sub(r'[^A-Za-z0-9_]', '_', name)[:80]
        stem = f"merged_{safe}"
        unique = stem
        suffix = 1
        while unique in used_stems:
            unique = f"{stem}_{suffix}"
            suffix += 1
        used_stems.add(unique)
        dump_path = dump_dir / f"{unique}.txt"
        with dump_path.open('w', encoding='utf-8') as fh:
            for idx, (addr, line) in enumerate(merged):
                addr_str = f'0x{addr:08x}' if addr is not None else blank_addr
                fh.write(f'[{idx:5d}] {addr_str} {line.rstrip()}\n')
        print(f"[dump] Merged table written to {dump_path}")
    return merged_funcs
def _resolve_target_address(
    merged: list[tuple[int | None, str]],
    tgt_idx: int,
    tgt_mode: int,
    name: str,
) -> int | None:
    """Map an INST_PREFETCH_TARGET marker to a concrete instruction address.

    * PLACE_MODE_BLOCK_ENTRY: walk back from the marker to the closest
      preceding line matching BLOCK_LABEL_RE and use the first addressed row
      after it. An ``[adjust]`` message is logged when this differs from the
      marker's own target. Without such a label this is the default rule.
    * Any other mode: the first addressed row after the [ck_label] comment.

    Returns None when no addressed row follows the chosen anchor.
    """
    def first_addr_after(anchor: int) -> int | None:
        # Address of the first merged row after *anchor* that carries one.
        for addr, _text in merged[anchor + 1:]:
            if addr is not None:
                return addr
        return None

    marker_target = first_addr_after(tgt_idx)
    if tgt_mode != PLACE_MODE_BLOCK_ENTRY:
        return marker_target

    label_idx = next(
        (k for k in range(tgt_idx - 1, -1, -1)
         if BLOCK_LABEL_RE.match(merged[k][1].strip())),
        None,
    )
    if label_idx is None:
        return marker_target

    block_target = first_addr_after(label_idx)
    if (block_target is not None and marker_target is not None
            and block_target != marker_target):
        print(f"[adjust] {name[:60]!r}: BLOCK_ENTRY mode — "
              f"block label at merged[{label_idx}] "
              f"→ target 0x{block_target:x} (was 0x{marker_target:x}, "
              f"saved {marker_target - block_target}B)")
    return block_target
def _clamp_prefetch_region(
    name: str,
    pair_idx: int,
    pc_next: int,
    target: int,
    orig_klength: int,
    direction: str,
    offset_cl: int,
    func_end: int,
) -> tuple[int, int] | None:
    """Compute (koffset, klength) for one prefetch pair with OOB clamping.

    The prefetched region spans ``klength + 1`` cachelines of
    ``CACHELINE_SIZE`` bytes starting at ``pc_next + koffset``.

    Args:
        name: Function name (used only in log messages).
        pair_idx: Ordinal of this pair within the function (log messages only).
        pc_next: Address of the instruction after s_prefetch_inst_pc_rel.
        target: Resolved INST_PREFETCH_TARGET address.
        orig_klength: klength parsed from the unpatched instruction.
        direction: ``DIR_BACKWARD`` anchors the region's *end* just past the
            target's cacheline. Any other value is treated as forward and
            anchors the region's *start* at the target's cacheline.
        offset_cl: Cacheline offset added to that anchor.
        func_end: Exclusive upper bound used for forward clamping.

    Returns:
        ``(koffset, klength)``. *klength* may be ``NOP_KLENGTH_SENTINEL`` if
        the prefetch is entirely out of bounds. Returns ``None`` if the pair
        should be skipped entirely (e.g. negative forward koffset).
    """
    # Start of the cacheline holding the target; all region bounds below are
    # multiples of CACHELINE_SIZE derived from it.
    target_aligned = target & ~(CACHELINE_SIZE - 1)
    offset_bytes = offset_cl * CACHELINE_SIZE
    klength = orig_klength
    if direction == DIR_BACKWARD:
        # Region [region_start, region_end) ends one line past the (shifted)
        # target line and extends klength + 1 lines toward lower addresses.
        region_end = target_aligned + CACHELINE_SIZE + offset_bytes
        region_start = region_end - (klength + 1) * CACHELINE_SIZE
        # First cacheline strictly after the one containing pc_next; the
        # region is not allowed to start below it.
        min_base = (pc_next & ~(CACHELINE_SIZE - 1)) + CACHELINE_SIZE
        if region_start < min_base:
            region_start = min_base
            usable = (region_end - region_start) // CACHELINE_SIZE
            if usable <= 0:
                klength = NOP_KLENGTH_SENTINEL
                print(f"[nop] {name[:60]!r}: backward prefetch fully OOB "
                      f"(min_base 0x{min_base:x} >= region_end 0x{region_end:x}), "
                      f"replacing with 2× s_nop")
            else:
                klength = usable - 1
                print(f"[clamp] {name[:60]!r}: backward start clamped "
                      f"(first cacheline after pc_next: 0x{min_base:x}), "
                      f"klength {orig_klength} → {klength}")
        if klength == NOP_KLENGTH_SENTINEL:
            print(f"[debug] func={name[:60]!r} pair {pair_idx}: "
                  f"pc_next=0x{pc_next:x} dir=backward → NOP (0 cachelines in bounds)")
            return (0, NOP_KLENGTH_SENTINEL)
        # NOTE(review): unlike the forward path, region_end is not checked
        # against func_end here.
        prefetch_base = region_start
        koffset = prefetch_base - pc_next
        print(f"[debug] func={name[:60]!r} pair {pair_idx}: "
              f"pc_next=0x{pc_next:x} target=0x{target:x} dir=backward "
              f"offset={offset_cl}cl prefetch_base=0x{prefetch_base:x} "
              f"koffset=0x{koffset:x} ({koffset}B) klength={klength} "
              f"region=[0x{region_start:x}, 0x{region_end:x}) "
              f"({(region_end - region_start)}B = {klength + 1} cachelines)")
        return (koffset, klength)
    # ── Forward direction ────────────────────────────────────────────────
    # The region starts at the (shifted) target cacheline.
    prefetch_base = target_aligned + offset_bytes
    koffset = prefetch_base - pc_next
    if koffset < 0:
        print(f"[warn] {name[:60]!r}: negative koffset — target before prefetch, skipping")
        return None
    region_end = prefetch_base + (klength + 1) * CACHELINE_SIZE
    if region_end > func_end:
        # Ceil-divide the in-bounds span into cachelines; 0 when
        # prefetch_base is at or past func_end.
        needed = max(0, (func_end - prefetch_base + CACHELINE_SIZE - 1) // CACHELINE_SIZE)
        if needed == 0:
            klength = NOP_KLENGTH_SENTINEL
            print(f"[nop] {name[:60]!r}: forward prefetch fully OOB "
                  f"(prefetch_base 0x{prefetch_base:x} >= func_end 0x{func_end:x}), "
                  f"replacing with 2× s_nop")
        else:
            klength = needed - 1
            region_end = prefetch_base + (klength + 1) * CACHELINE_SIZE
            print(f"[clamp] {name[:60]!r}: forward end clamped "
                  f"(func_end 0x{func_end:x}), klength {orig_klength} → {klength}")
    if klength == NOP_KLENGTH_SENTINEL:
        print(f"[debug] func={name[:60]!r} pair {pair_idx}: "
              f"pc_next=0x{pc_next:x} dir=forward → NOP (0 cachelines in bounds)")
    else:
        region_start = prefetch_base
        print(f"[debug] func={name[:60]!r} pair {pair_idx}: "
              f"pc_next=0x{pc_next:x} target=0x{target:x} dir=forward "
              f"offset={offset_cl}cl "
              f"koffset=0x{koffset:x} ({koffset}B) klength={klength} "
              f"region=[0x{region_start:x}, 0x{region_end:x}) "
              f"({(region_end - region_start)}B = {klength + 1} cachelines)")
    # For a NOP result the koffset is returned but unused by the patchers.
    return (koffset, klength)
def find_best_koffset_hybrid(merged_funcs: dict[str, list[tuple[int | None, str]]],
                             label: str) -> dict[str, list[tuple[int, int]]]:
    """Compute per-function (koffset, klength) for INST_PREFETCH/INST_PREFETCH_TARGET label pairs.

    Returns {funcname: [(koffset, klength), ...]} for each function containing
    the given label. klength is clamped so the prefetch does not extend past
    the end of the function. If no cachelines are in bounds, klength is set
    to NOP_KLENGTH_SENTINEL and the prefetch will be replaced with 2× s_nop.

    The per-function lists are consumed *positionally* by patch_asm_s and
    replace_gpu_in_fatobj: the i-th entry is applied to the i-th
    s_prefetch_inst_pc_rel. Some [ck_prefetch] sites cannot be computed (no
    pc_next, no target address, negative forward koffset). Such a site still
    contributes an entry holding the instruction's current operands, which
    re-encodes it unchanged. Dropping the entry instead would shift every
    later value onto the wrong instruction. A function none of whose sites
    could be computed gets no entry at all.
    """
    # [ck_prefetch] marks INST_PREFETCH sites, [ck_label] marks INST_PREFETCH_TARGET targets.
    prefetch_re = re.compile(rf";\s*\[ck_prefetch\].*\bname\s*=\s*{re.escape(label)}\b")
    target_re = re.compile(rf";\s*\[ck_label\].*\bname\s*=\s*{re.escape(label)}\b")
    either_re = re.compile(rf";\s*(?:\[ck_label\]|\[ck_prefetch\]).*\bname\s*=\s*{re.escape(label)}\b")
    klength_re = re.compile(r's_prefetch_inst_pc_rel\s+\S+\s*,\s*\S+\s*,\s*(\d+)')
    # First operand (koffset) as written in the .s: decimal or 0x-hex, optionally negative.
    koffset_re = re.compile(r's_prefetch_inst_pc_rel\s+(-?0[xX][0-9a-fA-F]+|-?\d+)\s*,')
    mode_re = re.compile(r'\bmode\s*=\s*(\d+)')
    dir_re = re.compile(r'\bdir\s*=\s*(\w+)')
    offset_re = re.compile(r'\boffset\s*=\s*(-?\d+)')
    results: dict[str, list[tuple[int, int]]] = {}
    for name, merged in merged_funcs.items():
        if not any(either_re.search(line) for _, line in merged):
            continue
        # Determine function end address (for OOB clamping).
        func_end: int = 0
        for addr, _ in reversed(merged):
            if addr is not None:
                func_end = addr + 8  # conservative: largest instruction is 8 bytes
                break
        # Classify markers from .s comments.
        prefetch_sites: list[PrefetchSite] = []
        target_sites: list[TargetSite] = []
        for idx, (_addr, line) in enumerate(merged):
            if prefetch_re.search(line):
                m_dir = dir_re.search(line)
                m_off = offset_re.search(line)
                prefetch_sites.append(PrefetchSite(
                    idx=idx,
                    direction=m_dir.group(1) if m_dir else DIR_FORWARD,
                    offset_cl=int(m_off.group(1)) if m_off else 0,
                ))
            elif target_re.search(line):
                m_mode = mode_re.search(line)
                target_sites.append(TargetSite(
                    idx=idx,
                    mode=int(m_mode.group(1)) if m_mode else PLACE_MODE_DEFAULT,
                ))
        pairs: list[tuple[int, int]] = []
        n_computed = 0  # entries carrying freshly computed (not pass-through) values
        for pf in prefetch_sites:
            # Find the s_prefetch_inst_pc_rel instruction and parse its operands.
            pf_instr_idx: int | None = None
            orig_klength = 3  # default
            orig_koffset = 0  # round 1 emits koffset=0 as the placeholder
            j = pf.idx + 1
            while j < len(merged) and 's_prefetch_inst_pc_rel' not in merged[j][1]:
                j += 1
            if j < len(merged):
                pf_instr_idx = j
                m_kl = klength_re.search(merged[j][1])
                if m_kl:
                    orig_klength = int(m_kl.group(1))
                m_ko = koffset_re.search(merged[j][1])
                if m_ko:
                    try:
                        orig_koffset = int(m_ko.group(1), 0)
                    except ValueError:
                        # e.g. "012" is not a valid base-0 literal; keep 0.
                        pass
            # Re-encodes the instruction exactly as it is now; used when this
            # site cannot be computed, so later entries stay aligned.
            unchanged = (orig_koffset, orig_klength)
            # pc_next = address of the instruction after s_prefetch_inst_pc_rel.
            pc_next: int | None = None
            if pf_instr_idx is not None:
                for k in range(pf_instr_idx + 1, len(merged)):
                    if merged[k][0] is not None:
                        pc_next = merged[k][0]
                        break
            if pc_next is None:
                print(f"[warn] {name[:60]!r}: no pc_next for prefetch at merged[{pf.idx}], skipping")
                if pf_instr_idx is not None:
                    # The instruction exists, so it still occupies a slot.
                    pairs.append(unchanged)
                continue
            # Find the nearest INST_PREFETCH_TARGET after this INST_PREFETCH.
            tgt = next((t for t in target_sites if t.idx > pf.idx), None)
            if tgt is None:
                # Unpaired prefetch — treat as forward with target=pc_next.
                print(f"[warn] {name[:60]!r}: unpaired prefetch label at merged[{pf.idx}] — "
                      f"using koffset=0, clamping klength")
                result = _clamp_prefetch_region(
                    name, len(pairs), pc_next, pc_next, orig_klength,
                    DIR_FORWARD, 0, func_end)
            else:
                target = _resolve_target_address(merged, tgt.idx, tgt.mode, name)
                if target is None:
                    print(f"[warn] {name[:60]!r}: no target address for label at merged[{tgt.idx}]")
                    pairs.append(unchanged)
                    continue
                result = _clamp_prefetch_region(
                    name, len(pairs), pc_next, target, orig_klength,
                    pf.direction, pf.offset_cl, func_end)
            if result is None:
                pairs.append(unchanged)
            else:
                pairs.append(result)
                n_computed += 1
        if n_computed:
            results[name] = pairs
    if not results:
        print(f"[skip] Label '{label}' not found in any matching function block.")
    else:
        total = sum(len(v) for v in results.values())
        print(f"[offsets] {len(results)} function(s), {total} pair(s) with koffset for '{label}'.")
    return results
# ---------------------------------------------------------------------------
# Patching
# ---------------------------------------------------------------------------
def patch_asm_s(asm_file: Path, func_koffsets: dict[str, list[tuple[int, int]]]) -> bool:
    """Patch s_prefetch_inst_pc_rel koffset and klength operands in the .s file.

    The file is scanned line by line and the current function is tracked with
    FUNC_LABEL_RE. Inside a function listed in *func_koffsets*, the i-th
    patchable s_prefetch_inst_pc_rel receives the i-th ``(koffset, klength)``
    pair. A klength of NOP_KLENGTH_SENTINEL replaces the instruction with two
    ``s_nop 0`` lines, and the first of them keeps any trailing text.

    The operand pattern accepts any second (soffset) operand, not only
    ``null``, and decimal or (possibly negative) hex koffsets. This keeps the
    pairing in step with replace_gpu_in_fatobj, which patches every
    s_prefetch_inst_pc_rel regardless of its operands, and lets a negative
    value written by ``hex()`` be re-patched later.

    Returns True if any change was made.
    """
    prefetch_re = re.compile(
        r'(s_prefetch_inst_pc_rel\s+)(?:-?0[xX][0-9a-fA-F]+|-?\d+)'
        r'(\s*,\s*[^,\s]+\s*,\s*)(?:\d+)')
    # Full-line regex to capture indentation and the entire prefetch instruction
    # (used for NOP replacement).
    prefetch_full_re = re.compile(
        r'^(\s*)s_prefetch_inst_pc_rel\s+\S+\s*,\s*\S+\s*,\s*\d+(.*)$')
    text = asm_file.read_text(encoding="utf-8", errors="replace")
    out_lines: list[str] = []
    current_func = "<top>"
    func_pf_idx: dict[str, int] = {}
    n_patched = 0
    n_nopped = 0
    for line in text.splitlines(keepends=True):
        m = FUNC_LABEL_RE.match(line.rstrip())
        if m:
            current_func = m.group(1)
        if current_func in func_koffsets:
            pair_list = func_koffsets[current_func]
            idx = func_pf_idx.get(current_func, 0)
            if idx < len(pair_list):
                koffset, klength = pair_list[idx]
                if klength == NOP_KLENGTH_SENTINEL:
                    # Replace 8-byte prefetch with 2× 4-byte s_nop 0
                    m_full = prefetch_full_re.match(line.rstrip('\n\r'))
                    if m_full:
                        indent = m_full.group(1)
                        trailing = m_full.group(2)  # e.g. comment
                        eol = line[len(line.rstrip('\n\r')):]  # preserve \n
                        nop_lines = (f"{indent}s_nop 0{trailing}{eol}"
                                     f"{indent}s_nop 0{eol}")
                        print(f"[patch-s] {current_func[:60]}: prefetch[{idx}] → "
                              f"2× s_nop 0 (OOB)")
                        func_pf_idx[current_func] = idx + 1
                        n_nopped += 1
                        n_patched += 1
                        out_lines.append(nop_lines)
                        continue
                else:
                    koffset_str = hex(koffset)
                    new_line, n = prefetch_re.subn(
                        rf'\g<1>{koffset_str}\g<2>{klength}', line)
                    if n:
                        print(f"[patch-s] {current_func[:60]}: prefetch[{idx}] → "
                              f"koffset={koffset_str} klength={klength}")
                        func_pf_idx[current_func] = idx + 1
                        n_patched += n
                        out_lines.append(new_line)
                        continue
        out_lines.append(line)
    new_text = ''.join(out_lines)
    print(f"[patch-s] {n_patched} operand(s) patched ({n_nopped} replaced with NOPs).")
    if new_text == text:
        print("[patch-s] No change.")
        return False
    asm_file.write_text(new_text, encoding="utf-8")
    print(f"[patch-s] Written: {asm_file}")
    return True
# ---------------------------------------------------------------------------
# ELF / fatbin helpers
# ---------------------------------------------------------------------------
def _find_elf_text_section(data: bytes | bytearray, base: int = 0) -> tuple[int, int, int] | None:
"""Find .text section in an ELF image starting at data[base:].
Returns (file_offset_from_base, size, vaddr) or None."""
import struct as _s
d = data[base:]
if len(d) < 64 or d[:4] != b'\x7fELF':
return None
ei_class, ei_data = d[4], d[5]
endian = '<' if ei_data == 1 else '>'
try:
if ei_class == 2:
e_shoff, = _s.unpack_from(f'{endian}Q', d, 40)
e_shentsize, = _s.unpack_from(f'{endian}H', d, 58)
e_shnum, = _s.unpack_from(f'{endian}H', d, 60)
e_shstrndx, = _s.unpack_from(f'{endian}H', d, 62)
addr_in_shdr, off_in_shdr, sz_in_shdr = 16, 24, 32
addr_fmt, off_fmt, sz_fmt = f'{endian}Q', f'{endian}Q', f'{endian}Q'
else:
e_shoff, = _s.unpack_from(f'{endian}I', d, 32)
e_shentsize, = _s.unpack_from(f'{endian}H', d, 46)
e_shnum, = _s.unpack_from(f'{endian}H', d, 48)
e_shstrndx, = _s.unpack_from(f'{endian}H', d, 50)
addr_in_shdr, off_in_shdr, sz_in_shdr = 12, 16, 20
addr_fmt, off_fmt, sz_fmt = f'{endian}I', f'{endian}I', f'{endian}I'
shstr_sh = e_shoff + e_shstrndx * e_shentsize
shstr_off, = _s.unpack_from(off_fmt, d, shstr_sh + off_in_shdr)
for i in range(e_shnum):
sh = e_shoff + i * e_shentsize
name_idx, = _s.unpack_from(f'{endian}I', d, sh)
ns = shstr_off + name_idx
ne = d.index(b'\x00', ns)
if d[ns:ne] == b'.text':
sec_addr, = _s.unpack_from(addr_fmt, d, sh + addr_in_shdr)
sec_off, = _s.unpack_from(off_fmt, d, sh + off_in_shdr)
sec_sz, = _s.unpack_from(sz_fmt, d, sh + sz_in_shdr)
return (sec_off, sec_sz, sec_addr)
except (_s.error, ValueError):
pass
return None
def _find_gpu_bundle(data: bytes | bytearray, tag: str = "fatbin"
) -> tuple[int, int, int, str] | None:
"""Locate the GPU bundle in a fat .o / fatbin.
Returns (magic_idx, gpu_off, gpu_sz, gpu_triple) or None.
*gpu_off* is relative to *magic_idx* (the absolute start of the GPU ELF
in *data* is ``magic_idx + gpu_off``).
"""
import struct as _s
MAGIC = b'__CLANG_OFFLOAD_BUNDLE__'
magic_idx = data.find(MAGIC)
if magic_idx < 0:
print(f"[{tag}] __CLANG_OFFLOAD_BUNDLE__ magic not found")
return None
hdr = magic_idx + len(MAGIC)
if hdr + 8 > len(data):
print(f"[{tag}] Truncated fatbin header")
return None
num_bundles, = _s.unpack_from('<Q', data, hdr)
cur = hdr + 8
gpu_off = gpu_sz = 0
gpu_triple = ""
for _ in range(num_bundles):
if cur + 24 > len(data):
break
off, sz, triple_sz = _s.unpack_from('<QQQ', data, cur)
cur += 24
if triple_sz == 0 or triple_sz > 512 or cur + triple_sz > len(data):
break
triple = data[cur:cur + triple_sz].decode('utf-8', errors='replace')
cur += triple_sz
if 'amdgcn' in triple or (triple.startswith('hip') and 'host' not in triple):
gpu_off, gpu_sz, gpu_triple = off, sz, triple
if not gpu_triple:
print(f"[{tag}] No GPU entry found in fatbin header")
return None
return (magic_idx, gpu_off, gpu_sz, gpu_triple)
def _objdump_gpu_elf(data: bytes | bytearray, abs_gpu_start: int, gpu_sz: int,
mcpu: str, objdump_path: str, tmp_path: Path,
tag: str = "fatbin") -> str | None:
"""Extract the GPU ELF from *data*, run objdump -d, return the text or None."""
try:
tmp_path.write_bytes(bytes(data[abs_gpu_start:abs_gpu_start + gpu_sz]))
result = subprocess.run(
[objdump_path, f"--mcpu={mcpu}", "-d", str(tmp_path)],
text=True, capture_output=True,
)
if result.returncode != 0:
print(f"[{tag}] objdump on GPU ELF failed: {result.stderr[:200]}")
return None
return result.stdout
finally:
tmp_path.unlink(missing_ok=True)
def _patch_one_prefetch(fat_data: bytearray, instr_pos: int, instr_va: int,
                        idx: int, new_koffset: int, new_klength: int) -> None:
    """Rewrite one 8-byte s_prefetch_inst_pc_rel at ``fat_data[instr_pos]`` in place.

    The instruction is read as two little-endian dwords. With *new_klength*
    equal to NOP_KLENGTH_SENTINEL, both dwords become S_NOP_ENCODING.
    Otherwise the bits selected by KLENGTH_MASK in dw0 receive *new_klength*
    (low 7 bits), the KOFFSET_MASK bits of dw1 receive *new_koffset*
    (two's-complement truncated to that width), and all other bits are kept.
    Before/after encodings are logged with *instr_va* and *idx*.
    """
    import struct as _struct
    old_dw0, old_dw1 = _struct.unpack_from('<II', fat_data, instr_pos)
    prefix = f"[patch-obj] VA 0x{instr_va:x}:"
    print(f"{prefix} BEFORE "
          f"dw0=0x{old_dw0:08x} dw1=0x{old_dw1:08x} "
          f"raw={fat_data[instr_pos:instr_pos + 8].hex()} "
          f"klength_bits12_6={(old_dw0 >> KLENGTH_SHIFT) & 0x7F}")
    if new_klength == NOP_KLENGTH_SENTINEL:
        _struct.pack_into('<II', fat_data, instr_pos, S_NOP_ENCODING, S_NOP_ENCODING)
        print(f"{prefix} AFTER "
              f"2× s_nop 0 (0x{S_NOP_ENCODING:08x}) "
              f"raw={fat_data[instr_pos:instr_pos + 8].hex()}")
        print(f"{prefix} prefetch[{idx}] → "
              f"2× s_nop 0 (OOB)")
        return
    new_dw0 = (old_dw0 & ~KLENGTH_MASK) | ((new_klength & 0x7F) << KLENGTH_SHIFT)
    new_dw1 = (old_dw1 & ~KOFFSET_MASK) | (new_koffset & KOFFSET_MASK)
    _struct.pack_into('<II', fat_data, instr_pos, new_dw0, new_dw1)
    print(f"{prefix} AFTER "
          f"dw0=0x{new_dw0:08x} dw1=0x{new_dw1:08x} "
          f"raw={fat_data[instr_pos:instr_pos + 8].hex()} "
          f"klength_bits12_6={(new_dw0 >> KLENGTH_SHIFT) & 0x7F}")
    print(f"{prefix} prefetch[{idx}] → "
          f"koffset={hex(new_koffset)} klength={new_klength}")
def replace_gpu_in_fatobj(fat_obj: Path, mcpu: str, objdump_path: str,
                          func_koffsets: dict[str, list[tuple[int, int]]]) -> bool:
    """Patch s_prefetch_inst_pc_rel koffsets/klengths in the GPU ELF inside a fat .o.

    The GPU code object is located through the offload bundle and
    disassembled via ``_objdump_gpu_elf``. Each ``s_prefetch_inst_pc_rel``
    found inside a function named in *func_koffsets* is then rewritten in
    place with ``_patch_one_prefetch`` (direct binary patching, no relinking).

    Args:
        fat_obj: The fat object file. It is rewritten on disk only if at least
            one instruction was patched and every patch site lay inside the
            GPU ELF's .text.
        mcpu: GPU target forwarded to ``_objdump_gpu_elf``.
        objdump_path: objdump executable forwarded to ``_objdump_gpu_elf``.
        func_koffsets: Maps a function name (as captured by OBJDUMP_FUNC_RE)
            to its (koffset, klength) pairs. Pairs are consumed in the order
            that function's prefetch instructions appear in the disassembly.
            Prefetches beyond the end of the list are left untouched. A
            klength of NOP_KLENGTH_SENTINEL replaces that prefetch with
            2x s_nop 0.

    Returns:
        True if instructions were patched and the written file verifies.
        False otherwise: no GPU bundle, no ELF magic or .text, disassembly
        failure, nothing to patch, a prefetch VA outside .text (file left
        unmodified), or a write that did not persist.
    """
    import hashlib

    fat_data = bytearray(fat_obj.read_bytes())
    bundle = _find_gpu_bundle(fat_data, tag="patch-obj")
    if bundle is None:
        return False
    magic_idx, gpu_off, gpu_sz, gpu_triple = bundle
    abs_gpu_start = magic_idx + gpu_off
    print(f"[patch-obj] GPU bundle: '{gpu_triple}' abs=0x{abs_gpu_start:x} size={gpu_sz}")
    if fat_data[abs_gpu_start:abs_gpu_start + 4] != b'\x7fELF':
        print("[patch-obj] GPU data does not start with ELF magic")
        return False
    text_info = _find_elf_text_section(fat_data, abs_gpu_start)
    if text_info is None:
        print("[patch-obj] Cannot find .text in GPU ELF")
        return False
    text_foff, text_sz, text_va = text_info
    print(f"[patch-obj] .text: foff=0x{text_foff:x} size={text_sz} va=0x{text_va:x}")
    objdump_text = _objdump_gpu_elf(fat_data, abs_gpu_start, gpu_sz, mcpu, objdump_path,
                                    fat_obj.with_suffix(".ck_gpu_elf_tmp"), tag="patch-obj")
    if objdump_text is None:
        return False

    # _patch_one_prefetch touches 8 bytes (two dwords) at the patch site.
    instr_size = 8
    text_end_va = text_va + text_sz
    n_patched = 0
    current_func: str | None = None
    func_pf_idx: dict[str, int] = {}
    for line in objdump_text.splitlines():
        m = OBJDUMP_FUNC_RE.match(line)
        if m:
            current_func = m.group(1)
            continue
        if not current_func or current_func not in func_koffsets:
            continue
        if 's_prefetch_inst_pc_rel' not in line:
            continue
        m2 = OBJDUMP_ADDR_RE.search(line)
        if not m2:
            continue
        idx = func_pf_idx.get(current_func, 0)
        pair_list = func_koffsets[current_func]
        if idx >= len(pair_list):
            continue
        instr_va = int(m2.group(1), 16)
        # The file offset below is only meaningful for VAs inside .text.
        # Anything else would overwrite unrelated bytes of the fat object,
        # so abort before writing anything to disk.
        if not (text_va <= instr_va and instr_va + instr_size <= text_end_va):
            print(f"[patch-obj] VA 0x{instr_va:x} ({current_func[:60]} prefetch[{idx}]) "
                  f"lies outside .text [0x{text_va:x}, 0x{text_end_va:x}); "
                  f"aborting, {fat_obj.name} left unmodified")
            return False
        new_koffset, new_klength = pair_list[idx]
        instr_pos = abs_gpu_start + text_foff + (instr_va - text_va)
        _patch_one_prefetch(fat_data, instr_pos, instr_va, idx,
                            new_koffset, new_klength)
        func_pf_idx[current_func] = idx + 1
        n_patched += 1
    if n_patched == 0:
        print("[patch-obj] No s_prefetch_inst_pc_rel found to patch")
        return False
    fat_obj.write_bytes(bytes(fat_data))
    # Sanity check: re-read and verify the write persisted. md5 is only a
    # content fingerprint here; usedforsecurity=False keeps it available on
    # FIPS-restricted Python builds.
    written_hash = hashlib.md5(fat_obj.read_bytes(), usedforsecurity=False).hexdigest()
    expected_hash = hashlib.md5(fat_data, usedforsecurity=False).hexdigest()
    if written_hash != expected_hash:
        print(f"[patch-obj] WARNING: write verification failed! "
              f"expected={expected_hash} written={written_hash}")
        # The patch did not persist, so do not report success.
        return False
    print(f"[patch-obj] Patched {n_patched} instruction(s) in {fat_obj.name} "
          f"md5={written_hash}")
    return True
def verify_patched_obj(fat_obj: Path, mcpu: str, objdump_path: str,
func_koffsets: dict[str, list[tuple[int, int]]]) -> bool:
"""Verify patched s_prefetch_inst_pc_rel koffsets and klengths. Returns True if all match.
For NOP-replaced entries (klength == NOP_KLENGTH_SENTINEL), verification
checks that the raw bytes at the original position are 2× s_nop encoding.
"""
import struct as _struct
data = fat_obj.read_bytes()
bundle = _find_gpu_bundle(data, tag="verify")
if bundle is None:
return False
magic_idx, gpu_off, gpu_sz, _ = bundle
abs_start = magic_idx + gpu_off
# Locate .text for raw byte diagnostics
text_info = _find_elf_text_section(data, abs_start)
text_foff = text_va = 0
if text_info:
text_foff, text_sz, text_va = text_info
objdump_text = _objdump_gpu_elf(data, abs_start, gpu_sz, mcpu, objdump_path,
fat_obj.with_suffix(".ck_verify_tmp"), tag="verify")
if objdump_text is None:
return False
prefetch_re = re.compile(
r's_prefetch_inst_pc_rel\s+(0x[0-9a-fA-F]+|\d+)\s*,\s*\S+\s*,\s*(\d+)')
current_func: str | None = None
func_pf_idx: dict[str, int] = {}
# Track VAs already consumed as part of a NOP pair so the second s_nop
# of a pair (or compiler-emitted s_nops) are not misidentified.
consumed_nop_vas: set[int] = set()
ok = True
checked = 0
for line in objdump_text.splitlines():
m = OBJDUMP_FUNC_RE.match(line)
if m:
current_func = m.group(1)
continue
if current_func and current_func in func_koffsets:
idx = func_pf_idx.get(current_func, 0)
pair_list = func_koffsets[current_func]
if idx >= len(pair_list):
continue
exp_koff, exp_klen = pair_list[idx]
if exp_klen == NOP_KLENGTH_SENTINEL:
# Expect 2× s_nop at this position. Match by checking whether
# the raw bytes at this VA form a NOP pair (both dwords are
# S_NOP_ENCODING). This avoids confusion with compiler-emitted
# s_nop instructions that are not part of our patching.
if 's_nop' in line:
m_addr = OBJDUMP_ADDR_RE.search(line)
if m_addr and text_info:
va = int(m_addr.group(1), 16)
if va in consumed_nop_vas:
continue # second nop of an already-verified pair
pos = abs_start + text_foff + (va - text_va)
if 0 <= pos and pos + 8 <= len(data):
dw0 = _struct.unpack_from('<I', data, pos)[0]
dw1 = _struct.unpack_from('<I', data, pos + 4)[0]
if dw0 == S_NOP_ENCODING and dw1 == S_NOP_ENCODING:
consumed_nop_vas.add(va)
consumed_nop_vas.add(va + 4)
status = "OK"
print(f"[verify] {status}: {current_func[:60]} prefetch[{idx}] "
f"→ 2× s_nop dw0=0x{dw0:08x} dw1=0x{dw1:08x}")
func_pf_idx[current_func] = idx + 1
checked += 1
else:
m2 = prefetch_re.search(line)
if m2:
actual_koff = int(m2.group(1), 0)
actual_klen = int(m2.group(2))
koff_ok = actual_koff == exp_koff
klen_ok = actual_klen == exp_klen
status = "OK" if koff_ok else "MISMATCH"
if not koff_ok:
ok = False
extra = ""
if not klen_ok:
extra = " (klength mismatch — may need rebuild)"
print(f"[verify] {status}: {current_func[:60]} prefetch[{idx}] "
f"koffset exp={hex(exp_koff)} act={hex(actual_koff)} "
f"klength exp={exp_klen} act={actual_klen}{extra}")
# Diagnostic: dump raw bytes from the .o at the instruction position
m_addr = OBJDUMP_ADDR_RE.search(line)
if m_addr and text_info: