diff --git a/mlir/include/mlir/Dialect/Arithmetic/CMakeLists.txt b/mlir/include/mlir/Dialect/Arithmetic/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Arithmetic/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h b/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h @@ -0,0 +1,56 @@ +//===- Arithmetic.h - Arithmetic dialect --------------------------*- C++-*-==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_DIALECT_ARITHMETIC_IR_ARITHMETIC_H_ +#define MLIR_DIALECT_ARITHMETIC_IR_ARITHMETIC_H_ + +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Interfaces/CastInterfaces.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Interfaces/VectorInterfaces.h" + +//===----------------------------------------------------------------------===// +// ArithmeticDialect +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Arithmetic/IR/ArithmeticOpsDialect.h.inc" + +//===----------------------------------------------------------------------===// +// Arithmetic Dialect Enum Attributes +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Arithmetic/IR/ArithmeticOpsEnums.h.inc" + +//===----------------------------------------------------------------------===// +// Arithmetic Dialect Operations +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.h.inc" + 
+//===----------------------------------------------------------------------===// +// Utility Functions +//===----------------------------------------------------------------------===// + +namespace mlir { +namespace arith { + +/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer +/// comparison predicates. +bool applyCmpPredicate(arith::CmpIPredicate predicate, const APInt &lhs, + const APInt &rhs); + +/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point +/// comparison predicates. +bool applyCmpPredicate(arith::CmpFPredicate predicate, const APFloat &lhs, + const APFloat &rhs); + +} // end namespace arith +} // end namespace mlir + +#endif // MLIR_DIALECT_ARITHMETIC_IR_ARITHMETIC_H_ diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticBase.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticBase.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticBase.td @@ -0,0 +1,68 @@ +//===- ArithmeticBase.td - Base defs for arith dialect ------*- tablegen -*-==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef ARITHMETIC_BASE +#define ARITHMETIC_BASE + +include "mlir/IR/OpBase.td" + +def Arithmetic_Dialect : Dialect { + let name = "arith"; + let cppNamespace = "::mlir::arith"; + let description = [{ + The arithmetic dialect is intended to hold basic integer and floating point + mathematical operations. This includes unary, binary, and ternary arithmetic + ops, bitwise and shift ops, cast ops, and compare ops. Operations in this + dialect also accept vectors and tensors of integers or floats. 
+ }]; +} + +// The predicate indicates the type of the comparison to perform: +// (un)orderedness, (in)equality and less/greater than (or equal to) as +// well as predicates that are always true or false. +def Arith_CmpFPredicateAttr : I64EnumAttr< + "CmpFPredicate", "", + [ + I64EnumAttrCase<"AlwaysFalse", 0, "false">, + I64EnumAttrCase<"OEQ", 1, "oeq">, + I64EnumAttrCase<"OGT", 2, "ogt">, + I64EnumAttrCase<"OGE", 3, "oge">, + I64EnumAttrCase<"OLT", 4, "olt">, + I64EnumAttrCase<"OLE", 5, "ole">, + I64EnumAttrCase<"ONE", 6, "one">, + I64EnumAttrCase<"ORD", 7, "ord">, + I64EnumAttrCase<"UEQ", 8, "ueq">, + I64EnumAttrCase<"UGT", 9, "ugt">, + I64EnumAttrCase<"UGE", 10, "uge">, + I64EnumAttrCase<"ULT", 11, "ult">, + I64EnumAttrCase<"ULE", 12, "ule">, + I64EnumAttrCase<"UNE", 13, "une">, + I64EnumAttrCase<"UNO", 14, "uno">, + I64EnumAttrCase<"AlwaysTrue", 15, "true">, + ]> { + let cppNamespace = "::mlir::arith"; +} + +def Arith_CmpIPredicateAttr : I64EnumAttr< + "CmpIPredicate", "", + [ + I64EnumAttrCase<"eq", 0>, + I64EnumAttrCase<"ne", 1>, + I64EnumAttrCase<"slt", 2>, + I64EnumAttrCase<"sle", 3>, + I64EnumAttrCase<"sgt", 4>, + I64EnumAttrCase<"sge", 5>, + I64EnumAttrCase<"ult", 6>, + I64EnumAttrCase<"ule", 7>, + I64EnumAttrCase<"ugt", 8>, + I64EnumAttrCase<"uge", 9>, + ]> { + let cppNamespace = "::mlir::arith"; +} + +#endif // ARITHMETIC_BASE diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td @@ -0,0 +1,997 @@ +//===- ArithmeticOps.td - Arithmetic op definitions --------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef ARITHMETIC_OPS +#define ARITHMETIC_OPS + +include "mlir/Dialect/Arithmetic/IR/ArithmeticBase.td" +include "mlir/Interfaces/CastInterfaces.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Interfaces/VectorInterfaces.td" + +// Base class for Arithmetic dialect ops. Ops in this dialect have no side +// effects and can be applied element-wise to vectors and tensors. +class Arith_Op traits = []> : + Op] # + ElementwiseMappable.traits>; + +// Base class for integer and floating point arithmetic ops. All ops have one +// result, require operands and results to be of the same type, and can accept +// tensors or vectors of integers or floats. +class Arith_ArithmeticOp traits = []> : + Arith_Op; + +// Base class for unary arithmetic operations. +class Arith_UnaryOp traits = []> : + Arith_ArithmeticOp { + let assemblyFormat = "$operand attr-dict `:` type($result)"; +} + +// Base class for binary arithmetic operations. +class Arith_BinaryOp traits = []> : + Arith_ArithmeticOp { + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($result)"; +} + +// Base class for ternary arithmetic operations. +class Arith_TernaryOp traits = []> : + Arith_ArithmeticOp { + let assemblyFormat = "$a `,` $b `,` $c attr-dict `:` type($result)"; +} + +// Base class for integer binary operations. +class Arith_IntBinaryOp traits = []> : + Arith_BinaryOp, + Arguments<(ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs)>, + Results<(outs SignlessIntegerLike:$result)>; + +// Base class for floating point unary operations. +class Arith_FloatUnaryOp traits = []> : + Arith_UnaryOp, + Arguments<(ins FloatLike:$operand)>, + Results<(outs FloatLike:$result)>; + +// Base class for floating point binary operations. 
+class Arith_FloatBinaryOp traits = []> : + Arith_BinaryOp, + Arguments<(ins FloatLike:$lhs, FloatLike:$rhs)>, + Results<(outs FloatLike:$result)>; + +// Base class for arithmetic cast operations. Requires a single operand and +// result. If either is a shaped type, then the other must be of the same shape. +class Arith_CastOp traits = []> : + Arith_Op]>, + Arguments<(ins From:$in)>, + Results<(outs To:$out)> { + let builders = [ + OpBuilder<(ins "Value":$source, "Type":$destType), [{ + impl::buildCastOp($_builder, $_state, source, destType); + }]> + ]; + + let assemblyFormat = "$in attr-dict `:` type($in) `to` type($out)"; +} + +// Casts do not accept indices. Type constraint for signless-integer-like types +// excluding indices: signless integers, vectors or tensors thereof. +def SignlessFixedWidthIntegerLike : TypeConstraint.predicate, + TensorOf<[AnySignlessInteger]>.predicate]>, + "signless-fixed-width-integer-like">; + +// Cast from an integer type to another integer type. +class Arith_IToICastOp traits = []> : + Arith_CastOp; +// Cast from an integer type to a floating point type. +class Arith_IToFCastOp traits = []> : + Arith_CastOp; +// Cast from a floating point type to an integer type. +class Arith_FToICastOp traits = []> : + Arith_CastOp; +// Cast from a floating point type to another floating point type. +class Arith_FToFCastOp traits = []> : + Arith_CastOp; + +// Base class for compare operations. Requires two operands of the same type +// and returns a single `BoolLike` result. If the operand type is a vector or +// tensor, then the result will be one of `i1` of the same shape. 
+class Arith_CompareOp traits = []> : + Arith_Op]> { + let results = (outs BoolLike:$result); + + let assemblyFormat = "$predicate `,` $lhs `,` $rhs attr-dict `:` type($lhs)"; +} + +//===----------------------------------------------------------------------===// +// ConstantOp +//===----------------------------------------------------------------------===// + +def Arith_ConstantOp : Op]> { + let summary = "integer or floating point constant"; + let description = [{ + The `constant` operation produces an SSA value equal to some integer or + floating-point constant specified by an attribute. This is the way MLIR + forms simple integer and floating point constants. + + Example: + + ```mlir + // Integer constant + %1 = arith.constant 42 : i32 + + // Equivalent generic form + %1 = "arith.constant"() {value = 42 : i32} : () -> i32 + ``` + }]; + + let arguments = (ins AnyAttr:$value); + let results = (outs SignlessIntegerOrFloatLike:$result); + + let builders = [ + OpBuilder<(ins "Attribute":$value), + [{ build($_builder, $_state, value.getType(), value); }]>, + OpBuilder<(ins "Attribute":$value, "Type":$type), + [{ build($_builder, $_state, type, value); }]>, + ]; + + let hasFolder = 1; + let assemblyFormat = "attr-dict $value"; +} + +//===----------------------------------------------------------------------===// +// AddIOp +//===----------------------------------------------------------------------===// + +def Arith_AddIOp : Arith_IntBinaryOp<"addi", [Commutative]> { + let summary = "integer addition operation"; + let description = [{ + The `addi` operation takes two operands and returns one result, each of + these is required to be the same type. This type may be an integer scalar + type, a vector whose element type is integer, or a tensor of integers. It + has no standard attributes. + + Example: + + ```mlir + // Scalar addition. + %a = arith.addi %b, %c : i64 + + // SIMD vector element-wise addition, e.g. for Intel SSE.
+ %f = arith.addi %g, %h : vector<4xi32> + + // Tensor element-wise addition. + %x = arith.addi %y, %z : tensor<4x?xi8> + ``` + }]; + let hasFolder = 1; + let hasCanonicalizer = 1; +} + +//===----------------------------------------------------------------------===// +// SubIOp +//===----------------------------------------------------------------------===// + +def Arith_SubIOp : Arith_IntBinaryOp<"subi"> { + let summary = "integer subtraction operation"; + let hasFolder = 1; + let hasCanonicalizer = 1; +} + +//===----------------------------------------------------------------------===// +// MulIOp +//===----------------------------------------------------------------------===// + +def Arith_MulIOp : Arith_IntBinaryOp<"muli", [Commutative]> { + let summary = "integer multiplication operation"; + let hasFolder = 1; +} + +//===----------------------------------------------------------------------===// +// DivUIOp +//===----------------------------------------------------------------------===// + +def Arith_DivUIOp : Arith_IntBinaryOp<"divui"> { + let summary = "unsigned integer division operation"; + let description = [{ + Unsigned integer division. Rounds towards zero. Treats the leading bit as + the most significant, i.e. for `i16` given two's complement representation, + `6 / -2 = 6 / (2^16 - 2) = 0`. + + Note: the semantics of division by zero is TBD; do NOT assume any specific + behavior. + + Example: + + ```mlir + // Scalar unsigned integer division. + %a = arith.divui %b, %c : i64 + + // SIMD vector element-wise division. + %f = arith.divui %g, %h : vector<4xi32> + + // Tensor element-wise integer division. 
+ %x = arith.divui %y, %z : tensor<4x?xi8> + ``` + }]; + let hasFolder = 1; +} + +//===----------------------------------------------------------------------===// +// DivSIOp +//===----------------------------------------------------------------------===// + +def Arith_DivSIOp : Arith_IntBinaryOp<"divsi"> { + let summary = "signed integer division operation"; + let description = [{ + Signed integer division. Rounds towards zero. Treats the leading bit as + sign, i.e. `6 / -2 = -3`. + + Note: the semantics of division by zero or signed division overflow (minimum + value divided by -1) is TBD; do NOT assume any specific behavior. + + Example: + + ```mlir + // Scalar signed integer division. + %a = arith.divsi %b, %c : i64 + + // SIMD vector element-wise division. + %f = arith.divsi %g, %h : vector<4xi32> + + // Tensor element-wise integer division. + %x = arith.divsi %y, %z : tensor<4x?xi8> + ``` + }]; + let hasFolder = 1; +} + +//===----------------------------------------------------------------------===// +// CeilDivSIOp +//===----------------------------------------------------------------------===// + +def Arith_CeilDivSIOp : Arith_IntBinaryOp<"ceildivsi"> { + let summary = "signed ceil integer division operation"; + let description = [{ + Signed integer division. Rounds towards positive infinity, i.e. `7 / -2 = -3`. + + Note: the semantics of division by zero or signed division overflow (minimum + value divided by -1) is TBD; do NOT assume any specific behavior. + + Example: + + ```mlir + // Scalar signed integer division. + %a = arith.ceildivsi %b, %c : i64 + ``` + }]; + let hasFolder = 1; +} + +//===----------------------------------------------------------------------===// +// FloorDivSIOp +//===----------------------------------------------------------------------===// + +def Arith_FloorDivSIOp : Arith_IntBinaryOp<"floordivsi"> { + let summary = "signed floor integer division operation"; + let description = [{ + Signed integer division. 
Rounds towards negative infinity, i.e. `5 / -2 = -3`. + + Note: the semantics of division by zero or signed division overflow (minimum + value divided by -1) is TBD; do NOT assume any specific behavior. + + Example: + + ```mlir + // Scalar signed integer division. + %a = arith.floordivsi %b, %c : i64 + + ``` + }]; + let hasFolder = 1; +} + +//===----------------------------------------------------------------------===// +// RemUIOp +//===----------------------------------------------------------------------===// + +def Arith_RemUIOp : Arith_IntBinaryOp<"remui"> { + let summary = "unsigned integer division remainder operation"; + let description = [{ + Unsigned integer division remainder. Treats the leading bit as the most + significant, i.e. for `i16`, `6 % -2 = 6 % (2^16 - 2) = 6`. + + Note: the semantics of division by zero is TBD; do NOT assume any specific + behavior. + + Example: + + ```mlir + // Scalar unsigned integer division remainder. + %a = arith.remui %b, %c : i64 + + // SIMD vector element-wise division remainder. + %f = arith.remui %g, %h : vector<4xi32> + + // Tensor element-wise integer division remainder. + %x = arith.remui %y, %z : tensor<4x?xi8> + ``` + }]; + let hasFolder = 1; +} + +//===----------------------------------------------------------------------===// +// RemSIOp +//===----------------------------------------------------------------------===// + +def Arith_RemSIOp : Arith_IntBinaryOp<"remsi"> { + let summary = "signed integer division remainder operation"; + let description = [{ + Signed integer division remainder. Treats the leading bit as sign, i.e. `6 % + -2 = 0`. + + Note: the semantics of division by zero is TBD; do NOT assume any specific + behavior. + + Example: + + ```mlir + // Scalar signed integer division remainder. + %a = arith.remsi %b, %c : i64 + + // SIMD vector element-wise division remainder. + %f = arith.remsi %g, %h : vector<4xi32> + + // Tensor element-wise integer division remainder.
+ %x = arith.remsi %y, %z : tensor<4x?xi8> + ``` + }]; + let hasFolder = 1; +} + +//===----------------------------------------------------------------------===// +// AndIOp +//===----------------------------------------------------------------------===// + +def Arith_AndIOp : Arith_IntBinaryOp<"andi", [Commutative]> { + let summary = "integer binary and"; + let description = [{ + The `andi` operation takes two operands and returns one result, each of + these is required to be the same type. This type may be an integer scalar + type, a vector whose element type is integer, or a tensor of integers. It + has no standard attributes. + + Example: + + ```mlir + // Scalar integer bitwise and. + %a = arith.andi %b, %c : i64 + + // SIMD vector element-wise bitwise integer and. + %f = arith.andi %g, %h : vector<4xi32> + + // Tensor element-wise bitwise integer and. + %x = arith.andi %y, %z : tensor<4x?xi8> + ``` + }]; + let hasFolder = 1; +} + +//===----------------------------------------------------------------------===// +// OrIOp +//===----------------------------------------------------------------------===// + +def Arith_OrIOp : Arith_IntBinaryOp<"ori", [Commutative]> { + let summary = "integer binary or"; + let description = [{ + The `ori` operation takes two operands and returns one result, each of these + is required to be the same type. This type may be an integer scalar type, a + vector whose element type is integer, or a tensor of integers. It has no + standard attributes. + + Example: + + ```mlir + // Scalar integer bitwise or. + %a = arith.ori %b, %c : i64 + + // SIMD vector element-wise bitwise integer or. + %f = arith.ori %g, %h : vector<4xi32> + + // Tensor element-wise bitwise integer or.
+ %x = arith.ori %y, %z : tensor<4x?xi8> + ``` + }]; + let hasFolder = 1; +} + +//===----------------------------------------------------------------------===// +// XOrIOp +//===----------------------------------------------------------------------===// + +def Arith_XOrIOp : Arith_IntBinaryOp<"xori", [Commutative]> { + let summary = "integer binary xor"; + let description = [{ + The `xori` operation takes two operands and returns one result, each of + these is required to be the same type. This type may be an integer scalar + type, a vector whose element type is integer, or a tensor of integers. It + has no standard attributes. + + Example: + + ```mlir + // Scalar integer bitwise xor. + %a = arith.xori %b, %c : i64 + + // SIMD vector element-wise bitwise integer xor. + %f = arith.xori %g, %h : vector<4xi32> + + // Tensor element-wise bitwise integer xor. + %x = arith.xori %y, %z : tensor<4x?xi8> + ``` + }]; + let hasFolder = 1; + let hasCanonicalizer = 1; +} + +//===----------------------------------------------------------------------===// +// ShLIOp +//===----------------------------------------------------------------------===// + +def Arith_ShLIOp : Arith_IntBinaryOp<"shli"> { + let summary = "integer left-shift"; + let description = [{ + The `shli` operation shifts an integer value to the left by a variable + amount. The low order bits are filled with zeros. + + Example: + + ```mlir + %1 = arith.constant 5 : i8 // %1 is 0b00000101 + %2 = arith.constant 3 : i8 + %3 = arith.shli %1, %2 : i8 // %3 is 0b00101000 + ``` + }]; +} + +//===----------------------------------------------------------------------===// +// ShRUIOp +//===----------------------------------------------------------------------===// + +def Arith_ShRUIOp : Arith_IntBinaryOp<"shrui"> { + let summary = "unsigned integer right-shift"; + let description = [{ + The `shrui` operation shifts an integer value to the right by a variable + amount. The integer is interpreted as unsigned.
The high order bits are + always filled with zeros. + + Example: + + ```mlir + %1 = arith.constant 160 : i8 // %1 is 0b10100000 + %2 = arith.constant 3 : i8 + %3 = arith.shrui %1, %2 : i8 // %3 is 0b00010100 + ``` + }]; +} + +//===----------------------------------------------------------------------===// +// ShRSIOp +//===----------------------------------------------------------------------===// + +def Arith_ShRSIOp : Arith_IntBinaryOp<"shrsi"> { + let summary = "signed integer right-shift"; + let description = [{ + The `shrsi` operation shifts an integer value to the right by a variable + amount. The integer is interpreted as signed. The high order bits in the + output are filled with copies of the most-significant bit of the shifted + value (which means that the sign of the value is preserved). + + Example: + + ```mlir + %1 = arith.constant 160 : i8 // %1 is 0b10100000 + %2 = arith.constant 3 : i8 + %3 = arith.shrsi %1, %2 : i8 // %3 is 0b11110100 + %4 = arith.constant 96 : i8 // %4 is 0b01100000 + %5 = arith.shrsi %4, %2 : i8 // %5 is 0b00001100 + ``` + }]; +} + +//===----------------------------------------------------------------------===// +// NegFOp +//===----------------------------------------------------------------------===// + +def Arith_NegFOp : Arith_FloatUnaryOp<"negf"> { + let summary = "floating point negation"; + let description = [{ + The `negf` operation computes the negation of a given value. It takes one + operand and returns one result of the same type. This type may be a float + scalar type, a vector whose element type is float, or a tensor of floats. + It has no standard attributes. + + Example: + + ```mlir + // Scalar negation value. + %a = arith.negf %b : f64 + + // SIMD vector element-wise negation value. + %f = arith.negf %g : vector<4xf32> + + // Tensor element-wise negation value.
+ %x = arith.negf %y : tensor<4x?xf16> + ``` + }]; +} + +//===----------------------------------------------------------------------===// +// AddFOp +//===----------------------------------------------------------------------===// + +def Arith_AddFOp : Arith_FloatBinaryOp<"addf"> { + let summary = "floating point addition operation"; + let description = [{ + The `addf` operation takes two operands and returns one result, each of + these is required to be the same type. This type may be a floating point + scalar type, a vector whose element type is a floating point type, or a + floating point tensor. + + Example: + + ```mlir + // Scalar addition. + %a = arith.addf %b, %c : f64 + + // SIMD vector addition, e.g. for Intel SSE. + %f = arith.addf %g, %h : vector<4xf32> + + // Tensor addition. + %x = arith.addf %y, %z : tensor<4x?xbf16> + ``` + + TODO: In the distant future, this will accept optional attributes for fast + math, contraction, rounding mode, and other controls. + }]; + let hasFolder = 1; +} + +//===----------------------------------------------------------------------===// +// SubFOp +//===----------------------------------------------------------------------===// + +def Arith_SubFOp : Arith_FloatBinaryOp<"subf"> { + let summary = "floating point subtraction operation"; + let hasFolder = 1; +} + +//===----------------------------------------------------------------------===// +// MulFOp +//===----------------------------------------------------------------------===// + +def Arith_MulFOp : Arith_FloatBinaryOp<"mulf"> { + let summary = "floating point multiplication operation"; + let description = [{ + The `mulf` operation takes two operands and returns one result, each of + these is required to be the same type. This type may be a floating point + scalar type, a vector whose element type is a floating point type, or a + floating point tensor. + + Example: + + ```mlir + // Scalar multiplication.
+ %a = arith.mulf %b, %c : f64 + + // SIMD pointwise vector multiplication, e.g. for Intel SSE. + %f = arith.mulf %g, %h : vector<4xf32> + + // Tensor pointwise multiplication. + %x = arith.mulf %y, %z : tensor<4x?xbf16> + ``` + + TODO: In the distant future, this will accept optional attributes for fast + math, contraction, rounding mode, and other controls. + }]; + let hasFolder = 1; +} + +//===----------------------------------------------------------------------===// +// DivFOp +//===----------------------------------------------------------------------===// + +def Arith_DivFOp : Arith_FloatBinaryOp<"divf"> { + let summary = "floating point division operation"; + let hasFolder = 1; +} + +//===----------------------------------------------------------------------===// +// RemFOp +//===----------------------------------------------------------------------===// + +def Arith_RemFOp : Arith_FloatBinaryOp<"remf"> { + let summary = "floating point division remainder operation"; +} + +//===----------------------------------------------------------------------===// +// ExtUIOp +//===----------------------------------------------------------------------===// + +def Arith_ExtUIOp : Arith_IToICastOp<"extui"> { + let summary = "integer zero extension operation"; + let description = [{ + The integer zero extension operation takes an integer input of + width M and an integer destination type of width N. The destination + bit-width must be larger than the input bit-width (N > M). + The top-most (N - M) bits of the output are filled with zeros. 
+ + Example: + + ```mlir + %1 = arith.constant 5 : i3 // %1 is 0b101 + %2 = arith.extui %1 : i3 to i6 // %2 is 0b000101 + %3 = arith.constant 2 : i3 // %3 is 0b010 + %4 = arith.extui %3 : i3 to i6 // %4 is 0b000010 + + %5 = arith.extui %0 : vector<2 x i32> to vector<2 x i64> + ``` + }]; + + let hasFolder = 1; + let verifier = [{ return verifyExtOp(*this); }]; +} + +//===----------------------------------------------------------------------===// +// ExtSIOp +//===----------------------------------------------------------------------===// + +def Arith_ExtSIOp : Arith_IToICastOp<"extsi"> { + let summary = "integer sign extension operation"; + + let description = [{ + The integer sign extension operation takes an integer input of + width M and an integer destination type of width N. The destination + bit-width must be larger than the input bit-width (N > M). + The top-most (N - M) bits of the output are filled with copies + of the most-significant bit of the input. + + Example: + + ```mlir + %1 = arith.constant 5 : i3 // %1 is 0b101 + %2 = arith.extsi %1 : i3 to i6 // %2 is 0b111101 + %3 = arith.constant 2 : i3 // %3 is 0b010 + %4 = arith.extsi %3 : i3 to i6 // %4 is 0b000010 + + %5 = arith.extsi %0 : vector<2 x i32> to vector<2 x i64> + ``` + }]; + + let hasFolder = 1; + let verifier = [{ return verifyExtOp(*this); }]; +} + +//===----------------------------------------------------------------------===// +// ExtFOp +//===----------------------------------------------------------------------===// + +def Arith_ExtFOp : Arith_FToFCastOp<"extf"> { + let summary = "cast from floating-point to wider floating-point"; + let description = [{ + Cast a floating-point value to a larger floating-point-typed value. + The destination type must be strictly wider than the source type. + When operating on vectors, casts elementwise.
+ }]; + + let verifier = [{ return verifyExtOp(*this); }]; +} + +//===----------------------------------------------------------------------===// +// TruncIOp +//===----------------------------------------------------------------------===// + +def Arith_TruncIOp : Arith_IToICastOp<"trunci"> { + let summary = "integer truncation operation"; + let description = [{ + The integer truncation operation takes an integer input of + width M and an integer destination type of width N. The destination + bit-width must be smaller than the input bit-width (N < M). + The top-most (M - N) bits of the input are discarded. + + Example: + + ```mlir + %1 = arith.constant 21 : i5 // %1 is 0b10101 + %2 = arith.trunci %1 : i5 to i4 // %2 is 0b0101 + %3 = arith.trunci %1 : i5 to i3 // %3 is 0b101 + + %5 = arith.trunci %0 : vector<2 x i32> to vector<2 x i16> + ``` + }]; + + let hasFolder = 1; + let verifier = [{ return verifyTruncateOp(*this); }]; +} + +//===----------------------------------------------------------------------===// +// TruncFOp +//===----------------------------------------------------------------------===// + +def Arith_TruncFOp : Arith_FToFCastOp<"truncf"> { + let summary = "cast from floating-point to narrower floating-point"; + let description = [{ + Truncate a floating-point value to a smaller floating-point-typed value. + The destination type must be strictly narrower than the source type. + If the value cannot be exactly represented, it is rounded using the default + rounding mode. When operating on vectors, casts elementwise.
+  }];
+
+  let hasFolder = 1;
+  let verifier = [{ return verifyTruncateOp(*this); }];
+}
+
+//===----------------------------------------------------------------------===//
+// UIToFPOp
+//===----------------------------------------------------------------------===//
+
+def Arith_UIToFPOp : Arith_IToFCastOp<"uitofp"> {
+  let summary = "cast from unsigned integer type to floating-point";
+  let description = [{
+    Cast from a value interpreted as unsigned integer to the corresponding
+    floating-point value. If the value cannot be exactly represented, it is
+    rounded using the default rounding mode. When operating on vectors, casts
+    elementwise.
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// SIToFPOp
+//===----------------------------------------------------------------------===//
+
+def Arith_SIToFPOp : Arith_IToFCastOp<"sitofp"> {
+  let summary = "cast from signed integer type to floating-point";
+  let description = [{
+    Cast from a value interpreted as a signed integer to the corresponding
+    floating-point value. If the value cannot be exactly represented, it is
+    rounded using the default rounding mode. When operating on vectors, casts
+    elementwise.
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// FPToUIOp
+//===----------------------------------------------------------------------===//
+
+def Arith_FPToUIOp : Arith_FToICastOp<"fptoui"> {
+  let summary = "cast from floating-point type to unsigned integer type";
+  let description = [{
+    Cast from a value interpreted as floating-point to the nearest (rounding
+    towards zero) unsigned integer value. When operating on vectors, casts
+    elementwise.
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// FPToSIOp
+//===----------------------------------------------------------------------===//
+
+def Arith_FPToSIOp : Arith_FToICastOp<"fptosi"> {
+  let summary = "cast from floating-point type to signed integer type";
+  let description = [{
+    Cast from a value interpreted as floating-point to the nearest (rounding
+    towards zero) signed integer value. When operating on vectors, casts
+    elementwise.
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// IndexCastOp
+//===----------------------------------------------------------------------===//
+
+def Arith_IndexCastOp : Arith_IToICastOp<"index_cast"> {
+  let summary = "cast between index and integer types";
+  let description = [{
+    Casts between scalar or vector integers and corresponding 'index' scalar or
+    vectors. Index is an integer of platform-specific bit width. If casting to
+    a wider integer, the value is sign-extended. If casting to a narrower
+    integer, the value is truncated.
+  }];
+
+  let hasFolder = 1;
+  let hasCanonicalizer = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// BitcastOp
+//===----------------------------------------------------------------------===//
+
+def Arith_BitcastOp : Arith_CastOp<"bitcast", SignlessIntegerOrFloatLike,
+                                   SignlessIntegerOrFloatLike> {
+  let summary = "bitcast between values of equal bit width";
+  let description = [{
+    Bitcast an integer or floating point value to an integer or floating point
+    value of equal bit width. When operating on vectors, casts elementwise.
+
+    Note that this implements a logical bitcast independent of target
+    endianness. This allows constant folding without target information and is
+    consistent with the bitcast constant folders in LLVM (see
+    https://github.com/llvm/llvm-project/blob/18c19414eb/llvm/lib/IR/ConstantFold.cpp#L168).
+    For targets where the source and target type have the same endianness (which
+    is the standard), this cast will also change no bits at runtime, but it may
+    still require an operation, for example if the machine has different
+    floating point and integer register files. For targets that have a different
+    endianness for the source and target types (e.g. float is big-endian and
+    integer is little-endian) a proper lowering would add operations to swap the
+    order of words in addition to the bitcast.
+  }];
+
+  let hasFolder = 1;
+  let hasCanonicalizer = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// CmpIOp
+//===----------------------------------------------------------------------===//
+
+def Arith_CmpIOp : Arith_CompareOp<"cmpi"> {
+  let summary = "integer comparison operation";
+  let description = [{
+    The `cmpi` operation is a generic comparison for integer-like types. Its two
+    arguments can be integers, vectors or tensors thereof as long as their types
+    match. The operation produces an i1 for the former case, a vector or a
+    tensor of i1 with the same shape as inputs in the other cases.
+
+    Its first argument is an attribute that defines which type of comparison is
+    performed. The following comparisons are supported:
+
+    - equal (mnemonic: `"eq"`; integer value: `0`)
+    - not equal (mnemonic: `"ne"`; integer value: `1`)
+    - signed less than (mnemonic: `"slt"`; integer value: `2`)
+    - signed less than or equal (mnemonic: `"sle"`; integer value: `3`)
+    - signed greater than (mnemonic: `"sgt"`; integer value: `4`)
+    - signed greater than or equal (mnemonic: `"sge"`; integer value: `5`)
+    - unsigned less than (mnemonic: `"ult"`; integer value: `6`)
+    - unsigned less than or equal (mnemonic: `"ule"`; integer value: `7`)
+    - unsigned greater than (mnemonic: `"ugt"`; integer value: `8`)
+    - unsigned greater than or equal (mnemonic: `"uge"`; integer value: `9`)
+
+    The result is `1` if the comparison is true and `0` otherwise. For vector or
+    tensor operands, the comparison is performed elementwise and the element of
+    the result indicates whether the comparison is true for the operand elements
+    with the same indices as those of the result.
+
+    Note: while the custom assembly form uses strings, the actual underlying
+    attribute has integer type (or rather enum class in C++ code) as seen from
+    the generic assembly form. String literals are used to improve readability
+    of the IR by humans.
+
+    This operation only applies to integer-like operands, but not floats. The
+    main reason being that comparison operations have diverging sets of
+    attributes: integers require sign specification while floats require various
+    floating point-related particularities, e.g., `-ffast-math` behavior,
+    IEEE754 compliance, etc
+    ([rationale](../Rationale/Rationale.md#splitting-floating-point-vs-integer-operations)).
+    The type of comparison is specified as attribute to avoid introducing ten
+    similar operations, taking into account that they are often implemented
+    using the same operation downstream
+    ([rationale](../Rationale/Rationale.md#specifying-comparison-kind-as-attribute)). The
+    separation between signed and unsigned order comparisons is necessary
+    because of integers being signless. The comparison operation must know how
+    to interpret values with the foremost bit being set: negatives in two's
+    complement or large positives
+    ([rationale](../Rationale/Rationale.md#specifying-sign-in-integer-comparison-operations)).
+
+    Example:
+
+    ```mlir
+    // Custom form of scalar "signed less than" comparison.
+    %x = arith.cmpi "slt", %lhs, %rhs : i32
+
+    // Generic form of the same operation.
+    %x = "arith.cmpi"(%lhs, %rhs) {predicate = 2 : i64} : (i32, i32) -> i1
+
+    // Custom form of vector equality comparison.
+    %x = arith.cmpi "eq", %lhs, %rhs : vector<4xi64>
+
+    // Generic form of the same operation.
+    %x = "arith.cmpi"(%lhs, %rhs) {predicate = 0 : i64}
+        : (vector<4xi64>, vector<4xi64>) -> vector<4xi1>
+    ```
+  }];
+
+  let arguments = (ins Arith_CmpIPredicateAttr:$predicate,
+                       SignlessIntegerLike:$lhs,
+                       SignlessIntegerLike:$rhs);
+
+  let builders = [
+    OpBuilder<(ins "CmpIPredicate":$predicate, "Value":$lhs, "Value":$rhs), [{
+      build($_builder, $_state, ::getI1SameShape(lhs.getType()),
+            predicate, lhs, rhs);
+    }]>
+  ];
+
+  let extraClassDeclaration = [{
+    static StringRef getPredicateAttrName() { return "predicate"; }
+    static CmpIPredicate getPredicateByName(StringRef name);
+
+    CmpIPredicate getPredicate() {
+      return (CmpIPredicate) (*this)->getAttrOfType(
+          getPredicateAttrName()).getInt();
+    }
+  }];
+
+  let hasFolder = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// CmpFOp
+//===----------------------------------------------------------------------===//
+
+def Arith_CmpFOp : Arith_CompareOp<"cmpf"> {
+  let summary = "floating-point comparison operation";
+  let description = [{
+    The `cmpf` operation compares its two operands according to the float
+    comparison rules and the predicate specified by the respective attribute.
+    The predicate defines the type of comparison: (un)orderedness, (in)equality
+    and signed less/greater than (or equal to) as well as predicates that are
+    always true or false. The operands must have the same type, and this type
+    must be a float type, or a vector or tensor thereof. The result is an i1,
+    or a vector/tensor thereof having the same shape as the inputs. Unlike cmpi,
+    the operands are always treated as signed. The u prefix indicates
+    *unordered* comparison, not unsigned comparison, so "une" means unordered or
+    not equal. For the sake of readability by humans, custom assembly form for
+    the operation uses a string-typed attribute for the predicate. The value of
+    this attribute corresponds to lower-cased name of the predicate constant,
+    e.g., "one" means "ordered not equal". The string representation of the
+    attribute is merely a syntactic sugar and is converted to an integer
+    attribute by the parser.
+
+    Example:
+
+    ```mlir
+    %r1 = arith.cmpf "oeq", %0, %1 : f32
+    %r2 = arith.cmpf "ult", %0, %1 : tensor<42x42xf64>
+    %r3 = "arith.cmpf"(%0, %1) {predicate = 0 : i64} : (f32, f32) -> i1
+    ```
+  }];
+
+  let arguments = (ins Arith_CmpFPredicateAttr:$predicate,
+                       FloatLike:$lhs,
+                       FloatLike:$rhs);
+
+  let builders = [
+    OpBuilder<(ins "CmpFPredicate":$predicate, "Value":$lhs, "Value":$rhs), [{
+      build($_builder, $_state, ::getI1SameShape(lhs.getType()),
+            predicate, lhs, rhs);
+    }]>
+  ];
+
+  let extraClassDeclaration = [{
+    static StringRef getPredicateAttrName() { return "predicate"; }
+    static CmpFPredicate getPredicateByName(StringRef name);
+
+    CmpFPredicate getPredicate() {
+      return (CmpFPredicate) (*this)->getAttrOfType(
+          getPredicateAttrName()).getInt();
+    }
+  }];
+
+  let hasFolder = 1;
+}
+
+#endif // ARITHMETIC_OPS
diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Arithmetic/IR/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Arithmetic/IR/CMakeLists.txt
@@ -0,0 +1,5 @@
+set(LLVM_TARGET_DEFINITIONS ArithmeticOps.td)
+mlir_tablegen(ArithmeticOpsEnums.h.inc -gen-enum-decls)
+mlir_tablegen(ArithmeticOpsEnums.cpp.inc -gen-enum-defs)
+add_mlir_dialect(ArithmeticOps arith)
+add_mlir_doc(ArithmeticOps ArithmeticOps Dialects/ -gen-dialect-doc)
diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt
--- a/mlir/include/mlir/Dialect/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_subdirectory(Affine)
+add_subdirectory(Arithmetic)
 add_subdirectory(Async)
 add_subdirectory(ArmNeon)
 add_subdirectory(ArmSVE)
diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
--- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -13,30 +13,72 @@
 include "mlir/Interfaces/VectorInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 
-class Math_Op traits = []>
-    : Op;
-
+// Base class for math dialect ops. Ops in this dialect have no side effects and
+// can be applied element-wise to vectors and tensors.
+class Math_Op traits = []> :
+    Op] #
+    ElementwiseMappable.traits>;
+
+// Base class for unary math operations on floating point types. Require an
+// operand and result of the same type. This type can be a floating point type,
+// or vector or tensor thereof.
 class Math_FloatUnaryOp traits = []> :
-    Math_Op,
-    SameOperandsAndResultType] # ElementwiseMappable.traits> {
+    Math_Op {
   let arguments = (ins FloatLike:$operand);
-  let results = (outs FloatLike:$result);
 
   let assemblyFormat = "$operand attr-dict `:` type($result)";
 }
 
+// Base class for binary math operations on floating point types. Require two
+// operands and one result of the same type. This type can be a floating point
+// type, or a vector or tensor thereof.
 class Math_FloatBinaryOp traits = []> :
-    Math_Op,
-    SameOperandsAndResultType] # ElementwiseMappable.traits> {
+    Math_Op {
   let arguments = (ins FloatLike:$lhs, FloatLike:$rhs);
   let results = (outs FloatLike:$result);
 
   let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($result)";
 }
 
+// Base class for floating point ternary operations. Require three operands and
+// one result of the same type. This type can be a floating point type, or a
+// vector or tensor thereof.
+class Math_FloatTernaryOp traits = []> :
+    Math_Op {
+  let arguments = (ins FloatLike:$a, FloatLike:$b, FloatLike:$c);
+  let results = (outs FloatLike:$result);
+
+  let assemblyFormat = "$a `,` $b `,` $c attr-dict `:` type($result)";
+}
+
+//===----------------------------------------------------------------------===//
+// AbsOp
+//===----------------------------------------------------------------------===//
+
+def Math_AbsOp : Math_FloatUnaryOp<"abs"> {
+  let summary = "floating point absolute-value operation";
+  let description = [{
+    The `abs` operation computes the absolute value. It takes one operand and
+    returns one result of the same type. This type may be a float scalar type,
+    a vector whose element type is float, or a tensor of floats.
+
+    Example:
+
+    ```mlir
+    // Scalar absolute value.
+    %a = math.abs %b : f64
+
+    // SIMD vector element-wise absolute value.
+    %f = math.abs %g : vector<4xf32>
+
+    // Tensor element-wise absolute value.
+    %x = math.abs %y : tensor<4x?xf8>
+    ```
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // AtanOp
 //===----------------------------------------------------------------------===//
@@ -110,6 +152,73 @@
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// CeilOp
+//===----------------------------------------------------------------------===//
+
+def Math_CeilOp : Math_FloatUnaryOp<"ceil"> {
+  let summary = "ceiling of the specified value";
+  let description = [{
+    Syntax:
+
+    ```
+    operation ::= ssa-id `=` `math.ceil` ssa-use `:` type
+    ```
+
+    The `ceil` operation computes the ceiling of a given value. It takes one
+    operand and returns one result of the same type. This type may be a float
+    scalar type, a vector whose element type is float, or a tensor of floats.
+    It has no standard attributes.
+
+    Example:
+
+    ```mlir
+    // Scalar ceiling value.
+    %a = math.ceil %b : f64
+
+    // SIMD vector element-wise ceiling value.
+    %f = math.ceil %g : vector<4xf32>
+
+    // Tensor element-wise ceiling value.
+    %x = math.ceil %y : tensor<4x?xf8>
+    ```
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// CopySignOp
+//===----------------------------------------------------------------------===//
+
+def Math_CopySignOp : Math_FloatBinaryOp<"copysign"> {
+  let summary = "A copysign operation";
+  let description = [{
+    Syntax:
+
+    ```
+    operation ::= ssa-id `=` `math.copysign` ssa-use `,` ssa-use `:` type
+    ```
+
+    The `copysign` returns a value with the magnitude of the first operand and
+    the sign of the second operand. It takes two operands and returns one
+    result of the same type. This type may be a float scalar type, a vector
+    whose element type is float, or a tensor of floats. It has no standard
+    attributes.
+
+    Example:
+
+    ```mlir
+    // Scalar copysign value.
+    %a = math.copysign %b, %c : f64
+
+    // SIMD vector element-wise copysign value.
+    %f = math.copysign %g, %h : vector<4xf32>
+
+    // Tensor element-wise copysign value.
+    %x = math.copysign %y, %z : tensor<4x?xf8>
+    ```
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // CosOp
 //===----------------------------------------------------------------------===//
@@ -276,6 +385,77 @@
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// FloorOp
+//===----------------------------------------------------------------------===//
+
+def Math_FloorOp : Math_FloatUnaryOp<"floor"> {
+  let summary = "floor of the specified value";
+  let description = [{
+    Syntax:
+
+    ```
+    operation ::= ssa-id `=` `math.floor` ssa-use `:` type
+    ```
+
+    The `floor` operation computes the floor of a given value. It takes one
+    operand and returns one result of the same type. This type may be a float
+    scalar type, a vector whose element type is float, or a tensor of floats.
+    It has no standard attributes.
+
+    Example:
+
+    ```mlir
+    // Scalar floor value.
+    %a = math.floor %b : f64
+
+    // SIMD vector element-wise floor value.
+    %f = math.floor %g : vector<4xf32>
+
+    // Tensor element-wise floor value.
+    %x = math.floor %y : tensor<4x?xf8>
+    ```
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// FmaOp
+//===----------------------------------------------------------------------===//
+
+def Math_FmaOp : Math_FloatTernaryOp<"fma"> {
+  let summary = "floating point fused multiply-add operation";
+  let description = [{
+    Syntax:
+
+    ```
+    operation ::= ssa-id `=` `math.fma` ssa-use `,` ssa-use `,` ssa-use `:` type
+    ```
+
+    The `fma` operation takes three operands and returns one result, each of
+    these is required to be the same type. This type may be a floating point
+    scalar type, a vector whose element type is a floating point type, or a
+    floating point tensor.
+
+    Example:
+
+    ```mlir
+    // Scalar fused multiply-add: d = a*b + c
+    %d = math.fma %a, %b, %c : f64
+
+    // SIMD vector fused multiply-add, e.g. for Intel SSE.
+    %i = math.fma %f, %g, %h : vector<4xf32>
+
+    // Tensor fused multiply-add.
+    %w = math.fma %x, %y, %z : tensor<4x?xbf16>
+    ```
+
+    The semantics of the operation correspond to those of the `llvm.fma`
+    [intrinsic](https://llvm.org/docs/LangRef.html#llvm-fma-intrinsic). In the
+    particular case of lowering to LLVM, this is guaranteed to lower
+    to the `llvm.fma.*` intrinsic.
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // LogOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arithmetic/CMakeLists.txt b/mlir/lib/Dialect/Arithmetic/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Arithmetic/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(IR)
diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticCanonicalization.td b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticCanonicalization.td
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticCanonicalization.td
@@ -0,0 +1,131 @@
+//===- ArithmeticCanonicalization.td - Arith patterns ------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef ARITHMETIC_PATTERNS
+#define ARITHMETIC_PATTERNS
+
+include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.td"
+
+// Add two integer attributes and create a new one with the result.
+def AddIntAttrs : NativeCodeCall<"addIntegerAttrs($_builder, $0, $1, $2)">;
+
+// Subtract two integer attributes and create a new one with the result.
+def SubIntAttrs : NativeCodeCall<"subIntegerAttrs($_builder, $0, $1, $2)">;
+
+//===----------------------------------------------------------------------===//
+// AddIOp
+//===----------------------------------------------------------------------===//
+
+// addi is commutative and will be canonicalized to have its constants appear
+// as the second operand.
+
+// addi(addi(x, c0), c1) -> addi(x, c0 + c1)
+def AddIAddConstant :
+    Pat<(Arith_AddIOp:$res
+            (Arith_AddIOp $x, (Arith_ConstantOp APIntAttr:$c0)),
+            (Arith_ConstantOp APIntAttr:$c1)),
+        (Arith_AddIOp $x, (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)))>;
+
+// addi(subi(x, c0), c1) -> addi(x, c1 - c0)
+def AddISubConstantRHS :
+    Pat<(Arith_AddIOp:$res
+            (Arith_SubIOp $x, (Arith_ConstantOp APIntAttr:$c0)),
+            (Arith_ConstantOp APIntAttr:$c1)),
+        (Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)))>;
+
+// addi(subi(c0, x), c1) -> subi(c0 + c1, x)
+def AddISubConstantLHS :
+    Pat<(Arith_AddIOp:$res
+            (Arith_SubIOp (Arith_ConstantOp APIntAttr:$c0), $x),
+            (Arith_ConstantOp APIntAttr:$c1)),
+        (Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x)>;
+
+//===----------------------------------------------------------------------===//
+// SubIOp
+//===----------------------------------------------------------------------===//
+
+// subi(addi(x, c0), c1) -> addi(x, c0 - c1)
+def SubIRHSAddConstant :
+    Pat<(Arith_SubIOp:$res
+            (Arith_AddIOp $x, (Arith_ConstantOp APIntAttr:$c0)),
+            (Arith_ConstantOp APIntAttr:$c1)),
+        (Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c0, $c1)))>;
+
+// subi(c1, addi(x, c0)) -> subi(c1 - c0, x)
+def SubILHSAddConstant :
+    Pat<(Arith_SubIOp:$res
+            (Arith_ConstantOp APIntAttr:$c1),
+            (Arith_AddIOp $x, (Arith_ConstantOp APIntAttr:$c0))),
+        (Arith_SubIOp (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)), $x)>;
+
+// subi(subi(x, c0), c1) -> subi(x, c0 + c1)
+def SubIRHSSubConstantRHS :
+    Pat<(Arith_SubIOp:$res
+            (Arith_SubIOp $x, (Arith_ConstantOp APIntAttr:$c0)),
+            (Arith_ConstantOp APIntAttr:$c1)),
+        (Arith_SubIOp $x, (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)))>;
+
+// subi(subi(c0, x), c1) -> subi(c0 - c1, x)
+def SubIRHSSubConstantLHS :
+    Pat<(Arith_SubIOp:$res
+            (Arith_SubIOp (Arith_ConstantOp APIntAttr:$c0), $x),
+            (Arith_ConstantOp APIntAttr:$c1)),
+        (Arith_SubIOp (Arith_ConstantOp (SubIntAttrs $res, $c0, $c1)), $x)>;
+
+// subi(c1, subi(x, c0)) -> subi(c0 + c1, x)
+def SubILHSSubConstantRHS :
+    Pat<(Arith_SubIOp:$res
+            (Arith_ConstantOp APIntAttr:$c1),
+            (Arith_SubIOp $x, (Arith_ConstantOp APIntAttr:$c0))),
+        (Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x)>;
+
+// subi(c1, subi(c0, x)) -> addi(x, c1 - c0)
+def SubILHSSubConstantLHS :
+    Pat<(Arith_SubIOp:$res
+            (Arith_ConstantOp APIntAttr:$c1),
+            (Arith_SubIOp (Arith_ConstantOp APIntAttr:$c0), $x)),
+        (Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)))>;
+
+//===----------------------------------------------------------------------===//
+// XOrIOp
+//===----------------------------------------------------------------------===//
+
+// xori is commutative and will be canonicalized to have its constants appear
+// as the second operand.
+
+// not(cmpi(pred, a, b)) -> cmpi(~pred, a, b), where not(x) is xori(x, 1)
+def InvertPredicate : NativeCodeCall<"invertPredicate($0)">;
+def XOrINotCmpI :
+    Pat<(Arith_XOrIOp
+            (Arith_CmpIOp $pred, $a, $b),
+            (Arith_ConstantOp ConstantAttr)),
+        (Arith_CmpIOp (InvertPredicate $pred), $a, $b)>;
+
+//===----------------------------------------------------------------------===//
+// IndexCastOp
+//===----------------------------------------------------------------------===//
+
+// index_cast(index_cast(x)) -> x, if dstType == srcType.
+def IndexCastOfIndexCast :
+    Pat<(Arith_IndexCastOp:$res (Arith_IndexCastOp $x)),
+        (replaceWithValue $x),
+        [(Constraint> $res, $x)]>;
+
+// index_cast(extsi(x)) -> index_cast(x)
+def IndexCastOfExtSI :
+    Pat<(Arith_IndexCastOp (Arith_ExtSIOp $x)), (Arith_IndexCastOp $x)>;
+
+//===----------------------------------------------------------------------===//
+// BitcastOp
+//===----------------------------------------------------------------------===//
+
+// bitcast(bitcast(x)) -> x
+def BitcastOfBitcast :
+    Pat<(Arith_BitcastOp (Arith_BitcastOp $x)), (replaceWithValue $x)>;
+
+#endif // ARITHMETIC_PATTERNS
diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticDialect.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticDialect.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticDialect.cpp
@@ -0,0 +1,37 @@
+//===- ArithmeticDialect.cpp - MLIR Arithmetic dialect implementation -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Transforms/InliningUtils.h"
+
+using namespace mlir;
+using namespace mlir::arith;
+
+#include "mlir/Dialect/Arithmetic/IR/ArithmeticOpsDialect.cpp.inc"
+
+namespace {
+/// This class defines the interface for handling inlining for arithmetic
+/// dialect operations.
+struct ArithmeticInlinerInterface : public DialectInlinerInterface {
+  using DialectInlinerInterface::DialectInlinerInterface;
+
+  /// All arithmetic dialect ops can be inlined.
+  bool isLegalToInline(Operation *, Region *, bool,
+                       BlockAndValueMapping &) const final {
+    return true;
+  }
+};
+} // end anonymous namespace
+
+void mlir::arith::ArithmeticDialect::initialize() {
+  addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.cpp.inc"
+      >();
+  addInterfaces();
+}
diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -0,0 +1,737 @@
+//===- ArithmeticOps.cpp - MLIR Arithmetic dialect ops implementation -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/CommonFolders.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+
+using namespace mlir;
+using namespace mlir::arith;
+
+//===----------------------------------------------------------------------===//
+// Pattern helpers
+//===----------------------------------------------------------------------===//
+
+static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
+                                   Attribute lhs, Attribute rhs) {
+  return builder.getIntegerAttr(res.getType(),
+                                lhs.cast().getInt() +
+                                    rhs.cast().getInt());
+}
+
+static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
+                                   Attribute lhs, Attribute rhs) {
+  return builder.getIntegerAttr(res.getType(),
+                                lhs.cast().getInt() -
+                                    rhs.cast().getInt());
+}
+
+/// Invert an integer comparison predicate.
+static arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred) {
+  switch (pred) {
+  case arith::CmpIPredicate::eq:
+    return arith::CmpIPredicate::ne;
+  case arith::CmpIPredicate::ne:
+    return arith::CmpIPredicate::eq;
+  case arith::CmpIPredicate::slt:
+    return arith::CmpIPredicate::sge;
+  case arith::CmpIPredicate::sle:
+    return arith::CmpIPredicate::sgt;
+  case arith::CmpIPredicate::sgt:
+    return arith::CmpIPredicate::sle;
+  case arith::CmpIPredicate::sge:
+    return arith::CmpIPredicate::slt;
+  case arith::CmpIPredicate::ult:
+    return arith::CmpIPredicate::uge;
+  case arith::CmpIPredicate::ule:
+    return arith::CmpIPredicate::ugt;
+  case arith::CmpIPredicate::ugt:
+    return arith::CmpIPredicate::ule;
+  case arith::CmpIPredicate::uge:
+    return arith::CmpIPredicate::ult;
+  }
+  llvm_unreachable("unknown cmpi predicate kind");
+}
+
+static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
+  return arith::CmpIPredicateAttr::get(pred.getContext(),
+                                       invertPredicate(pred.getValue()));
+}
+
+//===----------------------------------------------------------------------===//
+// TableGen'd canonicalization patterns
+//===----------------------------------------------------------------------===//
+
+namespace {
+#include "ArithmeticCanonicalization.inc"
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// AddIOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult arith::AddIOp::fold(ArrayRef operands) {
+  // addi(x, 0) -> x
+  if (matchPattern(rhs(), m_Zero()))
+    return lhs();
+
+  return constFoldBinaryOp(operands,
+      [](APInt a, APInt b) { return a + b; });
+}
+
+void arith::AddIOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *context) {
+  patterns.insert(
+      context);
+}
+
+//===----------------------------------------------------------------------===//
+// SubIOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult arith::SubIOp::fold(ArrayRef operands) {
+  // subi(x,x) -> 0
+  if (getOperand(0) == getOperand(1))
+    return Builder(getContext()).getZeroAttr(getType());
+  // subi(x,0) -> x
+  if (matchPattern(rhs(), m_Zero()))
+    return lhs();
+
+  return constFoldBinaryOp(operands,
+      [](APInt a, APInt b) { return a - b; });
+}
+
+void arith::SubIOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *context) {
+  patterns.insert(context);
+}
+
+//===----------------------------------------------------------------------===//
+// MulIOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult arith::MulIOp::fold(ArrayRef operands) {
+  // muli(x, 0) -> 0
+  if (matchPattern(rhs(), m_Zero()))
+    return rhs();
+  // muli(x, 1) -> x
+  if (matchPattern(rhs(), m_One()))
+    return getOperand(0);
+  // TODO: Handle the overflow case.
+
+  // default folder
+  return constFoldBinaryOp(operands,
+      [](APInt a, APInt b) { return a * b; });
+}
+
+//===----------------------------------------------------------------------===//
+// DivUIOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult arith::DivUIOp::fold(ArrayRef operands) {
+  // Don't fold if it would require a division by zero.
+  bool div0 = false;
+  auto result = constFoldBinaryOp(operands, [&](APInt a, APInt b) {
+    if (div0 || !b) {
+      div0 = true;
+      return a;
+    }
+    return a.udiv(b);
+  });
+
+  // Fold out division by one. Assumes all tensors of all ones are splats.
+  if (auto rhs = operands[1].dyn_cast_or_null()) {
+    if (rhs.getValue() == 1)
+      return lhs();
+  } else if (auto rhs = operands[1].dyn_cast_or_null()) {
+    if (rhs.getSplatValue().getValue() == 1)
+      return lhs();
+  }
+
+  return div0 ? Attribute() : result;
+}
+
+//===----------------------------------------------------------------------===//
+// DivSIOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult arith::DivSIOp::fold(ArrayRef operands) {
+  // Don't fold if it would overflow or if it requires a division by zero.
+  bool overflowOrDiv0 = false;
+  auto result = constFoldBinaryOp(operands, [&](APInt a, APInt b) {
+    if (overflowOrDiv0 || !b) {
+      overflowOrDiv0 = true;
+      return a;
+    }
+    return a.sdiv_ov(b, overflowOrDiv0);
+  });
+
+  // Fold out division by one. Assumes all tensors of all ones are splats.
+  if (auto rhs = operands[1].dyn_cast_or_null()) {
+    if (rhs.getValue() == 1)
+      return lhs();
+  } else if (auto rhs = operands[1].dyn_cast_or_null()) {
+    if (rhs.getSplatValue().getValue() == 1)
+      return lhs();
+  }
+
+  return overflowOrDiv0 ? Attribute() : result;
+}
+
+//===----------------------------------------------------------------------===//
+// Ceil and floor division folding helpers
+//===----------------------------------------------------------------------===//
+
+static APInt signedCeilNonnegInputs(APInt a, APInt b, bool &overflow) {
+  // Returns (a-1)/b + 1 == ceil(a/b) for a > 0, b > 0; a == 0 not handled.
+  APInt one(a.getBitWidth(), 1, true); // Signed value 1.
+  APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
+  return val.sadd_ov(one, overflow);
+}
+
+//===----------------------------------------------------------------------===//
+// CeilDivSIOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult arith::CeilDivSIOp::fold(ArrayRef operands) {
+  // Don't fold if it would overflow or if it requires a division by zero.
+  bool overflowOrDiv0 = false;
+  auto result = constFoldBinaryOp(operands, [&](APInt a, APInt b) {
+    if (overflowOrDiv0 || !b) {
+      overflowOrDiv0 = true;
+      return a;
+    }
+    unsigned bits = a.getBitWidth();
+    APInt zero = APInt::getZero(bits);
+    if (a.sgt(zero) && b.sgt(zero)) {
+      // Both positive, return ceil(a, b).
+      return signedCeilNonnegInputs(a, b, overflowOrDiv0);
+    }
+    if (a.slt(zero) && b.slt(zero)) {
+      // Both negative, return ceil(-a, -b).
+      APInt posA = zero.ssub_ov(a, overflowOrDiv0);
+      APInt posB = zero.ssub_ov(b, overflowOrDiv0);
+      return signedCeilNonnegInputs(posA, posB, overflowOrDiv0);
+    }
+    if (a.slt(zero) && b.sgt(zero)) {
+      // A is negative, b is positive, return - ( -a / b).
+      APInt posA = zero.ssub_ov(a, overflowOrDiv0);
+      APInt div = posA.sdiv_ov(b, overflowOrDiv0);
+      return zero.ssub_ov(div, overflowOrDiv0);
+    }
+    // A is zero, or a is positive and b is negative: return - (a / -b).
+    APInt posB = zero.ssub_ov(b, overflowOrDiv0);
+    APInt div = a.sdiv_ov(posB, overflowOrDiv0);
+    return zero.ssub_ov(div, overflowOrDiv0);
+  });
+
+  // Fold out ceil division by one. Assumes all tensors of all ones are
+  // splats.
+  if (auto rhs = operands[1].dyn_cast_or_null()) {
+    if (rhs.getValue() == 1)
+      return lhs();
+  } else if (auto rhs = operands[1].dyn_cast_or_null()) {
+    if (rhs.getSplatValue().getValue() == 1)
+      return lhs();
+  }
+
+  return overflowOrDiv0 ? Attribute() : result;
+}
+
+//===----------------------------------------------------------------------===//
+// FloorDivSIOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult arith::FloorDivSIOp::fold(ArrayRef operands) {
+  // Don't fold if it would overflow or if it requires a division by zero.
+  bool overflowOrDiv0 = false;
+  auto result = constFoldBinaryOp(operands, [&](APInt a, APInt b) {
+    if (overflowOrDiv0 || !b) {
+      overflowOrDiv0 = true;
+      return a;
+    }
+    unsigned bits = a.getBitWidth();
+    APInt zero = APInt::getZero(bits);
+    if (a.sge(zero) && b.sgt(zero)) {
+      // Both positive (or a is zero), return a / b.
+      return a.sdiv_ov(b, overflowOrDiv0);
+    }
+    if (a.sle(zero) && b.slt(zero)) {
+      // Both negative (or a is zero), return -a / -b.
+      APInt posA = zero.ssub_ov(a, overflowOrDiv0);
+      APInt posB = zero.ssub_ov(b, overflowOrDiv0);
+      return posA.sdiv_ov(posB, overflowOrDiv0);
+    }
+    if (a.slt(zero) && b.sgt(zero)) {
+      // A is negative, b is positive, return - ceil(-a, b).
+      APInt posA = zero.ssub_ov(a, overflowOrDiv0);
+      APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0);
+      return zero.ssub_ov(ceil, overflowOrDiv0);
+    }
+    // A is positive, b is negative, return - ceil(a, -b).
+    APInt posB = zero.ssub_ov(b, overflowOrDiv0);
+    APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0);
+    return zero.ssub_ov(ceil, overflowOrDiv0);
+  });
+
+  // Fold out floor division by one. Assumes all tensors of all ones are
+  // splats.
+  if (auto rhs = operands[1].dyn_cast_or_null()) {
+    if (rhs.getValue() == 1)
+      return lhs();
+  } else if (auto rhs = operands[1].dyn_cast_or_null()) {
+    if (rhs.getSplatValue().getValue() == 1)
+      return lhs();
+  }
+
+  return overflowOrDiv0 ? Attribute() : result;
+}
+
+//===----------------------------------------------------------------------===//
+// RemUIOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult arith::RemUIOp::fold(ArrayRef operands) {
+  auto rhs = operands.back().dyn_cast_or_null();
+  if (!rhs)
+    return {};
+  auto rhsValue = rhs.getValue();
+
+  // x % 1 = 0
+  if (rhsValue.isOneValue())
+    return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
+
+  // Don't fold if it requires division by zero.
+  if (rhsValue.isNullValue())
+    return {};
+
+  auto lhs = operands.front().dyn_cast_or_null();
+  if (!lhs)
+    return {};
+  return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhsValue));
+}
+
+//===----------------------------------------------------------------------===//
+// RemSIOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult arith::RemSIOp::fold(ArrayRef operands) {
+  auto rhs = operands.back().dyn_cast_or_null();
+  if (!rhs)
+    return {};
+  auto rhsValue = rhs.getValue();
+
+  // x % 1 = 0
+  if (rhsValue.isOneValue())
+    return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
+
+  // Don't fold if it requires division by zero.
+  if (rhsValue.isNullValue())
+    return {};
+
+  auto lhs = operands.front().dyn_cast_or_null();
+  if (!lhs)
+    return {};
+  return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhsValue));
+}
+
+//===----------------------------------------------------------------------===//
+// AndIOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult arith::AndIOp::fold(ArrayRef operands) {
+  // and(x, 0) -> 0
+  if (matchPattern(rhs(), m_Zero()))
+    return rhs();
+  // and(x, allOnes) -> x
+  APInt intValue;
+  if (matchPattern(rhs(), m_ConstantInt(&intValue)) && intValue.isAllOnes())
+    return lhs();
+  // and(x, x) -> x
+  if (lhs() == rhs())
+    return rhs();
+
+  return constFoldBinaryOp(operands,
+      [](APInt a, APInt b) { return a & b; });
+}
+
+//===----------------------------------------------------------------------===//
+// OrIOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult arith::OrIOp::fold(ArrayRef operands) {
+  // or(x, 0) -> x
+  if (matchPattern(rhs(), m_Zero()))
+    return lhs();
+  // or(x, x) -> x
+  if (lhs() == rhs())
+    return rhs();
+
+  return constFoldBinaryOp(operands,
+      [](APInt a, APInt b) { return a | b; });
+}
+
+//===----------------------------------------------------------------------===//
+// XOrIOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult arith::XOrIOp::fold(ArrayRef<Attribute> operands) {
+  /// xor(x, 0) -> x
+  if (matchPattern(rhs(), m_Zero()))
+    return lhs();
+  /// xor(x, x) -> 0
+  if (lhs() == rhs())
+    return Builder(getContext()).getZeroAttr(getType());
+
+  return constFoldBinaryOp<IntegerAttr>(operands,
+                                        [](APInt a, APInt b) { return a ^ b; });
+}
+
+void arith::XOrIOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *context) {
+  patterns.insert<XOrINotCmpI>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// AddFOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult arith::AddFOp::fold(ArrayRef<Attribute> operands) {
+  return constFoldBinaryOp<FloatAttr>(
+      operands, [](APFloat a, APFloat b) { return a + b; });
+}
+
+//===----------------------------------------------------------------------===//
+// SubFOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult arith::SubFOp::fold(ArrayRef<Attribute> operands) {
+  return constFoldBinaryOp<FloatAttr>(
+      operands, [](APFloat a, APFloat b) { return a - b; });
+}
+
+//===----------------------------------------------------------------------===//
+// MulFOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult arith::MulFOp::fold(ArrayRef<Attribute> operands) {
+  return constFoldBinaryOp<FloatAttr>(
+      operands, [](APFloat a, APFloat b) { return a * b; });
+}
+
+//===----------------------------------------------------------------------===//
+// DivFOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult arith::DivFOp::fold(ArrayRef<Attribute> operands) {
+  return constFoldBinaryOp<FloatAttr>(
+      operands, [](APFloat a, APFloat b) { return a / b; });
+}
+
+//===----------------------------------------------------------------------===//
+// Verifiers for integer and floating point extension/truncation ops
+//===----------------------------------------------------------------------===//
+
+// Extend ops can only extend to a wider type.
+template <typename ValType, typename Op>
+static LogicalResult verifyExtOp(Op op) {
+  Type srcType = getElementTypeOrSelf(op.in().getType());
+  Type dstType = getElementTypeOrSelf(op.getType());
+
+  if (srcType.cast<ValType>().getWidth() >= dstType.cast<ValType>().getWidth())
+    return op.emitError("result type ")
+           << dstType << " must be wider than operand type " << srcType;
+
+  return success();
+}
+
+// Truncate ops can only truncate to a shorter type.
+template <typename ValType, typename Op>
+static LogicalResult verifyTruncateOp(Op op) {
+  Type srcType = getElementTypeOrSelf(op.in().getType());
+  Type dstType = getElementTypeOrSelf(op.getType());
+
+  if (srcType.cast<ValType>().getWidth() <= dstType.cast<ValType>().getWidth())
+    return op.emitError("result type ")
+           << dstType << " must be shorter than operand type " << srcType;
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// ExtUIOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult arith::ExtUIOp::fold(ArrayRef<Attribute> operands) {
+  if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>())
+    return IntegerAttr::get(
+        getType(), lhs.getValue().zext(getType().getIntOrFloatBitWidth()));
+
+  return {};
+}
+
+//===----------------------------------------------------------------------===//
+// ExtSIOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult arith::ExtSIOp::fold(ArrayRef<Attribute> operands) {
+  if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>())
+    return IntegerAttr::get(
+        getType(), lhs.getValue().sext(getType().getIntOrFloatBitWidth()));
+
+  return {};
+}
+
+//===----------------------------------------------------------------------===//
+// IndexCastOp
+//===----------------------------------------------------------------------===//
+
+bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
+                                           TypeRange outputs) {
+  assert(inputs.size() == 1 && outputs.size() == 1 &&
+         "index_cast op expects one operand and one result");
+
+  // Shape equivalence is guaranteed by op traits.
+  auto srcType = getElementTypeOrSelf(inputs.front());
+  auto dstType = getElementTypeOrSelf(outputs.front());
+
+  return (srcType.isIndex() && dstType.isSignlessInteger()) ||
+         (srcType.isSignlessInteger() && dstType.isIndex());
+}
+
+OpFoldResult arith::IndexCastOp::fold(ArrayRef<Attribute> operands) {
+  // index_cast(constant) -> constant
+  // A little hack because we go through int. Otherwise, the size of the
+  // constant might need to change.
+  if (auto value = operands[0].dyn_cast_or_null<IntegerAttr>())
+    return IntegerAttr::get(getType(), value.getInt());
+
+  return {};
+}
+
+void arith::IndexCastOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *context) {
+  patterns.insert<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// BitcastOp
+//===----------------------------------------------------------------------===//
+
+bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
+  assert(inputs.size() == 1 && outputs.size() == 1 &&
+         "bitcast op expects one operand and one result");
+
+  // Shape equivalence is guaranteed by op traits.
+  auto srcType = getElementTypeOrSelf(inputs.front());
+  auto dstType = getElementTypeOrSelf(outputs.front());
+
+  // Types are guaranteed to be integers or floats by constraints.
+  return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
+}
+
+OpFoldResult arith::BitcastOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 1 && "bitcast op expects 1 operand");
+
+  auto resType = getType();
+  auto operand = operands[0];
+  if (!operand)
+    return {};
+
+  /// Bitcast dense elements.
+  if (auto denseAttr = operand.dyn_cast_or_null<DenseElementsAttr>())
+    return denseAttr.bitcast(resType.cast<ShapedType>().getElementType());
+  /// Other shaped types unhandled.
+  if (resType.isa<ShapedType>())
+    return {};
+
+  /// Bitcast integer or float to integer or float.
+  APInt bits = operand.isa<FloatAttr>()
+                   ? operand.cast<FloatAttr>().getValue().bitcastToAPInt()
+                   : operand.cast<IntegerAttr>().getValue();
+
+  if (auto resFloatType = resType.dyn_cast<FloatType>())
+    return FloatAttr::get(resType,
+                          APFloat(resFloatType.getFloatSemantics(), bits));
+  return IntegerAttr::get(resType, bits);
+}
+
+void arith::BitcastOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *context) {
+  patterns.insert<BitcastOfBitcast>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// Helpers for compare ops
+//===----------------------------------------------------------------------===//
+
+/// Return the type of the same shape (scalar, vector or tensor) containing i1.
+static Type getI1SameShape(Type type) {
+  auto i1Type = IntegerType::get(type.getContext(), 1);
+  if (auto tensorType = type.dyn_cast<RankedTensorType>())
+    return RankedTensorType::get(tensorType.getShape(), i1Type);
+  if (type.isa<UnrankedTensorType>())
+    return UnrankedTensorType::get(i1Type);
+  if (auto vectorType = type.dyn_cast<VectorType>())
+    return VectorType::get(vectorType.getShape(), i1Type);
+  return i1Type;
+}
+
+//===----------------------------------------------------------------------===//
+// CmpIOp
+//===----------------------------------------------------------------------===//
+
+/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
+/// comparison predicates.
+bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
+                                    const APInt &lhs, const APInt &rhs) {
+  switch (predicate) {
+  case arith::CmpIPredicate::eq:
+    return lhs.eq(rhs);
+  case arith::CmpIPredicate::ne:
+    return lhs.ne(rhs);
+  case arith::CmpIPredicate::slt:
+    return lhs.slt(rhs);
+  case arith::CmpIPredicate::sle:
+    return lhs.sle(rhs);
+  case arith::CmpIPredicate::sgt:
+    return lhs.sgt(rhs);
+  case arith::CmpIPredicate::sge:
+    return lhs.sge(rhs);
+  case arith::CmpIPredicate::ult:
+    return lhs.ult(rhs);
+  case arith::CmpIPredicate::ule:
+    return lhs.ule(rhs);
+  case arith::CmpIPredicate::ugt:
+    return lhs.ugt(rhs);
+  case arith::CmpIPredicate::uge:
+    return lhs.uge(rhs);
+  }
+  llvm_unreachable("unknown cmpi predicate kind");
+}
+
+/// Returns true if the predicate is true for two equal operands.
+static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
+  switch (predicate) {
+  case arith::CmpIPredicate::eq:
+  case arith::CmpIPredicate::sle:
+  case arith::CmpIPredicate::sge:
+  case arith::CmpIPredicate::ule:
+  case arith::CmpIPredicate::uge:
+    return true;
+  case arith::CmpIPredicate::ne:
+  case arith::CmpIPredicate::slt:
+  case arith::CmpIPredicate::sgt:
+  case arith::CmpIPredicate::ult:
+  case arith::CmpIPredicate::ugt:
+    return false;
+  }
+  llvm_unreachable("unknown cmpi predicate kind");
+}
+
+OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 2 && "cmpi takes two operands");
+
+  // cmpi(pred, x, x) -> i1 constant; BoolAttr only fits a scalar i1 result.
+  if (lhs() == rhs() && !getType().isa<ShapedType>()) {
+    auto val = applyCmpPredicateToEqualOperands(getPredicate());
+    return BoolAttr::get(getContext(), val);
+  }
+
+  auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
+  auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
+  if (!lhs || !rhs)
+    return {};
+
+  auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
+  return BoolAttr::get(getContext(), val);
+}
+
+//===----------------------------------------------------------------------===//
+// CmpFOp
+//===----------------------------------------------------------------------===//
+
+/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
+/// comparison predicates.
+bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
+                                    const APFloat &lhs, const APFloat &rhs) {
+  auto cmpResult = lhs.compare(rhs);
+  switch (predicate) {
+  case arith::CmpFPredicate::AlwaysFalse:
+    return false;
+  case arith::CmpFPredicate::OEQ:
+    return cmpResult == APFloat::cmpEqual;
+  case arith::CmpFPredicate::OGT:
+    return cmpResult == APFloat::cmpGreaterThan;
+  case arith::CmpFPredicate::OGE:
+    return cmpResult == APFloat::cmpGreaterThan ||
+           cmpResult == APFloat::cmpEqual;
+  case arith::CmpFPredicate::OLT:
+    return cmpResult == APFloat::cmpLessThan;
+  case arith::CmpFPredicate::OLE:
+    return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
+  case arith::CmpFPredicate::ONE:
+    return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
+  case arith::CmpFPredicate::ORD:
+    return cmpResult != APFloat::cmpUnordered;
+  case arith::CmpFPredicate::UEQ:
+    return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
+  case arith::CmpFPredicate::UGT:
+    return cmpResult == APFloat::cmpUnordered ||
+           cmpResult == APFloat::cmpGreaterThan;
+  case arith::CmpFPredicate::UGE:
+    return cmpResult == APFloat::cmpUnordered ||
+           cmpResult == APFloat::cmpGreaterThan ||
+           cmpResult == APFloat::cmpEqual;
+  case arith::CmpFPredicate::ULT:
+    return cmpResult == APFloat::cmpUnordered ||
+           cmpResult == APFloat::cmpLessThan;
+  case arith::CmpFPredicate::ULE:
+    return cmpResult == APFloat::cmpUnordered ||
+           cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
+  case arith::CmpFPredicate::UNE:
+    return cmpResult != APFloat::cmpEqual;
+  case arith::CmpFPredicate::UNO:
+    return cmpResult == APFloat::cmpUnordered;
+  case arith::CmpFPredicate::AlwaysTrue:
+    return true;
+  }
+  llvm_unreachable("unknown cmpf predicate kind");
+}
+
+OpFoldResult arith::CmpFOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 2 && "cmpf takes two operands");
+
+  auto lhs = operands.front().dyn_cast_or_null<FloatAttr>();
+  auto rhs = operands.back().dyn_cast_or_null<FloatAttr>();
+
+  if (!lhs || !rhs)
+    return {};
+
+  auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
+  return BoolAttr::get(getContext(), val);
+}
+
+//===----------------------------------------------------------------------===//
+// TableGen'd op method definitions
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.cpp.inc"
+
+//===----------------------------------------------------------------------===//
+// TableGen'd enum attribute definitions
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arithmetic/IR/ArithmeticOpsEnums.cpp.inc"
diff --git a/mlir/lib/Dialect/Arithmetic/IR/CMakeLists.txt b/mlir/lib/Dialect/Arithmetic/IR/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Arithmetic/IR/CMakeLists.txt
@@ -0,0 +1,18 @@
+set(LLVM_TARGET_DEFINITIONS ArithmeticCanonicalization.td)
+mlir_tablegen(ArithmeticCanonicalization.inc -gen-rewriters)
+add_public_tablegen_target(MLIRArithmeticCanonicalizationIncGen)
+
+add_mlir_dialect_library(MLIRArithmetic
+  ArithmeticOps.cpp
+  ArithmeticDialect.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Arithmetic
+
+  DEPENDS
+  MLIRArithmeticOpsIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRDialect
+  MLIRIR
+  )
diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt
--- a/mlir/lib/Dialect/CMakeLists.txt
+++ b/mlir/lib/Dialect/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_subdirectory(Affine)
+add_subdirectory(Arithmetic)
 add_subdirectory(ArmNeon)
 add_subdirectory(ArmSVE)
 add_subdirectory(Async)
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -6673,6 +6673,118 @@
     "include/mlir/Transforms/InliningUtils.h",
 ])
 
