diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -203,6 +203,12 @@
     bool enableVLAVectorization, bool enableSIMDIndex32);
 
+void populateSparseGPUCodegenPatterns(RewritePatternSet &patterns,
+                                      unsigned numThreads);
+
+std::unique_ptr<Pass> createSparseGPUCodegenPass();
+std::unique_ptr<Pass> createSparseGPUCodegenPass(unsigned numThreads);
+
 //===----------------------------------------------------------------------===//
 // Registration.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -310,6 +310,26 @@
   ];
 }
 
+def SparseGPUCodegen : Pass<"sparse-gpu-codegen", "ModuleOp"> {
+  let summary = "Generates GPU code during sparsification";
+  let description = [{
+    Enables the sparse compiler to use GPU acceleration.
+  }];
+  let constructor = "mlir::createSparseGPUCodegenPass()";
+  let dependentDialects = [
+    "arith::ArithDialect",
+    "bufferization::BufferizationDialect",
+    "gpu::GPUDialect",
+    "linalg::LinalgDialect",
+    "memref::MemRefDialect",
+    "scf::SCFDialect",
+    "sparse_tensor::SparseTensorDialect",
+  ];
+  let options = [
+    Option<"numThreads", "num_threads", "int32_t", "1024", "Sets the number of GPU threads">,
+  ];
+}
+
 def StorageSpecifierToLLVM : Pass<"sparse-storage-specifier-to-llvm", "ModuleOp"> {
   let summary = "Lower sparse storage specifer to llvm structure";
   let description = [{
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@
   CodegenUtils.cpp
   LoopEmitter.cpp
   SparseBufferRewriting.cpp
+  SparseGPUCodegen.cpp
   SparseStorageSpecifierToLLVM.cpp
   SparseTensorCodegen.cpp
   SparseTensorConversion.cpp
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
@@ -0,0 +1,247 @@
+//===- SparseGPUCodegen.cpp - Generates GPU code (using CUDA) -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is a prototype GPU code generator for the sparse compiler.
+// The objective is to eventually use the right combination of
+// direct code generation and library calls into vendor-specific
+// highly optimized sparse libraries (e.g. cuSparse for CUDA).
+//
+//===----------------------------------------------------------------------===//
+
+#include "CodegenUtils.h"
+#include "LoopEmitter.h"
+
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/Matchers.h"
+
+using namespace mlir;
+using namespace mlir::sparse_tensor;
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// Helper methods.
+//===----------------------------------------------------------------------===//
+
+/// Marks the given top module as a GPU container module.
+static void markAsGPUContainer(ModuleOp topModule) {
+  topModule->setAttr(gpu::GPUDialect::getContainerModuleAttrName(),
+                     UnitAttr::get(topModule->getContext()));
+}
+
+/// Constructs a new GPU module (for GPU kernels) inside the given top module.
+static gpu::GPUModuleOp genGPUModule(OpBuilder &builder, ModuleOp topModule,
+                                     StringRef name) {
+  markAsGPUContainer(topModule);
+  builder.setInsertionPointToStart(&topModule.getBodyRegion().front());
+  return builder.create<gpu::GPUModuleOp>(topModule->getLoc(), name);
+}
+
+/// Constructs a new GPU kernel in the given GPU module.
+static gpu::GPUFuncOp genGPUFunc(OpBuilder &builder, gpu::GPUModuleOp gpuModule,
+                                 StringRef name,
+                                 SmallVectorImpl<Value> &args) {
+  builder.setInsertionPointToStart(&gpuModule.getBodyRegion().front());
+  SmallVector<Type> argsTp;
+  for (unsigned i = 0, e = args.size(); i < e; i++)
+    argsTp.push_back(args[i].getType());
+  FunctionType type = FunctionType::get(gpuModule->getContext(), argsTp, {});
+  auto gpuFunc =
+      builder.create<gpu::GPUFuncOp>(gpuModule->getLoc(), name, type);
+  gpuFunc->setAttr(gpu::GPUDialect::getKernelFuncAttrName(),
+                   builder.getUnitAttr());
+  return gpuFunc;
+}
+
+/// Constructs code to launch GPU kernel.
+static void genLaunchGPUFunc(OpBuilder &builder, gpu::GPUFuncOp gpuFunc,
+                             SmallVectorImpl<Value> &args,
+                             unsigned numThreads) {
+  Location loc = gpuFunc->getLoc();
+  Value none = TypedValue<::mlir::IntegerType>{};
+  Value one = constantIndex(builder, loc, 1);
+  Value numT = constantIndex(builder, loc, numThreads);
+  gpu::KernelDim3 gridSize = {one, one, one};
+  gpu::KernelDim3 blckSize = {numT, one, one};
+  builder.create<gpu::LaunchFuncOp>(loc, gpuFunc, gridSize, blckSize,
+                                    /*dynSharedMemSz*/ none, args);
+}
+
+/// Maps the provided ranked host buffer into the device address space.
+/// Writes from the host are guaranteed to be visible to device kernels
+/// that are launched afterwards. Writes from the device are guaranteed
+/// to be visible on the host after synchronizing with the device kernel
+/// completion.
+static Value genHostRegisterMemref(OpBuilder &builder, Location loc,
+                                   Value mem) {
+  MemRefType memTp = mem.getType().cast<MemRefType>();
+  UnrankedMemRefType resTp =
+      UnrankedMemRefType::get(memTp.getElementType(), /*memorySpace=*/0);
+  Value cast = builder.create<memref::CastOp>(loc, resTp, mem);
+  builder.create<gpu::HostRegisterOp>(loc, cast);
+  return mem; // convenience pass-through
+}
+
+/// Constructs code for new GPU kernel.
+static void genGPUCode(PatternRewriter &rewriter, gpu::GPUFuncOp gpuFunc,
+                       scf::ParallelOp forallOp,
+                       SmallVectorImpl<Value> &constants,
+                       SmallVectorImpl<Value> &scalars,
+                       SmallVectorImpl<Value> &buffers) {
+  Location loc = gpuFunc->getLoc();
+  Block &block = gpuFunc.getBody().front();
+  rewriter.setInsertionPointToStart(&block);
+
+  // Re-generate the constants, recapture all arguments.
+  unsigned arg = 0;
+  IRMapping irMap;
+  for (Value c : constants)
+    irMap.map(c, rewriter.clone(*c.getDefiningOp())->getResult(0));
+  for (Value s : scalars)
+    irMap.map(s, block.getArgument(arg++));
+  for (Value b : buffers)
+    irMap.map(b, block.getArgument(arg++));
+
+  // Assume a 1-dimensional grid/block configuration (only x dimension),
+  // so that:
+  //   row = blockIdx.x * blockDim.x + threadIdx.x
+  //   inc = blockDim.x * gridDim.x
+  Value bid = rewriter.create<gpu::BlockIdOp>(loc, gpu::Dimension::x);
+  Value bsz = rewriter.create<gpu::BlockDimOp>(loc, gpu::Dimension::x);
+  Value tid = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::x);
+  Value gsz = rewriter.create<gpu::GridDimOp>(loc, gpu::Dimension::x);
+  Value mul = rewriter.create<arith::MulIOp>(loc, bid, bsz);
+  Value row = rewriter.create<arith::AddIOp>(loc, mul, tid);
+  Value inc = rewriter.create<arith::MulIOp>(loc, bsz, gsz);
+
+  // Construct the iteration over the computational space that
+  // accounts for the fact that the total number of threads and
+  // the amount of work to be done usually do not match precisely.
+  //   for (r = row; r < N; r += inc) {
+  //     <loop-body>
+  //   }
+  Value upper = irMap.lookup(forallOp.getUpperBound()[0]);
+  scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, row, upper, inc);
+  rewriter.cloneRegionBefore(forallOp.getLoopBody(), forOp.getLoopBody(),
+                             forOp.getLoopBody().begin(), irMap);
+
+  // Done.
+  rewriter.setInsertionPointAfter(forOp);
+  rewriter.create<gpu::ReturnOp>(gpuFunc->getLoc());
+}
+
+//===----------------------------------------------------------------------===//
+// Rewriting rules.
+//===----------------------------------------------------------------------===//
+
+/// Proof-of-concept rewriter. This rule generates a CUDA implementation
+/// for each outermost forall loop generated by the sparse compiler.
+//
+// TODO: right now this only works with parallelization-strategy=dense-outer-loop;
+//       give this rewriter its own flags in the future
+//
+struct ForallRewriter : public OpRewritePattern<scf::ParallelOp> {
+  using OpRewritePattern<scf::ParallelOp>::OpRewritePattern;
+
+  ForallRewriter(MLIRContext *context, unsigned nT)
+      : OpRewritePattern(context), numThreads(nT){};
+
+  LogicalResult matchAndRewrite(scf::ParallelOp forallOp,
+                                PatternRewriter &rewriter) const override {
+    // Reject inadmissible loop forms.
+    // Essentially, only accept a loop, generated by the sparse compiler,
+    // of the form
+    //   forall (i = 0; i < N; i++)
+    // so that cyclic scheduling over the threads is easy.
+    if (!forallOp->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName()) ||
+        forallOp.getNumReductions() != 0 || forallOp.getNumLoops() != 1 ||
+        !matchPattern(forallOp.getLowerBound()[0], m_Zero()) ||
+        !matchPattern(forallOp.getStep()[0], m_One()))
+      return failure();
+    // Collect every value that is computed outside the parallel loop.
+    SetVector<Value> invariants; // stable iteration!
+    forallOp->walk([&](Operation *op) {
+      // Collect all values of admissible ops.
+      for (OpOperand &o : op->getOpOperands()) {
+        Value val = o.get();
+        Block *block;
+        if (auto arg = val.dyn_cast<BlockArgument>())
+          block = arg.getOwner();
+        else
+          block = val.getDefiningOp()->getBlock();
+        if (!isNestedIn(block, forallOp))
+          invariants.insert(val);
+      }
+    });
+    // Outline the outside values as proper parameters. Fail when sharing
+    // a value between host and device is not straightforward.
+    SmallVector<Value> constants;
+    SmallVector<Value> scalars;
+    SmallVector<Value> buffers;
+    for (Value val : invariants) {
+      Type tp = val.getType();
+      if (val.getDefiningOp<arith::ConstantOp>())
+        constants.push_back(val);
+      else if (tp.isa<FloatType>() || tp.isIntOrIndex())
+        scalars.push_back(val);
+      else if (isa<MemRefType>(tp))
+        buffers.push_back(val);
+      else
+        return failure(); // don't know how to share
+    }
+    // Prepare the outlined arguments, register buffers.
+    Location loc = forallOp->getLoc();
+    SmallVector<Value> args;
+    for (Value s : scalars)
+      args.push_back(s);
+    for (Value b : buffers)
+      args.push_back(genHostRegisterMemref(rewriter, loc, b));
+    auto saveIp = rewriter.saveInsertionPoint();
+    // Set up GPU module and construct GPU function.
+    //
+    // TODO: only generate the GPU module once, avoid name conflicts
+    //
+    ModuleOp topModule = forallOp->getParentOfType<ModuleOp>();
+    auto gpuModule = genGPUModule(rewriter, topModule, "sparsekernels");
+    auto gpuFunc = genGPUFunc(rewriter, gpuModule, "kernel", args);
+    genGPUCode(rewriter, gpuFunc, forallOp, constants, scalars, buffers);
+    // Generate code that launches the kernel.
+    rewriter.restoreInsertionPoint(saveIp);
+    genLaunchGPUFunc(rewriter, gpuFunc, args, numThreads);
+    rewriter.eraseOp(forallOp);
+    return success();
+  }
+
+private:
+  // Helper method to see if a block appears in the given loop.
+  static bool isNestedIn(Block *block, scf::ParallelOp forallOp) {
+    for (Operation *o = block->getParentOp(); o; o = o->getParentOp()) {
+      if (o == forallOp)
+        return true;
+    }
+    return false;
+  }
+
+  unsigned numThreads;
+};
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Public method for populating GPU rewriting rules.
+//===----------------------------------------------------------------------===//
+
+void mlir::populateSparseGPUCodegenPatterns(RewritePatternSet &patterns,
+                                            unsigned numThreads) {
+  patterns.add<ForallRewriter>(patterns.getContext(), numThreads);
+}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
@@ -28,6 +29,7 @@
 #define GEN_PASS_DEF_SPARSETENSORCODEGEN
 #define GEN_PASS_DEF_SPARSEBUFFERREWRITE
 #define GEN_PASS_DEF_SPARSEVECTORIZATION
+#define GEN_PASS_DEF_SPARSEGPUCODEGEN
 #define GEN_PASS_DEF_STORAGESPECIFIERTOLLVM
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
 } // namespace mlir
@@ -281,6 +283,21 @@
   }
 };
 
+struct SparseGPUCodegenPass
+    : public impl::SparseGPUCodegenBase<SparseGPUCodegenPass> {
+
+  SparseGPUCodegenPass() = default;
+  SparseGPUCodegenPass(const SparseGPUCodegenPass &pass) = default;
+  SparseGPUCodegenPass(unsigned nT) { numThreads = nT; }
+
+  void runOnOperation() override {
+    auto *ctx = &getContext();
+    RewritePatternSet patterns(ctx);
+    populateSparseGPUCodegenPatterns(patterns, numThreads);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
+
 struct StorageSpecifierToLLVMPass
     : public impl::StorageSpecifierToLLVMBase<StorageSpecifierToLLVMPass> {
@@ -406,6 +423,14 @@
       vectorLength, enableVLAVectorization, enableSIMDIndex32);
 }
 
+std::unique_ptr<Pass> mlir::createSparseGPUCodegenPass() {
+  return std::make_unique<SparseGPUCodegenPass>();
+}
+
+std::unique_ptr<Pass> mlir::createSparseGPUCodegenPass(unsigned numThreads) {
+  return std::make_unique<SparseGPUCodegenPass>(numThreads);
+}
+
 std::unique_ptr<Pass> mlir::createStorageSpecifierToLLVMPass() {
   return std::make_unique<StorageSpecifierToLLVMPass>();
 }
diff --git a/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul.mlir b/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul.mlir
@@ -0,0 +1,61 @@
+// RUN: mlir-opt %s --linalg-generalize-named-ops \
+// RUN:             --pre-sparsification-rewrite \
+// RUN:             --sparsification="parallelization-strategy=dense-outer-loop" \
+// RUN:             --sparse-gpu-codegen | FileCheck %s
+
+#CSR = #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>
+
+//
+// Compute matrix-matrix product C = A B.
+//
+// CHECK-LABEL: gpu.func @kernel(
+// CHECK-SAME:    %[[VAL_0:.*0]]: index,
+// CHECK-SAME:    %[[VAL_1:.*1]]: index,
+// CHECK-SAME:    %[[VAL_2:.*2]]: memref<?xindex>,
+// CHECK-SAME:    %[[VAL_3:.*3]]: memref<?xindex>,
+// CHECK-SAME:    %[[VAL_4:.*4]]: memref<?xf64>,
+// CHECK-SAME:    %[[VAL_5:.*5]]: memref<?x?xf64>,
+// CHECK-SAME:    %[[VAL_6:.*6]]: memref<?x?xf64>) kernel {
+// CHECK:         %[[VAL_7:.*]] = arith.constant 1 : index
+// CHECK:         %[[VAL_8:.*]] = arith.constant 0 : index
+// CHECK:         %[[VAL_9:.*]] = gpu.block_id x
+// CHECK:         %[[VAL_10:.*]] = gpu.block_dim x
+// CHECK:         %[[VAL_11:.*]] = gpu.thread_id x
+// CHECK:         %[[VAL_12:.*]] = gpu.grid_dim x
+// CHECK:         %[[VAL_13:.*]] = arith.muli %[[VAL_9]], %[[VAL_10]] : index
+// CHECK:         %[[VAL_14:.*]] = arith.addi %[[VAL_13]], %[[VAL_11]] : index
+// CHECK:         %[[VAL_15:.*]] = arith.muli %[[VAL_10]], %[[VAL_12]] : index
+// CHECK:         scf.for %[[VAL_16:.*]] = %[[VAL_14]] to %[[VAL_1]] step %[[VAL_15]] {
+// CHECK:           %[[VAL_17:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_16]]] : memref<?xindex>
+// CHECK:           %[[VAL_18:.*]] = arith.addi %[[VAL_16]], %[[VAL_7]] : index
+// CHECK:           %[[VAL_19:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_18]]] : memref<?xindex>
+// CHECK:           scf.for %[[VAL_20:.*]] = %[[VAL_17]] to %[[VAL_19]] step %[[VAL_7]] {
+// CHECK:             %[[VAL_21:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_20]]] : memref<?xindex>
+// CHECK:             %[[VAL_22:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_20]]] : memref<?xf64>
+// CHECK:             scf.for %[[VAL_23:.*]] = %[[VAL_8]] to %[[VAL_0]] step %[[VAL_7]] {
+// CHECK:               %[[VAL_24:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_16]], %[[VAL_23]]] : memref<?x?xf64>
+// CHECK:               %[[VAL_25:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_21]], %[[VAL_23]]] : memref<?x?xf64>
+// CHECK:               %[[VAL_26:.*]] = arith.mulf %[[VAL_22]], %[[VAL_25]] : f64
+// CHECK:               %[[VAL_27:.*]] = arith.addf %[[VAL_24]], %[[VAL_26]] : f64
+// CHECK:               memref.store %[[VAL_27]], %[[VAL_5]]{{\[}}%[[VAL_16]], %[[VAL_23]]] : memref<?x?xf64>
+// CHECK:             } {"Emitted from" = "linalg.generic"}
+// CHECK:           } {"Emitted from" = "linalg.generic"}
+// CHECK:         }
+// CHECK:         gpu.return
+// CHECK:       }
+//
+//
+// CHECK-LABEL: func.func @matmul
+// CHECK:         gpu.host_register
+// CHECK:         gpu.host_register
+// CHECK:         gpu.host_register
+// CHECK:         gpu.host_register
+// CHECK:         gpu.host_register
+// CHECK:         gpu.launch_func @sparsekernels::@kernel blocks
+//
+func.func @matmul(%A: tensor<?x?xf64, #CSR>, %B: tensor<?x?xf64>, %C_in: tensor<?x?xf64>) -> tensor<?x?xf64> {
+  %C_out = linalg.matmul
+    ins(%A, %B: tensor<?x?xf64, #CSR>, tensor<?x?xf64>)
+    outs(%C_in: tensor<?x?xf64>) -> tensor<?x?xf64>
+  return %C_out : tensor<?x?xf64>
+}
diff --git a/mlir/test/Dialect/SparseTensor/GPU/gpu_matvec.mlir b/mlir/test/Dialect/SparseTensor/GPU/gpu_matvec.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/GPU/gpu_matvec.mlir
@@ -0,0 +1,58 @@
+// RUN: mlir-opt %s --linalg-generalize-named-ops \
+// RUN:             --pre-sparsification-rewrite \
+// RUN:             --sparsification="parallelization-strategy=dense-outer-loop" \
+// RUN:             --sparse-gpu-codegen | FileCheck %s
+
+#CSR = #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>
+
+//
+// Compute matrix-vector product y = A x.
+//
+//
+// CHECK:       gpu.func @kernel(
+// CHECK-SAME:    %[[VAL_0:.*0]]: index,
+// CHECK-SAME:    %[[VAL_1:.*1]]: memref<?xf64>,
+// CHECK-SAME:    %[[VAL_2:.*2]]: memref<?xindex>,
+// CHECK-SAME:    %[[VAL_3:.*3]]: memref<?xindex>,
+// CHECK-SAME:    %[[VAL_4:.*4]]: memref<?xf64>,
+// CHECK-SAME:    %[[VAL_5:.*5]]: memref<?xf64>) kernel {
+// CHECK:         %[[VAL_6:.*]] = arith.constant 1 : index
+// CHECK:         %[[VAL_7:.*]] = gpu.block_id x
+// CHECK:         %[[VAL_8:.*]] = gpu.block_dim x
+// CHECK:         %[[VAL_9:.*]] = gpu.thread_id x
+// CHECK:         %[[VAL_10:.*]] = gpu.grid_dim x
+// CHECK:         %[[VAL_11:.*]] = arith.muli %[[VAL_7]], %[[VAL_8]] : index
+// CHECK:         %[[VAL_12:.*]] = arith.addi %[[VAL_11]], %[[VAL_9]] : index
+// CHECK:         %[[VAL_13:.*]] = arith.muli %[[VAL_8]], %[[VAL_10]] : index
+// CHECK:         scf.for %[[VAL_14:.*]] = %[[VAL_12]] to %[[VAL_0]] step %[[VAL_13]] {
+// CHECK:           %[[VAL_15:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_14]]] : memref<?xf64>
+// CHECK:           %[[VAL_16:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_14]]] : memref<?xindex>
+// CHECK:           %[[VAL_17:.*]] = arith.addi %[[VAL_14]], %[[VAL_6]] : index
+// CHECK:           %[[VAL_18:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_17]]] : memref<?xindex>
+// CHECK:           %[[VAL_19:.*]] = scf.for %[[VAL_20:.*]] = %[[VAL_16]] to %[[VAL_18]] step %[[VAL_6]] iter_args(%[[VAL_21:.*]] = %[[VAL_15]]) -> (f64) {
+// CHECK:             %[[VAL_22:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_20]]] : memref<?xindex>
+// CHECK:             %[[VAL_23:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_20]]] : memref<?xf64>
+// CHECK:             %[[VAL_24:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_22]]] : memref<?xf64>
+// CHECK:             %[[VAL_25:.*]] = arith.mulf %[[VAL_23]], %[[VAL_24]] : f64
+// CHECK:             %[[VAL_26:.*]] = arith.addf %[[VAL_21]], %[[VAL_25]] : f64
+// CHECK:             scf.yield %[[VAL_26]] : f64
+// CHECK:           } {"Emitted from" = "linalg.generic"}
+// CHECK:           memref.store %[[VAL_27:.*]], %[[VAL_1]]{{\[}}%[[VAL_14]]] : memref<?xf64>
+// CHECK:         }
+// CHECK:         gpu.return
+// CHECK:       }
+//
+// CHECK-LABEL: func.func @matvec
+// CHECK:         gpu.host_register
+// CHECK:         gpu.host_register
+// CHECK:         gpu.host_register
+// CHECK:         gpu.host_register
+// CHECK:         gpu.host_register
+// CHECK:         gpu.launch_func @sparsekernels::@kernel blocks
+//
+func.func @matvec(%A: tensor<?x?xf64, #CSR>, %x: tensor<?xf64>, %y_in: tensor<?xf64>) -> tensor<?xf64> {
+  %y_out = linalg.matvec
+    ins(%A, %x: tensor<?x?xf64, #CSR>, tensor<?xf64>)
+    outs(%y_in: tensor<?xf64>) -> tensor<?xf64>
+  return %y_out : tensor<?xf64>
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -2343,6 +2343,7 @@
         ":DialectUtils",
        ":FuncDialect",
        ":FuncTransforms",
+        ":GPUDialect",
        ":IR",
        ":LLVMCommonConversion",
        ":LLVMDialect",
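Editor's note (not part of the patch): the tests above drive the new pass through mlir-opt, as shown in the RUN lines. For readers wiring the pass into a C++ pipeline instead, a minimal sketch using the entry points added to Passes.h is given below. The helper name buildSparseGPUPipeline is hypothetical, and the sketch assumes sparsification with parallelization-strategy=dense-outer-loop has already been scheduled so that the annotated outermost scf.parallel loops exist when the GPU codegen pass runs.

// Illustrative sketch only; assumes the standard mlir::PassManager API.
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Pass/PassManager.h"

// Adds the module-level sparse GPU codegen pass, overriding the default of
// 1024 threads explicitly via the unsigned overload introduced in this patch.
static void buildSparseGPUPipeline(mlir::PassManager &pm) {
  pm.addPass(mlir::createSparseGPUCodegenPass(/*numThreads=*/1024));
}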