diff --git a/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.h @@ -12,9 +12,13 @@ #include "mlir/Pass/Pass.h" namespace mlir { + +class TypeConverter; + namespace arith { #define GEN_PASS_DECL_ARITHMETICBUFFERIZE +#define GEN_PASS_DECL_ARITHMETICEMULATEWIDEINT #define GEN_PASS_DECL_ARITHMETICEXPANDOPS #define GEN_PASS_DECL_ARITHMETICUNSIGNEDWHENEQUIVALENT #include "mlir/Dialect/Arithmetic/Transforms/Passes.h.inc" @@ -25,6 +29,16 @@ /// Create a pass to bufferize arith.constant ops. std::unique_ptr createConstantBufferizePass(uint64_t alignment = 0); +/// Creates a pass to emulate 2*N-bit integer operations with N-bit operations. +std::unique_ptr +createEmulateWideIntPass(unsigned widestIntSupported = 32); + +std::unique_ptr +createWideIntEmulationTypeConverter(unsigned widestIntSupported = 32); + +void populateWideIntEmulationPatterns(TypeConverter &typeConverter, + RewritePatternSet &patterns); + /// Add patterns to expand Arithmetic ops for LLVM lowering. void populateArithmeticExpandOpsPatterns(RewritePatternSet &patterns); diff --git a/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.td @@ -49,4 +49,22 @@ let constructor = "mlir::arith::createArithmeticUnsignedWhenEquivalentPass()"; } +def ArithmeticEmulateWideInt : Pass<"arith-emulate-wide-int"> { + let summary = "Emulate 2*N-bit integer operations using N-bit operations"; + let description = [{ + Emulate integer operations that use too wide integer types with equivalent + operations on supported integer types. + This pass is intended to preserve semantics but not necessarily provide the + most efficient implementation. 
+ + Currently, only power-of-two integer types are supported. + }]; + let constructor = "mlir::arith::createEmulateWideIntPass()"; + let options = [ + Option<"widestIntSupported", "widest-int-supported", "unsigned", /*default=*/"32", + "Widest integer type supported by the target">, + ]; + let dependentDialects = ["vector::VectorDialect"]; +} + #endif // MLIR_DIALECT_ARITHMETIC_TRANSFORMS_PASSES diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_dialect_library(MLIRArithmeticTransforms BufferizableOpInterfaceImpl.cpp Bufferize.cpp + EmulateWideInt.cpp ExpandOps.cpp UnsignedWhenEquivalent.cpp diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp @@ -0,0 +1,175 @@ +//===- EmulateWideInt.cpp - Pass to emulate integer operations that are +// too wide for the target +// with equivalent operations on supported integer types ------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Arithmetic/Transforms/Passes.h" + +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/MathExtras.h" + +#include + +namespace mlir::arith { +#define GEN_PASS_DEF_ARITHMETICEMULATEWIDEINT +#include "mlir/Dialect/Arithmetic/Transforms/Passes.h.inc" +} // namespace mlir::arith + +using namespace mlir; + +namespace { +// Converts integer types that are too wide for the target to supported ones. +// Currently, we only handle power-of-two integer types and support conversions +// of integers twice as wide as the maximum supported. Wide integers are +// represented as vectors, e.g., i64 --> vector<2xi32>, where the first element +// is the low half of the original integer, and the second element the high +// half. +class WideIntEmulationConverter final : public TypeConverter { +public: + explicit WideIntEmulationConverter(unsigned widestIntSupported) + : maxIntWidth(widestIntSupported) { + // Scalar case. + addConversion([widestInt = + widestIntSupported](IntegerType ty) -> Optional { + const unsigned width = ty.getWidth(); + if (width <= widestInt) + return ty; + + // i2N --> vector<2xiN> + if (width == 2 * widestInt) + return VectorType::get({2}, + IntegerType::get(ty.getContext(), widestInt)); + + return None; + }); + + // Vector case. 
+ addConversion([widestInt = + widestIntSupported](VectorType ty) -> Optional { + if (auto intTy = ty.getElementType().dyn_cast()) { + const unsigned width = intTy.getWidth(); + if (width <= widestInt) + return ty; + + // vector<...xi2N> --> vector<2x...xiN> + if (width == 2 * widestInt) { + SmallVector newShape = {2}; + llvm::append_range(newShape, ty.getShape()); + return VectorType::get(newShape, + IntegerType::get(ty.getContext(), widestInt)); + } + + return None; + } + return ty; + }); + + // Function case. + addConversion([this](FunctionType ty) -> Optional { + // Convert inputs and results, e.g.: + // (i2N, i2N) -> i2N --> (vector<2xiN>, vector<2xiN>) -> vector<2xiN> + SmallVector inputs; + if (failed(convertTypes(ty.getInputs(), inputs))) + return None; + + SmallVector results; + if (failed(convertTypes(ty.getResults(), results))) + return None; + + return FunctionType::get(ty.getContext(), inputs, results); + }); + } + + unsigned getMaxIntegerWidth() const { return maxIntWidth; } + +private: + unsigned maxIntWidth; +}; + +struct EmulateWideIntPass final + : arith::impl::ArithmeticEmulateWideIntBase { + EmulateWideIntPass(unsigned widestIntSupported) { + this->widestIntSupported.setValue(widestIntSupported); + } + + /// Rejects a `widestIntSupported` that is not a power of two, then applies a + /// partial conversion in which `func` ops are legal only if their types are + /// legal for the type converter. Signals pass failure on any error. 
+ void runOnOperation() override { + if (!llvm::isPowerOf2_32(widestIntSupported)) { + getOperation()->emitError("widest-int-supported must be a power of two"); + signalPassFailure(); + return; + } + + Operation *op = getOperation(); + MLIRContext *ctx = op->getContext(); + + WideIntEmulationConverter typeConverter(widestIntSupported); + auto addUnrealizedCast = [](OpBuilder &builder, Type type, + ValueRange inputs, Location loc) { + auto cast = builder.create(loc, type, inputs); + return Optional(cast.getResult(0)); + }; + typeConverter.addSourceMaterialization(addUnrealizedCast); + typeConverter.addTargetMaterialization(addUnrealizedCast); + + ConversionTarget target(*ctx); + // clang-format off + target.addDynamicallyLegalOp< + // func ops + func::FuncOp, func::CallOp, func::ReturnOp + >( + // clang-format on + [&typeConverter](Operation *op) { + if (auto func = dyn_cast(op)) + return typeConverter.isLegal(func.getFunctionType()); + + return typeConverter.isLegal(op); + }); + target.addLegalOp(); + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + + RewritePatternSet patterns(ctx); + arith::populateWideIntEmulationPatterns(typeConverter, patterns); + + if (failed(applyPartialConversion(op, target, std::move(patterns)))) + signalPassFailure(); + } +}; +} // end anonymous namespace + +namespace mlir::arith { + +void populateWideIntEmulationPatterns(TypeConverter &typeConverter, + RewritePatternSet &patterns) { + // Populate `func.*` conversion patterns. 
+ populateFunctionOpInterfaceTypeConversionPattern(patterns, + typeConverter); + populateCallOpTypeConversionPattern(patterns, typeConverter); + populateReturnOpTypeConversionPattern(patterns, typeConverter); +} + +std::unique_ptr createEmulateWideIntPass(unsigned widestIntSupported) { + return std::make_unique(widestIntSupported); +} + +std::unique_ptr +createWideIntEmulationTypeConverter(unsigned widestIntSupported) { + return std::make_unique(widestIntSupported); +} + +} // namespace mlir::arith diff --git a/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir b/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir @@ -0,0 +1,47 @@ +// RUN: mlir-opt --arith-emulate-wide-int="widest-int-supported=32" %s | FileCheck %s + +// Expect no conversions, i32 is supported. +// CHECK-LABEL: func @addi_same_i32 +// CHECK-SAME: ([[ARG:%.+]]: i32) -> i32 +// CHECK-NEXT: [[X:%.+]] = arith.addi [[ARG]], [[ARG]] : i32 +// CHECK-NEXT: return [[X]] : i32 +func.func @addi_same_i32(%a : i32) -> i32 { + %x = arith.addi %a, %a : i32 + return %x : i32 +} + +// Expect no conversions, vectors of i32 are supported. 
+// CHECK-LABEL: func @addi_same_vector_i32 +// CHECK-SAME: ([[ARG:%.+]]: vector<2xi32>) -> vector<2xi32> +// CHECK-NEXT: [[X:%.+]] = arith.addi [[ARG]], [[ARG]] : vector<2xi32> +// CHECK-NEXT: return [[X]] : vector<2xi32> +func.func @addi_same_vector_i32(%a : vector<2xi32>) -> vector<2xi32> { + %x = arith.addi %a, %a : vector<2xi32> + return %x : vector<2xi32> +} + +// Expect vector<4xi64> to be converted to vector<2x4xi32>. +// CHECK-LABEL: func @identity_vector +// CHECK-SAME: ([[ARG:%.+]]: vector<2x4xi32>) -> vector<2x4xi32> +// CHECK-NEXT: return [[ARG]] : vector<2x4xi32> +func.func @identity_vector(%x : vector<4xi64>) -> vector<4xi64> { + return %x : vector<4xi64> +} + +// Expect vector<3x4xi64> to be converted to vector<2x3x4xi32>. +// CHECK-LABEL: func @identity_vector2d +// CHECK-SAME: ([[ARG:%.+]]: vector<2x3x4xi32>) -> vector<2x3x4xi32> +// CHECK-NEXT: return [[ARG]] : vector<2x3x4xi32> +func.func @identity_vector2d(%x : vector<3x4xi64>) -> vector<3x4xi64> { + return %x : vector<3x4xi64> +} + +// Expect the call's operand and result types to be converted as well. +// CHECK-LABEL: func @call +// CHECK-SAME: ([[ARG:%.+]]: vector<2x4xi32>) -> vector<2x4xi32> +// CHECK-NEXT: [[RES:%.+]] = call @identity_vector([[ARG]]) : (vector<2x4xi32>) -> vector<2x4xi32> +// CHECK-NEXT: return [[RES]] : vector<2x4xi32> +func.func @call(%a : vector<4xi64>) -> vector<4xi64> { + %res = func.call @identity_vector(%a) : (vector<4xi64>) -> vector<4xi64> + return %res : vector<4xi64> +}