diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -51,21 +51,21 @@ // NVVM intrinsic operations //===----------------------------------------------------------------------===// -class NVVM_IntrOp overloadedResults, - list overloadedOperands, list traits, +class NVVM_IntrOp traits, int numResults> : LLVM_IntrOpBase; + /*list overloadedResults=*/[], + /*list overloadedOperands=*/[], + traits, numResults>; //===----------------------------------------------------------------------===// // NVVM special register op definitions //===----------------------------------------------------------------------===// -class NVVM_SpecialRegisterOp traits = []> : - NVVM_IntrOp, - Arguments<(ins)> { +class NVVM_SpecialRegisterOp traits = []> : + NVVM_IntrOp { + let arguments = (ins); let assemblyFormat = "attr-dict `:` type($res)"; } @@ -92,6 +92,16 @@ def NVVM_GridDimYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.y">; def NVVM_GridDimZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.z">; +//===----------------------------------------------------------------------===// +// NVVM approximate op definitions +//===----------------------------------------------------------------------===// + +def NVVM_RcpApproxFtzF32Op : NVVM_IntrOp<"rcp.approx.ftz.f", [NoSideEffect], 1> { + let arguments = (ins F32:$arg); + let results = (outs F32:$res); + let assemblyFormat = "$arg attr-dict `:` type($res)"; +} + //===----------------------------------------------------------------------===// // NVVM synchronization op definitions //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/OptimizeForNVVM.h b/mlir/include/mlir/Dialect/LLVMIR/Transforms/OptimizeForNVVM.h new file mode 100644 --- /dev/null +++ 
b/mlir/include/mlir/Dialect/LLVMIR/Transforms/OptimizeForNVVM.h @@ -0,0 +1,25 @@ +//===- OptimizeForNVVM.h - Optimize LLVM IR for NVVM -*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_LLVMIR_TRANSFORMS_OPTIMIZEFORNVVM_H +#define MLIR_DIALECT_LLVMIR_TRANSFORMS_OPTIMIZEFORNVVM_H + +#include + +namespace mlir { +class Pass; + +namespace NVVM { + +/// Creates a pass that optimizes LLVM IR for the NVVM target. +std::unique_ptr createOptimizeForTargetPass(); + +} // namespace NVVM +} // namespace mlir + +#endif // MLIR_DIALECT_LLVMIR_TRANSFORMS_OPTIMIZEFORNVVM_H diff --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.h b/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.h @@ -10,6 +10,7 @@ #define MLIR_DIALECT_LLVMIR_TRANSFORMS_PASSES_H #include "mlir/Dialect/LLVMIR/Transforms/LegalizeForExport.h" +#include "mlir/Dialect/LLVMIR/Transforms/OptimizeForNVVM.h" #include "mlir/Pass/Pass.h" namespace mlir { diff --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td b/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td @@ -16,4 +16,9 @@ let constructor = "mlir::LLVM::createLegalizeForExportPass()"; } +def NVVMOptimizeForTarget : Pass<"llvm-optimize-for-nvvm-target"> { + let summary = "Optimize NVVM IR"; + let constructor = "mlir::NVVM::createOptimizeForTargetPass()"; +} + #endif // MLIR_DIALECT_LLVMIR_TRANSFORMS_PASSES diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt --- 
a/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_dialect_library(MLIRLLVMIRTransforms LegalizeForExport.cpp + OptimizeForNVVM.cpp DEPENDS MLIRLLVMPassIncGen diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp @@ -0,0 +1,100 @@ +//===- OptimizeForNVVM.cpp - Optimize NVVM IR ------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/Dialect/LLVMIR/Transforms/OptimizeForNVVM.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; + +namespace { +// Replaces fdiv on fp16 with fp32 multiplication with reciprocal plus one +// (conditional) Newton iteration. +// +// This is as accurate as promoting the division to fp32 in the NVPTX backend, +// but faster because it performs fewer Newton iterations, avoids the slow path +// for e.g. denormals, and allows reuse of the reciprocal for multiple divisions +// by the same divisor. 
+struct ExpandDivF16 : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + +private: + LogicalResult matchAndRewrite(LLVM::FDivOp op, + PatternRewriter &rewriter) const override; +}; + +struct NVVMOptimizeForTarget + : public NVVMOptimizeForTargetBase { + void runOnOperation() override; + + void getDependentDialects(DialectRegistry &registry) const override { + registry.insert(); + } +}; +} // namespace + +LogicalResult ExpandDivF16::matchAndRewrite(LLVM::FDivOp op, + + PatternRewriter &rewriter) const { + if (!op.getType().isF16()) + return rewriter.notifyMatchFailure(op, "not f16"); + Location loc = op.getLoc(); + + Type f32Type = rewriter.getF32Type(); + Type i32Type = rewriter.getI32Type(); + + // Extend lhs and rhs to fp32. + Value lhs = rewriter.create(loc, f32Type, op.getLhs()); + Value rhs = rewriter.create(loc, f32Type, op.getRhs()); + + // float rcp = rcp.approx.ftz.f32(rhs), approx = lhs * rcp. + Value rcp = rewriter.create(loc, f32Type, rhs); + Value approx = rewriter.create(loc, lhs, rcp); + + // Refine the approximation with one Newton iteration: + // float refined = approx + (lhs - approx * rhs) * rcp; + Value err = rewriter.create( + loc, approx, rewriter.create(loc, rhs), lhs); + Value refined = rewriter.create(loc, err, rcp, approx); + + // Use refined value if approx is normal (exponent neither all 0 nor all 1). + Value mask = rewriter.create( + loc, i32Type, rewriter.getUI32IntegerAttr(0x7f800000)); + Value cast = rewriter.create(loc, i32Type, approx); + Value exp = rewriter.create(loc, i32Type, cast, mask); + Value zero = rewriter.create( + loc, i32Type, rewriter.getUI32IntegerAttr(0)); + Value pred = rewriter.create( + loc, + rewriter.create(loc, LLVM::ICmpPredicate::eq, exp, zero), + rewriter.create(loc, LLVM::ICmpPredicate::eq, exp, mask)); + Value result = + rewriter.create(loc, f32Type, pred, approx, refined); + + // Replace with truncation back to fp16. 
+ rewriter.replaceOpWithNewOp(op, op.getType(), result); + + return success(); +} + +void NVVMOptimizeForTarget::runOnOperation() { + MLIRContext *ctx = getOperation()->getContext(); + RewritePatternSet patterns(ctx); + patterns.add(ctx); + if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + return signalPassFailure(); +} + +std::unique_ptr NVVM::createOptimizeForTargetPass() { + return std::make_unique(); +} diff --git a/mlir/test/Dialect/LLVMIR/nvvm-optimize.mlir b/mlir/test/Dialect/LLVMIR/nvvm-optimize.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/LLVMIR/nvvm-optimize.mlir @@ -0,0 +1,24 @@ +// RUN: mlir-opt %s -llvm-optimize-for-nvvm-target | FileCheck %s + +// CHECK-LABEL: llvm.func @fdiv_fp16 +llvm.func @fdiv_fp16(%arg0 : f16, %arg1 : f16) -> f16 { + // CHECK-DAG: %[[c0:.*]] = llvm.mlir.constant(0 : ui32) : i32 + // CHECK-DAG: %[[mask:.*]] = llvm.mlir.constant(2139095040 : ui32) : i32 + // CHECK-DAG: %[[lhs:.*]] = llvm.fpext %arg0 : f16 to f32 + // CHECK-DAG: %[[rhs:.*]] = llvm.fpext %arg1 : f16 to f32 + // CHECK-DAG: %[[rcp:.*]] = nvvm.rcp.approx.ftz.f %[[rhs]] : f32 + // CHECK-DAG: %[[approx:.*]] = llvm.fmul %[[lhs]], %[[rcp]] : f32 + // CHECK-DAG: %[[neg:.*]] = llvm.fneg %[[rhs]] : f32 + // CHECK-DAG: %[[err:.*]] = "llvm.intr.fma"(%[[approx]], %[[neg]], %[[lhs]]) : (f32, f32, f32) -> f32 + // CHECK-DAG: %[[refined:.*]] = "llvm.intr.fma"(%[[err]], %[[rcp]], %[[approx]]) : (f32, f32, f32) -> f32 + // CHECK-DAG: %[[cast:.*]] = llvm.bitcast %[[approx]] : f32 to i32 + // CHECK-DAG: %[[exp:.*]] = llvm.and %[[cast]], %[[mask]] : i32 + // CHECK-DAG: %[[is_zero:.*]] = llvm.icmp "eq" %[[exp]], %[[c0]] : i32 + // CHECK-DAG: %[[is_mask:.*]] = llvm.icmp "eq" %[[exp]], %[[mask]] : i32 + // CHECK-DAG: %[[pred:.*]] = llvm.or %[[is_zero]], %[[is_mask]] : i1 + // CHECK-DAG: %[[select:.*]] = llvm.select %[[pred]], %[[approx]], %[[refined]] : i1, f32 + // CHECK-DAG: %[[result:.*]] = llvm.fptrunc %[[select]] : f32 to f16 + %result 
= llvm.fdiv %arg0, %arg1 : f16 + // CHECK: llvm.return %[[result]] : f16 + llvm.return %result : f16 +} diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir --- a/mlir/test/Dialect/LLVMIR/nvvm.mlir +++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir @@ -29,6 +29,13 @@ llvm.return %0 : i32 } +// CHECK-LABEL: @nvvm_rcp +func.func @nvvm_rcp(%arg0: f32) -> f32 { + // CHECK: nvvm.rcp.approx.ftz.f %arg0 : f32 + %0 = nvvm.rcp.approx.ftz.f %arg0 : f32 + llvm.return %0 : f32 +} + // CHECK-LABEL: @llvm_nvvm_barrier0 func.func @llvm_nvvm_barrier0() { // CHECK: nvvm.barrier0 diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir --- a/mlir/test/Target/LLVMIR/nvvmir.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir.mlir @@ -33,6 +33,13 @@ llvm.return %1 : i32 } +// CHECK-LABEL: @nvvm_rcp +llvm.func @nvvm_rcp(%0: f32) -> f32 { + // CHECK: call float @llvm.nvvm.rcp.approx.ftz.f + %1 = nvvm.rcp.approx.ftz.f %0 : f32 + llvm.return %1 : f32 +} + // CHECK-LABEL: @llvm_nvvm_barrier0 llvm.func @llvm_nvvm_barrier0() { // CHECK: call void @llvm.nvvm.barrier0() diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -3379,7 +3379,9 @@ ":IR", ":LLVMDialect", ":LLVMPassIncGen", + ":NVVMDialect", ":Pass", + ":Transforms", ], )