LLVM 23.0.0git
MachineSMEABIPass.cpp
Go to the documentation of this file.
1//===- MachineSMEABIPass.cpp ----------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass implements the SME ABI requirements for ZA state. This includes
10// implementing the lazy (and agnostic) ZA state save schemes around calls.
11//
12//===----------------------------------------------------------------------===//
13//
14// This pass works by collecting instructions that require ZA to be in a
15// specific state (e.g., "ACTIVE" or "SAVED") and inserting the necessary state
16// transitions to ensure ZA is in the required state before instructions. State
17// transitions represent actions such as setting up or restoring a lazy save.
18// Certain points within a function may also have predefined states independent
19// of any instructions, for example, a "shared_za" function is always entered
20// and exited in the "ACTIVE" state.
21//
22// To handle ZA state across control flow, we make use of edge bundling. This
23// assigns each block an "incoming" and "outgoing" edge bundle (representing
24// incoming and outgoing edges). Initially, these are unique to each block;
25// then, in the process of forming bundles, the outgoing bundle of a block is
26// joined with the incoming bundle of all successors. The result is that each
27// bundle can be assigned a single ZA state, which ensures the state required
28// by all of a block's successors is the same, and that each basic block will
29// always be entered with the same ZA state. This eliminates the need for splitting
30// edges to insert state transitions or "phi" nodes for ZA states.
31//
32// See below for a simple example of edge bundling.
33//
34// The following shows a conditionally executed basic block (BB1):
35//
36// if (cond)
37// BB1
38// BB2
39//
40// Initial Bundles Joined Bundles
41//
42// ┌──0──┐ ┌──0──┐
43// │ BB0 │ │ BB0 │
44// └──1──┘ └──1──┘
45// ├───────┐ ├───────┐
46// ▼ │ ▼ │
47// ┌──2──┐ │ ─────► ┌──1──┐ │
48// │ BB1 │ ▼ │ BB1 │ ▼
49// └──3──┘ ┌──4──┐ └──1──┘ ┌──1──┐
50// └───►4 BB2 │ └───►1 BB2 │
51// └──5──┘ └──2──┘
52//
53// On the left are the initial per-block bundles, and on the right are the
54// joined bundles (which are the result of the EdgeBundles analysis).
55
56#include "AArch64InstrInfo.h"
58#include "AArch64Subtarget.h"
69
70using namespace llvm;
71
72#define DEBUG_TYPE "aarch64-machine-sme-abi"
73
74namespace {
75
76// Note: For agnostic ZA, we assume the function is always entered/exited in the
77// "ACTIVE" state -- this _may_ not be the case (since OFF is also a
78// possibility, but for the purpose of placing ZA saves/restores, that does not
79// matter).
/// The states ZA (and ZT0) may be in at a given program point. The pass
/// computes the state needed before each instruction and inserts transitions
/// between states.
enum ZAState : uint8_t {
  // Any/unknown state (not valid)
  ANY = 0,

  // ZA is in use and active (i.e. within the accumulator)
  ACTIVE,

  // ZA is active, but ZT0 has been saved.
  // This handles the edge case of sharedZA && !sharesZT0.
  ACTIVE_ZT0_SAVED,

  // A ZA save has been set up or committed (i.e. ZA is dormant or off)
  // If the function uses ZT0 it must also be saved.
  LOCAL_SAVED,

  // ZA has been committed to the lazy save buffer of the current function.
  // If the function uses ZT0 it must also be saved.
  // ZA is off.
  LOCAL_COMMITTED,

  // The ZA/ZT0 state on entry to the function.
  ENTRY,

  // ZA is off.
  OFF,

  // The number of ZA states (not a valid state) -- sentinel for sizing tables.
  NUM_ZA_STATE
};
109
110/// A bitmask enum to record live physical registers that the "emit*" routines
111/// may need to preserve. Note: This only tracks registers we may clobber.
/// A bitmask enum to record live physical registers that the "emit*" routines
/// may need to preserve. Note: This only tracks registers we may clobber.
enum LiveRegs : uint8_t {
  None = 0,
  NZCV = 1 << 0,
  // W0 (low 32 bits of X0) and its high half are tracked as separate bits so
  // the pass can preserve exactly what is live.
  W0 = 1 << 1,
  W0_HI = 1 << 2,
  // Convenience value: the whole of X0 (both halves live).
  X0 = W0 | W0_HI,
  LLVM_MARK_AS_BITMASK_ENUM(/* LargestValue = */ W0_HI)
};
120
121/// Holds the virtual registers live physical registers have been saved to.
/// Holds the virtual registers live physical registers have been saved to.
struct PhysRegSave {
  // Which physical registers were live (and therefore saved).
  LiveRegs PhysLiveRegs;
  // Virtual register holding the saved NZCV flags (via MRS), if any.
  Register StatusFlags = AArch64::NoRegister;
  // Virtual register holding the saved W0/X0 value, if any.
  Register X0Save = AArch64::NoRegister;
};
127
128/// Contains the needed ZA state (and live registers) at an instruction. That is
129/// the state ZA must be in _before_ "InsertPt".
struct InstInfo {
  // State ZA must be in before this instruction executes.
  ZAState NeededState{ZAState::ANY};
  // NOTE(review): callers aggregate-initialize this struct with three values
  // ({NeededState, InsertPt, PhysLiveRegs}) and read an `InsertPt` member --
  // an iterator member appears to be missing from this view; confirm.
  // Tracked physical registers live at this point.
  LiveRegs PhysLiveRegs = LiveRegs::None;
};
135
136/// Contains the needed ZA state for each instruction in a block. Instructions
137/// that do not require a ZA state are not recorded.
struct BlockInfo {
  // NOTE(review): code elsewhere reads/writes `Block.Insts` (a sequence of
  // InstInfo) -- that member is not visible in this view; confirm.
  // A predefined state on entry (e.g. ENTRY for the entry block,
  // LOCAL_COMMITTED for EH pads), or ANY if none.
  ZAState FixedEntryState{ZAState::ANY};
  // The incoming/outgoing states that would avoid a transition at the
  // block boundary (taken from the first/last InstInfo of the block).
  ZAState DesiredIncomingState{ZAState::ANY};
  ZAState DesiredOutgoingState{ZAState::ANY};
  // Tracked physical registers live at block entry/exit.
  LiveRegs PhysLiveRegsAtEntry = LiveRegs::None;
  LiveRegs PhysLiveRegsAtExit = LiveRegs::None;
};
146
147/// Contains the needed ZA state information for all blocks within a function.
struct FunctionInfo {
  // NOTE(review): collectNeededZAStates() returns
  // `FunctionInfo{std::move(Blocks), ...}` and users index `FnInfo.Blocks` --
  // a per-block container member is not visible in this view; confirm.
  // Point after the SelectionDAG-emitted SME prologue (SMEStateAllocPseudo),
  // if present: a safe spot to insert TPIDR2 block setup.
  std::optional<MachineBasicBlock::iterator> AfterSMEProloguePt;
  // Tracked physical registers live at that point.
  LiveRegs PhysLiveRegsAfterSMEPrologue = LiveRegs::None;
};
153
154/// State/helpers that is only needed when emitting code to handle
155/// saving/restoring ZA.
156class EmitContext {
157public:
158 EmitContext() = default;
159
160 /// Get or create a TPIDR2 block in \p MF.
161 int getTPIDR2Block(MachineFunction &MF) {
162 if (TPIDR2BlockFI)
163 return *TPIDR2BlockFI;
164 MachineFrameInfo &MFI = MF.getFrameInfo();
165 TPIDR2BlockFI = MFI.CreateStackObject(16, Align(16), false);
166 return *TPIDR2BlockFI;
167 }
168
169 /// Get or create agnostic ZA buffer pointer in \p MF.
170 Register getAgnosticZABufferPtr(MachineFunction &MF) {
171 if (AgnosticZABufferPtr != AArch64::NoRegister)
172 return AgnosticZABufferPtr;
173 Register BufferPtr =
174 MF.getInfo<AArch64FunctionInfo>()->getEarlyAllocSMESaveBuffer();
175 AgnosticZABufferPtr =
176 BufferPtr != AArch64::NoRegister
177 ? BufferPtr
178 : MF.getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass);
179 return AgnosticZABufferPtr;
180 }
181
182 int getZT0SaveSlot(MachineFunction &MF) {
183 if (ZT0SaveFI)
184 return *ZT0SaveFI;
185 MachineFrameInfo &MFI = MF.getFrameInfo();
186 ZT0SaveFI = MFI.CreateSpillStackObject(64, Align(16));
187 return *ZT0SaveFI;
188 }
189
190 /// Returns true if the function must allocate a ZA save buffer on entry. This
191 /// will be the case if, at any point in the function, a ZA save was emitted.
192 bool needsSaveBuffer() const {
193 assert(!(TPIDR2BlockFI && AgnosticZABufferPtr) &&
194 "Cannot have both a TPIDR2 block and agnostic ZA buffer");
195 return TPIDR2BlockFI || AgnosticZABufferPtr != AArch64::NoRegister;
196 }
197
198private:
199 std::optional<int> ZT0SaveFI;
200 std::optional<int> TPIDR2BlockFI;
201 Register AgnosticZABufferPtr = AArch64::NoRegister;
202};
203
204StringRef getZAStateString(ZAState State) {
205#define MAKE_CASE(V) \
206 case V: \
207 return #V;
208 switch (State) {
209 MAKE_CASE(ZAState::ANY)
210 MAKE_CASE(ZAState::ACTIVE)
211 MAKE_CASE(ZAState::ACTIVE_ZT0_SAVED)
212 MAKE_CASE(ZAState::LOCAL_SAVED)
213 MAKE_CASE(ZAState::LOCAL_COMMITTED)
214 MAKE_CASE(ZAState::ENTRY)
215 MAKE_CASE(ZAState::OFF)
216 default:
217 llvm_unreachable("Unexpected ZAState");
218 }
219#undef MAKE_CASE
220}
221
222static bool isZAorZTRegOp(const TargetRegisterInfo &TRI,
223 const MachineOperand &MO) {
224 if (!MO.isReg() || !MO.getReg().isPhysical())
225 return false;
226 return any_of(TRI.subregs_inclusive(MO.getReg()), [](const MCPhysReg &SR) {
227 return AArch64::MPR128RegClass.contains(SR) ||
228 AArch64::ZTRRegClass.contains(SR);
229 });
230}
231
232/// Returns the required ZA state needed before \p MI and an iterator pointing
233/// to where any code required to change the ZA state should be inserted.
static std::pair<ZAState, MachineBasicBlock::iterator>
getInstNeededZAState(const TargetRegisterInfo &TRI, MachineInstr &MI,
                     SMEAttrs SMEFnAttrs) {

  // Note: InOutZAUsePseudo, RequiresZASavePseudo, and RequiresZT0SavePseudo are
  // intended to mark the position immediately before a call. Due to
  // SelectionDAG constraints, these markers occur after the ADJCALLSTACKDOWN,
  // so we use std::prev(InsertPt) to get the position before the call.

  // An inout-ZA call needs ZA live in the accumulator.
  if (MI.getOpcode() == AArch64::InOutZAUsePseudo)
    return {ZAState::ACTIVE, std::prev(InsertPt)};

  // Note: If we need to save both ZA and ZT0 we use RequiresZASavePseudo.
  if (MI.getOpcode() == AArch64::RequiresZASavePseudo)
    return {ZAState::LOCAL_SAVED, std::prev(InsertPt)};

  // If we only need to save ZT0 there's two cases to consider:
  // 1. The function has ZA state (that we don't need to save).
  //    - In this case we switch to the "ACTIVE_ZT0_SAVED" state.
  //      This only saves ZT0.
  // 2. The function does not have ZA state
  //    - In this case we switch to "LOCAL_COMMITTED" state.
  //      This saves ZT0 and turns ZA off.
  if (MI.getOpcode() == AArch64::RequiresZT0SavePseudo) {
    return {SMEFnAttrs.hasZAState() ? ZAState::ACTIVE_ZT0_SAVED
                                    : ZAState::LOCAL_COMMITTED,
            std::prev(InsertPt)};
  }

  if (MI.isReturn()) {
    // Private-ZA functions exit with ZA off; shared-ZA functions exit ACTIVE.
    bool ZAOffAtReturn = SMEFnAttrs.hasPrivateZAInterface();
    return {ZAOffAtReturn ? ZAState::OFF : ZAState::ACTIVE, InsertPt};
  }

  // Any direct use/def of a ZA or ZT0 register requires ZA to be active.
  for (auto &MO : MI.operands()) {
    if (isZAorZTRegOp(TRI, MO))
      return {ZAState::ACTIVE, InsertPt};
  }

  return {ZAState::ANY, InsertPt};
}
276
struct MachineSMEABI : public MachineFunctionPass {
  inline static char ID = 0;

