LLVM 23.0.0git
SPIRVLegalizeImplicitBinding.cpp
Go to the documentation of this file.
1//===- SPIRVLegalizeImplicitBinding.cpp - Legalize implicit bindings ----*- C++
2//-*-===//
3//
4// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5// See https://llvm.org/LICENSE.txt for license information.
6// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7//
8//===----------------------------------------------------------------------===//
9//
10// This pass legalizes the @llvm.spv.resource.handlefromimplicitbinding
11// intrinsic by replacing it with a call to
12// @llvm.spv.resource.handlefrombinding.
13//
14//===----------------------------------------------------------------------===//
15
17#include "SPIRV.h"
18#include "llvm/ADT/BitVector.h"
19#include "llvm/ADT/STLExtras.h"
21#include "llvm/IR/IRBuilder.h"
22#include "llvm/IR/InstVisitor.h"
23#include "llvm/IR/Intrinsics.h"
24#include "llvm/IR/IntrinsicsSPIRV.h"
25#include "llvm/IR/Module.h"
26#include "llvm/Pass.h"
27#include <algorithm>
28#include <vector>
29
30using namespace llvm;
31
32namespace {
33class SPIRVLegalizeImplicitBindingImpl {
34public:
35 bool runOnModule(Module &M);
36
37private:
38 void collectBindingInfo(Module &M);
39 uint32_t getAndReserveFirstUnusedBinding(uint32_t DescSet);
40 void replaceImplicitBindingCalls(Module &M);
41 void replaceResourceHandleCall(Module &M, CallInst *OldCI,
42 uint32_t NewBinding);
43 void replaceCounterHandleCall(Module &M, CallInst *OldCI,
44 uint32_t NewBinding);
45 void verifyUniqueOrderIdPerResource(SmallVectorImpl<CallInst *> &Calls);
46
47 // A map from descriptor set to a bit vector of used binding numbers.
48 std::vector<BitVector> UsedBindings;
49 // A list of all implicit binding calls, to be sorted by order ID.
50 SmallVector<CallInst *, 16> ImplicitBindingCalls;
51};
52
53class SPIRVLegalizeImplicitBindingLegacy : public ModulePass {
54public:
55 static char ID;
56 SPIRVLegalizeImplicitBindingLegacy() : ModulePass(ID) {}
57 StringRef getPassName() const override {
58 return "SPIRV Legalize Implicit Binding";
59 }
60 bool runOnModule(Module &M) override {
61 return SPIRVLegalizeImplicitBindingImpl().runOnModule(M);
62 }
63};
64
65struct BindingInfoCollector : public InstVisitor<BindingInfoCollector> {
66 std::vector<BitVector> &UsedBindings;
67 SmallVector<CallInst *, 16> &ImplicitBindingCalls;
68
69 BindingInfoCollector(std::vector<BitVector> &UsedBindings,
70 SmallVector<CallInst *, 16> &ImplicitBindingCalls)
71 : UsedBindings(UsedBindings), ImplicitBindingCalls(ImplicitBindingCalls) {
72 }
73
74 void addBinding(uint32_t DescSet, uint32_t Binding) {
75 if (UsedBindings.size() <= DescSet) {
76 UsedBindings.resize(DescSet + 1);
77 UsedBindings[DescSet].resize(64);
78 }
79 if (UsedBindings[DescSet].size() <= Binding) {
80 UsedBindings[DescSet].resize(2 * Binding + 1);
81 }
82 UsedBindings[DescSet].set(Binding);
83 }
84
85 void visitCallInst(CallInst &CI) {
86 if (CI.getIntrinsicID() == Intrinsic::spv_resource_handlefrombinding) {
87 const uint32_t DescSet =
88 cast<ConstantInt>(CI.getArgOperand(0))->getZExtValue();
89 const uint32_t Binding =
90 cast<ConstantInt>(CI.getArgOperand(1))->getZExtValue();
91 addBinding(DescSet, Binding);
92 } else if (CI.getIntrinsicID() ==
93 Intrinsic::spv_resource_handlefromimplicitbinding) {
94 ImplicitBindingCalls.push_back(&CI);
95 } else if (CI.getIntrinsicID() ==
96 Intrinsic::spv_resource_counterhandlefrombinding) {
97 const uint32_t DescSet =
98 cast<ConstantInt>(CI.getArgOperand(2))->getZExtValue();
99 const uint32_t Binding =
100 cast<ConstantInt>(CI.getArgOperand(1))->getZExtValue();
101 addBinding(DescSet, Binding);
102 } else if (CI.getIntrinsicID() ==
103 Intrinsic::spv_resource_counterhandlefromimplicitbinding) {
104 ImplicitBindingCalls.push_back(&CI);
105 }
106 }
107};
108
109static uint32_t getOrderId(const CallInst *CI) {
110 uint32_t OrderIdArgIdx = 0;
111 switch (CI->getIntrinsicID()) {
112 case Intrinsic::spv_resource_handlefromimplicitbinding:
113 OrderIdArgIdx = 0;
114 break;
115 case Intrinsic::spv_resource_counterhandlefromimplicitbinding:
116 OrderIdArgIdx = 1;
117 break;
118 default:
119 llvm_unreachable("CallInst is not an implicit binding intrinsic");
120 }
121 return cast<ConstantInt>(CI->getArgOperand(OrderIdArgIdx))->getZExtValue();
122}
123
124static uint32_t getDescSet(const CallInst *CI) {
125 uint32_t DescSetArgIdx;
126 switch (CI->getIntrinsicID()) {
127 case Intrinsic::spv_resource_handlefromimplicitbinding:
128 case Intrinsic::spv_resource_handlefrombinding:
129 DescSetArgIdx = 1;
130 break;
131 case Intrinsic::spv_resource_counterhandlefromimplicitbinding:
132 case Intrinsic::spv_resource_counterhandlefrombinding:
133 DescSetArgIdx = 2;
134 break;
135 default:
136 llvm_unreachable("CallInst is not an implicit binding intrinsic");
137 }
138 return cast<ConstantInt>(CI->getArgOperand(DescSetArgIdx))->getZExtValue();
139}
140
141void SPIRVLegalizeImplicitBindingImpl::collectBindingInfo(Module &M) {
142 BindingInfoCollector InfoCollector(UsedBindings, ImplicitBindingCalls);
143 InfoCollector.visit(M);
144
145 // Sort the collected calls by their order ID.
146 llvm::sort(ImplicitBindingCalls, [](const CallInst *A, const CallInst *B) {
147 return getOrderId(A) < getOrderId(B);
148 });
149}
150
151void SPIRVLegalizeImplicitBindingImpl::verifyUniqueOrderIdPerResource(
152 SmallVectorImpl<CallInst *> &Calls) {
153 // Check that the order Id is unique per resource.
154 for (uint32_t i = 1; i < Calls.size(); ++i) {
155 const uint32_t OrderA = getOrderId(Calls[i - 1]);
156 const uint32_t OrderB = getOrderId(Calls[i]);
157 if (OrderA == OrderB) {
158 const uint32_t DescSetA = getDescSet(Calls[i - 1]);
159 const uint32_t DescSetB = getDescSet(Calls[i]);
160 if (DescSetA != DescSetB) {
161 report_fatal_error("Implicit binding calls with the same order ID must "
162 "have the same descriptor set");
163 }
164 }
165 }
166}
167
168uint32_t SPIRVLegalizeImplicitBindingImpl::getAndReserveFirstUnusedBinding(
169 uint32_t DescSet) {
170 if (UsedBindings.size() <= DescSet) {
171 UsedBindings.resize(DescSet + 1);
172 UsedBindings[DescSet].resize(64);
173 }
174
175 int NewBinding = UsedBindings[DescSet].find_first_unset();
176 if (NewBinding == -1) {
177 NewBinding = UsedBindings[DescSet].size();
178 UsedBindings[DescSet].resize(2 * NewBinding + 1);
179 }
180
181 UsedBindings[DescSet].set(NewBinding);
182 return NewBinding;
183}
184
185void SPIRVLegalizeImplicitBindingImpl::replaceImplicitBindingCalls(Module &M) {
186 uint32_t lastOrderId = -1;
187 uint32_t lastBindingNumber = -1;
188
189 for (CallInst *OldCI : ImplicitBindingCalls) {
190 const uint32_t OrderId = getOrderId(OldCI);
191 uint32_t BindingNumber;
192 if (OrderId == lastOrderId) {
193 BindingNumber = lastBindingNumber;
194 } else {
195 const uint32_t DescSet = getDescSet(OldCI);
196 BindingNumber = getAndReserveFirstUnusedBinding(DescSet);
197 }
198
199 if (OldCI->getIntrinsicID() ==
200 Intrinsic::spv_resource_handlefromimplicitbinding) {
201 replaceResourceHandleCall(M, OldCI, BindingNumber);
202 } else {
203 assert(OldCI->getIntrinsicID() ==
204 Intrinsic::spv_resource_counterhandlefromimplicitbinding &&
205 "Unexpected implicit binding intrinsic");
206 replaceCounterHandleCall(M, OldCI, BindingNumber);
207 }
208 lastOrderId = OrderId;
209 lastBindingNumber = BindingNumber;
210 }
211}
212
213bool SPIRVLegalizeImplicitBindingImpl::runOnModule(Module &M) {
214 collectBindingInfo(M);
215 if (ImplicitBindingCalls.empty()) {
216 return false;
217 }
218 verifyUniqueOrderIdPerResource(ImplicitBindingCalls);
219
220 replaceImplicitBindingCalls(M);
221 return true;
222}
223} // namespace
224
227 return SPIRVLegalizeImplicitBindingImpl().runOnModule(M)
230}
231
// Storage for the pass ID declared in SPIRVLegalizeImplicitBindingLegacy and
// handed to ModulePass(ID) by its constructor.
char SPIRVLegalizeImplicitBindingLegacy::ID = 0;

