diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h @@ -233,6 +233,15 @@ [](MLIRContext *context) { context->loadDialect(); }); } + /// Declares that the transformations associated with the operations + /// registered by this dialect extension need to register additional + /// extensions, beyond just dialects, e.g. LLVM IR translations that are + /// called during IR transformation (such as when generating embedded binary + /// blobs). The callback is appended to `generatedDialectLoaders`. + void declareRegistration(std::function fun) { + generatedDialectLoaders.push_back(fun); + } + private: /// Callbacks performing extension initialization, e.g., registering ops, /// types and defining the additional data. diff --git a/mlir/lib/Dialect/GPU/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/GPU/TransformOps/CMakeLists.txt --- a/mlir/lib/Dialect/GPU/TransformOps/CMakeLists.txt +++ b/mlir/lib/Dialect/GPU/TransformOps/CMakeLists.txt @@ -12,6 +12,7 @@ MLIRGPUDeviceMapperEnumsGen LINK_LIBS PUBLIC + MLIRDLTIDialect MLIRGPUDialect MLIRGPUTransforms MLIRIR @@ -24,4 +25,9 @@ # ConversionPatterns MLIRNVGPUToNVVM MLIRGPUToNVVMTransforms + + # Translations (needed to serialize to cubin) + MLIRNVVMToLLVMIRTranslation + MLIRGPUToLLVMIRTranslation + MLIRLLVMToLLVMIRTranslation ) diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp --- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp +++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp @@ -13,6 +13,7 @@ #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/DLTI/DLTI.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include 
"mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/GPU/TransformOps/Utils.h" @@ -29,11 +30,16 @@ #include "mlir/IR/AffineExpr.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/DialectRegistry.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/Visitors.h" #include "mlir/Support/LLVM.h" +#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" @@ -1455,10 +1461,19 @@ declareGeneratedDialect(); declareGeneratedDialect(); declareGeneratedDialect(); + declareGeneratedDialect(); + declareGeneratedDialect(); + declareGeneratedDialect(); registerTransformOps< #define GET_OP_LIST #include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.cpp.inc" >(); + // Register translations (needed to serialize to cubin).
+ declareRegistration([](MLIRContext *c) { + registerNVVMDialectTranslation(*c); + registerGPUDialectTranslation(*c); + registerLLVMDialectTranslation(*c); + }); } }; } // namespace diff --git a/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt @@ -19,6 +19,7 @@ MLIRBufferizationDialect MLIRBufferizationTransforms MLIRFuncDialect + MLIRIndexDialect MLIRIR MLIRLinalgDialect MLIRLinalgTransforms diff --git a/mlir/lib/Dialect/Linalg/TransformOps/DialectExtension.cpp b/mlir/lib/Dialect/Linalg/TransformOps/DialectExtension.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/DialectExtension.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/DialectExtension.cpp @@ -10,6 +10,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.h" #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h" @@ -36,6 +37,7 @@ declareGeneratedDialect(); declareGeneratedDialect(); + declareGeneratedDialect(); declareGeneratedDialect(); declareGeneratedDialect(); declareGeneratedDialect(); 
diff --git a/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f32.mlir b/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f32.mlir --- a/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f32.mlir +++ b/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f32.mlir @@ -1,27 +1,19 @@ // RUN: mlir-opt %s \ -// RUN: -test-transform-dialect-interpreter \ -// RUN: | FileCheck %s --check-prefix=CHECK-MMA-SYNC - -// CHECK-MMA-SYNC-LABEL: func @main() { -// CHECK-MMA-SYNC: nvgpu.mma.sync(%{{.*}}) {mmaShape = [16, 8, 4], tf32Enabled} -// CHECK-MMA-SYNC-SAME: : (vector<2x1xf32>, vector<1x1xf32>, vector<2x2xf32>) -> vector<2x2xf32> - -// Tested to run locally in 1.7s. - -// RUN: mlir-opt %s \ -// RUN: -test-transform-dialect-interpreter \ +// RUN: -test-transform-dialect-interpreter=debug-payload-root-tag="payload" \ // RUN: -test-transform-dialect-erase-schedule \ -// RUN: -test-lower-to-nvvm="kernel-index-bitwidth=32 cubin-chip=sm_80 cubin-features=+ptx76" \ // RUN: | mlir-cpu-runner \ // RUN: --shared-libs=%mlir_cuda_runtime \ // RUN: --shared-libs=%mlir_runner_utils \ // RUN: --entry-point-result=void \ // RUN: | FileCheck %s + !lhs_memref_type = memref<16x4xf32> !rhs_memref_type = memref<4x8xf32> !res_memref_type = memref<16x8xf32> +module attributes {transform.target_tag="payload"} { + func.func @compute_linspace_val(%ridx: index, %cidx: index, %strideCidx: index) -> f32 { %r = arith.index_cast %ridx : index to i32 %c = arith.index_cast %cidx : index to i32 @@ -154,10 +146,135 @@ func.func private @printMemrefF32(memref<*xf32>) +} // module + + +/// Schedule to lower device GPU IR and host IR to LLVM. +/// In the future this should be preloaded from a separate file.
+module @lowering_schedule attributes { transform.with_named_sequence } { + +// Spell out lowering to NVVM to make it less bespoke and more easily configurable. +transform.named_sequence @lower_gpu( + %module: !transform.any_op {transform.consumed}) -> !transform.any_op { + + %m2 = transform.apply_registered_pass "gpu-kernel-outlining" to %module : (!transform.any_op) -> !transform.any_op + + %gpu_module = transform.structured.match ops{["gpu.module"]} in %m2 : (!transform.any_op) -> !transform.any_op + %gm2 = transform.apply_registered_pass "convert-vector-to-scf" to %gpu_module : (!transform.any_op) -> !transform.any_op + %gm3 = transform.apply_registered_pass "convert-scf-to-cf" to %gm2 : (!transform.any_op) -> !transform.any_op + %gm4 = transform.apply_registered_pass "expand-strided-metadata" to %gm3 : (!transform.any_op) -> !transform.any_op + %gm5 = transform.apply_registered_pass "lower-affine" to %gm4 : (!transform.any_op) -> !transform.any_op + transform.apply_conversion_patterns to %gm5 { + transform.apply_conversion_patterns.dialect_to_llvm "math" + transform.apply_conversion_patterns.dialect_to_llvm "memref" + transform.apply_conversion_patterns.func.func_to_llvm + transform.apply_conversion_patterns.dialect_to_llvm "index" + } with type_converter { + transform.apply_conversion_patterns.memref.memref_to_llvm_type_converter + {index_bitwidth = 32, + use_bare_ptr = true, + use_bare_ptr_memref_call_conv = true, + use_opaque_pointers = true} + } { + legal_dialects = ["llvm", "gpu", "nvvm"], + partial_conversion + } : !transform.any_op + + // apply_conversion_patterns loses track of handles so we rematch. + %gpu_module2 = transform.structured.match ops{["gpu.module"]} in %m2 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %gpu_module2 { + transform.apply_patterns.gpu.gpu_rewrite_patterns + } : !transform.any_op + + // apply_conversion_patterns loses track of handles so we rematch.
+ %gpu_module3 = transform.structured.match ops{["gpu.module"]} in %m2 : (!transform.any_op) -> !transform.any_op + transform.apply_conversion_patterns to %gpu_module3 { + transform.apply_conversion_patterns.dialect_to_llvm "arith" + transform.apply_conversion_patterns.dialect_to_llvm "cf" + transform.apply_conversion_patterns.vector.vector_to_llvm + transform.apply_conversion_patterns.func.func_to_llvm + transform.apply_conversion_patterns.dialect_to_llvm "memref" + transform.apply_conversion_patterns.gpu.gpu_to_nvvm + transform.apply_conversion_patterns.gpu.gpu_wmma_to_nvvm + transform.apply_conversion_patterns.gpu.gpu_subgroup_reduce_to_nvvm + transform.apply_conversion_patterns.nvgpu.nvgpu_to_nvvm + } with type_converter { + transform.apply_conversion_patterns.memref.memref_to_llvm_type_converter + {index_bitwidth = 32, + use_bare_ptr = true, + use_bare_ptr_memref_call_conv = true, + use_opaque_pointers = true} + } { + legal_dialects = ["llvm", "memref", "nvvm"], + legal_ops = ["func.func", "gpu.module", "gpu.module_end", "gpu.yield"], + illegal_dialects = ["gpu"], + illegal_ops = ["llvm.cos", "llvm.exp", "llvm.exp2", "llvm.fabs", "llvm.fceil", + "llvm.ffloor", "llvm.log", "llvm.log10", "llvm.log2","llvm.pow", + "llvm.sin", "llvm.sqrt"], + partial_conversion + } : !transform.any_op + + // apply_conversion_patterns loses track of handles so we rematch.
+ %gpu_module4 = transform.structured.match ops{["gpu.module"]} in %m2 : (!transform.any_op) -> !transform.any_op + %mm2 = transform.apply_registered_pass "convert-vector-to-llvm" to %gpu_module4 : (!transform.any_op) -> !transform.any_op + %mm3 = transform.apply_registered_pass "canonicalize" to %mm2 : (!transform.any_op) -> !transform.any_op + %mm4 = transform.apply_registered_pass "cse" to %mm3 : (!transform.any_op) -> !transform.any_op + %mm5 = transform.apply_registered_pass "reconcile-unrealized-casts" to %mm4 : (!transform.any_op) -> !transform.any_op + %mm6 = transform.apply_registered_pass "gpu-to-cubin" to %mm5 {options="chip=sm_80 features=+ptx76"} : (!transform.any_op) -> !transform.any_op + + transform.yield %m2 : !transform.any_op +} + +transform.named_sequence @lower_host( + %module: !transform.any_op {transform.consumed}) -> !transform.any_op { + %m3 = transform.apply_registered_pass "convert-vector-to-scf" to %module : (!transform.any_op) -> !transform.any_op + %m4 = transform.apply_registered_pass "convert-scf-to-cf" to %m3 : (!transform.any_op) -> !transform.any_op + %m5 = transform.apply_registered_pass "expand-strided-metadata" to %m4 : (!transform.any_op) -> !transform.any_op + %m6 = transform.apply_registered_pass "lower-affine" to %m5 : (!transform.any_op) -> !transform.any_op + + // TODO: apply_conversion_patterns loses track of handles so we only apply it to func.func ops.
+ %func = transform.structured.match ops{["func.func"]} in %m6 : (!transform.any_op) -> !transform.any_op + transform.apply_conversion_patterns to %func { + transform.apply_conversion_patterns.dialect_to_llvm "math" + transform.apply_conversion_patterns.vector.vector_to_llvm + transform.apply_conversion_patterns.dialect_to_llvm "memref" + transform.apply_conversion_patterns.func.func_to_llvm + transform.apply_conversion_patterns.dialect_to_llvm "index" + transform.apply_conversion_patterns.dialect_to_llvm "arith" + transform.apply_conversion_patterns.dialect_to_llvm "cf" + } with type_converter { + transform.apply_conversion_patterns.memref.memref_to_llvm_type_converter + {index_bitwidth = 64, + use_bare_ptr = true, + use_bare_ptr_memref_call_conv = true, + use_opaque_pointers = true} + } { + legal_dialects = ["llvm", "nvvm"], + legal_ops = ["builtin.module", "gpu.module", "gpu.module_end", "gpu.yield"], + partial_conversion + } : !transform.any_op + + %m7 = transform.apply_registered_pass "gpu-to-llvm" to %m6 + : (!transform.any_op) -> !transform.any_op + %m8 = transform.apply_registered_pass "reconcile-unrealized-casts" to %m7 + : (!transform.any_op) -> !transform.any_op + + transform.yield %m8 : !transform.any_op +} + + transform.sequence failures(propagate) { -^bb1(%arg1: !transform.any_op): - %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1 +^bb1(%toplevel_module: !transform.any_op): + + %matmul = transform.structured.match ops{["linalg.matmul"]} in %toplevel_module : (!transform.any_op) -> !transform.any_op transform.nvgpu.rewrite_matmul_as_mma_sync %matmul : (!transform.any_op) -> () + + %m2 = transform.include @lower_gpu failures(propagate) (%toplevel_module) + : (!transform.any_op) -> (!transform.any_op) + %m3 = transform.include @lower_host failures(propagate) (%m2) + : (!transform.any_op) -> (!transform.any_op) } + +} // transform module diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel 
b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -5007,17 +5007,21 @@ ":ArithDialect", ":AsmParser", ":ControlFlowDialect", + ":DLTIDialect", ":DialectUtils", ":FuncDialect", ":GPUCommonTransforms", ":GPUDialect", + ":GPUToLLVMIRTranslation", ":GPUToNVVMTransforms", ":GPUTransformOpsIncGen", ":GPUTransforms", ":IR", ":LLVMCommonConversion", + ":LLVMToLLVMIRTranslation", ":MemRefDialect", ":NVVMDialect", + ":NVVMToLLVMIRTranslation", ":Parser", ":SCFDialect", ":SideEffectInterfaces", @@ -9751,6 +9755,7 @@ ":FuncDialect", ":GPUDialect", ":IR", + ":IndexDialect", ":LinalgDialect", ":LinalgMatchOpsIncGen", ":LinalgTransformEnumsIncGen",