+td_library(
+    name = "ArithmeticOpsTdFiles",
+    srcs = [
+        "include/mlir/Dialect/Arithmetic/IR/ArithmeticBase.td",
+        "include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td",
+    ],
+    includes = ["include"],
+    deps = [
+        ":OpBaseTdFiles",
+        ":SideEffectInterfacesTdFiles",
+        ":VectorInterfacesTdFiles",
+    ],
+)
+
+gentbl_cc_library(
+    name = "ArithmeticBaseIncGen",
+    strip_include_prefix = "include",
+    tbl_outs = [
+        (
+            [
+                "-gen-dialect-decls",
+                "-dialect=arith",
+            ],
+            "include/mlir/Dialect/Arithmetic/IR/ArithmeticOpsDialect.h.inc",
+        ),
+        (
+            [
+                "-gen-dialect-defs",
+                "-dialect=arith",
+            ],
+            "include/mlir/Dialect/Arithmetic/IR/ArithmeticOpsDialect.cpp.inc",
+        ),
+        (
+            ["-gen-enum-decls"],
+            "include/mlir/Dialect/Arithmetic/IR/ArithmeticOpsEnums.h.inc",
+        ),
+        (
+            ["-gen-enum-defs"],
+            "include/mlir/Dialect/Arithmetic/IR/ArithmeticOpsEnums.cpp.inc",
+        ),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "include/mlir/Dialect/Arithmetic/IR/ArithmeticBase.td",
+    deps = [":ArithmeticOpsTdFiles"],
+)
+
+gentbl_cc_library(
+    name = "ArithmeticOpsIncGen",
+    strip_include_prefix = "include",
+    tbl_outs = [
+        (
+            ["-gen-op-decls"],
+            "include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.h.inc",
+        ),
+        (
+            ["-gen-op-defs"],
+            "include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.cpp.inc",
+        ),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td",
+    deps = [
+        ":ArithmeticOpsTdFiles",
+        ":CastInterfacesTdFiles",
+    ],
+)
+
+gentbl_cc_library(
+    name = "ArithmeticCanonicalizationIncGen",
+    strip_include_prefix = "include/mlir/Dialect/Arithmetic/IR",
+    tbl_outs = [
+        (
+            ["-gen-rewriters"],
+            "include/mlir/Dialect/Arithmetic/IR/ArithmeticCanonicalization.inc",
+        ),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "lib/Dialect/Arithmetic/IR/ArithmeticCanonicalization.td",
+    deps = [
+        ":ArithmeticOpsTdFiles",
+        ":CastInterfacesTdFiles",
+        ":StdOpsTdFiles",
+    ],
+)
+
+cc_library(
+    name = "ArithmeticDialect",
+    srcs = glob(
+        [
+            "lib/Dialect/Arithmetic/IR/*.cpp",
+            "lib/Dialect/Arithmetic/IR/*.h",
+        ],
+    ),
+    hdrs = [
+        "include/mlir/Dialect/Arithmetic/IR/Arithmetic.h",
+        "include/mlir/Transforms/InliningUtils.h",
+    ],
+    includes = ["include"],
+    deps = [
+        ":ArithmeticBaseIncGen",
+        ":ArithmeticCanonicalizationIncGen",
+        ":ArithmeticOpsIncGen",
+        ":CommonFolders",
+        ":IR",
+        ":SideEffectInterfaces",
+        ":StandardOps",
+        ":Support",
+        ":VectorInterfaces",
+        "//llvm:Support",
+    ],
+)
+
 td_library(
     name = "MathOpsTdFiles",
     srcs = [