// Register the legacy pass with argument "legalize-spirv-implicit-binding".
// The two trailing `false` values are INITIALIZE_PASS's `cfg` and `analysis`
// parameters.
INITIALIZE_PASS(SPIRVLegalizeImplicitBindingLegacy,
                "legalize-spirv-implicit-binding",
                "Legalize SPIR-V implicit bindings", false, false)
237
239 return new SPIRVLegalizeImplicitBindingLegacy();
240}
241
242void SPIRVLegalizeImplicitBindingImpl::replaceResourceHandleCall(
243 Module &M, CallInst *OldCI, uint32_t NewBinding) {
244 IRBuilder<> Builder(OldCI);
245 const uint32_t DescSet =
246 cast<ConstantInt>(OldCI->getArgOperand(1))->getZExtValue();
247
248 SmallVector<Value *, 8> Args;
249 Args.push_back(Builder.getInt32(DescSet));
250 Args.push_back(Builder.getInt32(NewBinding));
251
252 // Copy the remaining arguments from the old call.
253 for (uint32_t i = 2; i < OldCI->arg_size(); ++i) {
254 Args.push_back(OldCI->getArgOperand(i));
255 }
256
258 &M, Intrinsic::spv_resource_handlefrombinding, OldCI->getType());
259 CallInst *NewCI = Builder.CreateCall(NewFunc, Args);
260 NewCI->setCallingConv(OldCI->getCallingConv());
261
262 OldCI->replaceAllUsesWith(NewCI);
263 OldCI->eraseFromParent();
264}
265
266void SPIRVLegalizeImplicitBindingImpl::replaceCounterHandleCall(
267 Module &M, CallInst *OldCI, uint32_t NewBinding) {
268 IRBuilder<> Builder(OldCI);
269 const uint32_t DescSet =
270 cast<ConstantInt>(OldCI->getArgOperand(2))->getZExtValue();
271
272 SmallVector<Value *, 8> Args;
273 Args.push_back(OldCI->getArgOperand(0));
274 Args.push_back(Builder.getInt32(NewBinding));
275 Args.push_back(Builder.getInt32(DescSet));
276
277 Type *Tys[] = {OldCI->getType(), OldCI->getArgOperand(0)->getType()};
279 &M, Intrinsic::spv_resource_counterhandlefrombinding, Tys);
280 CallInst *NewCI = Builder.CreateCall(NewFunc, Args);
281 NewCI->setCallingConv(OldCI->getCallingConv());
282
283 OldCI->replaceAllUsesWith(NewCI);
284 OldCI->eraseFromParent();
285}
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
This file implements the BitVector class.
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
DXIL Resource Implicit Binding
Module.h This file contains the declarations for the Module class.
Machine Check Debug Module
#define INITIALIZE_PASS(passName, arg, name, cfg, analysis)
Definition PassSupport.h:56
This file contains some templates that are useful if you are working with the STL at all.
This file defines the SmallVector class.
void setCallingConv(CallingConv::ID CC)
CallingConv::ID getCallingConv() const
Value * getArgOperand(unsigned i) const
LLVM_ABI Intrinsic::ID getIntrinsicID() const
Returns the intrinsic ID of the intrinsic called or Intrinsic::not_intrinsic if the called function i...
unsigned arg_size() const
This class represents a function call, abstracting a target machine's calling convention.
Base class for instruction visitors.
Definition InstVisitor.h:78
LLVM_ABI InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
ModulePass class - This class is used to implement unstructured interprocedural optimizations and ana...
Definition Pass.h:255
A Module instance is used to store all the information related to an LLVM module.
Definition Module.h:67
A set of analyses that are preserved following a run of a transformation pass.
Definition Analysis.h:112
static PreservedAnalyses none()
Convenience factory function for the empty preserved set.
Definition Analysis.h:115
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition Analysis.h:118
PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM)
Type * getType() const
All values are typed, get the type of this value.
Definition Value.h:255
LLVM_ABI void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition Value.cpp:553
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr char Args[]
Key for Kernel::Metadata::mArgs.
LLVM_ABI Function * getOrInsertDeclaration(Module *M, ID id, ArrayRef< Type * > OverloadTys={})
Look up the Function declaration of the intrinsic id in the Module M.
This is an optimization pass for GlobalISel generic memory operations.
auto size(R &&Range, std::enable_if_t< std::is_base_of< std::random_access_iterator_tag, typename std::iterator_traits< decltype(Range.begin())>::iterator_category >::value, void > *=nullptr)
Get the size of a range.
Definition STLExtras.h:1669
void sort(IteratorTy Start, IteratorTy End)
Definition STLExtras.h:1636
LLVM_ABI void report_fatal_error(Error Err, bool gen_crash_diag=true)
Definition Error.cpp:163
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
IRBuilder(LLVMContext &, FolderTy, InserterTy, MDNode *, ArrayRef< OperandBundleDef >) -> IRBuilder< FolderTy, InserterTy >
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:559
AnalysisManager< Module > ModuleAnalysisManager
Convenience typedef for the Module analysis manager.
Definition MIRParser.h:39
ModulePass * createSPIRVLegalizeImplicitBindingPass()