diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -13,6 +13,8 @@
 
 namespace mlir {
 class DataFlowSolver;
+class ConversionTarget;
+class TypeConverter;
 
 namespace arith {
 
@@ -42,6 +44,27 @@
 void populateArithNarrowTypeEmulationPatterns(
     NarrowTypeEmulationConverter &typeConverter, RewritePatternSet &patterns);
 
+/// Populate the type conversions needed to emulate the unsupported
+/// `sourceTypes` with `targetType`. Scalars of a source type, and shaped types
+/// whose element type is a source type, are converted to (shaped types of)
+/// `targetType`; all other types are left unchanged.
+void populateEmulateUnsupportedFloatsConversions(TypeConverter &converter,
+                                                 ArrayRef<Type> sourceTypes,
+                                                 Type targetType);
+
+/// Add rewrite patterns for converting operations that use illegal float types
+/// to ones that use legal ones, truncating any converted results back to
+/// their original types. Operations that have regions are not rewritten.
+void populateEmulateUnsupportedFloatsPatterns(RewritePatternSet &patterns,
+                                              TypeConverter &converter);
+
+/// Set up a dialect conversion to reject arithmetic operations on unsupported
+/// float types. All other operations are left legal. `converter` is captured
+/// by reference in the legality callbacks and must outlive `target`.
+void populateEmulateUnsupportedFloatsLegality(ConversionTarget &target,
+                                              TypeConverter &converter);
+
 /// Add patterns to expand Arith ceil/floor division ops.
 void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns);
 
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
@@ -63,6 +63,32 @@
   }];
 }
 