  MachineSMEABI(CodeGenOptLevel OptLevel = CodeGenOptLevel::Default)
      : MachineFunctionPass(ID), OptLevel(OptLevel) {}

  bool runOnMachineFunction(MachineFunction &MF) override;

  StringRef getPassName() const override { return "Machine SME ABI pass"; }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
  }

  /// Collects the needed ZA state (and live registers) before each instruction
  /// within the machine function.
  FunctionInfo collectNeededZAStates(SMEAttrs SMEFnAttrs);

  /// Assigns each edge bundle a ZA state based on the desired states of
  /// incoming and outgoing blocks in the bundle.
  SmallVector<ZAState> assignBundleZAStates(const EdgeBundles &Bundles,
                                            const FunctionInfo &FnInfo);

  /// Inserts code to handle changes between ZA states within the function.
  /// E.g., ACTIVE -> LOCAL_SAVED will insert code required to save ZA.
  void insertStateChanges(EmitContext &, const FunctionInfo &FnInfo,
                          const EdgeBundles &Bundles,
                          ArrayRef<ZAState> BundleStates);

  /// Appends the (verified) callee symbol and register mask for the SME ABI
  /// routine \p LC to \p MIB; emits an error if the routine is unsupported or
  /// uses a calling convention other than \p ExpectedCC.
  void addSMELibCall(MachineInstrBuilder &MIB, RTLIB::Libcall LC,
                     CallingConv::ID ExpectedCC);

  void emitZT0SaveRestore(EmitContext &, MachineBasicBlock &MBB,
                          MachineBasicBlock::iterator MBBI, bool IsSave);

  // Emission routines for private and shared ZA functions (using lazy saves).
  void emitSMEPrologue(MachineBasicBlock &MBB,
  void emitRestoreLazySave(EmitContext &, MachineBasicBlock &MBB,
                           LiveRegs PhysLiveRegs);
  void emitSetupLazySave(EmitContext &, MachineBasicBlock &MBB,
  void emitAllocateLazySaveBuffer(EmitContext &, MachineBasicBlock &MBB,
                                  bool ClearTPIDR2, bool On);

  // Emission routines for agnostic ZA functions.
  void emitSetupFullZASave(MachineBasicBlock &MBB,
                           LiveRegs PhysLiveRegs);
  // Emit a "full" ZA save or restore. It is "full" in the sense that this
  // function will emit a call to __arm_sme_save or __arm_sme_restore, which
  // handles saving and restoring both ZA and ZT0.
  void emitFullZASaveRestore(EmitContext &, MachineBasicBlock &MBB,
                             LiveRegs PhysLiveRegs, bool IsSave);
  void emitAllocateFullZASaveBuffer(EmitContext &, MachineBasicBlock &MBB,
                                    LiveRegs PhysLiveRegs);

  /// Attempts to find an insertion point before \p Inst where the status flags
  /// are not live. If \p Inst is `Block.Insts.end()` a point before the end of
  /// the block is found.
  std::pair<MachineBasicBlock::iterator, LiveRegs>
  findStateChangeInsertionPoint(MachineBasicBlock &MBB, const BlockInfo &Block,
  void emitStateChange(EmitContext &, MachineBasicBlock &MBB,
                       MachineBasicBlock::iterator MBBI, ZAState From,
                       ZAState To, LiveRegs PhysLiveRegs);

  // Helpers for switching between lazy/full ZA save/restore routines.
  // Agnostic-ZA functions use the full (__arm_sme_save/restore) scheme;
  // everything else uses the lazy TPIDR2-based scheme.
  void emitZASave(EmitContext &Context, MachineBasicBlock &MBB,
    if (AFI->getSMEFnAttrs().hasAgnosticZAInterface())
      return emitFullZASaveRestore(Context, MBB, MBBI, PhysLiveRegs,
                                   /*IsSave=*/true);
    return emitSetupLazySave(Context, MBB, MBBI);
  }
  void emitZARestore(EmitContext &Context, MachineBasicBlock &MBB,
    if (AFI->getSMEFnAttrs().hasAgnosticZAInterface())
      return emitFullZASaveRestore(Context, MBB, MBBI, PhysLiveRegs,
                                   /*IsSave=*/false);
    return emitRestoreLazySave(Context, MBB, MBBI, PhysLiveRegs);
  }
  void emitAllocateZASaveBuffer(EmitContext &Context, MachineBasicBlock &MBB,
                                LiveRegs PhysLiveRegs) {
    if (AFI->getSMEFnAttrs().hasAgnosticZAInterface())
      return emitAllocateFullZASaveBuffer(Context, MBB, MBBI, PhysLiveRegs);
    return emitAllocateLazySaveBuffer(Context, MBB, MBBI);
  }

  /// Collects the reachable calls from \p MBBI marked with \p Marker. This is
  /// intended to be used to emit lazy save remarks. Note: This stops at the
  /// first marked call along any path.
  void collectReachableMarkedCalls(const MachineBasicBlock &MBB,
                                   unsigned Marker) const;

  void emitCallSaveRemarks(const MachineBasicBlock &MBB,
                           unsigned Marker, StringRef RemarkName,
                           StringRef SaveName) const;

  /// Emits an error diagnostic on the function's LLVMContext, prefixed with
  /// the function name.
  void emitError(const Twine &Message) {
    LLVMContext &Context = MF->getFunction().getContext();
    Context.emitError(MF->getName() + ": " + Message);
  }

  /// Save live physical registers to virtual registers.
  PhysRegSave createPhysRegSave(LiveRegs PhysLiveRegs, MachineBasicBlock &MBB,
  /// Restore physical registers from a save of their previous values.
  void restorePhyRegSave(const PhysRegSave &RegSave, MachineBasicBlock &MBB,

private:

  // Cached per-function state (set up in runOnMachineFunction).
  MachineFunction *MF = nullptr;
  const AArch64Subtarget *Subtarget = nullptr;
  const AArch64RegisterInfo *TRI = nullptr;
  const AArch64FunctionInfo *AFI = nullptr;
  const AArch64InstrInfo *TII = nullptr;
  const LibcallLoweringInfo *LLI = nullptr;

  MachineRegisterInfo *MRI = nullptr;
  MachineLoopInfo *MLI = nullptr;
};
417
418static LiveRegs getPhysLiveRegs(LiveRegUnits const &LiveUnits) {
419 LiveRegs PhysLiveRegs = LiveRegs::None;
420 if (!LiveUnits.available(AArch64::NZCV))
421 PhysLiveRegs |= LiveRegs::NZCV;
422 // We have to track W0 and X0 separately as otherwise things can get
423 // confused if we attempt to preserve X0 but only W0 was defined.
424 if (!LiveUnits.available(AArch64::W0))
425 PhysLiveRegs |= LiveRegs::W0;
426 if (!LiveUnits.available(AArch64::W0_HI))
427 PhysLiveRegs |= LiveRegs::W0_HI;
428 return PhysLiveRegs;
429}
430
431static void setPhysLiveRegs(LiveRegUnits &LiveUnits, LiveRegs PhysLiveRegs) {
432 if (PhysLiveRegs & LiveRegs::NZCV)
433 LiveUnits.addReg(AArch64::NZCV);
434 if (PhysLiveRegs & LiveRegs::W0)
435 LiveUnits.addReg(AArch64::W0);
436 if (PhysLiveRegs & LiveRegs::W0_HI)
437 LiveUnits.addReg(AArch64::W0_HI);
438}
439
440[[maybe_unused]] bool isCallStartOpcode(unsigned Opc) {
441 switch (Opc) {
442 case AArch64::TLSDESC_CALLSEQ:
443 case AArch64::TLSDESC_AUTH_CALLSEQ:
444 case AArch64::ADJCALLSTACKDOWN:
445 return true;
446 default:
447 return false;
448 }
449}
450
FunctionInfo MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
  assert((SMEFnAttrs.hasAgnosticZAInterface() || SMEFnAttrs.hasZT0State() ||
          SMEFnAttrs.hasZAState()) &&
         "Expected function to have ZA/ZT0 state!");

  LiveRegs PhysLiveRegsAfterSMEPrologue = LiveRegs::None;
  std::optional<MachineBasicBlock::iterator> AfterSMEProloguePt;

  for (MachineBasicBlock &MBB : *MF) {
    BlockInfo &Block = Blocks[MBB.getNumber()];

    if (MBB.isEntryBlock()) {
      // Entry block: starts in the function's entry state.
      Block.FixedEntryState = ZAState::ENTRY;
    } else if (MBB.isEHPad()) {
      // EH entry block: entered with ZA committed to the lazy save buffer.
      Block.FixedEntryState = ZAState::LOCAL_COMMITTED;
    }

    // Walk the block backwards so liveness can be tracked incrementally.
    LiveRegUnits LiveUnits(*TRI);
    LiveUnits.addLiveOuts(MBB);

    Block.PhysLiveRegsAtExit = getPhysLiveRegs(LiveUnits);
    auto FirstTerminatorInsertPt = MBB.getFirstTerminator();
    auto FirstNonPhiInsertPt = MBB.getFirstNonPHI();
    for (MachineInstr &MI : reverse(MBB)) {
      if (MI.isDebugInstr())
        continue;

      LiveUnits.stepBackward(MI);
      LiveRegs PhysLiveRegs = getPhysLiveRegs(LiveUnits);
      // The SMEStateAllocPseudo marker is added to a function if the save
      // buffer was allocated in SelectionDAG. It marks the end of the
      // allocation -- which is a safe point for this pass to insert any TPIDR2
      // block setup.
      if (MI.getOpcode() == AArch64::SMEStateAllocPseudo) {
        AfterSMEProloguePt = MBBI;
        PhysLiveRegsAfterSMEPrologue = PhysLiveRegs;
      }
      // Note: We treat Agnostic ZA as inout_za with an alternate save/restore.
      auto [NeededState, InsertPt] = getInstNeededZAState(*TRI, MI, SMEFnAttrs);
      assert((InsertPt == MBBI || isCallStartOpcode(InsertPt->getOpcode())) &&
             "Unexpected state change insertion point!");
      if (MBBI == FirstTerminatorInsertPt)
        Block.PhysLiveRegsAtExit = PhysLiveRegs;
      if (MBBI == FirstNonPhiInsertPt)
        Block.PhysLiveRegsAtEntry = PhysLiveRegs;
      if (NeededState != ZAState::ANY)
        Block.Insts.push_back({NeededState, InsertPt, PhysLiveRegs});
    }

    // Reverse vector (as we had to iterate backwards for liveness).
    std::reverse(Block.Insts.begin(), Block.Insts.end());

    // Record the desired states on entry/exit of this block. These are the
    // states that would not incur a state transition.
    if (!Block.Insts.empty()) {
      Block.DesiredIncomingState = Block.Insts.front().NeededState;
      Block.DesiredOutgoingState = Block.Insts.back().NeededState;
    }
  }

  return FunctionInfo{std::move(Blocks), AfterSMEProloguePt,
                      PhysLiveRegsAfterSMEPrologue};
}
518
519/// Assigns each edge bundle a ZA state based on the desired states of incoming
520/// and outgoing blocks in the bundle.
522MachineSMEABI::assignBundleZAStates(const EdgeBundles &Bundles,
523 const FunctionInfo &FnInfo) {
524 SmallVector<ZAState> BundleStates(Bundles.getNumBundles());
525 for (unsigned I = 0, E = Bundles.getNumBundles(); I != E; ++I) {
526 std::optional<ZAState> BundleState;
527 for (unsigned BlockID : Bundles.getBlocks(I)) {
528 const BlockInfo &Block = FnInfo.Blocks[BlockID];
529 // Check if the block is an incoming block in the bundle. Note: We skip
530 // Block.FixedEntryState != ANY to ignore EH pads (which are only
531 // reachable via exceptions).
532 if (Block.FixedEntryState != ZAState::ANY ||
533 Bundles.getBundle(BlockID, /*Out=*/false) != I)
534 continue;
535
536 // Pick a state that matches all incoming blocks. Fall back to "ACTIVE" if
537 // any incoming state doesn't match. This will hoist the state from
538 // incoming blocks to outgoing blocks.
539 if (!BundleState)
540 BundleState = Block.DesiredIncomingState;
541 else if (BundleState != Block.DesiredIncomingState)
542 BundleState = ZAState::ACTIVE;
543 }
544
545 if (!BundleState || BundleState == ZAState::ANY)
546 BundleState = ZAState::ACTIVE;
547
548 BundleStates[I] = *BundleState;
549 }
550
551 return BundleStates;
552}
553
std::pair<MachineBasicBlock::iterator, LiveRegs>
MachineSMEABI::findStateChangeInsertionPoint(
    MachineBasicBlock &MBB, const BlockInfo &Block,
  LiveRegs PhysLiveRegs;
  // Default insertion point: before \p Inst, or before the first terminator
  // when asked for the end of the block.
  if (Inst != Block.Insts.end()) {
    InsertPt = Inst->InsertPt;
    PhysLiveRegs = Inst->PhysLiveRegs;
  } else {
    InsertPt = MBB.getFirstTerminator();
    PhysLiveRegs = Block.PhysLiveRegsAtExit;
  }

  if (PhysLiveRegs == LiveRegs::None)
    return {InsertPt, PhysLiveRegs}; // Nothing to do (no live regs).

  // Find the previous state change. We can not move before this point.
  MachineBasicBlock::iterator PrevStateChangeI;
  if (Inst == Block.Insts.begin()) {
    PrevStateChangeI = MBB.begin();
  } else {
    // Note: `std::prev(Inst)` is the previous InstInfo. We only create an
    // InstInfo object for instructions that require a specific ZA state, so the
    // InstInfo is the site of the previous state change in the block (which can
    // be several MIs earlier).
    PrevStateChangeI = std::prev(Inst)->InsertPt;
  }

  // Note: LiveUnits will only accurately track X0 and NZCV.
  LiveRegUnits LiveUnits(*TRI);
  setPhysLiveRegs(LiveUnits, PhysLiveRegs);
  auto BestCandidate = std::make_pair(InsertPt, PhysLiveRegs);
  // Scan backwards from the default point looking for a spot with fewer
  // live tracked registers.
  for (MachineBasicBlock::iterator I = InsertPt; I != PrevStateChangeI; --I) {
    // Don't move before/into a call (which may have a state change before it).
    if (I->getOpcode() == TII->getCallFrameDestroyOpcode() || I->isCall())
      break;
    LiveUnits.stepBackward(*I);
    LiveRegs CurrentPhysLiveRegs = getPhysLiveRegs(LiveUnits);
    // Find places where NZCV is available, but keep looking for locations where
    // both NZCV and X0 are available, which can avoid some copies.
    if (!(CurrentPhysLiveRegs & LiveRegs::NZCV))
      BestCandidate = {I, CurrentPhysLiveRegs};
    if (CurrentPhysLiveRegs == LiveRegs::None)
      break;
  }
  return BestCandidate;
}
602
603void MachineSMEABI::insertStateChanges(EmitContext &Context,
604 const FunctionInfo &FnInfo,
605 const EdgeBundles &Bundles,
606 ArrayRef<ZAState> BundleStates) {
607 for (MachineBasicBlock &MBB : *MF) {
608 const BlockInfo &Block = FnInfo.Blocks[MBB.getNumber()];
609 ZAState InState = BundleStates[Bundles.getBundle(MBB.getNumber(),
610 /*Out=*/false)];
611
612 ZAState CurrentState = Block.FixedEntryState;
613 if (CurrentState == ZAState::ANY)
614 CurrentState = InState;
615
616 for (auto &Inst : Block.Insts) {
617 if (CurrentState != Inst.NeededState) {
618 auto [InsertPt, PhysLiveRegs] =
619 findStateChangeInsertionPoint(MBB, Block, &Inst);
620 emitStateChange(Context, MBB, InsertPt, CurrentState, Inst.NeededState,
621 PhysLiveRegs);
622 CurrentState = Inst.NeededState;
623 }
624 }
625
626 if (MBB.succ_empty())
627 continue;
628
629 ZAState OutState =
630 BundleStates[Bundles.getBundle(MBB.getNumber(), /*Out=*/true)];
631 if (CurrentState != OutState) {
632 auto [InsertPt, PhysLiveRegs] =
633 findStateChangeInsertionPoint(MBB, Block, Block.Insts.end());
634 emitStateChange(Context, MBB, InsertPt, CurrentState, OutState,
635 PhysLiveRegs);
636 }
637 }
638}
639
642 if (MBB.empty())
643 return DebugLoc();
644 return MBBI != MBB.end() ? MBBI->getDebugLoc() : MBB.back().getDebugLoc();
645}
646
647/// Finds the first call (as determined by MachineInstr::isCall()) starting from
648/// \p MBBI in \p MBB marked with \p Marker (which is a marker opcode such as
649/// RequiresZASavePseudo). If a marked call is found, it is pushed to \p Calls
650/// and the function returns true.
static bool findMarkedCall(const MachineBasicBlock &MBB,
                           unsigned Marker, unsigned CallDestroyOpcode) {
  // Locate the first marker pseudo in the search range.
  auto IsMarker = [&](auto &MI) { return MI.getOpcode() == Marker; };
  auto MarkerInst = std::find_if(MBBI, MBB.end(), IsMarker);
  if (MarkerInst == MBB.end())
    return false;
  // Advance from the marker to the call it precedes, stopping if we hit the
  // end of the call sequence first.
  while (++I != MBB.end()) {
    if (I->isCall() || I->getOpcode() == CallDestroyOpcode)
      break;
  }
  if (I != MBB.end() && I->isCall())
    Calls.push_back(&*I);
  // Note: This function always returns true if a "Marker" was found.
  return true;
}
669
void MachineSMEABI::collectReachableMarkedCalls(
    const MachineBasicBlock &StartMBB,
    SmallVectorImpl<const MachineInstr *> &Calls, unsigned Marker) const {
  assert(Marker == AArch64::InOutZAUsePseudo ||
         Marker == AArch64::RequiresZASavePseudo ||
         Marker == AArch64::RequiresZT0SavePseudo);
  unsigned CallDestroyOpcode = TII->getCallFrameDestroyOpcode();
  // Check the start block first; a marked call ends the search on this path.
  if (findMarkedCall(StartMBB, StartInst, Calls, Marker, CallDestroyOpcode))
    return;

  // Worklist-based traversal of successors, visiting each block at most once.
                                 StartMBB.succ_rend());
  while (!Worklist.empty()) {
    const MachineBasicBlock *MBB = Worklist.pop_back_val();
    auto [_, Inserted] = Visited.insert(MBB);
    if (!Inserted)
      continue;

    // Only continue past blocks that did not contain a marked call.
    if (!findMarkedCall(*MBB, MBB->begin(), Calls, Marker, CallDestroyOpcode))
      Worklist.append(MBB->succ_rbegin(), MBB->succ_rend());
  }
}
694
695static StringRef getCalleeName(const MachineInstr &CallInst) {
696 assert(CallInst.isCall() && "expected a call");
697 for (const MachineOperand &MO : CallInst.operands()) {
698 if (MO.isSymbol())
699 return MO.getSymbolName();
700 if (MO.isGlobal())
701 return MO.getGlobal()->getName();
702 }
703 return {};
704}
705
void MachineSMEABI::emitCallSaveRemarks(const MachineBasicBlock &MBB,
                                        DebugLoc DL, unsigned Marker,
                                        StringRef RemarkName,
                                        StringRef SaveName) const {
  // Helper to build a remark attached to the given block/location.
  auto SaveRemark = [&](DebugLoc DL, const MachineBasicBlock &MBB) {
    return MachineOptimizationRemarkAnalysis("sme", RemarkName, DL, &MBB);
  };
  StringRef StateName = Marker == AArch64::RequiresZT0SavePseudo ? "ZT0" : "ZA";
  ORE->emit([&] {
    return SaveRemark(DL, MBB) << SaveName << " of " << StateName
                               << " emitted in '" << MF->getName() << "'";
  });
  // Per-call remarks are only computed when extra analysis is enabled.
  if (!ORE->allowExtraAnalysis("sme"))
    return;
  SmallVector<const MachineInstr *> CallsRequiringSaves;
  collectReachableMarkedCalls(MBB, MBBI, CallsRequiringSaves, Marker);
  for (const MachineInstr *CallInst : CallsRequiringSaves) {
    auto R = SaveRemark(CallInst->getDebugLoc(), *CallInst->getParent());
    R << "call";
    if (StringRef CalleeName = getCalleeName(*CallInst); !CalleeName.empty())
      R << " to '" << CalleeName << "'";
    R << " requires " << StateName << " save";
    ORE->emit(R);
  }
}
732
void MachineSMEABI::emitSetupLazySave(EmitContext &Context,

  emitCallSaveRemarks(MBB, MBBI, DL, AArch64::RequiresZASavePseudo,
                      "SMELazySaveZA", "lazy save");

  // Get pointer to TPIDR2 block.
  Register TPIDR2 = MRI->createVirtualRegister(&AArch64::GPR64spRegClass);
  Register TPIDR2Ptr = MRI->createVirtualRegister(&AArch64::GPR64RegClass);
  // Materialize the address of the TPIDR2 block (a frame object).
  BuildMI(MBB, MBBI, DL, TII->get(AArch64::ADDXri), TPIDR2)
      .addFrameIndex(Context.getTPIDR2Block(*MF))
      .addImm(0)
      .addImm(0);
  BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::COPY), TPIDR2Ptr)
      .addReg(TPIDR2);
  // Set TPIDR2_EL0 to point to TPIDR2 block.
  BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSR))
      .addImm(AArch64SysReg::TPIDR2_EL0)
      .addReg(TPIDR2Ptr);
}
755
PhysRegSave MachineSMEABI::createPhysRegSave(LiveRegs PhysLiveRegs,
                                             DebugLoc DL) {
  PhysRegSave RegSave{PhysLiveRegs};
  // Copy NZCV into a GPR (via MRS) if the flags are live.
  if (PhysLiveRegs & LiveRegs::NZCV) {
    RegSave.StatusFlags = MRI->createVirtualRegister(&AArch64::GPR64RegClass);
    BuildMI(MBB, MBBI, DL, TII->get(AArch64::MRS), RegSave.StatusFlags)
        .addImm(AArch64SysReg::NZCV)
        .addReg(AArch64::NZCV, RegState::Implicit);
  }
  // Note: Preserving X0 is "free" as this is before register allocation, so
  // the register allocator is still able to optimize these copies.
  if (PhysLiveRegs & LiveRegs::W0) {
    // Save all of X0 only when its high half (W0_HI) is also live.
    RegSave.X0Save = MRI->createVirtualRegister(PhysLiveRegs & LiveRegs::W0_HI
                                                    ? &AArch64::GPR64RegClass
                                                    : &AArch64::GPR32RegClass);
    BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::COPY), RegSave.X0Save)
        .addReg(PhysLiveRegs & LiveRegs::W0_HI ? AArch64::X0 : AArch64::W0);
  }
  return RegSave;
}
778
void MachineSMEABI::restorePhyRegSave(const PhysRegSave &RegSave,
                                      DebugLoc DL) {
  // Write the saved flags back into NZCV (via MSR), if they were saved.
  if (RegSave.StatusFlags != AArch64::NoRegister)
    BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSR))
        .addImm(AArch64SysReg::NZCV)
        .addReg(RegSave.StatusFlags)
        .addReg(AArch64::NZCV, RegState::ImplicitDefine);

  // Restore X0 (or just W0 if only the low half was saved).
  if (RegSave.X0Save != AArch64::NoRegister)
    BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::COPY),
            RegSave.PhysLiveRegs & LiveRegs::W0_HI ? AArch64::X0 : AArch64::W0)
        .addReg(RegSave.X0Save);
}
794
void MachineSMEABI::addSMELibCall(MachineInstrBuilder &MIB, RTLIB::Libcall LC,
                                  CallingConv::ID ExpectedCC) {
  // Resolve the target's implementation of the SME ABI routine; diagnose if
  // the target provides none.
  RTLIB::LibcallImpl LCImpl = LLI->getLibcallImpl(LC);
  if (LCImpl == RTLIB::Unsupported)
    emitError("cannot lower SME ABI (SME routines unsupported)");
  // Diagnose an implementation declared with an unexpected calling convention.
  if (CC != ExpectedCC)
    emitError("invalid calling convention for SME routine: '" + ImplName + "'");
  // FIXME: This assumes the ImplName StringRef is null-terminated.
  MIB.addExternalSymbol(ImplName.data());
  MIB.addRegMask(TRI->getCallPreservedMask(*MF, CC));
}
808
void MachineSMEABI::emitRestoreLazySave(EmitContext &Context,
                                        LiveRegs PhysLiveRegs) {
  Register TPIDR2EL0 = MRI->createVirtualRegister(&AArch64::GPR64RegClass);
  // X0 is used directly: the TPIDR2 block address is an argument to the
  // __arm_tpidr2_restore routine.
  Register TPIDR2 = AArch64::X0;

  // TODO: Emit these within the restore MBB to prevent unnecessary saves.
  PhysRegSave RegSave = createPhysRegSave(PhysLiveRegs, MBB, MBBI, DL);

  // Enable ZA.
  BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSRpstatesvcrImm1))
      .addImm(AArch64SVCR::SVCRZA)
      .addImm(1);
  // Get current TPIDR2_EL0.
  BuildMI(MBB, MBBI, DL, TII->get(AArch64::MRS), TPIDR2EL0)
      .addImm(AArch64SysReg::TPIDR2_EL0);
  // Get pointer to TPIDR2 block.
  BuildMI(MBB, MBBI, DL, TII->get(AArch64::ADDXri), TPIDR2)
      .addFrameIndex(Context.getTPIDR2Block(*MF))
      .addImm(0)
      .addImm(0);
  // (Conditionally) restore ZA state.
  auto RestoreZA = BuildMI(MBB, MBBI, DL, TII->get(AArch64::RestoreZAPseudo))
                       .addReg(TPIDR2EL0)
                       .addReg(TPIDR2);
  addSMELibCall(
      RestoreZA, RTLIB::SMEABI_TPIDR2_RESTORE,
  // Zero TPIDR2_EL0.
  BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSR))
      .addImm(AArch64SysReg::TPIDR2_EL0)
      .addReg(AArch64::XZR);

  restorePhyRegSave(RegSave, MBB, MBBI, DL);
}
846
void MachineSMEABI::emitZAMode(MachineBasicBlock &MBB,
                               bool ClearTPIDR2, bool On) {

  // Optionally zero TPIDR2_EL0 (abandoning any in-progress lazy save).
  if (ClearTPIDR2)
    BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSR))
        .addImm(AArch64SysReg::TPIDR2_EL0)
        .addReg(AArch64::XZR);

  // Enable or disable ZA (PSTATE.SM ZA bit) according to \p On.
  BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSRpstatesvcrImm1))
      .addImm(AArch64SVCR::SVCRZA)
      .addImm(On ? 1 : 0);
}
862
void MachineSMEABI::emitAllocateLazySaveBuffer(
    EmitContext &Context, MachineBasicBlock &MBB,
  MachineFrameInfo &MFI = MF->getFrameInfo();
  Register SP = MRI->createVirtualRegister(&AArch64::GPR64RegClass);
  Register SVL = MRI->createVirtualRegister(&AArch64::GPR64RegClass);
  Register Buffer = AFI->getEarlyAllocSMESaveBuffer();

  // Calculate SVL.
  BuildMI(MBB, MBBI, DL, TII->get(AArch64::RDSVLI_XI), SVL).addImm(1);

  // 1. Allocate the lazy save buffer.
  if (Buffer == AArch64::NoRegister) {
    // TODO: On Windows, we allocate the lazy save buffer in SelectionDAG (so
    // Buffer != AArch64::NoRegister). This is done to reuse the existing
    // expansions (which can insert stack checks). This works, but it means we
    // will always allocate the lazy save buffer (even if the function contains
    // no lazy saves). If we want to handle Windows here, we'll need to
    // implement something similar to LowerWindowsDYNAMIC_STACKALLOC.
    assert(!Subtarget->isTargetWindows() &&
           "Lazy ZA save is not yet supported on Windows");
    Buffer = MRI->createVirtualRegister(&AArch64::GPR64RegClass);
    // Get original stack pointer.
    BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::COPY), SP)
        .addReg(AArch64::SP);
    // Allocate a lazy-save buffer object of the size given, normally SVL * SVL
    // (computed as SP - SVL * SVL via MSUB).
    BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSUBXrrr), Buffer)
        .addReg(SVL)
        .addReg(SVL)
        .addReg(SP);
    BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::COPY), AArch64::SP)
        .addReg(Buffer);
    // We have just allocated a variable sized object, tell this to PEI.
    MFI.CreateVariableSizedObject(Align(16), nullptr);
  }

  // 2. Setup the TPIDR2 block.
  {
    // Note: This case just needs to do `SVL << 48`. It is not implemented as we
    // generally don't support big-endian SVE/SME.
    if (!Subtarget->isLittleEndian())
      "TPIDR2 block initialization is not supported on big-endian targets");

    // Store buffer pointer and num_za_save_slices.
    // Bytes 10-15 are implicitly zeroed.
    BuildMI(MBB, MBBI, DL, TII->get(AArch64::STPXi))
        .addReg(Buffer)
        .addReg(SVL)
        .addFrameIndex(Context.getTPIDR2Block(*MF))
        .addImm(0);
  }
}
917
// Immediate with all eight tile-mask bits set -- presumably the ZERO_M tile
// mask selecting the whole of ZA; the use site falls on a line not visible in
// this extract (TODO confirm against the ZERO_M emission).
918static constexpr unsigned ZERO_ALL_ZA_MASK = 0b11111111;
919
// Emits the SME function prologue for ZA/ZT0 state:
//  - Reads TPIDR2_EL0 and emits CommitZASavePseudo (backed by the
//    RTLIB::SMEABI_TPIDR2_SAVE libcall) to commit any pending lazy save,
//    zeroing ZA/ZT0 when the function has "new" ZA/ZT0 state, then enables
//    ZA. NOTE(review): the guarding condition for this path is on a line
//    elided from this extract -- presumably a private-ZA interface check;
//    confirm upstream.
//  - On the shared-ZA-interface path, only zeroes ZA/ZT0 when "new".
920void MachineSMEABI::emitSMEPrologue(MachineBasicBlock &MBB,
923
924 bool ZeroZA = AFI->getSMEFnAttrs().isNewZA();
925 bool ZeroZT0 = AFI->getSMEFnAttrs().isNewZT0();
927 // Get current TPIDR2_EL0.
928 Register TPIDR2EL0 = MRI->createVirtualRegister(&AArch64::GPR64RegClass);
929 BuildMI(MBB, MBBI, DL, TII->get(AArch64::MRS))
930 .addReg(TPIDR2EL0, RegState::Define)
931 .addImm(AArch64SysReg::TPIDR2_EL0);
932 // If TPIDR2_EL0 is non-zero, commit the lazy save.
933 // NOTE: Functions that only use ZT0 don't need to zero ZA.
934 auto CommitZASave =
935 BuildMI(MBB, MBBI, DL, TII->get(AArch64::CommitZASavePseudo))
936 .addReg(TPIDR2EL0)
937 .addImm(ZeroZA)
938 .addImm(ZeroZT0);
939 addSMELibCall(
940 CommitZASave, RTLIB::SMEABI_TPIDR2_SAVE,
942 if (ZeroZA)
943 CommitZASave.addDef(AArch64::ZAB0, RegState::ImplicitDefine);
944 if (ZeroZT0)
945 CommitZASave.addDef(AArch64::ZT0, RegState::ImplicitDefine);
946 // Enable ZA (as ZA could have previously been in the OFF state).
947 BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSRpstatesvcrImm1))
948 .addImm(AArch64SVCR::SVCRZA)
949 .addImm(1);
950 } else if (AFI->getSMEFnAttrs().hasSharedZAInterface()) {
951 if (ZeroZA)
952 BuildMI(MBB, MBBI, DL, TII->get(AArch64::ZERO_M))
954 .addDef(AArch64::ZAB0, RegState::ImplicitDefine);
955 if (ZeroZT0)
956 BuildMI(MBB, MBBI, DL, TII->get(AArch64::ZERO_T)).addDef(AArch64::ZT0);
957 }
958}
959
// Saves (IsSave) or restores the full ZA state used by agnostic-ZA functions:
// copies the agnostic-ZA buffer pointer into X0 and calls
// __arm_sme_save / __arm_sme_restore. Live physical registers are preserved
// around the call via createPhysRegSave/restorePhyRegSave. A "full save"
// optimization remark is only emitted on the save path.
960void MachineSMEABI::emitFullZASaveRestore(EmitContext &Context,
963 LiveRegs PhysLiveRegs, bool IsSave) {
965
966 if (IsSave)
967 emitCallSaveRemarks(MBB, MBBI, DL, AArch64::RequiresZASavePseudo,
968 "SMEFullZASave", "full save");
969
970 PhysRegSave RegSave = createPhysRegSave(PhysLiveRegs, MBB, MBBI, DL);
971
972 // Copy the buffer pointer into X0.
973 Register BufferPtr = AArch64::X0;
974 BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::COPY), BufferPtr)
975 .addReg(Context.getAgnosticZABufferPtr(*MF));
976
977 // Call __arm_sme_save/__arm_sme_restore.
978 auto SaveRestoreZA = BuildMI(MBB, MBBI, DL, TII->get(AArch64::BL))
979 .addReg(BufferPtr, RegState::Implicit);
980 addSMELibCall(
981 SaveRestoreZA,
982 IsSave ? RTLIB::SMEABI_SME_SAVE : RTLIB::SMEABI_SME_RESTORE,
984
985 restorePhyRegSave(RegSave, MBB, MBBI, DL);
986}
987
// Spills (IsSave) or reloads ZT0 to/from its dedicated stack slot. The slot
// address is materialized with ADDXri from the frame index, then ZT0 is
// stored (STR_TX) or loaded (LDR_TX) through that pointer.
988void MachineSMEABI::emitZT0SaveRestore(EmitContext &Context,
991 bool IsSave) {
993
994 // Note: This will report calls that _only_ need ZT0 saved. Calls that save
995 // both ZA and ZT0 will be under the SMELazySaveZA remark. This prevents
996 // reporting the same calls twice.
997 if (IsSave)
998 emitCallSaveRemarks(MBB, MBBI, DL, AArch64::RequiresZT0SavePseudo,
999 "SMEZT0Save", "spill");
1000
1001 Register ZT0Save = MRI->createVirtualRegister(&AArch64::GPR64spRegClass);
1002
// Materialize the address of the ZT0 save slot.
1003 BuildMI(MBB, MBBI, DL, TII->get(AArch64::ADDXri), ZT0Save)
1004 .addFrameIndex(Context.getZT0SaveSlot(*MF))
1005 .addImm(0)
1006 .addImm(0);
1007
1008 if (IsSave) {
1009 BuildMI(MBB, MBBI, DL, TII->get(AArch64::STR_TX))
1010 .addReg(AArch64::ZT0)
1011 .addReg(ZT0Save);
1012 } else {
1013 BuildMI(MBB, MBBI, DL, TII->get(AArch64::LDR_TX), AArch64::ZT0)
1014 .addReg(ZT0Save);
1015 }
1016}
1017
// Allocates the save buffer for agnostic-ZA functions (unless SelectionDAG
// already allocated it early). The buffer size is obtained by calling
// __arm_sme_state_size (result in X0); the stack pointer is then decremented
// by that size and the resulting SP becomes the buffer pointer. Live physical
// registers are preserved around the libcall.
1018void MachineSMEABI::emitAllocateFullZASaveBuffer(
1019 EmitContext &Context, MachineBasicBlock &MBB,
1021 // Buffer already allocated in SelectionDAG.
1022 if (AFI->getEarlyAllocSMESaveBuffer())
1023 return;
1024
1026 Register BufferPtr = Context.getAgnosticZABufferPtr(*MF);
1027 Register BufferSize = MRI->createVirtualRegister(&AArch64::GPR64RegClass);
1028
1029 PhysRegSave RegSave = createPhysRegSave(PhysLiveRegs, MBB, MBBI, DL);
1030
1031 // Calculate the SME state size.
1032 {
1033 auto SMEStateSize = BuildMI(MBB, MBBI, DL, TII->get(AArch64::BL))
1034 .addReg(AArch64::X0, RegState::ImplicitDefine);
1035 addSMELibCall(
1036 SMEStateSize, RTLIB::SMEABI_SME_STATE_SIZE,
1038 BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::COPY), BufferSize)
1039 .addReg(AArch64::X0);
1040 }
1041
1042 // Allocate a buffer object of the size given __arm_sme_state_size.
1043 {
1044 MachineFrameInfo &MFI = MF->getFrameInfo();
1045 BuildMI(MBB, MBBI, DL, TII->get(AArch64::SUBXrx64), AArch64::SP)
1046 .addReg(AArch64::SP)
1047 .addReg(BufferSize)
1049 BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::COPY), BufferPtr)
1050 .addReg(AArch64::SP);
1051
1052 // We have just allocated a variable sized object, tell this to PEI.
1053 MFI.CreateVariableSizedObject(Align(16), nullptr);
1054 }
1055
1056 restorePhyRegSave(RegSave, MBB, MBBI, DL);
1057}
1058
1059struct FromState {
1060 ZAState From;
1061
1062 constexpr uint8_t to(ZAState To) const {
1063 static_assert(NUM_ZA_STATE < 16, "expected ZAState to fit in 4-bits");
1064 return uint8_t(From) << 4 | uint8_t(To);
1065 }
1066};
1067
1068constexpr FromState transitionFrom(ZAState From) { return FromState{From}; }
1069
// Emits the instructions required to move the ZA machine state from `From`
// to `To` at `InsertPt` (e.g. saving/restoring ZT0, setting up or committing
// a lazy save, or toggling PSTATE.ZA). Transitions are dispatched on the
// packed transitionFrom(From).to(To) key; an unhandled combination is a bug
// and hits llvm_unreachable.
1070void MachineSMEABI::emitStateChange(EmitContext &Context,
1073 ZAState From, ZAState To,
1074 LiveRegs PhysLiveRegs) {
1075 // ZA not used.
1076 if (From == ZAState::ANY || To == ZAState::ANY)
1077 return;
1078
1079 // If we're exiting from the ENTRY state that means that the function has not
1080 // used ZA, so in the case of private ZA/ZT0 functions we can omit any set up.
1081 if (From == ZAState::ENTRY && To == ZAState::OFF)
1082 return;
1083
1084 // TODO: Avoid setting up the save buffer if there's no transition to
1085 // LOCAL_SAVED.
1086 if (From == ZAState::ENTRY) {
1087 assert(&MBB == &MBB.getParent()->front() &&
1088 "ENTRY state only valid in entry block");
1089 emitSMEPrologue(MBB, MBB.getFirstNonPHI());
1090 if (To == ZAState::ACTIVE)
1091 return; // Nothing more to do (ZA is active after the prologue).
1092
1093 // Note: "emitSMEPrologue" zeros ZA (for "new" ZA state), so we may need to
1094 // setup a lazy save if "To" is "ZAState::LOCAL_SAVED". It may be possible
1095 // to improve this case by changing the placement of the zero instruction.
1096 From = ZAState::ACTIVE;
1097 }
1098
1099 SMEAttrs SMEFnAttrs = AFI->getSMEFnAttrs();
1100 bool IsAgnosticZA = SMEFnAttrs.hasAgnosticZAInterface();
1101 bool HasZT0State = SMEFnAttrs.hasZT0State();
1102 bool HasZAState = IsAgnosticZA || SMEFnAttrs.hasZAState();
1103
1104 switch (transitionFrom(From).to(To)) {
1105 // This section handles: ACTIVE <-> ACTIVE_ZT0_SAVED
1106 case transitionFrom(ZAState::ACTIVE).to(ZAState::ACTIVE_ZT0_SAVED):
1107 emitZT0SaveRestore(Context, MBB, InsertPt, /*IsSave=*/true);
1108 break;
1109 case transitionFrom(ZAState::ACTIVE_ZT0_SAVED).to(ZAState::ACTIVE):
1110 emitZT0SaveRestore(Context, MBB, InsertPt, /*IsSave=*/false);
1111 break;
1112
1113 // This section handles: ACTIVE[_ZT0_SAVED] -> LOCAL_SAVED
1114 case transitionFrom(ZAState::ACTIVE).to(ZAState::LOCAL_SAVED):
1115 case transitionFrom(ZAState::ACTIVE_ZT0_SAVED).to(ZAState::LOCAL_SAVED):
1116 if (HasZT0State && From == ZAState::ACTIVE)
1117 emitZT0SaveRestore(Context, MBB, InsertPt, /*IsSave=*/true);
1118 if (HasZAState)
1119 emitZASave(Context, MBB, InsertPt, PhysLiveRegs);
1120 break;
1121
1122 // This section handles: ACTIVE -> LOCAL_COMMITTED
1123 case transitionFrom(ZAState::ACTIVE).to(ZAState::LOCAL_COMMITTED):
1124 // TODO: We could support ZA state here, but this transition is currently
1125 // only possible when we _don't_ have ZA state.
1126 assert(HasZT0State && !HasZAState && "Expect to only have ZT0 state.");
1127 emitZT0SaveRestore(Context, MBB, InsertPt, /*IsSave=*/true);
1128 emitZAMode(MBB, InsertPt, /*ClearTPIDR2=*/false, /*On=*/false);
1129 break;
1130
1131 // This section handles: LOCAL_COMMITTED -> (OFF|LOCAL_SAVED)
1132 case transitionFrom(ZAState::LOCAL_COMMITTED).to(ZAState::OFF):
1133 case transitionFrom(ZAState::LOCAL_COMMITTED).to(ZAState::LOCAL_SAVED):
1134 // These transitions are a no-op.
1135 break;
1136
1137 // This section handles: LOCAL_(SAVED|COMMITTED) -> ACTIVE[_ZT0_SAVED]
1138 case transitionFrom(ZAState::LOCAL_COMMITTED).to(ZAState::ACTIVE):
1139 case transitionFrom(ZAState::LOCAL_COMMITTED).to(ZAState::ACTIVE_ZT0_SAVED):
1140 case transitionFrom(ZAState::LOCAL_SAVED).to(ZAState::ACTIVE):
1141 case transitionFrom(ZAState::LOCAL_SAVED).to(ZAState::ACTIVE_ZT0_SAVED):
1142 if (HasZAState)
1143 emitZARestore(Context, MBB, InsertPt, PhysLiveRegs);
1144 else
1145 emitZAMode(MBB, InsertPt, /*ClearTPIDR2=*/false, /*On=*/true);
1146 if (HasZT0State && To == ZAState::ACTIVE)
1147 emitZT0SaveRestore(Context, MBB, InsertPt, /*IsSave=*/false);
1148 break;
1149
1150 // This section handles transitions to OFF (not previously covered)
1151 case transitionFrom(ZAState::ACTIVE).to(ZAState::OFF):
1152 case transitionFrom(ZAState::ACTIVE_ZT0_SAVED).to(ZAState::OFF):
1153 case transitionFrom(ZAState::LOCAL_SAVED).to(ZAState::OFF):
1154 assert(SMEFnAttrs.hasPrivateZAInterface() &&
1155 "Did not expect to turn ZA off in shared/agnostic ZA function");
1156 emitZAMode(MBB, InsertPt, /*ClearTPIDR2=*/From == ZAState::LOCAL_SAVED,
1157 /*On=*/false);
1158 break;
1159
1160 default:
1161 dbgs() << "Error: Transition from " << getZAStateString(From) << " to "
1162 << getZAStateString(To) << '\n';
1163 llvm_unreachable("Unimplemented state transition");
1164 }
1165}
1166
1167/// Returns true if private ZA setup can be elided. This occurs when there is
1168/// no instruction within the function that requires ZA to be active.
1169static bool canElidePrivateZASetup(const FunctionInfo &FnInfo) {
1170 for (const BlockInfo &BlockInfo : FnInfo.Blocks) {
1171 for (const InstInfo &InstInfo : BlockInfo.Insts) {
1172 if (InstInfo.NeededState == ZAState::ACTIVE ||
1173 InstInfo.NeededState == ZAState::ACTIVE_ZT0_SAVED)
1174 return false;
1175 }
1176 }
1177 return true;
1178}
1179
1180} // end anonymous namespace
1181
1182INITIALIZE_PASS(MachineSMEABI, "aarch64-machine-sme-abi", "Machine SME ABI",
1183 false, false)
1184
1185bool MachineSMEABI::runOnMachineFunction(MachineFunction &MF) {
1186 AFI = MF.getInfo<AArch64FunctionInfo>();
1187 SMEAttrs SMEFnAttrs = AFI->getSMEFnAttrs();
1188 if (!SMEFnAttrs.hasZAState() && !SMEFnAttrs.hasZT0State() &&
1189 !SMEFnAttrs.hasAgnosticZAInterface())
1190 return false;
1191
1192 Subtarget = &MF.getSubtarget<AArch64Subtarget>();
1193 if (!Subtarget->hasSME() && !SMEFnAttrs.hasAgnosticZAInterface())
1194 return false;
1195
1196 assert(MF.getRegInfo().isSSA() && "Expected to be run on SSA form!");
1197
1198 this->MF = &MF;
1199 ORE = &getAnalysis<MachineOptimizationRemarkEmitterPass>().getORE();
1200 LLI = &getAnalysis<LibcallLoweringInfoWrapper>().getLibcallLowering(
1201 *MF.getFunction().getParent(), *Subtarget);
1202 TII = Subtarget->getInstrInfo();
1203 TRI = Subtarget->getRegisterInfo();
1204 MRI = &MF.getRegInfo();
1205
1206 const EdgeBundles &Bundles =
1207 getAnalysis<EdgeBundlesWrapperLegacy>().getEdgeBundles();
1208
1209 FunctionInfo FnInfo = collectNeededZAStates(SMEFnAttrs);
1210
1211 if (SMEFnAttrs.hasPrivateZAInterface() && canElidePrivateZASetup(FnInfo))
1212 return false;
1213
1214 SmallVector<ZAState> BundleStates = assignBundleZAStates(Bundles, FnInfo);
1215
1216 EmitContext Context;
1217 insertStateChanges(Context, FnInfo, Bundles, BundleStates);
1218
1219 if (Context.needsSaveBuffer()) {
1220 if (FnInfo.AfterSMEProloguePt) {
1221 // Note: With inline stack probes the AfterSMEProloguePt may not be in the
1222 // entry block (due to the probing loop).
1223 MachineBasicBlock::iterator MBBI = *FnInfo.AfterSMEProloguePt;
1224 emitAllocateZASaveBuffer(Context, *MBBI->getParent(), MBBI,
1225 FnInfo.PhysLiveRegsAfterSMEPrologue);
1226 } else {
1227 MachineBasicBlock &EntryBlock = MF.front();
1228 emitAllocateZASaveBuffer(
1229 Context, EntryBlock, EntryBlock.getFirstNonPHI(),
1230 FnInfo.Blocks[EntryBlock.getNumber()].PhysLiveRegsAtEntry);
1231 }
1232 }
1233
1234 return true;
1235}
1236
1238 return new MachineSMEABI(OptLevel);
1239}
static constexpr unsigned ZERO_ALL_ZA_MASK
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
MachineBasicBlock & MBB
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
MachineBasicBlock MachineBasicBlock::iterator MBBI
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
const HexagonInstrInfo * TII
#define _
IRTranslator LLVM IR MI
This file implements the LivePhysRegs utility for tracking liveness of physical registers.
#define ENTRY(ASMNAME, ENUM)
#define I(x, y, z)
Definition MD5.cpp:57
static DebugLoc getDebugLoc(MachineBasicBlock::instr_iterator FirstMI, MachineBasicBlock::instr_iterator LastMI)
Return the first DebugLoc that has line number information, given a range of instructions.
===- MachineOptimizationRemarkEmitter.h - Opt Diagnostics -*- C++ -*-—===//
#define MAKE_CASE(V)
Register const TargetRegisterInfo * TRI
if(PassOpts->AAPipeline)
#define INITIALIZE_PASS(passName, arg, name, cfg, analysis)
Definition PassSupport.h:56
This file defines the SmallVector class.
AArch64FunctionInfo - This class is derived from MachineFunctionInfo and contains private AArch64-spe...
Represent the analysis usage information of a pass.
AnalysisUsage & addPreservedID(const void *ID)
AnalysisUsage & addRequired()
LLVM_ABI void setPreservesCFG()
This function should be called by the pass, iff they do not:
Definition Pass.cpp:270
Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition ArrayRef.h:40
This class represents a function call, abstracting a target machine's calling convention.
A debug info location.
Definition DebugLoc.h:123
ArrayRef< unsigned > getBlocks(unsigned Bundle) const
getBlocks - Return an array of blocks that are connected to Bundle.
Definition EdgeBundles.h:53
unsigned getBundle(unsigned N, bool Out) const
getBundle - Return the ingoing (Out = false) or outgoing (Out = true) bundle number for basic block N
Definition EdgeBundles.h:47
unsigned getNumBundles() const
getNumBundles - Return the total number of bundles in the CFG.
Definition EdgeBundles.h:50
FunctionPass class - This class is used to implement most global optimizations.
Definition Pass.h:314
LLVMContext & getContext() const
getContext - Return a reference to the LLVMContext associated with this function.
Definition Function.cpp:358
const DebugLoc & getDebugLoc() const
Return the debug location for this node as a DebugLoc.
This is an important class for using LLVM in a threaded context.
Definition LLVMContext.h:68
LLVM_ABI void emitError(const Instruction *I, const Twine &ErrorStr)
emitError - Emit an error message to the currently installed error handler with optional location inf...
Tracks which library functions to use for a particular subtarget.
LLVM_ABI CallingConv::ID getLibcallImplCallingConv(RTLIB::LibcallImpl Call) const
Get the CallingConv that should be used for the specified libcall.
LLVM_ABI RTLIB::LibcallImpl getLibcallImpl(RTLIB::Libcall Call) const
Return the lowering's selection of implementation call for Call.
A set of register units used to track register liveness.
bool available(MCRegister Reg) const
Returns true if no part of physical register Reg is live.
void addReg(MCRegister Reg)
Adds register units covered by physical register Reg.
LLVM_ABI void stepBackward(const MachineInstr &MI)
Updates liveness when stepping backwards over the instruction MI.
LLVM_ABI void addLiveOuts(const MachineBasicBlock &MBB)
Adds registers living out of block MBB.
MachineInstrBundleIterator< const MachineInstr > const_iterator
int getNumber() const
MachineBasicBlocks are uniquely numbered at the function level, unless they're not in a MachineFuncti...
LLVM_ABI iterator getFirstNonPHI()
Returns a pointer to the first instruction in this block that is not a PHINode instruction.
succ_reverse_iterator succ_rbegin()
MachineInstrBundleIterator< MachineInstr > iterator
succ_reverse_iterator succ_rend()
The MachineFrameInfo class represents an abstract stack frame until prolog/epilog code is inserted.
LLVM_ABI int CreateStackObject(uint64_t Size, Align Alignment, bool isSpillSlot, const AllocaInst *Alloca=nullptr, uint8_t ID=0)
Create a new statically sized stack object, returning a nonnegative identifier to represent it.
LLVM_ABI int CreateSpillStackObject(uint64_t Size, Align Alignment)
Create a new statically sized stack object that represents a spill slot, returning a nonnegative iden...
LLVM_ABI int CreateVariableSizedObject(Align Alignment, const AllocaInst *Alloca)
Notify the MachineFrameInfo object that a variable sized object has been created.
MachineFunctionPass - This class adapts the FunctionPass interface to allow convenient creation of pa...
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - Subclasses that override getAnalysisUsage must call this.
StringRef getName() const
getName - Return the name of the corresponding LLVM function.
MachineFrameInfo & getFrameInfo()
getFrameInfo - Return the frame info object for the current function.
MachineRegisterInfo & getRegInfo()
getRegInfo - Return information about the registers currently in use.
Function & getFunction()
Return the LLVM function that this machine code represents.
unsigned getNumBlockIDs() const
getNumBlockIDs - Return the number of MBB ID's allocated.
Ty * getInfo()
getInfo - Keep track of various per-function pieces of information for backends that would like to do...
const MachineInstrBuilder & addExternalSymbol(const char *FnName, unsigned TargetFlags=0) const
const MachineInstrBuilder & addReg(Register RegNo, RegState Flags={}, unsigned SubReg=0) const
Add a new virtual register operand.
const MachineInstrBuilder & addImm(int64_t Val) const
Add a new immediate operand.
const MachineInstrBuilder & addFrameIndex(int Idx) const
const MachineInstrBuilder & addRegMask(const uint32_t *Mask) const
const MachineInstrBuilder & addDef(Register RegNo, RegState Flags={}, unsigned SubReg=0) const
Add a virtual register definition operand.
Representation of each machine instruction.
MachineOperand class - Representation of each machine instruction operand.
const GlobalValue * getGlobal() const
bool isReg() const
isReg - Tests if this is a MO_Register operand.
bool isSymbol() const
isSymbol - Tests if this is a MO_ExternalSymbol operand.
bool isGlobal() const
isGlobal - Tests if this is a MO_GlobalAddress operand.
const char * getSymbolName() const
Register getReg() const
getReg - Returns the register number.
Diagnostic information for optimization analysis remarks.
LLVM_ABI void emit(DiagnosticInfoOptimizationBase &OptDiag)
Emit an optimization remark.
bool allowExtraAnalysis(StringRef PassName) const
Whether we allow for extra compile-time budget to perform more analysis to be more informative.
MachineRegisterInfo - Keep track of information for virtual and physical registers,...
LLVM_ABI Register createVirtualRegister(const TargetRegisterClass *RegClass, StringRef Name="")
createVirtualRegister - Create and return a new virtual register in the function with the specified r...
Wrapper class representing virtual and physical registers.
Definition Register.h:20
constexpr bool isPhysical() const
Return true if the specified register number is in the physical register namespace.
Definition Register.h:83
SMEAttrs is a utility class to parse the SME ACLE attributes on functions.
bool hasAgnosticZAInterface() const
bool hasPrivateZAInterface() const
bool hasSharedZAInterface() const
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
typename SuperClass::const_iterator const_iterator
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Represent a constant reference to a string, i.e.
Definition StringRef.h:56
constexpr bool empty() const
Check if the string is empty.
Definition StringRef.h:141
constexpr const char * data() const
Get a pointer to the start of the string (which may not be null terminated).
Definition StringRef.h:138
TargetRegisterInfo base class - We assume that the target defines a static array of TargetRegisterDes...
Twine - A lightweight data structure for efficiently representing the concatenation of temporary valu...
Definition Twine.h:82
op_range operands()
Definition User.h:267
LLVM_ABI StringRef getName() const
Return a constant reference to the value's name.
Definition Value.cpp:318
const ParentTy * getParent() const
Definition ilist_node.h:34
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
static unsigned getArithExtendImm(AArch64_AM::ShiftExtendType ET, unsigned Imm)
getArithExtendImm - Encode the extend type and shift amount for an arithmetic instruction: imm: 3-bit...
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition CallingConv.h:24
@ AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0
Preserve X0-X13, X19-X29, SP, Z0-Z31, P0-P15.
@ AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1
Preserve X1-X15, X19-X29, SP, Z0-Z31, P0-P15.
This is an optimization pass for GlobalISel generic memory operations.
MachineInstrBuilder BuildMI(MachineFunction &MF, const MIMetadata &MIMD, const MCInstrDesc &MCID)
Builder interface. Specify how to create the initial instruction itself.
@ Implicit
Not emitted register (e.g. carry, or temporary result).
@ Define
Register definition.
FunctionPass * createMachineSMEABIPass(CodeGenOptLevel)
LLVM_ABI char & MachineDominatorsID
MachineDominators - This pass is a machine dominators analysis pass.
LLVM_ABI void reportFatalInternalError(Error Err)
Report a fatal error that indicates a bug in LLVM.
Definition Error.cpp:173
LLVM_ABI char & MachineLoopInfoID
MachineLoopInfo - This pass is a loop analysis pass.
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1745
auto reverse(ContainerTy &&C)
Definition STLExtras.h:407
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition Debug.cpp:209
CodeGenOptLevel
Code generation optimization level.
Definition CodeGen.h:82
@ Default
-O2, -Os, -Oz
Definition CodeGen.h:85
uint16_t MCPhysReg
An unsigned integer type large enough to represent all physical registers, but not necessarily virtua...
Definition MCRegister.h:21
This struct is a compact representation of a valid (non-zero power of two) alignment.
Definition Alignment.h:39
static StringRef getLibcallImplName(RTLIB::LibcallImpl CallImpl)
Get the libcall routine name for the specified libcall implementation.