diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp @@ -17,8 +17,13 @@ #include "mlir/IR/Location.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/Support/FormatVariadic.h" +#include +#include + namespace mlir { namespace spirv { #define GEN_PASS_DEF_SPIRVWEBGPUPREPAREPASS @@ -61,41 +66,62 @@ loc, llvm::formatv("Unexpected integer type for WebGPU: '{0}'", elemTy)); - // Calculate the 'low' and the 'high' result separately, using long - // multiplication: - // - // lhs = [0 0] [a b] - // rhs = [0 0] [c d] - // --lhs * rhs-- - // = [ a * c ] [ b * d ] + - // [ 0 ] [a * d + b * c] [ 0 ] + // Emulate the 32x32 -> 64-bit unsigned multiplication by splitting each i32 + // input element into two 16-bit digits, each held in an i32. This is so that + // the intermediate multiplications and additions do not overflow. We extract + // the 16-bit digits from each i32 (scalar or vector element) by masking (low + // digit) and shifting right (high digit). // - // ==> high = (a * c) + (a * d + b * c) >> 16 - Value low = rewriter.create(loc, lhs, rhs); - + // The multiplication algorithm used is the standard (long) multiplication. + // Multiplying two i32 integers produces 64 bits of result, i.e., four 16-bit + // digits. Once additions to a zero accumulator are folded away, we end up + // emitting only 4 multiplications and 4 additions.
Value cstLowMask = rewriter.create( loc, lhs.getType(), getScalarOrSplatAttr(argTy, (1 << 16) - 1)); - auto getLowHalf = [&rewriter, loc, cstLowMask](Value val) { + auto getLowDigit = [&rewriter, loc, cstLowMask](Value val) { return rewriter.create(loc, val, cstLowMask); }; Value cst16 = rewriter.create(loc, lhs.getType(), getScalarOrSplatAttr(argTy, 16)); - auto getHighHalf = [&rewriter, loc, cst16](Value val) { + auto getHighDigit = [&rewriter, loc, cst16](Value val) { return rewriter.create(loc, val, cst16); }; - Value lhsLow = getLowHalf(lhs); - Value lhsHigh = getHighHalf(lhs); - Value rhsLow = getLowHalf(rhs); - Value rhsHigh = getHighHalf(rhs); - - Value high0 = rewriter.create(loc, lhsHigh, rhsHigh); - Value mid = rewriter.create( - loc, rewriter.create(loc, lhsHigh, rhsLow), - rewriter.create(loc, lhsLow, rhsHigh)); - Value high1 = getHighHalf(mid); - Value high = rewriter.create(loc, high0, high1); + Value cst0 = rewriter.create(loc, lhs.getType(), + getScalarOrSplatAttr(argTy, 0)); + + Value lhsLow = getLowDigit(lhs); + Value lhsHigh = getHighDigit(lhs); + Value rhsLow = getLowDigit(rhs); + Value rhsHigh = getHighDigit(rhs); + + std::array lhsDigits = {lhsLow, lhsHigh}; + std::array rhsDigits = {rhsLow, rhsHigh}; + std::array resultDigits = {cst0, cst0, cst0, cst0}; + + for (auto [i, lhsDigit] : llvm::enumerate(lhsDigits)) { + for (auto [j, rhsDigit] : llvm::enumerate(rhsDigits)) { + Value &thisResDigit = resultDigits[i + j]; + Value mul = rewriter.create(loc, lhsDigit, rhsDigit); + Value current = rewriter.createOrFold(loc, thisResDigit, mul); + thisResDigit = getLowDigit(current); + + if (i + j + 1 != resultDigits.size()) { + Value &nextResDigit = resultDigits[i + j + 1]; + Value carry = rewriter.createOrFold(loc, nextResDigit, + getHighDigit(current)); + nextResDigit = carry; + } + } + } + + auto combineDigits = [loc, cst16, &rewriter](Value low, Value high) { + Value highBits = rewriter.create(loc, high, cst16); + return rewriter.create(loc, low,
highBits); + }; + Value low = combineDigits(resultDigits[0], resultDigits[1]); + Value high = combineDigits(resultDigits[2], resultDigits[3]); rewriter.replaceOpWithNewOp( op, op.getType(), llvm::makeArrayRef({low, high})); @@ -127,8 +153,6 @@ RewritePatternSet &patterns) { // WGSL currently does not support extended multiplication ops, see: // https://github.com/gpuweb/gpuweb/issues/1565. - // TODO(https://github.com/llvm/llvm-project/issues/59563): Add SMulExtended - // expansion. patterns.add(patterns.getContext()); } } // namespace spirv diff --git a/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir b/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir --- a/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir @@ -1,4 +1,5 @@ -// RUN: mlir-opt --split-input-file --verify-diagnostics --spirv-webgpu-prepare %s | FileCheck %s +// RUN: mlir-opt --split-input-file --verify-diagnostics \ +// RUN: --spirv-webgpu-prepare --cse %s | FileCheck %s //===----------------------------------------------------------------------===// // spirv.UMulExtended @@ -10,18 +11,23 @@ // CHECK-SAME: ([[ARG0:%.+]]: i32, [[ARG1:%.+]]: i32) // CHECK-DAG: [[CSTMASK:%.+]] = spirv.Constant 65535 : i32 // CHECK-DAG: [[CST16:%.+]] = spirv.Constant 16 : i32 -// CHECK-NEXT: [[RESLOW:%.+]] = spirv.IMul [[ARG0]], [[ARG1]] : i32 // CHECK-NEXT: [[LHSLOW:%.+]] = spirv.BitwiseAnd [[ARG0]], [[CSTMASK]] : i32 // CHECK-NEXT: [[LHSHI:%.+]] = spirv.ShiftRightLogical [[ARG0]], [[CST16]] : i32 // CHECK-NEXT: [[RHSLOW:%.+]] = spirv.BitwiseAnd [[ARG1]], [[CSTMASK]] : i32 // CHECK-NEXT: [[RHSHI:%.+]] = spirv.ShiftRightLogical [[ARG1]], [[CST16]] : i32 -// CHECK-DAG: [[RESHI0:%.+]] = spirv.IMul
[[LHSHI]], [[RHSHI]] : i32 -// CHECK-DAG: [[MID0:%.+]] = spirv.IMul [[LHSHI]], [[RHSLOW]] : i32 -// CHECK-DAG: [[MID1:%.+]] = spirv.IMul [[LHSLOW]], [[RHSHI]] : i32 -// CHECK-NEXT: [[MID:%.+]] = spirv.IAdd [[MID0]], [[MID1]] : i32 -// CHECK-NEXT: [[RESHI1:%.+]] = spirv.ShiftRightLogical [[MID]], [[CST16]] : i32 -// CHECK-NEXT: [[RESHI:%.+]] = spirv.IAdd [[RESHI0]], [[RESHI1]] : i32 -// CHECK-NEXT: [[RES:%.+]] = spirv.CompositeConstruct [[RESLOW]], [[RESHI]] : (i32, i32) -> !spirv.struct<(i32, i32)> +// CHECK-DAG: spirv.IMul [[LHSLOW]], [[RHSLOW]] +// CHECK-DAG: spirv.IMul [[LHSLOW]], [[RHSHI]] +// CHECK-DAG: spirv.IMul [[LHSHI]], [[RHSLOW]] +// CHECK-DAG: spirv.IMul [[LHSHI]], [[RHSHI]] +// CHECK-DAG: spirv.IAdd +// CHECK-DAG: spirv.IAdd +// CHECK-DAG: spirv.IAdd +// CHECK-DAG: spirv.IAdd +// CHECK: spirv.ShiftLeftLogical {{%.+}}, [[CST16]] : i32 +// CHECK: [[RESLO:%.+]] = spirv.BitwiseOr +// CHECK: spirv.ShiftLeftLogical {{%.+}}, [[CST16]] : i32 +// CHECK: [[RESHI:%.+]] = spirv.BitwiseOr +// CHECK: [[RES:%.+]] = spirv.CompositeConstruct [[RESLO]], [[RESHI]] : (i32, i32) -> !spirv.struct<(i32, i32)> // CHECK-NEXT: spirv.ReturnValue [[RES]] : !spirv.struct<(i32, i32)> spirv.func @umul_extended_i32(%arg0 : i32, %arg1 : i32) -> !spirv.struct<(i32, i32)> "None" { %0 = spirv.UMulExtended %arg0, %arg1 : !spirv.struct<(i32, i32)> @@ -32,18 +38,23 @@ // CHECK-SAME: ([[ARG0:%.+]]: vector<3xi32>, [[ARG1:%.+]]: vector<3xi32>) // CHECK-DAG: [[CSTMASK:%.+]] = spirv.Constant dense<65535> : vector<3xi32> // CHECK-DAG: [[CST16:%.+]] = spirv.Constant dense<16> : vector<3xi32> -// CHECK-NEXT: [[RESLOW:%.+]] = spirv.IMul [[ARG0]], [[ARG1]] : vector<3xi32> // CHECK-NEXT: [[LHSLOW:%.+]] = spirv.BitwiseAnd [[ARG0]], [[CSTMASK]] : vector<3xi32> // CHECK-NEXT: [[LHSHI:%.+]] = spirv.ShiftRightLogical [[ARG0]], [[CST16]] : vector<3xi32> // CHECK-NEXT: [[RHSLOW:%.+]] = spirv.BitwiseAnd [[ARG1]], [[CSTMASK]] : vector<3xi32> // CHECK-NEXT: [[RHSHI:%.+]] = spirv.ShiftRightLogical [[ARG1]], [[CST16]] : vector<3xi32>
-// CHECK-DAG: [[RESHI0:%.+]] = spirv.IMul [[LHSHI]], [[RHSHI]] : vector<3xi32> -// CHECK-DAG: [[MID0:%.+]] = spirv.IMul [[LHSHI]], [[RHSLOW]] : vector<3xi32> -// CHECK-DAG: [[MID1:%.+]] = spirv.IMul [[LHSLOW]], [[RHSHI]] : vector<3xi32> -// CHECK-NEXT: [[MID:%.+]] = spirv.IAdd [[MID0]], [[MID1]] : vector<3xi32> -// CHECK-NEXT: [[RESHI1:%.+]] = spirv.ShiftRightLogical [[MID]], [[CST16]] : vector<3xi32> -// CHECK-NEXT: [[RESHI:%.+]] = spirv.IAdd [[RESHI0]], [[RESHI1]] : vector<3xi32> -// CHECK-NEXT: [[RES:%.+]] = spirv.CompositeConstruct [[RESLOW]], [[RESHI]] +// CHECK-DAG: spirv.IMul [[LHSLOW]], [[RHSLOW]] +// CHECK-DAG: spirv.IMul [[LHSLOW]], [[RHSHI]] +// CHECK-DAG: spirv.IMul [[LHSHI]], [[RHSLOW]] +// CHECK-DAG: spirv.IMul [[LHSHI]], [[RHSHI]] +// CHECK-DAG: spirv.IAdd +// CHECK-DAG: spirv.IAdd +// CHECK-DAG: spirv.IAdd +// CHECK-DAG: spirv.IAdd +// CHECK: spirv.ShiftLeftLogical {{%.+}}, [[CST16]] +// CHECK: [[RESLOW:%.+]] = spirv.BitwiseOr +// CHECK: spirv.ShiftLeftLogical {{%.+}}, [[CST16]] +// CHECK: [[RESHI:%.+]] = spirv.BitwiseOr +// CHECK-NEXT: [[RES:%.+]] = spirv.CompositeConstruct [[RESLOW]], [[RESHI]] // CHECK-NEXT: spirv.ReturnValue [[RES]] : !spirv.struct<(vector<3xi32>, vector<3xi32>)> spirv.func @umul_extended_vector_i32(%arg0 : vector<3xi32>, %arg1 : vector<3xi32>) -> !spirv.struct<(vector<3xi32>, vector<3xi32>)> "None" { diff --git a/mlir/test/mlir-vulkan-runner/umul_extended.mlir b/mlir/test/mlir-vulkan-runner/umul_extended.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-vulkan-runner/umul_extended.mlir @@ -0,0 +1,66 @@ +// Make sure that unsigned extended multiplication produces expected results +// with and without expansion to primitive mul/add ops for WebGPU.
+ +// RUN: mlir-vulkan-runner %s \ +// RUN: --shared-libs=%mlir_lib_dir/libvulkan-runtime-wrappers%shlibext,%mlir_lib_dir/libmlir_runner_utils%shlibext \ +// RUN: --entry-point-result=void | FileCheck %s + +// RUN: mlir-vulkan-runner %s --vulkan-runner-spirv-webgpu-prepare \ +// RUN: --shared-libs=%mlir_lib_dir/libvulkan-runtime-wrappers%shlibext,%mlir_lib_dir/libmlir_runner_utils%shlibext \ +// RUN: --entry-point-result=void | FileCheck %s +// Expected output: the low 32-bit words of each product, then the high 32-bit words, both printed as signed i32. +// CHECK: [0, 1, -2, 1, 1048560, -87620295, -131071, -49] +// CHECK: [0, 0, 1, -2, 0, 65534, -131070, 6] +module attributes { + gpu.container_module, + spirv.target_env = #spirv.target_env< + #spirv.vce, #spirv.resource_limits<>> +} { + gpu.module @kernels { + gpu.func @kernel_add(%arg0 : memref<8xi32>, %arg1 : memref<8xi32>, %arg2 : memref<8xi32>, %arg3 : memref<8xi32>) + kernel attributes { spirv.entry_point_abi = #spirv.entry_point_abi} { + %0 = gpu.block_id x + %lhs = memref.load %arg0[%0] : memref<8xi32> + %rhs = memref.load %arg1[%0] : memref<8xi32> + %low, %hi = arith.mului_extended %lhs, %rhs : i32 + memref.store %low, %arg2[%0] : memref<8xi32> + memref.store %hi, %arg3[%0] : memref<8xi32> + gpu.return + } + } + + func.func @main() { + %buf0 = memref.alloc() : memref<8xi32> + %buf1 = memref.alloc() : memref<8xi32> + %buf2 = memref.alloc() : memref<8xi32> + %buf3 = memref.alloc() : memref<8xi32> + %i32_0 = arith.constant 0 : i32 + + // Zero-initialize the output buffers. + %buf4 = memref.cast %buf2 : memref<8xi32> to memref + %buf5 = memref.cast %buf3 : memref<8xi32> to memref + call @fillResource1DInt(%buf4, %i32_0) : (memref, i32) -> () + call @fillResource1DInt(%buf5, %i32_0) : (memref, i32) -> () + + %idx_0 = arith.constant 0 : index + %idx_1 = arith.constant 1 : index + %idx_8 = arith.constant 8 : index + + // Initialize input buffers. Several operand pairs overflow 32 bits to exercise the high result word.
+ %lhs_vals = arith.constant dense<[0, 1, -1, -1, 65535, 65535, -65535, 7]> : vector<8xi32> + %rhs_vals = arith.constant dense<[0, 1, 2, -1, 16, -1337, -65535, -7]> : vector<8xi32> + vector.store %lhs_vals, %buf0[%idx_0] : memref<8xi32>, vector<8xi32> + vector.store %rhs_vals, %buf1[%idx_0] : memref<8xi32>, vector<8xi32> + + gpu.launch_func @kernels::@kernel_add + blocks in (%idx_8, %idx_1, %idx_1) threads in (%idx_1, %idx_1, %idx_1) + args(%buf0 : memref<8xi32>, %buf1 : memref<8xi32>, %buf2 : memref<8xi32>, %buf3 : memref<8xi32>) + %buf_low = memref.cast %buf4 : memref to memref<*xi32> + %buf_hi = memref.cast %buf5 : memref to memref<*xi32> + call @printMemrefI32(%buf_low) : (memref<*xi32>) -> () + call @printMemrefI32(%buf_hi) : (memref<*xi32>) -> () + return + } + func.func private @fillResource1DInt(%0 : memref, %1 : i32) + func.func private @printMemrefI32(%ptr : memref<*xi32>) +} diff --git a/mlir/tools/mlir-vulkan-runner/CMakeLists.txt b/mlir/tools/mlir-vulkan-runner/CMakeLists.txt --- a/mlir/tools/mlir-vulkan-runner/CMakeLists.txt +++ b/mlir/tools/mlir-vulkan-runner/CMakeLists.txt @@ -74,6 +74,8 @@ MLIRTargetLLVMIRExport MLIRTransforms MLIRTranslateLib + MLIRVectorDialect + MLIRVectorToLLVM ${Vulkan_LIBRARY} ) diff --git a/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp b/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp --- a/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp +++ b/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp @@ -13,12 +13,12 @@ //===----------------------------------------------------------------------===// #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" -#include "mlir/Conversion/FuncToSPIRV/FuncToSPIRVPass.h" #include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h" #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h" #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" #include
"mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" +#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" @@ -30,18 +30,30 @@ #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/SPIRV/Transforms/Passes.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/Passes.h" #include "mlir/ExecutionEngine/JitRunner.h" -#include "mlir/ExecutionEngine/OptUtils.h" +#include "mlir/IR/Operation.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" -#include "mlir/Target/LLVMIR/Export.h" #include "llvm/Support/InitLLVM.h" #include "llvm/Support/TargetSelect.h" using namespace mlir; -static LogicalResult runMLIRPasses(Operation *op, JitRunnerOptions &options) { +namespace { +struct VulkanRunnerOptions { + llvm::cl::OptionCategory category{"mlir-vulkan-runner options"}; + llvm::cl::opt spirvWebGPUPrepare{ + "vulkan-runner-spirv-webgpu-prepare", + llvm::cl::desc("Run MLIR transforms used when targeting WebGPU"), + llvm::cl::cat(category)}; +}; +} // namespace + +static LogicalResult runMLIRPasses(Operation *op, + VulkanRunnerOptions &options) { auto module = dyn_cast(op); if (!module) return op->emitOpError("expected a 'builtin.module' op"); @@ -55,10 +67,13 @@ OpPassManager &modulePM = passManager.nest(); modulePM.addPass(spirv::createLowerABIAttributesPass()); modulePM.addPass(spirv::createUpdateVersionCapabilityExtensionPass()); + if (options.spirvWebGPUPrepare) + modulePM.addPass(spirv::createSPIRVWebGPUPreparePass()); passManager.addPass(createConvertGpuLaunchFuncToVulkanLaunchFuncPass()); LowerToLLVMOptions llvmOptions(module.getContext(), DataLayout(module)); passManager.addPass(createMemRefToLLVMConversionPass()); +
passManager.addPass(createConvertVectorToLLVMPass()); passManager.nest().addPass(LLVM::createRequestCWrappersPass()); passManager.addPass(createConvertFuncToLLVMPass(llvmOptions)); passManager.addPass(createReconcileUnrealizedCastsPass()); @@ -75,13 +90,21 @@ llvm::InitializeNativeTarget(); llvm::InitializeNativeTargetAsmPrinter(); + // Initialize runner-specific CLI options. These will be parsed and + // initialized in `JitRunnerMain`. + VulkanRunnerOptions options; + auto runPassesWithOptions = [&options](Operation *op, JitRunnerOptions &) { + return runMLIRPasses(op, options); + }; + mlir::JitRunnerConfig jitRunnerConfig; - jitRunnerConfig.mlirTransformer = runMLIRPasses; + jitRunnerConfig.mlirTransformer = runPassesWithOptions; mlir::DialectRegistry registry; registry.insert(); + mlir::func::FuncDialect, mlir::memref::MemRefDialect, + mlir::vector::VectorDialect>(); mlir::registerLLVMDialectTranslation(registry); return mlir::JitRunnerMain(argc, argv, registry, jitRunnerConfig); diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -7303,6 +7303,8 @@ ":SPIRVDialect", ":SPIRVTransforms", ":ToLLVMIRTranslation", + ":VectorDialect", + ":VectorToLLVM", "//llvm:Support", ], )