+def ArithEmulateUnsupportedFloats : Pass<"arith-emulate-unsupported-floats"> {
+  let summary = "Emulate operations on unsupported floats with extf/truncf";
+  let description = [{
+    Emulate arith and vector floating point operations that use float types
+    which are unsupported on a target by inserting extf/truncf pairs around all
+    such operations in order to produce arithmetic that can be performed while
+    preserving the original rounding behavior.
+
+    This pass does not attempt to reason about the operations being performed
+    to determine when type conversions can be elided.
+
+    `source-types` takes a comma-separated list of float type names; both it
+    and `target-type` accept f8E5M2, f8E4M3FN, f8E5M2FNUZ, f8E4M3FNUZ, bf16,
+    f16, f32, f64, f80 and f128. The target type must not be a source type.
+  }];
+
+  let options = [
+    ListOption<"sourceTypeStrs", "source-types", "std::string",
+      "MLIR types without arithmetic support on a given target">,
+    Option<"targetTypeStr", "target-type", "std::string", "\"f32\"",
+      "MLIR type to convert the unsupported source types to">,
+  ];
+
+  let dependentDialects = ["vector::VectorDialect"];
+}
+
 def ArithEmulateWideInt : Pass<"arith-emulate-wide-int"> {
   let summary = "Emulate 2*N-bit integer operations using N-bit operations";
   let description = [{
diff --git a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRArithTransforms
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
+  EmulateUnsupportedFloats.cpp
   EmulateWideInt.cpp
   EmulateNarrowType.cpp
   ExpandOps.cpp
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
@@ -0,0 +1,184 @@
+//===- EmulateUnsupportedFloats.cpp - Promote small floats --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// This pass promotes small floats (of some unsupported types T) to a supported
+// type U by wrapping all float operations on Ts with expansion to and
+// truncation from U, then operating on U.
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/StringSwitch.h"
+#include <optional>
+
+namespace mlir::arith {
+#define GEN_PASS_DEF_ARITHEMULATEUNSUPPORTEDFLOATS
+#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
+} // namespace mlir::arith
+
+using namespace mlir;
+
+namespace {
+/// Pass driver: turns the `source-types` / `target-type` string options into
+/// float types and runs the emulation as a partial dialect conversion over the
+/// operation the pass is scheduled on.
+struct EmulateUnsupportedFloatsPass
+    : arith::impl::ArithEmulateUnsupportedFloatsBase<
+          EmulateUnsupportedFloatsPass> {
+  using arith::impl::ArithEmulateUnsupportedFloatsBase<
+      EmulateUnsupportedFloatsPass>::ArithEmulateUnsupportedFloatsBase;
+
+  void runOnOperation() override;
+};
+
+/// Generic pattern (matches any operation name) that recreates an operation
+/// the type converter considers illegal on the converted operand and result
+/// types, then truncates every result whose type changed back to its original
+/// type so users of the operation see unchanged types.
+struct EmulateFloatPattern final : ConversionPattern {
+  EmulateFloatPattern(TypeConverter &converter, MLIRContext *ctx)
+      : ConversionPattern(converter, Pattern::MatchAnyOpTypeTag(),
+                          /*benefit=*/1, ctx) {}
+
+  LogicalResult match(Operation *op) const override;
+  void rewrite(Operation *op, ArrayRef<Value> operands,
+               ConversionPatternRewriter &rewriter) const override;
+};
+} // namespace
+
+/// Map strings to float types. This function is here because no one else needs
+/// it yet, feel free to abstract it out.
+/// Returns `std::nullopt` when `name` is not one of the spellings below.
+static std::optional<FloatType> parseFloatType(MLIRContext *ctx,
+                                               StringRef name) {
+  Builder b(ctx);
+  return llvm::StringSwitch<std::optional<FloatType>>(name)
+      .Case("f8E5M2", b.getFloat8E5M2Type())
+      .Case("f8E4M3FN", b.getFloat8E4M3FNType())
+      .Case("f8E5M2FNUZ", b.getFloat8E5M2FNUZType())
+      .Case("f8E4M3FNUZ", b.getFloat8E4M3FNUZType())
+      .Case("bf16", b.getBF16Type())
+      .Case("f16", b.getF16Type())
+      .Case("f32", b.getF32Type())
+      .Case("f64", b.getF64Type())
+      .Case("f80", b.getF80Type())
+      .Case("f128", b.getF128Type())
+      .Default(std::nullopt);
+}
+
+/// Only operations the type converter considers illegal are rewritten.
+/// Operations with regions are rejected because `rewrite` does not clone
+/// regions. Operations whose result types cannot be converted are rejected so
+/// that `rewrite` may rely on the conversion succeeding, even when these
+/// patterns are used with a converter other than the one this pass builds.
+LogicalResult EmulateFloatPattern::match(Operation *op) const {
+  auto *converter = getTypeConverter();
+  if (converter->isLegal(op))
+    return failure();
+  // The rewrite doesn't handle cloning regions.
+  if (op->getNumRegions() != 0)
+    return failure();
+  SmallVector<Type> resultTypes;
+  if (failed(converter->convertTypes(op->getResultTypes(), resultTypes)))
+    return failure();
+  return success();
+}
+
+/// Recreate `op` (same name, attributes and successors, no regions) on the
+/// remapped `operands` with converted result types, then `arith.truncf` each
+/// result whose type changed back to the original type and replace `op`.
+void EmulateFloatPattern::rewrite(Operation *op, ArrayRef<Value> operands,
+                                  ConversionPatternRewriter &rewriter) const {
+  Location loc = op->getLoc();
+  auto *converter = getTypeConverter();
+  SmallVector<Type> resultTypes;
+  // The conversion must run unconditionally: placing the call inside `assert`
+  // would skip it entirely in NDEBUG builds and leave `resultTypes` empty.
+  LogicalResult convertResult =
+      converter->convertTypes(op->getResultTypes(), resultTypes);
+  (void)convertResult;
+  assert(succeeded(convertResult) &&
+         "type conversions shouldn't fail in this pass");
+  Operation *expandedOp =
+      rewriter.create(loc, op->getName().getIdentifier(), operands, resultTypes,
+                      op->getAttrs(), op->getSuccessors(), /*regions=*/{});
+  SmallVector<Value> newResults(expandedOp->getResults());
+  for (auto [res, oldType, newType] : llvm::zip_equal(
+           MutableArrayRef<Value>(newResults), op->getResultTypes(),
+           resultTypes)) {
+    if (oldType != newType)
+      res = rewriter.create<arith::TruncFOp>(loc, oldType, res);
+  }
+  rewriter.replaceOp(op, newResults);
+}
+
+/// Convert every type in `sourceTypes`, and every shaped type whose element
+/// type is in `sourceTypes`, to `targetType` (keeping the shape); all other
+/// types are returned unchanged. Values are materialized into the converted
+/// type with `arith.extf`.
+void mlir::arith::populateEmulateUnsupportedFloatsConversions(
+    TypeConverter &converter, ArrayRef<Type> sourceTypes, Type targetType) {
+  // Copy the source types into the callback: `sourceTypes` is a non-owning
+  // view, while the callback lives as long as `converter`.
+  converter.addConversion([sourceTypes = SmallVector<Type>(sourceTypes),
+                           targetType](Type type) -> std::optional<Type> {
+    if (llvm::is_contained(sourceTypes, type))
+      return targetType;
+    if (auto shaped = type.dyn_cast<ShapedType>())
+      if (llvm::is_contained(sourceTypes, shaped.getElementType()))
+        return shaped.clone(targetType);
+    // All other types legal
+    return type;
+  });
+  converter.addTargetMaterialization(
+      [](OpBuilder &b, Type target, ValueRange input, Location loc) {
+        return b.create<arith::ExtFOp>(loc, target, input);
+      });
+}
+
+/// Register the generic `EmulateFloatPattern`, which consults `converter` to
+/// decide which operations to rewrite.
+void mlir::arith::populateEmulateUnsupportedFloatsPatterns(
+    RewritePatternSet &patterns, TypeConverter &converter) {
+  patterns.add<EmulateFloatPattern>(converter, patterns.getContext());
+}
+
+/// Operations are legal by default. Arith operations and the listed
+/// arithmetic-performing vector operations are legal only when `converter`
+/// accepts all of their types. `arith.extf`, `arith.truncf` and
+/// `arith.constant` are always legal. `converter` is captured by reference
+/// and must outlive `target`.
+void mlir::arith::populateEmulateUnsupportedFloatsLegality(
+    ConversionTarget &target, TypeConverter &converter) {
+  // Don't try to legalize functions and other ops that don't need expansion.
+  target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; });
+  target.addDynamicallyLegalDialect<arith::ArithDialect>(
+      [&](Operation *op) -> std::optional<bool> {
+        return converter.isLegal(op);
+      });
+  // Manually mark arithmetic-performing vector instructions.
+  target.addDynamicallyLegalOp<
+      vector::ContractionOp, vector::ReductionOp, vector::MultiDimReductionOp,
+      vector::FMAOp, vector::OuterProductOp, vector::MatmulOp, vector::ScanOp>(
+      [&](Operation *op) { return converter.isLegal(op); });
+  // The casts inserted by the emulation mention a source type by construction,
+  // and constants keep their source-typed result, so these stay legal.
+  target.addLegalOp<arith::ExtFOp, arith::TruncFOp, arith::ConstantOp>();
+}
+
+/// Parse the pass options, reject unknown type names and a target type that is
+/// also a source type, then emulate the source types with the target type via
+/// a partial dialect conversion.
+void EmulateUnsupportedFloatsPass::runOnOperation() {
+  MLIRContext *ctx = &getContext();
+  Operation *op = getOperation();
+  SmallVector<Type> sourceTypes;
+  Type targetType;
+
+  std::optional<FloatType> maybeTargetType = parseFloatType(ctx, targetTypeStr);
+  if (!maybeTargetType) {
+    emitError(UnknownLoc::get(ctx), "could not map target type '" +
+                                        targetTypeStr +
+                                        "' to a known floating-point type");
+    return signalPassFailure();
+  }
+  targetType = *maybeTargetType;
+  for (StringRef sourceTypeStr : sourceTypeStrs) {
+    std::optional<FloatType> maybeSourceType =
+        parseFloatType(ctx, sourceTypeStr);
+    if (!maybeSourceType) {
+      emitError(UnknownLoc::get(ctx), "could not map source type '" +
+                                          sourceTypeStr +
+                                          "' to a known floating-point type");
+      return signalPassFailure();
+    }
+    sourceTypes.push_back(*maybeSourceType);
+  }
+  if (sourceTypes.empty()) {
+    // Attach the warning to the anchor operation: `emitOptionalWarning` with
+    // a null location silently drops the diagnostic. With no source types the
+    // converter leaves every type unchanged, so there is nothing to convert.
+    op->emitWarning(
+        "no source types specified, float emulation will do nothing");
+    return;
+  }
+
+  if (llvm::is_contained(sourceTypes, targetType)) {
+    emitError(UnknownLoc::get(ctx),
+              "target type cannot be an unsupported source type");
+    return signalPassFailure();
+  }
+  TypeConverter converter;
+  arith::populateEmulateUnsupportedFloatsConversions(converter, sourceTypes,
+                                                     targetType);
+  RewritePatternSet patterns(ctx);
+  arith::populateEmulateUnsupportedFloatsPatterns(patterns, converter);
+  ConversionTarget target(getContext());
+  arith::populateEmulateUnsupportedFloatsLegality(target, converter);
+
+  if (failed(applyPartialConversion(op, target, std::move(patterns))))
+    signalPassFailure();
+}
diff --git a/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir b/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir
@@ -0,0 +1,74 @@
+// RUN: mlir-opt --split-input-file --arith-emulate-unsupported-floats="source-types=bf16,f8E4M3FNUZ target-type=f32" %s | FileCheck %s
+
+func.func @basic_expansion(%x: bf16) -> bf16 {
+// CHECK-LABEL: @basic_expansion
+// CHECK-SAME: [[X:%.+]]: bf16
+// CHECK-DAG: [[C:%.+]] = arith.constant {{.*}} : bf16
+// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] : bf16 to f32
+// CHECK-DAG: [[C_EXP:%.+]] = arith.extf [[C]] : bf16 to f32
+// CHECK: [[Y_EXP:%.+]] = arith.addf [[X_EXP]], [[C_EXP]] : f32
+// CHECK: [[Y:%.+]] = arith.truncf [[Y_EXP]] : f32 to bf16
+// CHECK: return [[Y]]
+  %c = arith.constant 1.0 : bf16
+  %y = arith.addf %x, %c : bf16
+  func.return %y : bf16
+}
+
+// -----
+
+func.func @chained(%x: bf16, %y: bf16, %z: bf16) -> i1 {
+// CHECK-LABEL: @chained
+// CHECK-SAME: [[X:%.+]]: bf16, [[Y:%.+]]: bf16, [[Z:%.+]]: bf16
+// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] : bf16 to f32
+// CHECK-DAG: [[Y_EXP:%.+]] = arith.extf [[Y]] : bf16 to f32
+// CHECK-DAG: [[Z_EXP:%.+]] = arith.extf [[Z]] : bf16 to f32
+// CHECK: [[P_EXP:%.+]] = arith.addf [[X_EXP]], [[Y_EXP]] : f32
+// CHECK: [[P:%.+]] = arith.truncf [[P_EXP]] : f32 to bf16
+// CHECK: [[P_EXP2:%.+]] = arith.extf [[P]] : bf16 to f32
+// CHECK: [[Q_EXP:%.+]] = arith.mulf [[P_EXP2]], [[Z_EXP]]
+// CHECK: [[Q:%.+]] = arith.truncf [[Q_EXP]] : f32 to bf16
+// CHECK: [[Q_EXP2:%.+]] = arith.extf [[Q]] : bf16 to f32
+// CHECK: [[RES:%.+]] = arith.cmpf ole, [[P_EXP2]], [[Q_EXP2]] : f32
+// CHECK: return [[RES]]
+  %p = arith.addf %x, %y : bf16
+  %q = arith.mulf %p, %z : bf16
+  %res = arith.cmpf ole, %p, %q : bf16
+  func.return %res : i1
+}
+
+// -----
+
+func.func @memops(%a: memref<4xf8E4M3FNUZ>, %b: memref<4xf8E4M3FNUZ>) {
+// CHECK-LABEL: @memops
+// CHECK: [[V:%.+]] = memref.load {{.*}} : memref<4xf8E4M3FNUZ>
+// CHECK: [[V_EXP:%.+]] = arith.extf [[V]] : f8E4M3FNUZ to f32
+// CHECK: memref.store [[V]]
+// CHECK: [[W:%.+]] = memref.load
+// CHECK: [[W_EXP:%.+]] = arith.extf [[W]] : f8E4M3FNUZ to f32
+// CHECK: [[X_EXP:%.+]] = arith.addf [[V_EXP]], [[W_EXP]] : f32
+// CHECK: [[X:%.+]] = arith.truncf [[X_EXP]] : f32 to f8E4M3FNUZ
+// CHECK: memref.store [[X]]
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %v = memref.load %a[%c0] : memref<4xf8E4M3FNUZ>
+  memref.store %v, %b[%c0] : memref<4xf8E4M3FNUZ>
+  %w = memref.load %a[%c1] : memref<4xf8E4M3FNUZ>
+  %x = arith.addf %v, %w : f8E4M3FNUZ
+  memref.store %x, %b[%c1] : memref<4xf8E4M3FNUZ>
+  func.return
+}
+
+// -----
+
+func.func @vectors(%a: vector<4xf8E4M3FNUZ>) -> vector<4xf32> {
+// CHECK-LABEL: @vectors
+// CHECK-SAME: [[A:%.+]]: vector<4xf8E4M3FNUZ>
+// CHECK: [[A_EXP:%.+]] = arith.extf [[A]] : vector<4xf8E4M3FNUZ> to vector<4xf32>
+// CHECK: [[B_EXP:%.+]] = arith.mulf [[A_EXP]], [[A_EXP]] : vector<4xf32>
+// CHECK: [[B:%.+]] = arith.truncf [[B_EXP]] : vector<4xf32> to vector<4xf8E4M3FNUZ>
+// CHECK: [[RET:%.+]] = arith.extf [[B]] : vector<4xf8E4M3FNUZ> to vector<4xf32>
+// CHECK: return [[RET]]
+  %b = arith.mulf %a, %a : vector<4xf8E4M3FNUZ>
+  %ret = arith.extf %b : vector<4xf8E4M3FNUZ> to vector<4xf32>
+  func.return %ret : vector<4xf32>
+}