LLVM 22.0.0git
ReOptimizeLayer.cpp
Go to the documentation of this file.
3
4using namespace llvm;
5using namespace orc;
6
7bool ReOptimizeLayer::ReOptMaterializationUnitState::tryStartReoptimize() {
8 std::unique_lock<std::mutex> Lock(Mutex);
9 if (Reoptimizing)
10 return false;
11
12 Reoptimizing = true;
13 return true;
14}
15
16void ReOptimizeLayer::ReOptMaterializationUnitState::reoptimizeSucceeded() {
17 std::unique_lock<std::mutex> Lock(Mutex);
18 assert(Reoptimizing && "Tried to mark unstarted reoptimization as done");
19 Reoptimizing = false;
20 CurVersion++;
21}
22
23void ReOptimizeLayer::ReOptMaterializationUnitState::reoptimizeFailed() {
24 std::unique_lock<std::mutex> Lock(Mutex);
25 assert(Reoptimizing && "Tried to mark unstarted reoptimization as done");
26 Reoptimizing = false;
27}
28
30 shared::CWrapperFunctionBuffer (*JITDispatch)(void *Ctx, void *Tag,
31 const char *Data,
32 size_t Size),
33 void *JITDispatchCtx, void *Tag, uint64_t MUID, uint32_t CurVersion) {
34 // Serialize the arguments into a WrapperFunctionBuffer and call dispatch.
36 auto ArgBytes =
37 shared::WrapperFunctionBuffer::allocate(SPSArgs::size(MUID, CurVersion));
38 shared::SPSOutputBuffer OB(ArgBytes.data(), ArgBytes.size());
39 if (!SPSArgs::serialize(OB, MUID, CurVersion)) {
40 errs()
41 << "Reoptimization error: could not serialize reoptimization arguments";
42 abort();
43 }
45 JITDispatch(JITDispatchCtx, Tag, ArgBytes.data(), ArgBytes.size())};
46
47 if (const char *ErrMsg = Buf.getOutOfBandError()) {
48 errs() << "Reoptimization error: " << ErrMsg << "\naborting.\n";
49 abort();
50 }
51}
52
54 const DataLayout &DL) {
55 auto Ctx = std::make_unique<LLVMContext>();
56 auto Mod = std::make_unique<Module>("orc-rt-lite-reoptimize.ll", *Ctx);
57 Mod->setDataLayout(DL);
58
59 IRBuilder<> Builder(*Ctx);
60
61 // Create basic types portably
62 Type *VoidTy = Type::getVoidTy(*Ctx);
63 Type *Int8Ty = Type::getInt8Ty(*Ctx);
65 Type *Int64Ty = Type::getInt64Ty(*Ctx);
66 Type *VoidPtrTy = PointerType::getUnqual(*Ctx);
67
68 // Helper function type: void (void*, void*, void*, uint64_t, uint32_t)
69 FunctionType *HelperFnTy = FunctionType::get(
70 VoidTy, {VoidPtrTy, VoidPtrTy, VoidPtrTy, Int64Ty, Int32Ty}, false);
71
72 // Define ReoptimizeTag with initializer = 0
73 GlobalVariable *ReoptimizeTag = new GlobalVariable(
74 *Mod, Int8Ty, false, GlobalValue::ExternalLinkage,
75 ConstantInt::get(Int8Ty, 0), "__orc_rt_reoptimize_tag");
76
77 // Define orc_rt_lite_reoptimize function: void (uint64_t, uint32_t)
78 FunctionType *ReOptimizeFnTy =
79 FunctionType::get(VoidTy, {Int64Ty, Int32Ty}, false);
80
81 Function *ReOptimizeFn =
83 "__orc_rt_reoptimize", Mod.get());
84
85 // Set parameter names
86 auto ArgIt = ReOptimizeFn->arg_begin();
87 Value *MUID = &*ArgIt++;
88 MUID->setName("MUID");
89 Value *CurVersion = &*ArgIt;
90 CurVersion->setName("CurVersion");
91
92 // Build function body
93 BasicBlock *Entry = BasicBlock::Create(*Ctx, "entry", ReOptimizeFn);
94 Builder.SetInsertPoint(Entry);
95
96 // Create absolute address constants
97 auto &JDI = PlatformJD.getExecutionSession()
100
101 Type *IntPtrTy = DL.getIntPtrType(*Ctx);
102 Constant *JITDispatchPtr = ConstantExpr::getIntToPtr(
103 ConstantInt::get(IntPtrTy, JDI.JITDispatchFunction.getValue()),
104 VoidPtrTy);
105 Constant *JITDispatchCtxPtr = ConstantExpr::getIntToPtr(
106 ConstantInt::get(IntPtrTy, JDI.JITDispatchContext.getValue()), VoidPtrTy);
107 Constant *HelperFnAddr = ConstantExpr::getIntToPtr(
108 ConstantInt::get(IntPtrTy, reinterpret_cast<uintptr_t>(
111
112 // Cast ReoptimizeTag to void*
113 Value *ReoptimizeTagPtr = Builder.CreatePointerCast(ReoptimizeTag, VoidPtrTy);
114
115 // Call the helper function
116 Builder.CreateCall(
117 HelperFnTy, HelperFnAddr,
118 {JITDispatchPtr, JITDispatchCtxPtr, ReoptimizeTagPtr, MUID, CurVersion});
119
120 // Return void
121 Builder.CreateRetVoid();
122
123 return BaseLayer.add(PlatformJD,
124 ThreadSafeModule(std::move(Mod), std::move(Ctx)));
125}
126
129 using ReoptimizeSPSSig = shared::SPSError(uint64_t, uint32_t);
130 WFs[Mangle("__orc_rt_reoptimize_tag")] =
131 ES.wrapAsyncWithSPS<ReoptimizeSPSSig>(this,
132 &ReOptimizeLayer::rt_reoptimize);
133 return ES.registerJITDispatchHandlers(PlatformJD, std::move(WFs));
134}
135
136void ReOptimizeLayer::emit(std::unique_ptr<MaterializationResponsibility> R,
137 ThreadSafeModule TSM) {
138 auto &JD = R->getTargetJITDylib();
139
140 bool HasNonCallable = false;
141 for (auto &KV : R->getSymbols()) {
142 auto &Flags = KV.second;
143 if (!Flags.isCallable())
144 HasNonCallable = true;
145 }
146
147 if (HasNonCallable) {
148 BaseLayer.emit(std::move(R), std::move(TSM));
149 return;
150 }
151
152 auto &MUState = createMaterializationUnitState(TSM);
153
154 if (auto Err = R->withResourceKeyDo([&](ResourceKey Key) {
155 registerMaterializationUnitResource(Key, MUState);
156 })) {
157 ES.reportError(std::move(Err));
158 R->failMaterialization();
159 return;
160 }
161
162 if (auto Err =
163 ProfilerFunc(*this, MUState.getID(), MUState.getCurVersion(), TSM)) {
164 ES.reportError(std::move(Err));
165 R->failMaterialization();
166 return;
167 }
168
169 auto InitialDests =
170 emitMUImplSymbols(MUState, MUState.getCurVersion(), JD, std::move(TSM));
171 if (!InitialDests) {
172 ES.reportError(InitialDests.takeError());
173 R->failMaterialization();
174 return;
175 }
176
177 RSManager.emitRedirectableSymbols(std::move(R), std::move(*InitialDests));
178}
179
182 unsigned CurVersion,
183 ThreadSafeModule &TSM) {
184 return TSM.withModuleDo([&](Module &M) -> Error {
185 Type *I64Ty = Type::getInt64Ty(M.getContext());
186 GlobalVariable *Counter = new GlobalVariable(
187 M, I64Ty, false, GlobalValue::InternalLinkage,
188 Constant::getNullValue(I64Ty), "__orc_reopt_counter");
189 for (auto &F : M) {
190 if (F.isDeclaration())
191 continue;
192 auto &BB = F.getEntryBlock();
193 auto *IP = &*BB.getFirstInsertionPt();
194 IRBuilder<> IRB(IP);
195 Value *Threshold = ConstantInt::get(I64Ty, CallCountThreshold, true);
196 Value *Cnt = IRB.CreateLoad(I64Ty, Counter);
197 // Use EQ to prevent further reoptimize calls.
198 Value *Cmp = IRB.CreateICmpEQ(Cnt, Threshold);
199 Value *Added = IRB.CreateAdd(Cnt, ConstantInt::get(I64Ty, 1));
200 (void)IRB.CreateStore(Added, Counter);
201 Instruction *SplitTerminator = SplitBlockAndInsertIfThen(Cmp, IP, false);
202 createReoptimizeCall(M, *SplitTerminator, MUID, CurVersion);
203 }
204 return Error::success();
205 });
206}
207
209ReOptimizeLayer::emitMUImplSymbols(ReOptMaterializationUnitState &MUState,
211 ThreadSafeModule TSM) {
213 cantFail(TSM.withModuleDo([&](Module &M) -> Error {
214 MangleAndInterner Mangle(ES, M.getDataLayout());
215 for (auto &F : M)
216 if (!F.isDeclaration()) {
217 std::string NewName =
218 (F.getName() + ".__def__." + Twine(Version)).str();
219 RenamedMap[Mangle(F.getName())] = Mangle(NewName);
220 F.setName(NewName);
221 }
222 return Error::success();
223 }));
224
225 auto RT = JD.createResourceTracker();
226 if (auto Err =
227 JD.define(std::make_unique<BasicIRLayerMaterializationUnit>(
228 BaseLayer, *getManglingOptions(), std::move(TSM)),
229 RT))
230 return Err;
231 MUState.setResourceTracker(RT);
232
233 SymbolLookupSet LookupSymbols;
234 for (auto [K, V] : RenamedMap)
235 LookupSymbols.add(V);
236
237 auto ImplSymbols =
238 ES.lookup({{&JD, JITDylibLookupFlags::MatchAllSymbols}}, LookupSymbols,
240 if (auto Err = ImplSymbols.takeError())
241 return Err;
242
244 for (auto [K, V] : RenamedMap)
245 Result[K] = (*ImplSymbols)[V];
246
247 return Result;
248}
249
250void ReOptimizeLayer::rt_reoptimize(SendErrorFn SendResult,
252 uint32_t CurVersion) {
253 auto &MUState = getMaterializationUnitState(MUID);
254 if (CurVersion < MUState.getCurVersion() || !MUState.tryStartReoptimize()) {
255 SendResult(Error::success());
256 return;
257 }
258
259 ThreadSafeModule TSM = cloneToNewContext(MUState.getThreadSafeModule());
260 auto OldRT = MUState.getResourceTracker();
261 auto &JD = OldRT->getJITDylib();
262
263 if (auto Err = ReOptFunc(*this, MUID, CurVersion + 1, OldRT, TSM)) {
264 ES.reportError(std::move(Err));
265 MUState.reoptimizeFailed();
266 SendResult(Error::success());
267 return;
268 }
269
270 auto SymbolDests =
271 emitMUImplSymbols(MUState, CurVersion + 1, JD, std::move(TSM));
272 if (!SymbolDests) {
273 ES.reportError(SymbolDests.takeError());
274 MUState.reoptimizeFailed();
275 SendResult(Error::success());
276 return;
277 }
278
279 if (auto Err = RSManager.redirect(JD, std::move(*SymbolDests))) {
280 ES.reportError(std::move(Err));
281 MUState.reoptimizeFailed();
282 SendResult(Error::success());
283 return;
284 }
285
286 MUState.reoptimizeSucceeded();
287 SendResult(Error::success());
288}
289
292 uint32_t CurVersion) {
293 Type *MUIDTy = IntegerType::get(M.getContext(), 64);
294 Type *VersionTy = IntegerType::get(M.getContext(), 32);
295 Function *ReoptimizeFunc = M.getFunction("__orc_rt_reoptimize");
296 if (!ReoptimizeFunc) {
297 std::vector<Type *> ArgTys = {MUIDTy, VersionTy};
298 FunctionType *FuncTy =
299 FunctionType::get(Type::getVoidTy(M.getContext()), ArgTys, false);
300 ReoptimizeFunc = Function::Create(FuncTy, GlobalValue::ExternalLinkage,
301 "__orc_rt_reoptimize", &M);
302 }
303 Constant *MUIDArg = ConstantInt::get(MUIDTy, MUID, false);
304 Constant *CurVersionArg = ConstantInt::get(VersionTy, CurVersion, false);
305 IRBuilder<> IRB(&IP);
306 (void)IRB.CreateCall(ReoptimizeFunc, {MUIDArg, CurVersionArg});
307}
308
309ReOptimizeLayer::ReOptMaterializationUnitState &
310ReOptimizeLayer::createMaterializationUnitState(const ThreadSafeModule &TSM) {
311 std::unique_lock<std::mutex> Lock(Mutex);
312 ReOptMaterializationUnitID MUID = NextID;
313 MUStates.emplace(MUID,
314 ReOptMaterializationUnitState(MUID, cloneToNewContext(TSM)));
315 ++NextID;
316 return MUStates.at(MUID);
317}
318
319ReOptimizeLayer::ReOptMaterializationUnitState &
320ReOptimizeLayer::getMaterializationUnitState(ReOptMaterializationUnitID MUID) {
321 std::unique_lock<std::mutex> Lock(Mutex);
322 return MUStates.at(MUID);
323}
324
325void ReOptimizeLayer::registerMaterializationUnitResource(
326 ResourceKey Key, ReOptMaterializationUnitState &State) {
327 std::unique_lock<std::mutex> Lock(Mutex);
328 MUResources[Key].insert(State.getID());
329}
330
332 std::unique_lock<std::mutex> Lock(Mutex);
333 for (auto MUID : MUResources[K])
334 MUStates.erase(MUID);
335
336 MUResources.erase(K);
337 return Error::success();
338}
339
341 ResourceKey SrcK) {
342 std::unique_lock<std::mutex> Lock(Mutex);
343 MUResources[DstK].insert_range(MUResources[SrcK]);
344 MUResources.erase(SrcK);
345}
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
#define F(x, y, z)
Definition MD5.cpp:54
static void orc_rt_lite_reoptimize_helper(shared::CWrapperFunctionBuffer(*JITDispatch)(void *Ctx, void *Tag, const char *Data, size_t Size), void *JITDispatchCtx, void *Tag, uint64_t MUID, uint32_t CurVersion)
LLVM Basic Block Representation.
Definition BasicBlock.h:62
static BasicBlock * Create(LLVMContext &Context, const Twine &Name="", Function *Parent=nullptr, BasicBlock *InsertBefore=nullptr)
Creates a new BasicBlock.
Definition BasicBlock.h:206
static LLVM_ABI Constant * getIntToPtr(Constant *C, Type *Ty, bool OnlyIfReduced=false)
This is an important base class in LLVM.
Definition Constant.h:43
static LLVM_ABI Constant * getNullValue(Type *Ty)
Constructor to create a '0' constant of arbitrary type.
A parsed version of the target data layout string, and methods for querying it.
Definition DataLayout.h:64
Lightweight error class with error context and mandatory checking.
Definition Error.h:159
static ErrorSuccess success()
Create a success value.
Definition Error.h:336
Tagged union holding either a T or a Error.
Definition Error.h:485
Class to represent function types.
static LLVM_ABI FunctionType * get(Type *Result, ArrayRef< Type * > Params, bool isVarArg)
This static method is the primary way of constructing a FunctionType.
static Function * Create(FunctionType *Ty, LinkageTypes Linkage, unsigned AddrSpace, const Twine &N="", Module *M=nullptr)
Definition Function.h:166
const Function & getFunction() const
Definition Function.h:164
@ InternalLinkage
Rename collisions when linking (static functions).
Definition GlobalValue.h:60
@ ExternalLinkage
Externally visible function.
Definition GlobalValue.h:53
Value * CreateICmpEQ(Value *LHS, Value *RHS, const Twine &Name="")
Definition IRBuilder.h:2332
LoadInst * CreateLoad(Type *Ty, Value *Ptr, const char *Name)
Provided to resolve 'CreateLoad(Ty, Ptr, "...")' correctly, instead of converting the string to 'bool...
Definition IRBuilder.h:1850
StoreInst * CreateStore(Value *Val, Value *Ptr, bool isVolatile=false)
Definition IRBuilder.h:1863
Value * CreateAdd(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Definition IRBuilder.h:1403
CallInst * CreateCall(FunctionType *FTy, Value *Callee, ArrayRef< Value * > Args={}, const Twine &Name="", MDNode *FPMathTag=nullptr)
Definition IRBuilder.h:2511
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Definition IRBuilder.h:2788
static LLVM_ABI IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
Definition Type.cpp:318
A Module instance is used to store all the information related to an LLVM module.
Definition Module.h:67
static PointerType * getUnqual(Type *ElementType)
This constructs a pointer to an object of the specified type in the default address space (address sp...
The instances of the Type class are immutable: once they are created, they are never changed.
Definition Type.h:45
static LLVM_ABI IntegerType * getInt64Ty(LLVMContext &C)
Definition Type.cpp:297
static LLVM_ABI IntegerType * getInt32Ty(LLVMContext &C)
Definition Type.cpp:296
static LLVM_ABI Type * getVoidTy(LLVMContext &C)
Definition Type.cpp:280
static LLVM_ABI IntegerType * getInt8Ty(LLVMContext &C)
Definition Type.cpp:294
LLVM Value Representation.
Definition Value.h:75
LLVM_ABI void setName(const Twine &Name)
Change the name of the value.
Definition Value.cpp:397
ExecutorProcessControl & getExecutorProcessControl()
Get the ExecutorProcessControl object associated with this ExecutionSession.
Definition Core.h:1382
LLVM_ABI void lookup(LookupKind K, const JITDylibSearchOrder &SearchOrder, SymbolLookupSet Symbols, SymbolState RequiredState, SymbolsResolvedCallback NotifyComplete, RegisterDependenciesFunction RegisterDependencies)
Search the given JITDylibs for the given symbols.
Definition Core.cpp:1766
DenseMap< SymbolStringPtr, JITDispatchHandlerFunction > JITDispatchHandlerAssociationMap
A map associating tag names with asynchronous wrapper function implementations in the JIT.
Definition Core.h:1365
const JITDispatchInfo & getJITDispatchInfo() const
Get the JIT dispatch function and context address for the executor.
const IRSymbolMapper::ManglingOptions *& getManglingOptions() const
Get the mangling options for this layer.
Definition Layer.h:79
Represents a JIT'd dynamic library.
Definition Core.h:906
ExecutionSession & getExecutionSession() const
Get a reference to the ExecutionSession for this JITDylib.
Definition Core.h:925
Error registerRuntimeFunctions(JITDylib &PlatformJD)
Registers reoptimize runtime dispatch handlers to given PlatformJD.
ReOptimizeLayer(ExecutionSession &ES, DataLayout &DL, IRLayer &BaseLayer, RedirectableSymbolManager &RM)
void emit(std::unique_ptr< MaterializationResponsibility > R, ThreadSafeModule TSM) override
Emits the given module.
static void createReoptimizeCall(Module &M, Instruction &IP, ReOptMaterializationUnitID MUID, unsigned CurVersion)
void handleTransferResources(JITDylib &JD, ResourceKey DstK, ResourceKey SrcK) override
This function will be called inside the session lock.
Error addOrcRTLiteSupport(JITDylib &PlatformJD, const DataLayout &DL)
Add ORC Runtime-lite support for reoptimization to PlatformJD.
static Error reoptimizeIfCallFrequent(ReOptimizeLayer &Parent, ReOptMaterializationUnitID MUID, unsigned CurVersion, ThreadSafeModule &TSM)
Basic AddProfilerFunc that reoptimizes the function when the call count exceeds CallCountThreshold.
Error handleRemoveResources(JITDylib &JD, ResourceKey K) override
This function will be called outside the session lock.
static const uint64_t CallCountThreshold
A set of symbols to look up, each associated with a SymbolLookupFlags value.
Definition Core.h:199
SymbolLookupSet & add(SymbolStringPtr Name, SymbolLookupFlags Flags=SymbolLookupFlags::RequiredSymbol)
Add an element to the set.
Definition Core.h:265
An LLVM Module together with a shared ThreadSafeContext.
decltype(auto) withModuleDo(Func &&F)
Locks the associated ThreadSafeContext and calls the given function on the contained Module.
A utility class for serializing to a blob from a variadic list.
Output char buffer with overflow check.
C++ wrapper function buffer: Same as CWrapperFunctionBuffer but auto-releases memory.
const char * getOutOfBandError() const
If this value is an out-of-band error then this returns the error message, otherwise returns nullptr.
static WrapperFunctionBuffer allocate(size_t Size)
Create a WrapperFunctionBuffer with the given size and return a pointer to the underlying memory.
uintptr_t ResourceKey
Definition Core.h:79
DenseMap< SymbolStringPtr, ExecutorSymbolDef > SymbolMap
A map from symbol names (as SymbolStringPtrs) to JITSymbols (address/flags pairs).
LLVM_ABI ThreadSafeModule cloneToNewContext(const ThreadSafeModule &TSMW, GVPredicate ShouldCloneDef=GVPredicate(), GVModifier UpdateClonedDefSource=GVModifier())
Clones the given module on to a new context.
@ Resolved
Queried, materialization begun.
Definition Core.h:780
SmartMutex< false > Mutex
Mutex - A standard, always enforced mutex.
Definition Mutex.h:66
This is an optimization pass for GlobalISel generic memory operations.
FunctionAddr VTableAddr uintptr_t uintptr_t Int32Ty
Definition InstrProf.h:296
FunctionAddr VTableAddr uintptr_t uintptr_t Version
Definition InstrProf.h:302
LLVM_ATTRIBUTE_VISIBILITY_DEFAULT AnalysisKey InnerAnalysisManagerProxy< AnalysisManagerT, IRUnitT, ExtraArgTs... >::Key
LLVM_ABI raw_fd_ostream & errs()
This returns a reference to a raw_ostream for standard error.
@ Mod
The access may modify the value stored in memory.
Definition ModRef.h:34
void cantFail(Error Err, const char *Msg=nullptr)
Report a fatal error if Err is a failure value.
Definition Error.h:769
FunctionAddr VTableAddr uintptr_t uintptr_t Data
Definition InstrProf.h:189
LLVM_ABI Instruction * SplitBlockAndInsertIfThen(Value *Cond, BasicBlock::iterator SplitBefore, bool Unreachable, MDNode *BranchWeights=nullptr, DomTreeUpdater *DTU=nullptr, LoopInfo *LI=nullptr, BasicBlock *ThenBlock=nullptr)
Split the containing block at the specified instruction - everything before SplitBefore stays in the ...