diff --git a/mlir/include/mlir/Dialect/Index/IR/IndexDialect.td b/mlir/include/mlir/Dialect/Index/IR/IndexDialect.td --- a/mlir/include/mlir/Dialect/Index/IR/IndexDialect.td +++ b/mlir/include/mlir/Dialect/Index/IR/IndexDialect.td @@ -80,6 +80,8 @@ /// Register all dialect operations. void registerOperations(); }]; + + let useDefaultAttributePrinterParser = 1; } #endif // INDEX_DIALECT diff --git a/mlir/include/mlir/Dialect/Index/IR/IndexEnums.td b/mlir/include/mlir/Dialect/Index/IR/IndexEnums.td --- a/mlir/include/mlir/Dialect/Index/IR/IndexEnums.td +++ b/mlir/include/mlir/Dialect/Index/IR/IndexEnums.td @@ -9,6 +9,36 @@ #ifndef INDEX_ENUMS #define INDEX_ENUMS +include "mlir/Dialect/Index/IR/IndexDialect.td" include "mlir/IR/EnumAttr.td" +//===----------------------------------------------------------------------===// +// IndexCmpPredicate +//===----------------------------------------------------------------------===// + +def IndexCmpPredicate : I32EnumAttr< + "IndexCmpPredicate", "index comparison predicate kind", + [ + I32EnumAttrCase<"EQ", 0, "eq">, + I32EnumAttrCase<"NE", 1, "ne">, + I32EnumAttrCase<"SLT", 2, "slt">, + I32EnumAttrCase<"SLE", 3, "sle">, + I32EnumAttrCase<"SGT", 4, "sgt">, + I32EnumAttrCase<"SGE", 5, "sge">, + I32EnumAttrCase<"ULT", 6, "ult">, + I32EnumAttrCase<"ULE", 7, "ule">, + I32EnumAttrCase<"UGT", 8, "ugt">, + I32EnumAttrCase<"UGE", 9, "uge"> + ]> { + let cppNamespace = "::mlir::index"; + let genSpecializedAttr = 0; +} + +//===----------------------------------------------------------------------===// +// IndexCmpPredicateAttr +//===----------------------------------------------------------------------===// + +def IndexCmpPredicateAttr : EnumAttr< + IndexDialect, IndexCmpPredicate, "cmp_predicate">; + #endif // INDEX_ENUMS diff --git a/mlir/include/mlir/Dialect/Index/IR/IndexOps.h b/mlir/include/mlir/Dialect/Index/IR/IndexOps.h --- a/mlir/include/mlir/Dialect/Index/IR/IndexOps.h +++ b/mlir/include/mlir/Dialect/Index/IR/IndexOps.h @@ -9,6 
+9,21 @@ #ifndef MLIR_DIALECT_INDEX_IR_INDEXOPS_H #define MLIR_DIALECT_INDEX_IR_INDEXOPS_H +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Interfaces/CastInterfaces.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" + +//===----------------------------------------------------------------------===// +// Forward Declarations +//===----------------------------------------------------------------------===// + +namespace mlir::index { +enum class IndexCmpPredicate : uint32_t; +class IndexCmpPredicateAttr; +} // namespace mlir::index + //===----------------------------------------------------------------------===// // ODS-Generated Declarations //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td --- a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td +++ b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td @@ -10,6 +10,9 @@ #define INDEX_OPS include "mlir/Dialect/Index/IR/IndexDialect.td" +include "mlir/Dialect/Index/IR/IndexEnums.td" +include "mlir/Interfaces/CastInterfaces.td" +include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/IR/OpBase.td" @@ -21,4 +24,431 @@ class IndexOp traits = []> : Op; +//===----------------------------------------------------------------------===// +// IndexBinaryOp +//===----------------------------------------------------------------------===// + +/// Base class for binary Index dialect operations. 
+class IndexBinaryOp<string mnemonic, list<Trait> traits = []> + : IndexOp<mnemonic, traits> { + let arguments = (ins Index:$lhs, Index:$rhs); + let results = (outs Index:$result); + let assemblyFormat = "$lhs `,` $rhs attr-dict"; +} + +//===----------------------------------------------------------------------===// +// AddOp +//===----------------------------------------------------------------------===// + +def Index_AddOp : IndexBinaryOp<"add"> { + let summary = "index addition"; + let description = [{ + The `index.add` operation takes two index values and computes their sum. + + Example: + + ```mlir + // c = a + b + %c = index.add %a, %b + ``` + }]; +} + +//===----------------------------------------------------------------------===// +// SubOp +//===----------------------------------------------------------------------===// + +def Index_SubOp : IndexBinaryOp<"sub"> { + let summary = "index subtraction"; + let description = [{ + The `index.sub` operation takes two index values and computes their + difference, subtracting the second operand from the first. + + Example: + + ```mlir + // c = a - b + %c = index.sub %a, %b + ``` + }]; +} + +//===----------------------------------------------------------------------===// +// MulOp +//===----------------------------------------------------------------------===// + +def Index_MulOp : IndexBinaryOp<"mul"> { + let summary = "index multiplication"; + let description = [{ + The `index.mul` operation takes two index values and computes their product. + + Example: + + ```mlir + // c = a * b + %c = index.mul %a, %b + ``` + }]; +} + +//===----------------------------------------------------------------------===// +// DivSOp +//===----------------------------------------------------------------------===// + +def Index_DivSOp : IndexBinaryOp<"divs"> { + let summary = "index signed division"; + let description = [{ + The `index.divs` operation takes two index values and computes their signed + quotient. Treats the leading bit as the sign and rounds towards zero, i.e. + `6 / -2 = -3`.
+ + Note: division by zero and signed division overflow are undefined behaviour. + + Example: + + ```mlir + // c = a / b + %c = index.divs %a, %b + ``` + }]; +} + +//===----------------------------------------------------------------------===// +// DivUOp +//===----------------------------------------------------------------------===// + +def Index_DivUOp : IndexBinaryOp<"divu"> { + let summary = "index unsigned division"; + let description = [{ + The `index.divu` operation takes two index values and computes their + unsigned quotient. Treats the leading bit as the most significant and rounds + towards zero, i.e. `6 / -2 = 0`. + + Note: division by zero is undefined behaviour. + + Example: + + ```mlir + // c = a / b + %c = index.divu %a, %b + ``` + }]; +} + +//===----------------------------------------------------------------------===// +// CeilDivSOp +//===----------------------------------------------------------------------===// + +def Index_CeilDivSOp : IndexBinaryOp<"ceildivs"> { + let summary = "index signed ceil division"; + let description = [{ + The `index.ceildivs` operation takes two index values and computes their + signed quotient. Treats the leading bit as the sign and rounds towards + positive infinity, i.e. `7 / -2 = -3`. + + Note: division by zero and signed division overflow are undefined behaviour. + + Example: + + ```mlir + // c = ceil(a / b) + %c = index.ceildivs %a, %b + ``` + }]; +} + +//===----------------------------------------------------------------------===// +// CeilDivUOp +//===----------------------------------------------------------------------===// + +def Index_CeilDivUOp : IndexBinaryOp<"ceildivu"> { + let summary = "index unsigned ceil division"; + let description = [{ + The `index.ceildivu` operation takes two index values and computes their + unsigned quotient. Treats the leading bit as the most significant and rounds + towards positive infinity, i.e. `6 / -2 = 1`. + + Note: division by zero is undefined behaviour. 
+ + Example: + + ```mlir + // c = ceil(a / b) + %c = index.ceildivu %a, %b + ``` + }]; +} + +//===----------------------------------------------------------------------===// +// FloorDivSOp +//===----------------------------------------------------------------------===// + +def Index_FloorDivSOp : IndexBinaryOp<"floordivs"> { + let summary = "index signed floor division"; + let description = [{ + The `index.floordivs` operation takes two index values and computes their + signed quotient. Treats the leading bit as the sign and rounds towards + negative infinity, i.e. `5 / -2 = -3`. + + Note: division by zero and signed division overflow are undefined behaviour. + + Example: + + ```mlir + // c = floor(a / b) + %c = index.floordivs %a, %b + ``` + }]; +} + +//===----------------------------------------------------------------------===// +// RemSOp +//===----------------------------------------------------------------------===// + +def Index_RemSOp : IndexBinaryOp<"rems"> { + let summary = "index signed remainder"; + let description = [{ + The `index.rems` operation takes two index values and computes their signed + remainder. Treats the leading bit as the sign, i.e. `6 % -2 = 0`. + + Example: + + ```mlir + // c = a % b + %c = index.rems %a, %b + ``` + }]; +} + +//===----------------------------------------------------------------------===// +// RemUOp +//===----------------------------------------------------------------------===// + +def Index_RemUOp : IndexBinaryOp<"remu"> { + let summary = "index unsigned remainder"; + let description = [{ + The `index.remu` operation takes two index values and computes their + unsigned remainder. Treats the leading bit as the most significant, i.e. + `6 % -2 = 6`. 
+ + Example: + + ```mlir + // c = a % b + %c = index.remu %a, %b + ``` + }]; +} + +//===----------------------------------------------------------------------===// +// MaxSOp +//===----------------------------------------------------------------------===// + +def Index_MaxSOp : IndexBinaryOp<"maxs"> { + let summary = "index signed maximum"; + let description = [{ + The `index.maxs` operation takes two index values and computes their signed + maximum value. Treats the leading bit as the sign, i.e. `max(-2, 6) = 6`. + + Example: + + ```mlir + // c = max(a, b) + %c = index.maxs %a, %b + ``` + }]; +} + +//===----------------------------------------------------------------------===// +// MaxUOp +//===----------------------------------------------------------------------===// + +def Index_MaxUOp : IndexBinaryOp<"maxu"> { + let summary = "index unsigned maximum"; + let description = [{ + The `index.maxu` operation takes two index values and computes their + unsigned maximum value. Treats the leading bit as the most significant, i.e. + `max(15, 6) = 15` or `max(-2, 6) = -2`. + + Example: + + ```mlir + // c = max(a, b) + %c = index.maxu %a, %b + ``` + }]; +} + +//===----------------------------------------------------------------------===// +// CastSOp +//===----------------------------------------------------------------------===// + +def Index_CastSOp : IndexOp<"casts", + [DeclareOpInterfaceMethods<CastOpInterface>]> { + let summary = "index signed cast"; + let description = [{ + The `index.casts` operation enables conversions between values of index type + and concrete fixed-width integer types. If casting to a wider integer, the + value is sign-extended. If casting to a narrower integer, the value is + truncated.
+ + Example: + + ```mlir + // Cast to i32 + %0 = index.casts %a : index to i32 + + // Cast from i64 + %1 = index.casts %b : i64 to index + ``` + }]; + + let arguments = (ins AnyTypeOf<[AnyInteger, Index]>:$input); + let results = (outs AnyTypeOf<[AnyInteger, Index]>:$output); + let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)"; +} + +//===----------------------------------------------------------------------===// +// CastUOp +//===----------------------------------------------------------------------===// + +def Index_CastUOp : IndexOp<"castu", + [DeclareOpInterfaceMethods<CastOpInterface>]> { + let summary = "index unsigned cast"; + let description = [{ + The `index.castu` operation enables conversions between values of index type + and concrete fixed-width integer types. If casting to a wider integer, the + value is zero-extended. If casting to a narrower integer, the value is + truncated. + + Example: + + ```mlir + // Cast to i32 + %0 = index.castu %a : index to i32 + + // Cast from i64 + %1 = index.castu %b : i64 to index + ``` + }]; + + let arguments = (ins AnyTypeOf<[AnyInteger, Index]>:$input); + let results = (outs AnyTypeOf<[AnyInteger, Index]>:$output); + let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)"; +} + +//===----------------------------------------------------------------------===// +// CmpOp +//===----------------------------------------------------------------------===// + +def Index_CmpOp : IndexOp<"cmp"> { + let summary = "index compare"; + let description = [{ + The `index.cmp` operation takes two index values and compares them according + to the comparison predicate and returns an `i1`.
The following comparisons + are supported: + + - `eq`: equal + - `ne`: not equal + - `slt`: signed less than + - `sle`: signed less than or equal + - `sgt`: signed greater than + - `sge`: signed greater than or equal + - `ult`: unsigned less than + - `ule`: unsigned less than or equal + - `ugt`: unsigned greater than + - `uge`: unsigned greater than or equal + + The result is `1` if the comparison is true and `0` otherwise. + + Example: + + ```mlir + // Signed less than comparison. + %0 = index.cmp slt(%a, %b) + + // Unsigned greater than or equal comparison. + %1 = index.cmp uge(%a, %b) + + // Not equal comparison. + %2 = index.cmp ne(%a, %b) + ``` + }]; + + let arguments = (ins IndexCmpPredicateAttr:$pred, Index:$lhs, Index:$rhs); + let results = (outs I1:$result); + let assemblyFormat = "`` $pred `(` $lhs `,` $rhs `)` attr-dict"; +} + +//===----------------------------------------------------------------------===// +// SizeOfOp +//===----------------------------------------------------------------------===// + +def Index_SizeOfOp : IndexOp<"sizeof"> { + let summary = "size in bits of the index type"; + let description = [{ + The `index.sizeof` operation produces an index-typed SSA value equal to the + size in bits of the `index` type. For example, on 32-bit systems, the result + is `32 : index`, and on 64-bit systems, the result is `64 : index`. + + Example: + + ```mlir + %0 = index.sizeof + ``` + }]; + + let results = (outs Index:$result); + let assemblyFormat = "attr-dict"; +} + +//===----------------------------------------------------------------------===// +// ConstantOp +//===----------------------------------------------------------------------===// + +def Index_ConstantOp : IndexOp<"constant", [ConstantLike]> { + let summary = "index constant"; + let description = [{ + The `index.constant` operation produces an index-typed SSA value equal to + some index-typed integer constant. 
+ + Example: + + ```mlir + %0 = index.constant 42 + ``` + }]; + + let arguments = (ins IndexAttr:$value); + let results = (outs Index:$result); + let assemblyFormat = "attr-dict $value"; + + let builders = [OpBuilder<(ins "int64_t":$value)>]; +} + +//===----------------------------------------------------------------------===// +// BoolConstantOp +//===----------------------------------------------------------------------===// + +def Index_BoolConstantOp : IndexOp<"bool.constant", [ConstantLike]> { + let summary = "boolean constant"; + let description = [{ + The `index.bool.constant` operation produces a bool-typed SSA value equal + to either `true` or `false`. + + This operation is used to materialize bool constants that arise when folding + `index.cmp`. + + Example: + + ```mlir + %0 = index.bool.constant true + ``` + }]; + + let arguments = (ins BoolAttr:$value); + let results = (outs I1:$result); + let assemblyFormat = "attr-dict $value"; +} + #endif // INDEX_OPS diff --git a/mlir/lib/Dialect/Index/IR/IndexAttrs.cpp b/mlir/lib/Dialect/Index/IR/IndexAttrs.cpp --- a/mlir/lib/Dialect/Index/IR/IndexAttrs.cpp +++ b/mlir/lib/Dialect/Index/IR/IndexAttrs.cpp @@ -8,6 +8,9 @@ #include "mlir/Dialect/Index/IR/IndexAttrs.h" #include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/DialectImplementation.h" +#include "llvm/ADT/TypeSwitch.h" using namespace mlir; using namespace mlir::index; diff --git a/mlir/lib/Dialect/Index/IR/IndexOps.cpp b/mlir/lib/Dialect/Index/IR/IndexOps.cpp --- a/mlir/lib/Dialect/Index/IR/IndexOps.cpp +++ b/mlir/lib/Dialect/Index/IR/IndexOps.cpp @@ -7,7 +7,10 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Dialect/Index/IR/IndexAttrs.h" #include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/OpImplementation.h" using namespace mlir; using namespace mlir::index; @@ -23,6 +26,22
@@ >(); } +//===----------------------------------------------------------------------===// +// CastSOp +//===----------------------------------------------------------------------===// + +bool CastSOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) { + return lhsTypes.front().isa<IndexType>() != rhsTypes.front().isa<IndexType>(); +} + +//===----------------------------------------------------------------------===// +// CastUOp +//===----------------------------------------------------------------------===// + +bool CastUOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) { + return lhsTypes.front().isa<IndexType>() != rhsTypes.front().isa<IndexType>(); +} + //===----------------------------------------------------------------------===// // ODS-Generated Definitions //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Index/index-errors.mlir b/mlir/test/Dialect/Index/index-errors.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Index/index-errors.mlir @@ -0,0 +1,31 @@ +// RUN: mlir-opt -split-input-file -verify-diagnostics %s + +func.func @invalid_cast(%a: index) { + // expected-error @below {{cast incompatible}} + %0 = index.casts %a : index to index + return +} + +// ----- + +func.func @invalid_cast(%a: i64) { + // expected-error @below {{cast incompatible}} + %0 = index.casts %a : i64 to i64 + return +} + +// ----- + +func.func @invalid_cast(%a: index) { + // expected-error @below {{cast incompatible}} + %0 = index.castu %a : index to index + return +} + +// ----- + +func.func @invalid_cast(%a: i64) { + // expected-error @below {{cast incompatible}} + %0 = index.castu %a : i64 to i64 + return +} diff --git a/mlir/test/Dialect/Index/index-ops.mlir b/mlir/test/Dialect/Index/index-ops.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Index/index-ops.mlir @@ -0,0 +1,102 @@ +// RUN: mlir-opt %s | mlir-opt | FileCheck %s + +// CHECK-LABEL: @binary_ops +// CHECK-SAME: %[[A:.*]]: index, %[[B:.*]]: index
+func.func @binary_ops(%a: index, %b: index) { + // CHECK-NEXT: index.add %[[A]], %[[B]] + %0 = index.add %a, %b + // CHECK-NEXT: index.sub %[[A]], %[[B]] + %1 = index.sub %a, %b + // CHECK-NEXT: index.mul %[[A]], %[[B]] + %2 = index.mul %a, %b + // CHECK-NEXT: index.divs %[[A]], %[[B]] + %3 = index.divs %a, %b + // CHECK-NEXT: index.divu %[[A]], %[[B]] + %4 = index.divu %a, %b + // CHECK-NEXT: index.ceildivs %[[A]], %[[B]] + %5 = index.ceildivs %a, %b + // CHECK-NEXT: index.ceildivu %[[A]], %[[B]] + %6 = index.ceildivu %a, %b + // CHECK-NEXT: index.floordivs %[[A]], %[[B]] + %7 = index.floordivs %a, %b + // CHECK-NEXT: index.rems %[[A]], %[[B]] + %8 = index.rems %a, %b + // CHECK-NEXT: index.remu %[[A]], %[[B]] + %9 = index.remu %a, %b + // CHECK-NEXT: index.maxs %[[A]], %[[B]] + %10 = index.maxs %a, %b + // CHECK-NEXT: index.maxu %[[A]], %[[B]] + %11 = index.maxu %a, %b + return +} + +// CHECK-LABEL: @cmp_op +// CHECK-SAME: %[[A:.*]]: index, %[[B:.*]]: index +func.func @cmp_op(%a: index, %b: index) { + // CHECK-NEXT: index.cmp eq(%[[A]], %[[B]]) + %0 = index.cmp eq(%a, %b) + // CHECK-NEXT: index.cmp ne(%[[A]], %[[B]]) + %1 = index.cmp ne(%a, %b) + // CHECK-NEXT: index.cmp slt(%[[A]], %[[B]]) + %2 = index.cmp slt(%a, %b) + // CHECK-NEXT: index.cmp sle(%[[A]], %[[B]]) + %3 = index.cmp sle(%a, %b) + // CHECK-NEXT: index.cmp sgt(%[[A]], %[[B]]) + %4 = index.cmp sgt(%a, %b) + // CHECK-NEXT: index.cmp sge(%[[A]], %[[B]]) + %5 = index.cmp sge(%a, %b) + // CHECK-NEXT: index.cmp ult(%[[A]], %[[B]]) + %6 = index.cmp ult(%a, %b) + // CHECK-NEXT: index.cmp ule(%[[A]], %[[B]]) + %7 = index.cmp ule(%a, %b) + // CHECK-NEXT: index.cmp ugt(%[[A]], %[[B]]) + %8 = index.cmp ugt(%a, %b) + // CHECK-NEXT: index.cmp uge(%[[A]], %[[B]]) + %9 = index.cmp uge(%a, %b) + return +} + +// CHECK-LABEL: @sizeof_op +func.func @sizeof_op() { + // CHECK: index.sizeof + %0 = index.sizeof + return +} + +// CHECK-LABEL: @constant_op +func.func @constant_op() { + // CHECK-NEXT: index.constant 0 + %0 = 
index.constant 0 + // CHECK-NEXT: index.constant 1 + %1 = index.constant 1 + // CHECK-NEXT: index.constant 42 + %2 = index.constant 42 + return +} + +// CHECK-LABEL: @bool_constant_op +func.func @bool_constant_op() { + // CHECK-NEXT: index.bool.constant true + %0 = index.bool.constant true + // CHECK-NEXT: index.bool.constant false + %1 = index.bool.constant false + return +} + +// CHECK-LABEL: @cast_op +// CHECK-SAME: %[[A:.*]]: index, %[[B:.*]]: i32, %[[C:.*]]: i64 +func.func @cast_op(%a: index, %b: i32, %c: i64) { + // CHECK-NEXT: index.casts %[[A]] : index to i64 + %0 = index.casts %a : index to i64 + // CHECK-NEXT: index.casts %[[B]] : i32 to index + %1 = index.casts %b : i32 to index + // CHECK-NEXT: index.casts %[[C]] : i64 to index + %2 = index.casts %c : i64 to index + // CHECK-NEXT: index.castu %[[A]] : index to i64 + %3 = index.castu %a : index to i64 + // CHECK-NEXT: index.castu %[[B]] : i32 to index + %4 = index.castu %b : i32 to index + // CHECK-NEXT: index.castu %[[C]] : i64 to index + %5 = index.castu %c : i64 to index + return +}