diff --git a/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.h @@ -15,16 +15,25 @@ namespace arith { #define GEN_PASS_DECL_ARITHMETICBUFFERIZE +#define GEN_PASS_DECL_ARITHMETICEMULATEWIDEINT #define GEN_PASS_DECL_ARITHMETICEXPANDOPS #define GEN_PASS_DECL_ARITHMETICUNSIGNEDWHENEQUIVALENT #include "mlir/Dialect/Arithmetic/Transforms/Passes.h.inc" +class WideIntEmulationConverter; + /// Create a pass to bufferize Arithmetic ops. std::unique_ptr<Pass> createArithmeticBufferizePass(); /// Create a pass to bufferize arith.constant ops. std::unique_ptr<Pass> createConstantBufferizePass(uint64_t alignment = 0); +/// Adds patterns to emulate wide Arithmetic and Function ops over integer +/// types into supported ones. This is done by splitting original power-of-two +/// i2N integer types into two iN halves. +void populateWideIntEmulationPatterns(WideIntEmulationConverter &typeConverter, + RewritePatternSet &patterns); + /// Add patterns to expand Arithmetic ops for LLVM lowering. void populateArithmeticExpandOpsPatterns(RewritePatternSet &patterns); diff --git a/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.td @@ -49,4 +49,24 @@ let constructor = "mlir::arith::createArithmeticUnsignedWhenEquivalentPass()"; } +def ArithmeticEmulateWideInt : Pass<"arith-emulate-wide-int"> { + let summary = "Emulate 2*N-bit integer operations using N-bit operations"; + let description = [{ + Emulate integer operations that use too wide integer types with equivalent + operations on supported narrow integer types. This is done by splitting + original integer values into two halves. 
+ + This pass is intended to preserve semantics but not necessarily provide the + most efficient implementation. + TODO: Optimize op emulation. + + Currently, only power-of-two integer bitwidths are supported. + }]; + let options = [ + Option<"widestIntSupported", "widest-int-supported", "unsigned", + /*default=*/"32", "Widest integer type supported by the target">, + ]; + let dependentDialects = ["vector::VectorDialect"]; +} + #endif // MLIR_DIALECT_ARITHMETIC_TRANSFORMS_PASSES diff --git a/mlir/include/mlir/Dialect/Arithmetic/Transforms/WideIntEmulationConverter.h b/mlir/include/mlir/Dialect/Arithmetic/Transforms/WideIntEmulationConverter.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Arithmetic/Transforms/WideIntEmulationConverter.h @@ -0,0 +1,34 @@ +//===- WideIntEmulationConverter.h - Type Converter for WIE -----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_ARITHMETIC_WIDE_INT_EMULATION_CONVERTER_H_ +#define MLIR_DIALECT_ARITHMETIC_WIDE_INT_EMULATION_CONVERTER_H_ + +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir::arith { +/// Converts integer types that are too wide for the target by splitting them in +/// two halves and thus turning into supported ones, i.e., i2*N --> iN, where N +/// is the widest integer bitwidth supported by the target. +/// Currently, we only handle power-of-two integer types and support conversions +/// of integers twice as wide as the maximum supported by the target. Wide +/// integers are represented as vectors, e.g., i64 --> vector<2xi32>, where the +/// first element is the low half of the original integer, and the second +/// element the high half. 
+class WideIntEmulationConverter : public TypeConverter { +public: + explicit WideIntEmulationConverter(unsigned widestIntSupportedByTarget); + + unsigned getMaxTargetIntBitWidth() const { return maxIntWidth; } + +private: + unsigned maxIntWidth; +}; +} // namespace mlir::arith + +#endif // MLIR_DIALECT_ARITHMETIC_WIDE_INT_EMULATION_CONVERTER_H_ diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_dialect_library(MLIRArithmeticTransforms BufferizableOpInterfaceImpl.cpp Bufferize.cpp + EmulateWideInt.cpp ExpandOps.cpp UnsignedWhenEquivalent.cpp diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp @@ -0,0 +1,120 @@ +//===- EmulateWideInt.cpp - Wide integer operation emulation ----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Arithmetic/Transforms/Passes.h" + +#include "mlir/Dialect/Arithmetic/Transforms/WideIntEmulationConverter.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/Support/MathExtras.h" +#include <cassert> + +namespace mlir::arith { +#define GEN_PASS_DEF_ARITHMETICEMULATEWIDEINT +#include "mlir/Dialect/Arithmetic/Transforms/Passes.h.inc" +} // namespace mlir::arith + +using namespace mlir; + +namespace { +struct EmulateWideIntPass final + : arith::impl::ArithmeticEmulateWideIntBase<EmulateWideIntPass> { + using ArithmeticEmulateWideIntBase::ArithmeticEmulateWideIntBase; + + void runOnOperation() override { + if (!llvm::isPowerOf2_32(widestIntSupported)) { + signalPassFailure(); + return; + } + + Operation *op = getOperation(); + MLIRContext *ctx = op->getContext(); + + arith::WideIntEmulationConverter typeConverter(widestIntSupported); + ConversionTarget target(*ctx); + target.addDynamicallyLegalOp<func::FuncOp>([&typeConverter](Operation *op) { + return typeConverter.isLegal(cast<func::FuncOp>(op).getFunctionType()); + }); + target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>( + [&typeConverter](Operation *op) { return typeConverter.isLegal(op); }); + + RewritePatternSet patterns(ctx); + arith::populateWideIntEmulationPatterns(typeConverter, patterns); + + if (failed(applyPartialConversion(op, target, std::move(patterns)))) + signalPassFailure(); + } +}; +} // end anonymous namespace + +arith::WideIntEmulationConverter::WideIntEmulationConverter( + unsigned widestIntSupportedByTarget) + : maxIntWidth(widestIntSupportedByTarget) { + assert(llvm::isPowerOf2_32(widestIntSupportedByTarget) && + "Only power-of-two integers are supported"); + + // Scalar case. 
+ addConversion([this](IntegerType ty) -> Optional<Type> { + unsigned width = ty.getWidth(); + if (width <= maxIntWidth) + return ty; + + // i2N --> vector<2xiN> + if (width == 2 * maxIntWidth) + return VectorType::get(2, IntegerType::get(ty.getContext(), maxIntWidth)); + + return None; + }); + + // Vector case. + addConversion([this](VectorType ty) -> Optional<Type> { + auto intTy = ty.getElementType().dyn_cast<IntegerType>(); + if (!intTy) + return ty; + + unsigned width = intTy.getWidth(); + if (width <= maxIntWidth) + return ty; + + // vector<...xi2N> --> vector<...x2xiN> + if (width == 2 * maxIntWidth) { + auto newShape = to_vector(ty.getShape()); + newShape.push_back(2); + return VectorType::get(newShape, + IntegerType::get(ty.getContext(), maxIntWidth)); + } + + return None; + }); + + // Function case. + addConversion([this](FunctionType ty) -> Optional<Type> { + // Convert inputs and results, e.g.: + // (i2N, i2N) -> i2N --> (vector<2xiN>, vector<2xiN>) -> vector<2xiN> + SmallVector<Type> inputs; + if (failed(convertTypes(ty.getInputs(), inputs))) + return None; + + SmallVector<Type> results; + if (failed(convertTypes(ty.getResults(), results))) + return None; + + return FunctionType::get(ty.getContext(), inputs, results); + }); +} + +void arith::populateWideIntEmulationPatterns( + WideIntEmulationConverter &typeConverter, RewritePatternSet &patterns) { + // Populate `func.*` conversion patterns. + populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns, + typeConverter); + populateCallOpTypeConversionPattern(patterns, typeConverter); + populateReturnOpTypeConversionPattern(patterns, typeConverter); +} diff --git a/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir b/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir @@ -0,0 +1,51 @@ +// RUN: mlir-opt --arith-emulate-wide-int="widest-int-supported=32" %s | FileCheck %s + +// Expect no conversions, i32 is supported. 
+// CHECK-LABEL: func @addi_same_i32 +// CHECK-SAME: ([[ARG:%.+]]: i32) -> i32 +// CHECK-NEXT: [[X:%.+]] = arith.addi [[ARG]], [[ARG]] : i32 +// CHECK-NEXT: return [[X]] : i32 +func.func @addi_same_i32(%a : i32) -> i32 { + %x = arith.addi %a, %a : i32 + return %x : i32 +} + +// Expect no conversions, i32 is supported. +// CHECK-LABEL: func @addi_same_vector_i32 +// CHECK-SAME: ([[ARG:%.+]]: vector<2xi32>) -> vector<2xi32> +// CHECK-NEXT: [[X:%.+]] = arith.addi [[ARG]], [[ARG]] : vector<2xi32> +// CHECK-NEXT: return [[X]] : vector<2xi32> +func.func @addi_same_vector_i32(%a : vector<2xi32>) -> vector<2xi32> { + %x = arith.addi %a, %a : vector<2xi32> + return %x : vector<2xi32> +} + +// CHECK-LABEL: func @identity_scalar +// CHECK-SAME: ([[ARG:%.+]]: vector<2xi32>) -> vector<2xi32> +// CHECK-NEXT: return [[ARG]] : vector<2xi32> +func.func @identity_scalar(%x : i64) -> i64 { + return %x : i64 +} + +// CHECK-LABEL: func @identity_vector +// CHECK-SAME: ([[ARG:%.+]]: vector<4x2xi32>) -> vector<4x2xi32> +// CHECK-NEXT: return [[ARG]] : vector<4x2xi32> +func.func @identity_vector(%x : vector<4xi64>) -> vector<4xi64> { + return %x : vector<4xi64> +} + +// CHECK-LABEL: func @identity_vector2d +// CHECK-SAME: ([[ARG:%.+]]: vector<3x4x2xi32>) -> vector<3x4x2xi32> +// CHECK-NEXT: return [[ARG]] : vector<3x4x2xi32> +func.func @identity_vector2d(%x : vector<3x4xi64>) -> vector<3x4xi64> { + return %x : vector<3x4xi64> +} + +// CHECK-LABEL: func @call +// CHECK-SAME: ([[ARG:%.+]]: vector<4x2xi32>) -> vector<4x2xi32> +// CHECK-NEXT: [[RES:%.+]] = call @identity_vector([[ARG]]) : (vector<4x2xi32>) -> vector<4x2xi32> +// CHECK-NEXT: return [[RES]] : vector<4x2xi32> +func.func @call(%a : vector<4xi64>) -> vector<4xi64> { + %res = func.call @identity_vector(%a) : (vector<4xi64>) -> vector<4xi64> + return %res : vector<4xi64> +}