diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -121,6 +121,10 @@ ArithmeticOp, Arguments<(ins FloatLike:$lhs, FloatLike:$rhs)>; +//===----------------------------------------------------------------------===// +// AbsFOp +//===----------------------------------------------------------------------===// + def AbsFOp : FloatUnaryOp<"absf"> { let summary = "floating point absolute-value operation"; let description = [{ @@ -131,16 +135,28 @@ }]; } +//===----------------------------------------------------------------------===// +// AddFOp +//===----------------------------------------------------------------------===// + def AddFOp : FloatArithmeticOp<"addf"> { let summary = "floating point addition operation"; let hasFolder = 1; } +//===----------------------------------------------------------------------===// +// AddIOp +//===----------------------------------------------------------------------===// + def AddIOp : IntArithmeticOp<"addi", [Commutative]> { let summary = "integer addition operation"; let hasFolder = 1; } +//===----------------------------------------------------------------------===// +// AllocOp +//===----------------------------------------------------------------------===// + def AllocOp : Std_Op<"alloc"> { let summary = "memory allocation operation"; let description = [{ @@ -213,11 +229,40 @@ let hasCanonicalizer = 1; } +//===----------------------------------------------------------------------===// +// AndOp +//===----------------------------------------------------------------------===// + def AndOp : IntArithmeticOp<"and", [Commutative]> { let summary = "integer binary and"; let hasFolder = 1; } +//===----------------------------------------------------------------------===// +// AssumeAlignmentOp +//===----------------------------------------------------------------------===// + +def AssumeAlignmentOp : Std_Op<"assume_alignment"> { + let summary = + "assertion that gives alignment information to the input memref"; + let description = [{ + The assume alignment operation takes a memref and a integer of alignment + value, and internally annotates the buffer with the given alignment. If + the buffer isn't aligned to the given alignment, the behavior is undefined. + + This operation doesn't affect the semantics of a correct program. It's for + optimization only, and the optimization is best-effort. 
+ }]; + let arguments = (ins AnyMemRef:$memref, PositiveI32Attr:$alignment); + let results = (outs); + + let assemblyFormat = "$memref `,` $alignment attr-dict `:` type($memref)"; +} + +//===----------------------------------------------------------------------===// +// AtomicRMWOp +//===----------------------------------------------------------------------===// + def ATOMIC_RMW_KIND_ADDF : I64EnumAttrCase<"addf", 0>; def ATOMIC_RMW_KIND_ADDI : I64EnumAttrCase<"addi", 1>; def ATOMIC_RMW_KIND_ASSIGN : I64EnumAttrCase<"assign", 2>; @@ -281,6 +326,10 @@ }]; } +//===----------------------------------------------------------------------===// +// BranchOp +//===----------------------------------------------------------------------===// + def BranchOp : Std_Op<"br", [Terminator]> { let summary = "branch operation"; let description = [{ @@ -316,6 +365,10 @@ let assemblyFormat = "$dest attr-dict"; } +//===----------------------------------------------------------------------===// +// CallOp +//===----------------------------------------------------------------------===// + def CallOp : Std_Op<"call", [CallOpInterface]> { let summary = "call operation"; let description = [{ @@ -372,6 +425,10 @@ }]; } +//===----------------------------------------------------------------------===// +// CallIndirectOp +//===----------------------------------------------------------------------===// + def CallIndirectOp : Std_Op<"call_indirect", [ CallOpInterface, TypesMatchWith<"callee input types match argument types", @@ -423,6 +480,10 @@ let assemblyFormat = "$callee `(` $operands `)` attr-dict `:` type($callee)"; } +//===----------------------------------------------------------------------===// +// CeilFOp +//===----------------------------------------------------------------------===// + def CeilFOp : FloatUnaryOp<"ceilf"> { let summary = "ceiling of the specified value"; let description = [{ @@ -433,6 +494,10 @@ }]; } +//===----------------------------------------------------------------------===// +// CmpFOp +//===----------------------------------------------------------------------===// + // The predicate indicates the type of the comparison to perform: // (un)orderedness, (in)equality and less/greater than (or equal to) as // well as predicates that are always true or false. 
@@ -519,6 +584,10 @@ let assemblyFormat = "$predicate `,` $lhs `,` $rhs attr-dict `:` type($lhs)"; } +//===----------------------------------------------------------------------===// +// CmpIOp +//===----------------------------------------------------------------------===// + def CMPI_P_EQ : I64EnumAttrCase<"eq", 0>; def CMPI_P_NE : I64EnumAttrCase<"ne", 1>; def CMPI_P_SLT : I64EnumAttrCase<"slt", 2>; @@ -594,6 +663,10 @@ let assemblyFormat = "$predicate `,` $lhs `,` $rhs attr-dict `:` type($lhs)"; } +//===----------------------------------------------------------------------===// +// CondBranchOp +//===----------------------------------------------------------------------===// + def CondBranchOp : Std_Op<"cond_br", [Terminator]> { let summary = "conditional branch operation"; let description = [{ @@ -705,6 +778,10 @@ let assemblyFormat = "$condition `,` successors attr-dict"; } +//===----------------------------------------------------------------------===// +// ConstantOp +//===----------------------------------------------------------------------===// + def ConstantOp : Std_Op<"constant", [NoSideEffect, DeclareOpInterfaceMethods]> { let summary = "constant"; @@ -727,6 +804,10 @@ let hasFolder = 1; } +//===----------------------------------------------------------------------===// +// CopySignOp +//===----------------------------------------------------------------------===// + def CopySignOp : FloatArithmeticOp<"copysign"> { let summary = "A copysign operation"; let description = [{ @@ -738,6 +819,10 @@ }]; } +//===----------------------------------------------------------------------===// +// CosOp +//===----------------------------------------------------------------------===// + def CosOp : FloatUnaryOp<"cos"> { let summary = "cosine of the specified value"; let description = [{ @@ -748,6 +833,10 @@ }]; } +//===----------------------------------------------------------------------===// +// DeallocOp +//===----------------------------------------------------------------------===// + def DeallocOp : Std_Op<"dealloc"> { let summary = "memory deallocation operation"; let description = [{ @@ -768,6 +857,10 @@ let assemblyFormat = "$memref attr-dict `:` type($memref)"; } +//===----------------------------------------------------------------------===// +// DimOp +//===----------------------------------------------------------------------===// + def DimOp : Std_Op<"dim", [NoSideEffect]> { let summary = "dimension index operation"; let description = [{ @@ -800,24 +893,26 @@ let hasFolder = 1; } +//===----------------------------------------------------------------------===// +// DivFOp +//===----------------------------------------------------------------------===// + def DivFOp : FloatArithmeticOp<"divf"> { let summary = "floating point division operation"; } -def SignedDivIOp : IntArithmeticOp<"divi_signed"> { - let summary = "signed integer division operation"; - let hasFolder = 1; -} - -def UnsignedDivIOp : IntArithmeticOp<"divi_unsigned"> { - let summary = "unsigned integer division operation"; - let hasFolder = 1; -} +//===----------------------------------------------------------------------===// +// ExpOp +//===----------------------------------------------------------------------===// def ExpOp : FloatUnaryOp<"exp"> { let summary = "base-e exponential of the specified value"; } +//===----------------------------------------------------------------------===// +// ExtractElementOp +//===----------------------------------------------------------------------===// + def ExtractElementOp 
: Std_Op<"extract_element", [NoSideEffect, TypesMatchWith<"result type matches element type of aggregate", @@ -862,22 +957,9 @@ }]; } -def IndexCastOp : CastOp<"index_cast">, Arguments<(ins AnyType:$in)> { - let summary = "cast between index and integer types"; - let description = [{ - Casts between integer scalars and 'index' scalars. Index is an integer of - platform-specific bit width. If casting to a wider integer, the value is - sign-extended. If casting to a narrower integer, the value is truncated. - }]; - - let extraClassDeclaration = [{ - /// Return true if `a` and `b` are valid operand and result pairs for - /// the operation. - static bool areCastCompatible(Type a, Type b); - }]; - - let hasFolder = 1; -} +//===----------------------------------------------------------------------===// +// FPExtOp +//===----------------------------------------------------------------------===// def FPExtOp : CastOp<"fpext">, Arguments<(ins AnyType:$in)> { let summary = "cast from floating-point to wider floating-point"; @@ -896,6 +978,10 @@ let hasFolder = 0; } +//===----------------------------------------------------------------------===// +// FPTruncOp +//===----------------------------------------------------------------------===// + def FPTruncOp : CastOp<"fptrunc">, Arguments<(ins AnyType:$in)> { let summary = "cast from floating-point to narrower floating-point"; let description = [{ @@ -914,6 +1000,31 @@ let hasFolder = 0; } +//===----------------------------------------------------------------------===// +// IndexCastOp +//===----------------------------------------------------------------------===// + +def IndexCastOp : CastOp<"index_cast">, Arguments<(ins AnyType:$in)> { + let summary = "cast between index and integer types"; + let description = [{ + Casts between integer scalars and 'index' scalars. Index is an integer of + platform-specific bit width. If casting to a wider integer, the value is + sign-extended. If casting to a narrower integer, the value is truncated. + }]; + + let extraClassDeclaration = [{ + /// Return true if `a` and `b` are valid operand and result pairs for + /// the operation. 
+ static bool areCastCompatible(Type a, Type b); + }]; + + let hasFolder = 1; +} + +//===----------------------------------------------------------------------===// +// LoadOp +//===----------------------------------------------------------------------===// + def LoadOp : Std_Op<"load", [TypesMatchWith<"result type matches element type of 'memref'", "memref", "result", @@ -956,6 +1067,10 @@ let assemblyFormat = "$memref `[` $indices `]` attr-dict `:` type($memref)"; } +//===----------------------------------------------------------------------===// +// LogOp +//===----------------------------------------------------------------------===// + def LogOp : FloatUnaryOp<"log"> { let summary = "base-e logarithm of the specified value"; } @@ -968,6 +1083,10 @@ let summary = "base-2 logarithm of the specified value"; } +//===----------------------------------------------------------------------===// +// MemRefCastOp +//===----------------------------------------------------------------------===// + def MemRefCastOp : CastOp<"memref_cast"> { let summary = "memref cast operation"; let description = [{ @@ -1022,16 +1141,28 @@ }]; } +//===----------------------------------------------------------------------===// +// MulFOp +//===----------------------------------------------------------------------===// + def MulFOp : FloatArithmeticOp<"mulf"> { let summary = "floating point multiplication operation"; let hasFolder = 1; } +//===----------------------------------------------------------------------===// +// MulIOp +//===----------------------------------------------------------------------===// + def MulIOp : IntArithmeticOp<"muli", [Commutative]> { let summary = "integer multiplication operation"; let hasFolder = 1; } +//===----------------------------------------------------------------------===// +// NegFOp +//===----------------------------------------------------------------------===// + def NegFOp : FloatUnaryOp<"negf"> { let summary = "floating point negation"; let description = [{ @@ -1042,11 +1173,19 @@ }]; } +//===----------------------------------------------------------------------===// +// OrOp +//===----------------------------------------------------------------------===// + def OrOp : IntArithmeticOp<"or", [Commutative]> { let summary = "integer binary or"; let hasFolder = 1; } +//===----------------------------------------------------------------------===// +// PrefetchOp +//===----------------------------------------------------------------------===// + def PrefetchOp : Std_Op<"prefetch"> { let summary = "prefetch operation"; let description = [{ @@ -1096,6 +1235,10 @@ let hasFolder = 1; } +//===----------------------------------------------------------------------===// +// RankOp +//===----------------------------------------------------------------------===// + def RankOp : Std_Op<"rank", [NoSideEffect]> { let summary = "rank operation"; let description = [{ @@ -1118,29 +1261,17 @@ let assemblyFormat = "operands attr-dict `:` type(operands)"; } +//===----------------------------------------------------------------------===// +// RemFOp +//===----------------------------------------------------------------------===// + def RemFOp : FloatArithmeticOp<"remf"> { let summary = "floating point division remainder operation"; } -def RsqrtOp : FloatUnaryOp<"rsqrt"> { - let summary = "reciprocal of sqrt (1 / sqrt of the specified value)"; - let description = [{ - The `rsqrt` operation computes the reciprocal of the square root. It takes - one operand and returns one result of the same type. 
This type may be a - float scalar type, a vector whose element type is float, or a tensor of - floats. It has no standard attributes. - }]; -} - -def SignedRemIOp : IntArithmeticOp<"remi_signed"> { - let summary = "signed integer division remainder operation"; - let hasFolder = 1; -} - -def UnsignedRemIOp : IntArithmeticOp<"remi_unsigned"> { - let summary = "unsigned integer division remainder operation"; - let hasFolder = 1; -} +//===----------------------------------------------------------------------===// +// ReturnOp +//===----------------------------------------------------------------------===// def ReturnOp : Std_Op<"return", [Terminator, HasParent<"FuncOp">]> { let summary = "return operation"; @@ -1164,6 +1295,24 @@ let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?"; } +//===----------------------------------------------------------------------===// +// RsqrtOp +//===----------------------------------------------------------------------===// + +def RsqrtOp : FloatUnaryOp<"rsqrt"> { + let summary = "reciprocal of sqrt (1 / sqrt of the specified value)"; + let description = [{ + The `rsqrt` operation computes the reciprocal of the square root. It takes + one operand and returns one result of the same type. This type may be a + float scalar type, a vector whose element type is float, or a tensor of + floats. It has no standard attributes. + }]; +} + +//===----------------------------------------------------------------------===// +// SelectOp +//===----------------------------------------------------------------------===// + def SelectOp : Std_Op<"select", [NoSideEffect, SameOperandsAndResultShape, AllTypesMatch<["true_value", "false_value", "result"]>, TypesMatchWith<"condition type matches i1 equivalent of result type", @@ -1209,6 +1358,64 @@ }]; } +//===----------------------------------------------------------------------===// +// ShiftLeftOp +//===----------------------------------------------------------------------===// + +def ShiftLeftOp : IntArithmeticOp<"shift_left"> { + let summary = "integer left-shift"; + let description = [{ + The shift_left operation shifts an integer value to the left by a variable + amount. The low order bits are filled with zeros. + + %1 = constant 5 : i8 // %1 is 0b00000101 + %2 = constant 3 : i8 + %3 = shift_left %1, %2 : (i8, i8) -> i8 // %3 is 0b00101000 + }]; +} + +//===----------------------------------------------------------------------===// +// SignedDivIOp +//===----------------------------------------------------------------------===// + +def SignedDivIOp : IntArithmeticOp<"divi_signed"> { + let summary = "signed integer division operation"; + let hasFolder = 1; +} + +//===----------------------------------------------------------------------===// +// SignedRemIOp +//===----------------------------------------------------------------------===// + +def SignedRemIOp : IntArithmeticOp<"remi_signed"> { + let summary = "signed integer division remainder operation"; + let hasFolder = 1; +} + +//===----------------------------------------------------------------------===// +// SignedShiftRightOp +//===----------------------------------------------------------------------===// + +def SignedShiftRightOp : IntArithmeticOp<"shift_right_signed"> { + let summary = "signed integer right-shift"; + let description = [{ + The shift_right_signed operation shifts an integer value to the right by + a variable amount. The integer is interpreted as signed. 
The high order + bits in the output are filled with copies of the most-significant bit + of the shifted value (which means that the sign of the value is preserved). + + %1 = constant 160 : i8 // %1 is 0b10100000 + %2 = constant 3 : i8 + %3 = shift_right_signed %1, %2 : (i8, i8) -> i8 // %3 is 0b11110100 + %4 = constant 96 : i8 // %4 is 0b01100000 + %5 = shift_right_signed %4, %2 : (i8, i8) -> i8 // %5 is 0b00001100 + }]; +} + +//===----------------------------------------------------------------------===// +// SignExtendIOp +//===----------------------------------------------------------------------===// + def SignExtendIOp : Std_Op<"sexti", [NoSideEffect, SameOperandsAndResultShape]> { let summary = "integer sign extension operation"; @@ -1244,46 +1451,9 @@ }]; } -def ShiftLeftOp : IntArithmeticOp<"shift_left"> { - let summary = "integer left-shift"; - let description = [{ - The shift_left operation shifts an integer value to the left by a variable - amount. The low order bits are filled with zeros. - - %1 = constant 5 : i8 // %1 is 0b00000101 - %2 = constant 3 : i8 - %3 = shift_left %1, %2 : (i8, i8) -> i8 // %3 is 0b00101000 - }]; -} - -def SignedShiftRightOp : IntArithmeticOp<"shift_right_signed"> { - let summary = "signed integer right-shift"; - let description = [{ - The shift_right_signed operation shifts an integer value to the right by - a variable amount. The integer is interpreted as signed. The high order - bits in the output are filled with copies of the most-significant bit - of the shifted value (which means that the sign of the value is preserved). - - %1 = constant 160 : i8 // %1 is 0b10100000 - %2 = constant 3 : i8 - %3 = shift_right_signed %1, %2 : (i8, i8) -> i8 // %3 is 0b11110100 - %4 = constant 96 : i8 // %4 is 0b01100000 - %5 = shift_right_signed %4, %2 : (i8, i8) -> i8 // %5 is 0b00001100 - }]; -} - -def UnsignedShiftRightOp : IntArithmeticOp<"shift_right_unsigned"> { - let summary = "unsigned integer right-shift"; - let description = [{ - The shift_right_unsigned operation shifts an integer value to the right by - a variable amount. The integer is interpreted as unsigned. The high order - bits are always filled with zeros. - - %1 = constant 160 : i8 // %1 is 0b10100000 - %2 = constant 3 : i8 - %3 = shift_right_unsigned %1, %2 : (i8, i8) -> i8 // %3 is 0b00010100 - }]; -} +//===----------------------------------------------------------------------===// +// SIToFPOp +//===----------------------------------------------------------------------===// def SIToFPOp : CastOp<"sitofp">, Arguments<(ins AnyType:$in)> { let summary = "cast from integer type to floating-point"; @@ -1303,6 +1473,10 @@ let hasFolder = 0; } +//===----------------------------------------------------------------------===// +// SplatOp +//===----------------------------------------------------------------------===// + def SplatOp : Std_Op<"splat", [NoSideEffect, TypesMatchWith<"operand type matches element type of result", "aggregate", "input", @@ -1343,6 +1517,24 @@ let assemblyFormat = "$input attr-dict `:` type($aggregate)"; } +//===----------------------------------------------------------------------===// +// SqrtOp +//===----------------------------------------------------------------------===// + +def SqrtOp : FloatUnaryOp<"sqrt"> { + let summary = "sqrt of the specified value"; + let description = [{ + The `sqrt` operation computes the square root. It takes one operand and + returns one result of the same type. 
This type may be a float scalar type, a + vector whose element type is float, or a tensor of floats. It has no standard + attributes. + }]; +} + +//===----------------------------------------------------------------------===// +// StoreOp +//===----------------------------------------------------------------------===// + def StoreOp : Std_Op<"store", [TypesMatchWith<"type of 'value' matches element type of 'memref'", "memref", "value", @@ -1389,16 +1581,28 @@ }]; } +//===----------------------------------------------------------------------===// +// SubFOp +//===----------------------------------------------------------------------===// + def SubFOp : FloatArithmeticOp<"subf"> { let summary = "floating point subtraction operation"; let hasFolder = 1; } +//===----------------------------------------------------------------------===// +// SubIOp +//===----------------------------------------------------------------------===// + def SubIOp : IntArithmeticOp<"subi"> { let summary = "integer subtraction operation"; let hasFolder = 1; } +//===----------------------------------------------------------------------===// +// SubViewOp +//===----------------------------------------------------------------------===// + def SubViewOp : Std_Op<"subview", [AttrSizedOperandSegments, NoSideEffect]> { let summary = "memref subview operation"; let description = [{ @@ -1565,15 +1769,9 @@ let hasCanonicalizer = 1; } -def SqrtOp : FloatUnaryOp<"sqrt"> { - let summary = "sqrt of the specified value"; - let description = [{ - The `sqrt` operation computes the square root. It takes one operand and - returns one result of the same type. This type may be a float scalar type, a - vector whose element type is float, or a tensor of floats. It has no standard - attributes. - }]; -} +//===----------------------------------------------------------------------===// +// TanhOp +//===----------------------------------------------------------------------===// def TanhOp : FloatUnaryOp<"tanh"> { let summary = "hyperbolic tangent of the specified value"; @@ -1585,6 +1783,10 @@ }]; } +//===----------------------------------------------------------------------===// +// TensorCastOp +//===----------------------------------------------------------------------===// + def TensorCastOp : CastOp<"tensor_cast"> { let summary = "tensor cast operation"; let description = [{ @@ -1611,6 +1813,10 @@ }]; } +//===----------------------------------------------------------------------===// +// TensorLoadOp +//===----------------------------------------------------------------------===// + def TensorLoadOp : Std_Op<"tensor_load", [SameOperandsAndResultShape, SameOperandsAndResultElementType, TypesMatchWith<"result type matches tensor equivalent of 'memref'", @@ -1648,6 +1854,10 @@ let assemblyFormat = "$memref attr-dict `:` type($memref)"; } +//===----------------------------------------------------------------------===// +// TensorStoreOp +//===----------------------------------------------------------------------===// + def TensorStoreOp : Std_Op<"tensor_store", [SameOperandsShape, SameOperandsElementType, TypesMatchWith<"type of 'value' matches tensor equivalent of 'memref'", @@ -1673,6 +1883,10 @@ let assemblyFormat = "$tensor `,` $memref attr-dict `:` type($memref)"; } +//===----------------------------------------------------------------------===// +// TruncateIOp +//===----------------------------------------------------------------------===// + def TruncateIOp : Std_Op<"trunci", [NoSideEffect, SameOperandsAndResultShape]> { let summary = 
"integer truncation operation"; let description = [{ @@ -1705,6 +1919,45 @@ }]; } +//===----------------------------------------------------------------------===// +// UnsignedDivIOp +//===----------------------------------------------------------------------===// + +def UnsignedDivIOp : IntArithmeticOp<"divi_unsigned"> { + let summary = "unsigned integer division operation"; + let hasFolder = 1; +} + +//===----------------------------------------------------------------------===// +// UnsignedRemIOp +//===----------------------------------------------------------------------===// + +def UnsignedRemIOp : IntArithmeticOp<"remi_unsigned"> { + let summary = "unsigned integer division remainder operation"; + let hasFolder = 1; +} + +//===----------------------------------------------------------------------===// +// UnsignedShiftRightOp +//===----------------------------------------------------------------------===// + +def UnsignedShiftRightOp : IntArithmeticOp<"shift_right_unsigned"> { + let summary = "unsigned integer right-shift"; + let description = [{ + The shift_right_unsigned operation shifts an integer value to the right by + a variable amount. The integer is interpreted as unsigned. The high order + bits are always filled with zeros. + + %1 = constant 160 : i8 // %1 is 0b10100000 + %2 = constant 3 : i8 + %3 = shift_right_unsigned %1, %2 : (i8, i8) -> i8 // %3 is 0b00010100 + }]; +} + +//===----------------------------------------------------------------------===// +// ViewOp +//===----------------------------------------------------------------------===// + def ViewOp : Std_Op<"view", [NoSideEffect]> { let summary = "memref view operation"; let description = [{ @@ -1767,11 +2020,19 @@ let hasCanonicalizer = 1; } +//===----------------------------------------------------------------------===// +// XOrOp +//===----------------------------------------------------------------------===// + def XOrOp : IntArithmeticOp<"xor", [Commutative]> { let summary = "integer binary xor"; let hasFolder = 1; } +//===----------------------------------------------------------------------===// +// ZeroExtendIOp +//===----------------------------------------------------------------------===// + def ZeroExtendIOp : Std_Op<"zexti", [NoSideEffect, SameOperandsAndResultShape]> { let summary = "integer zero extension operation"; let description = [{ @@ -1805,21 +2066,4 @@ }]; } -def AssumeAlignmentOp : Std_Op<"assume_alignment"> { - let summary = - "assertion that gives alignment information to the input memref"; - let description = [{ - The assume alignment operation takes a memref and a integer of alignment - value, and internally annotates the buffer with the given alignment. If - the buffer isn't aligned to the given alignment, the behavior is undefined. - - This operation doesn't affect the semantics of a correct program. It's for - optimization only, and the optimization is best-effort. 
- }]; - let arguments = (ins AnyMemRef:$memref, PositiveI32Attr:$alignment); - let results = (outs); - - let assemblyFormat = "$memref `,` $alignment attr-dict `:` type($memref)"; -} - #endif // STANDARD_OPS diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -390,6 +390,68 @@ } //===----------------------------------------------------------------------===// +// AndOp +//===----------------------------------------------------------------------===// + +OpFoldResult AndOp::fold(ArrayRef operands) { + /// and(x, 0) -> 0 + if (matchPattern(rhs(), m_Zero())) + return rhs(); + /// and(x,x) -> x + if (lhs() == rhs()) + return rhs(); + + return constFoldBinaryOp(operands, + [](APInt a, APInt b) { return a & b; }); +} + +//===----------------------------------------------------------------------===// +// AssumeAlignmentOp +//===----------------------------------------------------------------------===// + +static LogicalResult verify(AssumeAlignmentOp op) { + unsigned alignment = op.alignment().getZExtValue(); + if (!llvm::isPowerOf2_32(alignment)) + return op.emitOpError("alignment must be power of 2"); + return success(); +} + +//===----------------------------------------------------------------------===// +// AtomicRMWOp +//===----------------------------------------------------------------------===// + +static LogicalResult verify(AtomicRMWOp op) { + if (op.getMemRefType().getRank() != op.getNumOperands() - 2) + return op.emitOpError( + "expects the number of subscripts to be equal to memref rank"); + switch (op.kind()) { + case AtomicRMWKind::addf: + case AtomicRMWKind::maxf: + case AtomicRMWKind::minf: + case AtomicRMWKind::mulf: + if (!op.value().getType().isa()) + return op.emitOpError() + << "with kind '" << stringifyAtomicRMWKind(op.kind()) + << "' expects a floating-point type"; + break; + case AtomicRMWKind::addi: + case AtomicRMWKind::maxs: + case AtomicRMWKind::maxu: + case AtomicRMWKind::mins: + case AtomicRMWKind::minu: + case AtomicRMWKind::muli: + if (!op.value().getType().isa()) + return op.emitOpError() + << "with kind '" << stringifyAtomicRMWKind(op.kind()) + << "' expects an integer type"; + break; + default: + break; + } + return success(); +} + +//===----------------------------------------------------------------------===// // BranchOp //===----------------------------------------------------------------------===// @@ -1009,44 +1071,6 @@ return {}; } -//===----------------------------------------------------------------------===// -// SignedDivIOp -//===----------------------------------------------------------------------===// - -OpFoldResult SignedDivIOp::fold(ArrayRef operands) { - assert(operands.size() == 2 && "binary operation takes two operands"); - - // Don't fold if it would overflow or if it requires a division by zero. - bool overflowOrDiv0 = false; - auto result = constFoldBinaryOp(operands, [&](APInt a, APInt b) { - if (overflowOrDiv0 || !b) { - overflowOrDiv0 = true; - return a; - } - return a.sdiv_ov(b, overflowOrDiv0); - }); - return overflowOrDiv0 ? 
Attribute() : result; -} - -//===----------------------------------------------------------------------===// -// UnsignedDivIOp -//===----------------------------------------------------------------------===// - -OpFoldResult UnsignedDivIOp::fold(ArrayRef operands) { - assert(operands.size() == 2 && "binary operation takes two operands"); - - // Don't fold if it would require a division by zero. - bool div0 = false; - auto result = constFoldBinaryOp(operands, [&](APInt a, APInt b) { - if (div0 || !b) { - div0 = true; - return a; - } - return a.udiv(b); - }); - return div0 ? Attribute() : result; -} - // --------------------------------------------------------------------------- // DmaStartOp // --------------------------------------------------------------------------- @@ -1290,6 +1314,36 @@ } //===----------------------------------------------------------------------===// +// FPExtOp +//===----------------------------------------------------------------------===// + +bool FPExtOp::areCastCompatible(Type a, Type b) { + if (auto fa = a.dyn_cast()) + if (auto fb = b.dyn_cast()) + return fa.getWidth() < fb.getWidth(); + if (auto va = a.dyn_cast()) + if (auto vb = b.dyn_cast()) + return va.getShape().equals(vb.getShape()) && + areCastCompatible(va.getElementType(), vb.getElementType()); + return false; +} + +//===----------------------------------------------------------------------===// +// FPTruncOp +//===----------------------------------------------------------------------===// + +bool FPTruncOp::areCastCompatible(Type a, Type b) { + if (auto fa = a.dyn_cast()) + if (auto fb = b.dyn_cast()) + return fa.getWidth() > fb.getWidth(); + if (auto va = a.dyn_cast()) + if (auto vb = b.dyn_cast()) + return va.getShape().equals(vb.getShape()) && + areCastCompatible(va.getElementType(), vb.getElementType()); + return false; +} + +//===----------------------------------------------------------------------===// // IndexCastOp //===----------------------------------------------------------------------===// @@ -1436,6 +1490,22 @@ } //===----------------------------------------------------------------------===// +// OrOp +//===----------------------------------------------------------------------===// + +OpFoldResult OrOp::fold(ArrayRef operands) { + /// or(x, 0) -> x + if (matchPattern(rhs(), m_Zero())) + return lhs(); + /// or(x,x) -> x + if (lhs() == rhs()) + return rhs(); + + return constFoldBinaryOp(operands, + [](APInt a, APInt b) { return a | b; }); +} + +//===----------------------------------------------------------------------===// // PrefetchOp //===----------------------------------------------------------------------===// @@ -1518,58 +1588,6 @@ } //===----------------------------------------------------------------------===// -// SignedRemIOp -//===----------------------------------------------------------------------===// - -OpFoldResult SignedRemIOp::fold(ArrayRef operands) { - assert(operands.size() == 2 && "remi_signed takes two operands"); - - auto rhs = operands.back().dyn_cast_or_null(); - if (!rhs) - return {}; - auto rhsValue = rhs.getValue(); - - // x % 1 = 0 - if (rhsValue.isOneValue()) - return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0)); - - // Don't fold if it requires division by zero. 
- if (rhsValue.isNullValue()) - return {}; - - auto lhs = operands.front().dyn_cast_or_null(); - if (!lhs) - return {}; - return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhsValue)); -} - -//===----------------------------------------------------------------------===// -// UnsignedRemIOp -//===----------------------------------------------------------------------===// - -OpFoldResult UnsignedRemIOp::fold(ArrayRef operands) { - assert(operands.size() == 2 && "remi_unsigned takes two operands"); - - auto rhs = operands.back().dyn_cast_or_null(); - if (!rhs) - return {}; - auto rhsValue = rhs.getValue(); - - // x % 1 = 0 - if (rhsValue.isOneValue()) - return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0)); - - // Don't fold if it requires division by zero. - if (rhsValue.isNullValue()) - return {}; - - auto lhs = operands.front().dyn_cast_or_null(); - if (!lhs) - return {}; - return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhsValue)); -} - -//===----------------------------------------------------------------------===// // ReturnOp //===----------------------------------------------------------------------===// @@ -1594,15 +1612,6 @@ } //===----------------------------------------------------------------------===// -// SIToFPOp -//===----------------------------------------------------------------------===// - -// sitofp is applicable from integer types to float types. -bool SIToFPOp::areCastCompatible(Type a, Type b) { - return a.isSignlessInteger() && b.isa(); -} - -//===----------------------------------------------------------------------===// // SelectOp //===----------------------------------------------------------------------===// @@ -1644,37 +1653,91 @@ } //===----------------------------------------------------------------------===// -// SplatOp +// SignedDivIOp //===----------------------------------------------------------------------===// -static LogicalResult verify(SplatOp op) { - // TODO: we could replace this by a trait. - if (op.getOperand().getType() != - op.getType().cast().getElementType()) - return op.emitError("operand should be of elemental type of result type"); +OpFoldResult SignedDivIOp::fold(ArrayRef operands) { + assert(operands.size() == 2 && "binary operation takes two operands"); - return success(); + // Don't fold if it would overflow or if it requires a division by zero. + bool overflowOrDiv0 = false; + auto result = constFoldBinaryOp(operands, [&](APInt a, APInt b) { + if (overflowOrDiv0 || !b) { + overflowOrDiv0 = true; + return a; + } + return a.sdiv_ov(b, overflowOrDiv0); + }); + return overflowOrDiv0 ? Attribute() : result; } -// Constant folding hook for SplatOp. 
-OpFoldResult SplatOp::fold(ArrayRef operands) { - assert(operands.size() == 1 && "splat takes one operand"); +//===----------------------------------------------------------------------===// +// SignedRemIOp +//===----------------------------------------------------------------------===// - auto constOperand = operands.front(); - if (!constOperand || - (!constOperand.isa() && !constOperand.isa())) +OpFoldResult SignedRemIOp::fold(ArrayRef operands) { + assert(operands.size() == 2 && "remi_signed takes two operands"); + + auto rhs = operands.back().dyn_cast_or_null(); + if (!rhs) return {}; + auto rhsValue = rhs.getValue(); - auto shapedType = getType().cast(); - assert(shapedType.getElementType() == constOperand.getType() && - "incorrect input attribute type for folding"); + // x % 1 = 0 + if (rhsValue.isOneValue()) + return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0)); - // SplatElementsAttr::get treats single value for second arg as being a splat. - return SplatElementsAttr::get(shapedType, {constOperand}); -} + // Don't fold if it requires division by zero. + if (rhsValue.isNullValue()) + return {}; -//===----------------------------------------------------------------------===// -// StoreOp + auto lhs = operands.front().dyn_cast_or_null(); + if (!lhs) + return {}; + return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhsValue)); +} + +//===----------------------------------------------------------------------===// +// SIToFPOp +//===----------------------------------------------------------------------===// + +// sitofp is applicable from integer types to float types. +bool SIToFPOp::areCastCompatible(Type a, Type b) { + return a.isSignlessInteger() && b.isa(); +} + +//===----------------------------------------------------------------------===// +// SplatOp +//===----------------------------------------------------------------------===// + +static LogicalResult verify(SplatOp op) { + // TODO: we could replace this by a trait. + if (op.getOperand().getType() != + op.getType().cast().getElementType()) + return op.emitError("operand should be of elemental type of result type"); + + return success(); +} + +// Constant folding hook for SplatOp. +OpFoldResult SplatOp::fold(ArrayRef operands) { + assert(operands.size() == 1 && "splat takes one operand"); + + auto constOperand = operands.front(); + if (!constOperand || + (!constOperand.isa() && !constOperand.isa())) + return {}; + + auto shapedType = getType().cast(); + assert(shapedType.getElementType() == constOperand.getType() && + "incorrect input attribute type for folding"); + + // SplatElementsAttr::get treats single value for second arg as being a splat. 
+ return SplatElementsAttr::get(shapedType, {constOperand}); +} + +//===----------------------------------------------------------------------===// +// StoreOp //===----------------------------------------------------------------------===// static LogicalResult verify(StoreOp op) { @@ -1713,749 +1776,751 @@ } //===----------------------------------------------------------------------===// -// AndOp -//===----------------------------------------------------------------------===// - -OpFoldResult AndOp::fold(ArrayRef operands) { - /// and(x, 0) -> 0 - if (matchPattern(rhs(), m_Zero())) - return rhs(); - /// and(x,x) -> x - if (lhs() == rhs()) - return rhs(); - - return constFoldBinaryOp(operands, - [](APInt a, APInt b) { return a & b; }); -} - -//===----------------------------------------------------------------------===// -// OrOp -//===----------------------------------------------------------------------===// - -OpFoldResult OrOp::fold(ArrayRef operands) { - /// or(x, 0) -> x - if (matchPattern(rhs(), m_Zero())) - return lhs(); - /// or(x,x) -> x - if (lhs() == rhs()) - return rhs(); - - return constFoldBinaryOp(operands, - [](APInt a, APInt b) { return a | b; }); -} - -//===----------------------------------------------------------------------===// -// XOrOp -//===----------------------------------------------------------------------===// - -OpFoldResult XOrOp::fold(ArrayRef operands) { - /// xor(x, 0) -> x - if (matchPattern(rhs(), m_Zero())) - return lhs(); - /// xor(x,x) -> 0 - if (lhs() == rhs()) - return Builder(getContext()).getZeroAttr(getType()); - - return constFoldBinaryOp(operands, - [](APInt a, APInt b) { return a ^ b; }); -} - -//===----------------------------------------------------------------------===// -// TensorCastOp +// SubViewOp //===----------------------------------------------------------------------===// -bool TensorCastOp::areCastCompatible(Type a, Type b) { - auto aT = a.dyn_cast(); - auto bT = b.dyn_cast(); - if (!aT || !bT) - return false; - - if (aT.getElementType() != bT.getElementType()) - return false; - - return succeeded(verifyCompatibleShape(aT, bT)); -} +// Returns a MemRefType with dynamic sizes and offset and the same stride as the +// `memRefType` passed as argument. +// TODO(andydavis,ntv) Evolve to a more powerful inference that can also keep +// sizes and offset static. +static Type inferSubViewResultType(MemRefType memRefType) { + auto rank = memRefType.getRank(); + int64_t offset; + SmallVector strides; + auto res = getStridesAndOffset(memRefType, strides, offset); + assert(succeeded(res) && "SubViewOp expected strided memref type"); + (void)res; -OpFoldResult TensorCastOp::fold(ArrayRef operands) { - return impl::foldCastOp(*this); + // Assume sizes and offset are fully dynamic for now until canonicalization + // occurs on the ranges. Typed strides don't change though. + offset = MemRefType::getDynamicStrideOrOffset(); + // Overwrite strides because verifier will not pass. + // TODO(b/144419106): don't force degrade the strides to fully dynamic. 
+ for (auto &stride : strides) + stride = MemRefType::getDynamicStrideOrOffset(); + auto stridedLayout = + makeStridedLinearLayoutMap(strides, offset, memRefType.getContext()); + SmallVector sizes(rank, ShapedType::kDynamicSize); + return MemRefType::Builder(memRefType) + .setShape(sizes) + .setAffineMaps(stridedLayout); } -//===----------------------------------------------------------------------===// -// Helpers for Tensor[Load|Store]Op -//===----------------------------------------------------------------------===// - -static Type getTensorTypeFromMemRefType(Type type) { - if (auto memref = type.dyn_cast()) - return RankedTensorType::get(memref.getShape(), memref.getElementType()); - return NoneType::get(type.getContext()); +void mlir::SubViewOp::build(Builder *b, OperationState &result, Value source, + ValueRange offsets, ValueRange sizes, + ValueRange strides, Type resultType, + ArrayRef attrs) { + if (!resultType) + resultType = inferSubViewResultType(source.getType().cast()); + auto segmentAttr = b->getI32VectorAttr( + {1, static_cast(offsets.size()), static_cast(sizes.size()), + static_cast(strides.size())}); + build(b, result, resultType, source, offsets, sizes, strides, segmentAttr); + result.addAttributes(attrs); } -//===----------------------------------------------------------------------===// -// TruncateIOp -//===----------------------------------------------------------------------===// - -static LogicalResult verify(TruncateIOp op) { - auto srcType = getElementTypeOrSelf(op.getOperand().getType()); - auto dstType = getElementTypeOrSelf(op.getType()); - - if (srcType.isa()) - return op.emitError() << srcType << " is not a valid operand type"; - if (dstType.isa()) - return op.emitError() << dstType << " is not a valid result type"; - - if (srcType.cast().getWidth() <= - dstType.cast().getWidth()) - return op.emitError("operand type ") - << srcType << " must be wider than result type " << dstType; - - return success(); +void mlir::SubViewOp::build(Builder *b, OperationState &result, Type resultType, + Value source) { + build(b, result, source, /*offsets=*/{}, /*sizes=*/{}, /*strides=*/{}, + resultType); } -//===----------------------------------------------------------------------===// -// ViewOp -//===----------------------------------------------------------------------===// - -static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) { +static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) { OpAsmParser::OperandType srcInfo; - SmallVector offsetInfo; + SmallVector offsetsInfo; SmallVector sizesInfo; + SmallVector stridesInfo; auto indexType = parser.getBuilder().getIndexType(); Type srcType, dstType; - llvm::SMLoc offsetLoc; - if (parser.parseOperand(srcInfo) || parser.getCurrentLocation(&offsetLoc) || - parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square)) + if (parser.parseOperand(srcInfo) || + parser.parseOperandList(offsetsInfo, OpAsmParser::Delimiter::Square) || + parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) || + parser.parseOperandList(stridesInfo, OpAsmParser::Delimiter::Square)) { return failure(); + } - if (offsetInfo.size() > 1) - return parser.emitError(offsetLoc) << "expects 0 or 1 offset operand"; + auto builder = parser.getBuilder(); + result.addAttribute( + SubViewOp::getOperandSegmentSizeAttr(), + builder.getI32VectorAttr({1, static_cast(offsetsInfo.size()), + static_cast(sizesInfo.size()), + static_cast(stridesInfo.size())})); return failure( - 
parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(srcType) || parser.resolveOperand(srcInfo, srcType, result.operands) || - parser.resolveOperands(offsetInfo, indexType, result.operands) || + parser.resolveOperands(offsetsInfo, indexType, result.operands) || parser.resolveOperands(sizesInfo, indexType, result.operands) || + parser.resolveOperands(stridesInfo, indexType, result.operands) || parser.parseKeywordType("to", dstType) || parser.addTypeToList(dstType, result.types)); } -static void print(OpAsmPrinter &p, ViewOp op) { - p << op.getOperationName() << ' ' << op.getOperand(0) << '['; - auto dynamicOffset = op.getDynamicOffset(); - if (dynamicOffset != nullptr) - p.printOperand(dynamicOffset); - p << "][" << op.getDynamicSizes() << ']'; - p.printOptionalAttrDict(op.getAttrs()); +static void print(OpAsmPrinter &p, SubViewOp op) { + p << op.getOperationName() << ' ' << op.getOperand(0) << '[' << op.offsets() + << "][" << op.sizes() << "][" << op.strides() << ']'; + + std::array elidedAttrs = { + SubViewOp::getOperandSegmentSizeAttr()}; + p.printOptionalAttrDict(op.getAttrs(), elidedAttrs); p << " : " << op.getOperand(0).getType() << " to " << op.getType(); } -Value ViewOp::getDynamicOffset() { - int64_t offset; - SmallVector strides; - auto result = - succeeded(mlir::getStridesAndOffset(getType(), strides, offset)); - assert(result); - if (result && offset == MemRefType::getDynamicStrideOrOffset()) - return getOperand(1); - return nullptr; -} +static LogicalResult verify(SubViewOp op) { + auto baseType = op.getBaseMemRefType().cast(); + auto subViewType = op.getType(); -static LogicalResult verifyDynamicStrides(MemRefType memrefType, - ArrayRef strides) { - ArrayRef shape = memrefType.getShape(); - unsigned rank = memrefType.getRank(); - assert(rank == strides.size()); - bool dynamicStrides = false; - for (int i = rank - 2; i >= 0; --i) { - // If size at dim 'i + 1' is dynamic, set the 'dynamicStrides' flag. - if (ShapedType::isDynamic(shape[i + 1])) - dynamicStrides = true; - // If stride at dim 'i' is not dynamic, return error. - if (dynamicStrides && strides[i] != MemRefType::getDynamicStrideOrOffset()) - return failure(); + // The rank of the base and result subview must match. + if (baseType.getRank() != subViewType.getRank()) { + return op.emitError( + "expected rank of result type to match rank of base type "); } - return success(); -} - -static LogicalResult verify(ViewOp op) { - auto baseType = op.getOperand(0).getType().cast(); - auto viewType = op.getResult().getType().cast(); - - // The base memref should have identity layout map (or none). - if (baseType.getAffineMaps().size() > 1 || - (baseType.getAffineMaps().size() == 1 && - !baseType.getAffineMaps()[0].isIdentity())) - return op.emitError("unsupported map for base memref type ") << baseType; // The base memref and the view memref should be in the same memory space. - if (baseType.getMemorySpace() != viewType.getMemorySpace()) + if (baseType.getMemorySpace() != subViewType.getMemorySpace()) return op.emitError("different memory spaces specified for base memref " "type ") - << baseType << " and view memref type " << viewType; + << baseType << " and subview memref type " << subViewType; + + // Verify that the base memref type has a strided layout map. 
+ int64_t baseOffset; + SmallVector baseStrides; + if (failed(getStridesAndOffset(baseType, baseStrides, baseOffset))) + return op.emitError("base type ") << subViewType << " is not strided"; // Verify that the result memref type has a strided layout map. - int64_t offset; - SmallVector strides; - if (failed(getStridesAndOffset(viewType, strides, offset))) - return op.emitError("result type ") << viewType << " is not strided"; + int64_t subViewOffset; + SmallVector subViewStrides; + if (failed(getStridesAndOffset(subViewType, subViewStrides, subViewOffset))) + return op.emitError("result type ") << subViewType << " is not strided"; - // Verify that we have the correct number of operands for the result type. - unsigned memrefOperandCount = 1; - unsigned numDynamicDims = viewType.getNumDynamicDims(); - unsigned dynamicOffsetCount = - offset == MemRefType::getDynamicStrideOrOffset() ? 1 : 0; - if (op.getNumOperands() != - memrefOperandCount + numDynamicDims + dynamicOffsetCount) - return op.emitError("incorrect number of operands for type ") << viewType; + // Num offsets should either be zero or rank of memref. + if (op.getNumOffsets() != 0 && op.getNumOffsets() != subViewType.getRank()) { + return op.emitError("expected number of dynamic offsets specified to match " + "the rank of the result type ") + << subViewType; + } - // Verify dynamic strides symbols were added to correct dimensions based - // on dynamic sizes. - if (failed(verifyDynamicStrides(viewType, strides))) - return op.emitError("incorrect dynamic strides in view memref type ") - << viewType; + // Num sizes should either be zero or rank of memref. + if (op.getNumSizes() != 0 && op.getNumSizes() != subViewType.getRank()) { + return op.emitError("expected number of dynamic sizes specified to match " + "the rank of the result type ") + << subViewType; + } + + // Num strides should either be zero or rank of memref. + if (op.getNumStrides() != 0 && op.getNumStrides() != subViewType.getRank()) { + return op.emitError("expected number of dynamic strides specified to match " + "the rank of the result type ") + << subViewType; + } + + // Verify that if the shape of the subview type is static, then sizes are not + // dynamic values, and vice versa. + if ((subViewType.hasStaticShape() && op.getNumSizes() != 0) || + (op.getNumSizes() == 0 && !subViewType.hasStaticShape())) { + return op.emitError("invalid to specify dynamic sizes when subview result " + "type is statically shaped and viceversa"); + } + + // Verify that if dynamic sizes are specified, then the result memref type + // have full dynamic dimensions. + if (op.getNumSizes() > 0) { + if (llvm::any_of(subViewType.getShape(), [](int64_t dim) { + return dim != ShapedType::kDynamicSize; + })) { + // TODO: This is based on the assumption that number of size arguments are + // either 0, or the rank of the result type. It is possible to have more + // fine-grained verification where only particular dimensions are + // dynamic. That probably needs further changes to the shape op + // specification. + return op.emitError("expected shape of result type to be fully dynamic " + "when sizes are specified"); + } + } + + // Verify that if dynamic offsets are specified or base memref has dynamic + // offset or base memref has dynamic strides, then the subview offset is + // dynamic. 
+ if ((op.getNumOffsets() > 0 || + baseOffset == MemRefType::getDynamicStrideOrOffset() || + llvm::is_contained(baseStrides, + MemRefType::getDynamicStrideOrOffset())) && + subViewOffset != MemRefType::getDynamicStrideOrOffset()) { + return op.emitError( + "expected result memref layout map to have dynamic offset"); + } + + // For now, verify that if dynamic strides are specified, then all the result + // memref type have dynamic strides. + if (op.getNumStrides() > 0) { + if (llvm::any_of(subViewStrides, [](int64_t stride) { + return stride != MemRefType::getDynamicStrideOrOffset(); + })) { + return op.emitError("expected result type to have dynamic strides"); + } + } + + // If any of the base memref has dynamic stride, then the corresponding + // stride of the subview must also have dynamic stride. + assert(baseStrides.size() == subViewStrides.size()); + for (auto stride : enumerate(baseStrides)) { + if (stride.value() == MemRefType::getDynamicStrideOrOffset() && + subViewStrides[stride.index()] != + MemRefType::getDynamicStrideOrOffset()) { + return op.emitError( + "expected result type to have dynamic stride along a dimension if " + "the base memref type has dynamic stride along that dimension"); + } + } return success(); } -namespace { - -struct ViewOpShapeFolder : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +raw_ostream &mlir::operator<<(raw_ostream &os, SubViewOp::Range &range) { + return os << "range " << range.offset << ":" << range.size << ":" + << range.stride; +} - PatternMatchResult matchAndRewrite(ViewOp viewOp, - PatternRewriter &rewriter) const override { - // Return if none of the operands are constants. - if (llvm::none_of(viewOp.getOperands(), [](Value operand) { - return matchPattern(operand, m_ConstantIndex()); - })) - return matchFailure(); +SmallVector SubViewOp::getRanges() { + SmallVector res; + unsigned rank = getType().getRank(); + res.reserve(rank); + for (unsigned i = 0; i < rank; ++i) + res.emplace_back(Range{*(offsets().begin() + i), *(sizes().begin() + i), + *(strides().begin() + i)}); + return res; +} - // Get result memref type. - auto memrefType = viewOp.getType(); - if (memrefType.getAffineMaps().size() > 1) - return matchFailure(); - auto map = memrefType.getAffineMaps().empty() - ? AffineMap::getMultiDimIdentityMap(memrefType.getRank(), - rewriter.getContext()) - : memrefType.getAffineMaps()[0]; +LogicalResult +SubViewOp::getStaticStrides(SmallVectorImpl &staticStrides) { + // If the strides are dynamic return failure. + if (getNumStrides()) + return failure(); - // Get offset from old memref view type 'memRefType'. - int64_t oldOffset; - SmallVector oldStrides; - if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset))) - return matchFailure(); + // When static, the stride operands can be retrieved by taking the strides of + // the result of the subview op, and dividing the strides of the base memref. 
+ int64_t resultOffset, baseOffset; + SmallVector resultStrides, baseStrides; + if (failed( + getStridesAndOffset(getBaseMemRefType(), baseStrides, baseOffset)) || + llvm::is_contained(baseStrides, MemRefType::getDynamicStrideOrOffset()) || + failed(getStridesAndOffset(getType(), resultStrides, resultOffset))) + return failure(); - SmallVector newOperands; + assert(static_cast(resultStrides.size()) == getType().getRank() && + baseStrides.size() == resultStrides.size() && + "base and result memrefs must have the same rank"); + assert(!llvm::is_contained(resultStrides, + MemRefType::getDynamicStrideOrOffset()) && + "strides of subview op must be static, when there are no dynamic " + "strides specified"); + staticStrides.resize(getType().getRank()); + for (auto resultStride : enumerate(resultStrides)) { + auto baseStride = baseStrides[resultStride.index()]; + // The result stride is expected to be a multiple of the base stride. Abort + // if that is not the case. + if (resultStride.value() < baseStride || + resultStride.value() % baseStride != 0) + return failure(); + staticStrides[resultStride.index()] = resultStride.value() / baseStride; + } + return success(); +} - // Fold dynamic offset operand if it is produced by a constant. - auto dynamicOffset = viewOp.getDynamicOffset(); - int64_t newOffset = oldOffset; - unsigned dynamicOffsetOperandCount = 0; - if (dynamicOffset != nullptr) { - auto *defOp = dynamicOffset.getDefiningOp(); - if (auto constantIndexOp = dyn_cast_or_null(defOp)) { - // Dynamic offset will be folded into the map. - newOffset = constantIndexOp.getValue(); - } else { - // Unable to fold dynamic offset. Add it to 'newOperands' list. - newOperands.push_back(dynamicOffset); - dynamicOffsetOperandCount = 1; - } - } +namespace { - // Fold any dynamic dim operands which are produced by a constant. - SmallVector newShapeConstants; - newShapeConstants.reserve(memrefType.getRank()); +/// Pattern to rewrite a subview op with constant size arguments. +class SubViewOpShapeFolder final : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; - unsigned dynamicDimPos = viewOp.getDynamicSizesOperandStart(); - unsigned rank = memrefType.getRank(); - for (unsigned dim = 0, e = rank; dim < e; ++dim) { - int64_t dimSize = memrefType.getDimSize(dim); - // If this is already static dimension, keep it. - if (!ShapedType::isDynamic(dimSize)) { - newShapeConstants.push_back(dimSize); - continue; - } - auto *defOp = viewOp.getOperand(dynamicDimPos).getDefiningOp(); - if (auto constantIndexOp = dyn_cast_or_null(defOp)) { - // Dynamic shape dimension will be folded. - newShapeConstants.push_back(constantIndexOp.getValue()); - } else { - // Dynamic shape dimension not folded; copy operand from old memref. - newShapeConstants.push_back(dimSize); - newOperands.push_back(viewOp.getOperand(dynamicDimPos)); - } - dynamicDimPos++; + PatternMatchResult matchAndRewrite(SubViewOp subViewOp, + PatternRewriter &rewriter) const override { + MemRefType subViewType = subViewOp.getType(); + // Follow all or nothing approach for shapes for now. If all the operands + // for sizes are constants then fold it into the type of the result memref. + if (subViewType.hasStaticShape() || + llvm::any_of(subViewOp.sizes(), [](Value operand) { + return !matchPattern(operand, m_ConstantIndex()); + })) { + return matchFailure(); } - - // Compute new strides based on 'newShapeConstants'. 
- SmallVector newStrides(rank); - newStrides[rank - 1] = 1; - bool dynamicStrides = false; - for (int i = rank - 2; i >= 0; --i) { - if (ShapedType::isDynamic(newShapeConstants[i + 1])) - dynamicStrides = true; - if (dynamicStrides) - newStrides[i] = MemRefType::getDynamicStrideOrOffset(); - else - newStrides[i] = newShapeConstants[i + 1] * newStrides[i + 1]; + SmallVector staticShape(subViewOp.getNumSizes()); + for (auto size : llvm::enumerate(subViewOp.sizes())) { + auto defOp = size.value().getDefiningOp(); + assert(defOp); + staticShape[size.index()] = cast(defOp).getValue(); } + MemRefType newMemRefType = + MemRefType::Builder(subViewType).setShape(staticShape); + auto newSubViewOp = rewriter.create( + subViewOp.getLoc(), subViewOp.source(), subViewOp.offsets(), + ArrayRef(), subViewOp.strides(), newMemRefType); + // Insert a memref_cast for compatibility of the uses of the op. + rewriter.replaceOpWithNewOp(subViewOp, newSubViewOp, + subViewOp.getType()); + return matchSuccess(); + } +}; - // Regenerate strided layout map with 'newStrides' and 'newOffset'. - map = makeStridedLinearLayoutMap(newStrides, newOffset, - rewriter.getContext()); +// Pattern to rewrite a subview op with constant stride arguments. +class SubViewOpStrideFolder final : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; - // Create new memref type with constant folded dims and/or offset/strides. - MemRefType newMemRefType = MemRefType::Builder(memrefType) - .setShape(newShapeConstants) - .setAffineMaps({map}); - (void)dynamicOffsetOperandCount; // unused in opt mode - assert(static_cast(newOperands.size()) == - dynamicOffsetOperandCount + newMemRefType.getNumDynamicDims()); + PatternMatchResult matchAndRewrite(SubViewOp subViewOp, + PatternRewriter &rewriter) const override { + if (subViewOp.getNumStrides() == 0) { + return matchFailure(); + } + // Follow all or nothing approach for strides for now. If all the operands + // for strides are constants then fold it into the strides of the result + // memref. + int64_t baseOffset, resultOffset; + SmallVector baseStrides, resultStrides; + MemRefType subViewType = subViewOp.getType(); + if (failed(getStridesAndOffset(subViewOp.getBaseMemRefType(), baseStrides, + baseOffset)) || + failed(getStridesAndOffset(subViewType, resultStrides, resultOffset)) || + llvm::is_contained(baseStrides, + MemRefType::getDynamicStrideOrOffset()) || + llvm::any_of(subViewOp.strides(), [](Value stride) { + return !matchPattern(stride, m_ConstantIndex()); + })) { + return matchFailure(); + } - // Create new ViewOp. - auto newViewOp = rewriter.create(viewOp.getLoc(), newMemRefType, - viewOp.getOperand(0), newOperands); - // Insert a cast so we have the same type as the old memref type. 
- rewriter.replaceOpWithNewOp(viewOp, newViewOp, - viewOp.getType()); + SmallVector staticStrides(subViewOp.getNumStrides()); + for (auto stride : llvm::enumerate(subViewOp.strides())) { + auto defOp = stride.value().getDefiningOp(); + assert(defOp); + assert(baseStrides[stride.index()] > 0); + staticStrides[stride.index()] = + cast(defOp).getValue() * baseStrides[stride.index()]; + } + AffineMap layoutMap = makeStridedLinearLayoutMap( + staticStrides, resultOffset, rewriter.getContext()); + MemRefType newMemRefType = + MemRefType::Builder(subViewType).setAffineMaps(layoutMap); + auto newSubViewOp = rewriter.create( + subViewOp.getLoc(), subViewOp.source(), subViewOp.offsets(), + subViewOp.sizes(), ArrayRef(), newMemRefType); + // Insert a memref_cast for compatibility of the uses of the op. + rewriter.replaceOpWithNewOp(subViewOp, newSubViewOp, + subViewOp.getType()); return matchSuccess(); } }; -struct ViewOpMemrefCastFolder : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +// Pattern to rewrite a subview op with constant offset arguments. +class SubViewOpOffsetFolder final : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(ViewOp viewOp, + PatternMatchResult matchAndRewrite(SubViewOp subViewOp, PatternRewriter &rewriter) const override { - Value memrefOperand = viewOp.getOperand(0); - MemRefCastOp memrefCastOp = - dyn_cast_or_null(memrefOperand.getDefiningOp()); - if (!memrefCastOp) + if (subViewOp.getNumOffsets() == 0) { return matchFailure(); - Value allocOperand = memrefCastOp.getOperand(); - AllocOp allocOp = dyn_cast_or_null(allocOperand.getDefiningOp()); - if (!allocOp) + } + // Follow all or nothing approach for offsets for now. If all the operands + // for offsets are constants then fold it into the offset of the result + // memref. + int64_t baseOffset, resultOffset; + SmallVector baseStrides, resultStrides; + MemRefType subViewType = subViewOp.getType(); + if (failed(getStridesAndOffset(subViewOp.getBaseMemRefType(), baseStrides, + baseOffset)) || + failed(getStridesAndOffset(subViewType, resultStrides, resultOffset)) || + llvm::is_contained(baseStrides, + MemRefType::getDynamicStrideOrOffset()) || + baseOffset == MemRefType::getDynamicStrideOrOffset() || + llvm::any_of(subViewOp.offsets(), [](Value stride) { + return !matchPattern(stride, m_ConstantIndex()); + })) { return matchFailure(); - rewriter.replaceOpWithNewOp(viewOp, viewOp.getType(), allocOperand, - viewOp.operands()); + } + + auto staticOffset = baseOffset; + for (auto offset : llvm::enumerate(subViewOp.offsets())) { + auto defOp = offset.value().getDefiningOp(); + assert(defOp); + assert(baseStrides[offset.index()] > 0); + staticOffset += + cast(defOp).getValue() * baseStrides[offset.index()]; + } + + AffineMap layoutMap = makeStridedLinearLayoutMap( + resultStrides, staticOffset, rewriter.getContext()); + MemRefType newMemRefType = + MemRefType::Builder(subViewType).setAffineMaps(layoutMap); + auto newSubViewOp = rewriter.create( + subViewOp.getLoc(), subViewOp.source(), ArrayRef(), + subViewOp.sizes(), subViewOp.strides(), newMemRefType); + // Insert a memref_cast for compatibility of the uses of the op. 
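+    // For example, with base offset 0, base strides [16, 1] and constant
+    // offset operands 4 and 2, the folded layout offset is 4 * 16 + 2 * 1 = 66;
+    // the cast recovers the original (dynamic-offset) result type for existing
+    // users. (Illustrative values only.)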
+ rewriter.replaceOpWithNewOp(subViewOp, newSubViewOp, + subViewOp.getType()); return matchSuccess(); } }; } // end anonymous namespace -void ViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { - results.insert(context); +void SubViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); } //===----------------------------------------------------------------------===// -// SubViewOp +// TensorCastOp //===----------------------------------------------------------------------===// -// Returns a MemRefType with dynamic sizes and offset and the same stride as the -// `memRefType` passed as argument. -// TODO(andydavis,ntv) Evolve to a more powerful inference that can also keep -// sizes and offset static. -static Type inferSubViewResultType(MemRefType memRefType) { - auto rank = memRefType.getRank(); - int64_t offset; - SmallVector strides; - auto res = getStridesAndOffset(memRefType, strides, offset); - assert(succeeded(res) && "SubViewOp expected strided memref type"); - (void)res; - - // Assume sizes and offset are fully dynamic for now until canonicalization - // occurs on the ranges. Typed strides don't change though. - offset = MemRefType::getDynamicStrideOrOffset(); - // Overwrite strides because verifier will not pass. - // TODO(b/144419106): don't force degrade the strides to fully dynamic. - for (auto &stride : strides) - stride = MemRefType::getDynamicStrideOrOffset(); - auto stridedLayout = - makeStridedLinearLayoutMap(strides, offset, memRefType.getContext()); - SmallVector sizes(rank, ShapedType::kDynamicSize); - return MemRefType::Builder(memRefType) - .setShape(sizes) - .setAffineMaps(stridedLayout); -} +bool TensorCastOp::areCastCompatible(Type a, Type b) { + auto aT = a.dyn_cast(); + auto bT = b.dyn_cast(); + if (!aT || !bT) + return false; -void mlir::SubViewOp::build(Builder *b, OperationState &result, Value source, - ValueRange offsets, ValueRange sizes, - ValueRange strides, Type resultType, - ArrayRef attrs) { - if (!resultType) - resultType = inferSubViewResultType(source.getType().cast()); - auto segmentAttr = b->getI32VectorAttr( - {1, static_cast(offsets.size()), static_cast(sizes.size()), - static_cast(strides.size())}); - build(b, result, resultType, source, offsets, sizes, strides, segmentAttr); - result.addAttributes(attrs); -} + if (aT.getElementType() != bT.getElementType()) + return false; -void mlir::SubViewOp::build(Builder *b, OperationState &result, Type resultType, - Value source) { - build(b, result, source, /*offsets=*/{}, /*sizes=*/{}, /*strides=*/{}, - resultType); + return succeeded(verifyCompatibleShape(aT, bT)); } -static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) { - OpAsmParser::OperandType srcInfo; - SmallVector offsetsInfo; - SmallVector sizesInfo; - SmallVector stridesInfo; - auto indexType = parser.getBuilder().getIndexType(); - Type srcType, dstType; - if (parser.parseOperand(srcInfo) || - parser.parseOperandList(offsetsInfo, OpAsmParser::Delimiter::Square) || - parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) || - parser.parseOperandList(stridesInfo, OpAsmParser::Delimiter::Square)) { - return failure(); - } - - auto builder = parser.getBuilder(); - result.addAttribute( - SubViewOp::getOperandSegmentSizeAttr(), - builder.getI32VectorAttr({1, static_cast(offsetsInfo.size()), - static_cast(sizesInfo.size()), - static_cast(stridesInfo.size())})); - - return failure( - 
parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(srcType) || - parser.resolveOperand(srcInfo, srcType, result.operands) || - parser.resolveOperands(offsetsInfo, indexType, result.operands) || - parser.resolveOperands(sizesInfo, indexType, result.operands) || - parser.resolveOperands(stridesInfo, indexType, result.operands) || - parser.parseKeywordType("to", dstType) || - parser.addTypeToList(dstType, result.types)); +OpFoldResult TensorCastOp::fold(ArrayRef operands) { + return impl::foldCastOp(*this); } -static void print(OpAsmPrinter &p, SubViewOp op) { - p << op.getOperationName() << ' ' << op.getOperand(0) << '[' << op.offsets() - << "][" << op.sizes() << "][" << op.strides() << ']'; +//===----------------------------------------------------------------------===// +// Helpers for Tensor[Load|Store]Op +//===----------------------------------------------------------------------===// - std::array elidedAttrs = { - SubViewOp::getOperandSegmentSizeAttr()}; - p.printOptionalAttrDict(op.getAttrs(), elidedAttrs); - p << " : " << op.getOperand(0).getType() << " to " << op.getType(); +static Type getTensorTypeFromMemRefType(Type type) { + if (auto memref = type.dyn_cast()) + return RankedTensorType::get(memref.getShape(), memref.getElementType()); + return NoneType::get(type.getContext()); } -static LogicalResult verify(SubViewOp op) { - auto baseType = op.getBaseMemRefType().cast(); - auto subViewType = op.getType(); - - // The rank of the base and result subview must match. - if (baseType.getRank() != subViewType.getRank()) { - return op.emitError( - "expected rank of result type to match rank of base type "); - } - - // The base memref and the view memref should be in the same memory space. - if (baseType.getMemorySpace() != subViewType.getMemorySpace()) - return op.emitError("different memory spaces specified for base memref " - "type ") - << baseType << " and subview memref type " << subViewType; - - // Verify that the base memref type has a strided layout map. - int64_t baseOffset; - SmallVector baseStrides; - if (failed(getStridesAndOffset(baseType, baseStrides, baseOffset))) - return op.emitError("base type ") << subViewType << " is not strided"; - - // Verify that the result memref type has a strided layout map. - int64_t subViewOffset; - SmallVector subViewStrides; - if (failed(getStridesAndOffset(subViewType, subViewStrides, subViewOffset))) - return op.emitError("result type ") << subViewType << " is not strided"; - - // Num offsets should either be zero or rank of memref. - if (op.getNumOffsets() != 0 && op.getNumOffsets() != subViewType.getRank()) { - return op.emitError("expected number of dynamic offsets specified to match " - "the rank of the result type ") - << subViewType; - } - - // Num sizes should either be zero or rank of memref. - if (op.getNumSizes() != 0 && op.getNumSizes() != subViewType.getRank()) { - return op.emitError("expected number of dynamic sizes specified to match " - "the rank of the result type ") - << subViewType; - } - - // Num strides should either be zero or rank of memref. - if (op.getNumStrides() != 0 && op.getNumStrides() != subViewType.getRank()) { - return op.emitError("expected number of dynamic strides specified to match " - "the rank of the result type ") - << subViewType; - } - - // Verify that if the shape of the subview type is static, then sizes are not - // dynamic values, and vice versa. 
-  if ((subViewType.hasStaticShape() && op.getNumSizes() != 0) ||
-      (op.getNumSizes() == 0 && !subViewType.hasStaticShape())) {
-    return op.emitError("invalid to specify dynamic sizes when subview result "
-                        "type is statically shaped and viceversa");
-  }
+//===----------------------------------------------------------------------===//
+// TruncateIOp
+//===----------------------------------------------------------------------===//
-  // Verify that if dynamic sizes are specified, then the result memref type
-  // have full dynamic dimensions.
-  if (op.getNumSizes() > 0) {
-    if (llvm::any_of(subViewType.getShape(), [](int64_t dim) {
-          return dim != ShapedType::kDynamicSize;
-        })) {
-      // TODO: This is based on the assumption that number of size arguments are
-      // either 0, or the rank of the result type. It is possible to have more
-      // fine-grained verification where only particular dimensions are
-      // dynamic. That probably needs further changes to the shape op
-      // specification.
-      return op.emitError("expected shape of result type to be fully dynamic "
-                          "when sizes are specified");
-    }
-  }
+static LogicalResult verify(TruncateIOp op) {
+  auto srcType = getElementTypeOrSelf(op.getOperand().getType());
+  auto dstType = getElementTypeOrSelf(op.getType());
-  // Verify that if dynamic offsets are specified or base memref has dynamic
-  // offset or base memref has dynamic strides, then the subview offset is
-  // dynamic.
-  if ((op.getNumOffsets() > 0 ||
-       baseOffset == MemRefType::getDynamicStrideOrOffset() ||
-       llvm::is_contained(baseStrides,
-                          MemRefType::getDynamicStrideOrOffset())) &&
-      subViewOffset != MemRefType::getDynamicStrideOrOffset()) {
-    return op.emitError(
-        "expected result memref layout map to have dynamic offset");
-  }
+  if (srcType.isa<IndexType>())
+    return op.emitError() << srcType << " is not a valid operand type";
+  if (dstType.isa<IndexType>())
+    return op.emitError() << dstType << " is not a valid result type";
-  // For now, verify that if dynamic strides are specified, then all the result
-  // memref type have dynamic strides.
-  if (op.getNumStrides() > 0) {
-    if (llvm::any_of(subViewStrides, [](int64_t stride) {
-          return stride != MemRefType::getDynamicStrideOrOffset();
-        })) {
-      return op.emitError("expected result type to have dynamic strides");
-    }
-  }
+  if (srcType.cast<IntegerType>().getWidth() <=
+      dstType.cast<IntegerType>().getWidth())
+    return op.emitError("operand type ")
+           << srcType << " must be wider than result type " << dstType;
-  // If any of the base memref has dynamic stride, then the corresponding
-  // stride of the subview must also have dynamic stride.
- assert(baseStrides.size() == subViewStrides.size()); - for (auto stride : enumerate(baseStrides)) { - if (stride.value() == MemRefType::getDynamicStrideOrOffset() && - subViewStrides[stride.index()] != - MemRefType::getDynamicStrideOrOffset()) { - return op.emitError( - "expected result type to have dynamic stride along a dimension if " - "the base memref type has dynamic stride along that dimension"); - } - } return success(); } -raw_ostream &mlir::operator<<(raw_ostream &os, SubViewOp::Range &range) { - return os << "range " << range.offset << ":" << range.size << ":" - << range.stride; +//===----------------------------------------------------------------------===// +// UnsignedDivIOp +//===----------------------------------------------------------------------===// + +OpFoldResult UnsignedDivIOp::fold(ArrayRef operands) { + assert(operands.size() == 2 && "binary operation takes two operands"); + + // Don't fold if it would require a division by zero. + bool div0 = false; + auto result = constFoldBinaryOp(operands, [&](APInt a, APInt b) { + if (div0 || !b) { + div0 = true; + return a; + } + return a.udiv(b); + }); + return div0 ? Attribute() : result; } -SmallVector SubViewOp::getRanges() { - SmallVector res; - unsigned rank = getType().getRank(); - res.reserve(rank); - for (unsigned i = 0; i < rank; ++i) - res.emplace_back(Range{*(offsets().begin() + i), *(sizes().begin() + i), - *(strides().begin() + i)}); - return res; +//===----------------------------------------------------------------------===// +// UnsignedRemIOp +//===----------------------------------------------------------------------===// + +OpFoldResult UnsignedRemIOp::fold(ArrayRef operands) { + assert(operands.size() == 2 && "remi_unsigned takes two operands"); + + auto rhs = operands.back().dyn_cast_or_null(); + if (!rhs) + return {}; + auto rhsValue = rhs.getValue(); + + // x % 1 = 0 + if (rhsValue.isOneValue()) + return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0)); + + // Don't fold if it requires division by zero. + if (rhsValue.isNullValue()) + return {}; + + auto lhs = operands.front().dyn_cast_or_null(); + if (!lhs) + return {}; + return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhsValue)); } -LogicalResult -SubViewOp::getStaticStrides(SmallVectorImpl &staticStrides) { - // If the strides are dynamic return failure. - if (getNumStrides()) - return failure(); +//===----------------------------------------------------------------------===// +// ViewOp +//===----------------------------------------------------------------------===// - // When static, the stride operands can be retrieved by taking the strides of - // the result of the subview op, and dividing the strides of the base memref. 
- int64_t resultOffset, baseOffset; - SmallVector resultStrides, baseStrides; - if (failed( - getStridesAndOffset(getBaseMemRefType(), baseStrides, baseOffset)) || - llvm::is_contained(baseStrides, MemRefType::getDynamicStrideOrOffset()) || - failed(getStridesAndOffset(getType(), resultStrides, resultOffset))) +static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) { + OpAsmParser::OperandType srcInfo; + SmallVector offsetInfo; + SmallVector sizesInfo; + auto indexType = parser.getBuilder().getIndexType(); + Type srcType, dstType; + llvm::SMLoc offsetLoc; + if (parser.parseOperand(srcInfo) || parser.getCurrentLocation(&offsetLoc) || + parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square)) return failure(); - assert(static_cast(resultStrides.size()) == getType().getRank() && - baseStrides.size() == resultStrides.size() && - "base and result memrefs must have the same rank"); - assert(!llvm::is_contained(resultStrides, - MemRefType::getDynamicStrideOrOffset()) && - "strides of subview op must be static, when there are no dynamic " - "strides specified"); - staticStrides.resize(getType().getRank()); - for (auto resultStride : enumerate(resultStrides)) { - auto baseStride = baseStrides[resultStride.index()]; - // The result stride is expected to be a multiple of the base stride. Abort - // if that is not the case. - if (resultStride.value() < baseStride || - resultStride.value() % baseStride != 0) + if (offsetInfo.size() > 1) + return parser.emitError(offsetLoc) << "expects 0 or 1 offset operand"; + + return failure( + parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(srcType) || + parser.resolveOperand(srcInfo, srcType, result.operands) || + parser.resolveOperands(offsetInfo, indexType, result.operands) || + parser.resolveOperands(sizesInfo, indexType, result.operands) || + parser.parseKeywordType("to", dstType) || + parser.addTypeToList(dstType, result.types)); +} + +static void print(OpAsmPrinter &p, ViewOp op) { + p << op.getOperationName() << ' ' << op.getOperand(0) << '['; + auto dynamicOffset = op.getDynamicOffset(); + if (dynamicOffset != nullptr) + p.printOperand(dynamicOffset); + p << "][" << op.getDynamicSizes() << ']'; + p.printOptionalAttrDict(op.getAttrs()); + p << " : " << op.getOperand(0).getType() << " to " << op.getType(); +} + +Value ViewOp::getDynamicOffset() { + int64_t offset; + SmallVector strides; + auto result = + succeeded(mlir::getStridesAndOffset(getType(), strides, offset)); + assert(result); + if (result && offset == MemRefType::getDynamicStrideOrOffset()) + return getOperand(1); + return nullptr; +} + +static LogicalResult verifyDynamicStrides(MemRefType memrefType, + ArrayRef strides) { + ArrayRef shape = memrefType.getShape(); + unsigned rank = memrefType.getRank(); + assert(rank == strides.size()); + bool dynamicStrides = false; + for (int i = rank - 2; i >= 0; --i) { + // If size at dim 'i + 1' is dynamic, set the 'dynamicStrides' flag. + if (ShapedType::isDynamic(shape[i + 1])) + dynamicStrides = true; + // If stride at dim 'i' is not dynamic, return error. 
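+    // For example, a view result type memref<4x?x16xf32> must use strides
+    // (?, 16, 1): every dimension to the left of a dynamic size must itself
+    // carry a dynamic stride. (Illustrative shape only.)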
+ if (dynamicStrides && strides[i] != MemRefType::getDynamicStrideOrOffset()) return failure(); - staticStrides[resultStride.index()] = resultStride.value() / baseStride; } return success(); } -//===----------------------------------------------------------------------===// -// AssumeAlignmentOp -//===----------------------------------------------------------------------===// +static LogicalResult verify(ViewOp op) { + auto baseType = op.getOperand(0).getType().cast(); + auto viewType = op.getResult().getType().cast(); -static LogicalResult verify(AssumeAlignmentOp op) { - unsigned alignment = op.alignment().getZExtValue(); - if (!llvm::isPowerOf2_32(alignment)) - return op.emitOpError("alignment must be power of 2"); + // The base memref should have identity layout map (or none). + if (baseType.getAffineMaps().size() > 1 || + (baseType.getAffineMaps().size() == 1 && + !baseType.getAffineMaps()[0].isIdentity())) + return op.emitError("unsupported map for base memref type ") << baseType; + + // The base memref and the view memref should be in the same memory space. + if (baseType.getMemorySpace() != viewType.getMemorySpace()) + return op.emitError("different memory spaces specified for base memref " + "type ") + << baseType << " and view memref type " << viewType; + + // Verify that the result memref type has a strided layout map. + int64_t offset; + SmallVector strides; + if (failed(getStridesAndOffset(viewType, strides, offset))) + return op.emitError("result type ") << viewType << " is not strided"; + + // Verify that we have the correct number of operands for the result type. + unsigned memrefOperandCount = 1; + unsigned numDynamicDims = viewType.getNumDynamicDims(); + unsigned dynamicOffsetCount = + offset == MemRefType::getDynamicStrideOrOffset() ? 1 : 0; + if (op.getNumOperands() != + memrefOperandCount + numDynamicDims + dynamicOffsetCount) + return op.emitError("incorrect number of operands for type ") << viewType; + + // Verify dynamic strides symbols were added to correct dimensions based + // on dynamic sizes. + if (failed(verifyDynamicStrides(viewType, strides))) + return op.emitError("incorrect dynamic strides in view memref type ") + << viewType; return success(); } namespace { -/// Pattern to rewrite a subview op with constant size arguments. -class SubViewOpShapeFolder final : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; +struct ViewOpShapeFolder : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(SubViewOp subViewOp, + PatternMatchResult matchAndRewrite(ViewOp viewOp, PatternRewriter &rewriter) const override { - MemRefType subViewType = subViewOp.getType(); - // Follow all or nothing approach for shapes for now. If all the operands - // for sizes are constants then fold it into the type of the result memref. - if (subViewType.hasStaticShape() || - llvm::any_of(subViewOp.sizes(), [](Value operand) { - return !matchPattern(operand, m_ConstantIndex()); - })) { + // Return if none of the operands are constants. 
+ if (llvm::none_of(viewOp.getOperands(), [](Value operand) { + return matchPattern(operand, m_ConstantIndex()); + })) return matchFailure(); - } - SmallVector staticShape(subViewOp.getNumSizes()); - for (auto size : llvm::enumerate(subViewOp.sizes())) { - auto defOp = size.value().getDefiningOp(); - assert(defOp); - staticShape[size.index()] = cast(defOp).getValue(); - } - MemRefType newMemRefType = - MemRefType::Builder(subViewType).setShape(staticShape); - auto newSubViewOp = rewriter.create( - subViewOp.getLoc(), subViewOp.source(), subViewOp.offsets(), - ArrayRef(), subViewOp.strides(), newMemRefType); - // Insert a memref_cast for compatibility of the uses of the op. - rewriter.replaceOpWithNewOp(subViewOp, newSubViewOp, - subViewOp.getType()); - return matchSuccess(); - } -}; -// Pattern to rewrite a subview op with constant stride arguments. -class SubViewOpStrideFolder final : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; + // Get result memref type. + auto memrefType = viewOp.getType(); + if (memrefType.getAffineMaps().size() > 1) + return matchFailure(); + auto map = memrefType.getAffineMaps().empty() + ? AffineMap::getMultiDimIdentityMap(memrefType.getRank(), + rewriter.getContext()) + : memrefType.getAffineMaps()[0]; - PatternMatchResult matchAndRewrite(SubViewOp subViewOp, - PatternRewriter &rewriter) const override { - if (subViewOp.getNumStrides() == 0) { + // Get offset from old memref view type 'memRefType'. + int64_t oldOffset; + SmallVector oldStrides; + if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset))) return matchFailure(); + + SmallVector newOperands; + + // Fold dynamic offset operand if it is produced by a constant. + auto dynamicOffset = viewOp.getDynamicOffset(); + int64_t newOffset = oldOffset; + unsigned dynamicOffsetOperandCount = 0; + if (dynamicOffset != nullptr) { + auto *defOp = dynamicOffset.getDefiningOp(); + if (auto constantIndexOp = dyn_cast_or_null(defOp)) { + // Dynamic offset will be folded into the map. + newOffset = constantIndexOp.getValue(); + } else { + // Unable to fold dynamic offset. Add it to 'newOperands' list. + newOperands.push_back(dynamicOffset); + dynamicOffsetOperandCount = 1; + } } - // Follow all or nothing approach for strides for now. If all the operands - // for strides are constants then fold it into the strides of the result - // memref. - int64_t baseOffset, resultOffset; - SmallVector baseStrides, resultStrides; - MemRefType subViewType = subViewOp.getType(); - if (failed(getStridesAndOffset(subViewOp.getBaseMemRefType(), baseStrides, - baseOffset)) || - failed(getStridesAndOffset(subViewType, resultStrides, resultOffset)) || - llvm::is_contained(baseStrides, - MemRefType::getDynamicStrideOrOffset()) || - llvm::any_of(subViewOp.strides(), [](Value stride) { - return !matchPattern(stride, m_ConstantIndex()); - })) { - return matchFailure(); + + // Fold any dynamic dim operands which are produced by a constant. + SmallVector newShapeConstants; + newShapeConstants.reserve(memrefType.getRank()); + + unsigned dynamicDimPos = viewOp.getDynamicSizesOperandStart(); + unsigned rank = memrefType.getRank(); + for (unsigned dim = 0, e = rank; dim < e; ++dim) { + int64_t dimSize = memrefType.getDimSize(dim); + // If this is already static dimension, keep it. 
+ if (!ShapedType::isDynamic(dimSize)) { + newShapeConstants.push_back(dimSize); + continue; + } + auto *defOp = viewOp.getOperand(dynamicDimPos).getDefiningOp(); + if (auto constantIndexOp = dyn_cast_or_null(defOp)) { + // Dynamic shape dimension will be folded. + newShapeConstants.push_back(constantIndexOp.getValue()); + } else { + // Dynamic shape dimension not folded; copy operand from old memref. + newShapeConstants.push_back(dimSize); + newOperands.push_back(viewOp.getOperand(dynamicDimPos)); + } + dynamicDimPos++; } - SmallVector staticStrides(subViewOp.getNumStrides()); - for (auto stride : llvm::enumerate(subViewOp.strides())) { - auto defOp = stride.value().getDefiningOp(); - assert(defOp); - assert(baseStrides[stride.index()] > 0); - staticStrides[stride.index()] = - cast(defOp).getValue() * baseStrides[stride.index()]; + // Compute new strides based on 'newShapeConstants'. + SmallVector newStrides(rank); + newStrides[rank - 1] = 1; + bool dynamicStrides = false; + for (int i = rank - 2; i >= 0; --i) { + if (ShapedType::isDynamic(newShapeConstants[i + 1])) + dynamicStrides = true; + if (dynamicStrides) + newStrides[i] = MemRefType::getDynamicStrideOrOffset(); + else + newStrides[i] = newShapeConstants[i + 1] * newStrides[i + 1]; } - AffineMap layoutMap = makeStridedLinearLayoutMap( - staticStrides, resultOffset, rewriter.getContext()); - MemRefType newMemRefType = - MemRefType::Builder(subViewType).setAffineMaps(layoutMap); - auto newSubViewOp = rewriter.create( - subViewOp.getLoc(), subViewOp.source(), subViewOp.offsets(), - subViewOp.sizes(), ArrayRef(), newMemRefType); - // Insert a memref_cast for compatibility of the uses of the op. - rewriter.replaceOpWithNewOp(subViewOp, newSubViewOp, - subViewOp.getType()); + + // Regenerate strided layout map with 'newStrides' and 'newOffset'. + map = makeStridedLinearLayoutMap(newStrides, newOffset, + rewriter.getContext()); + + // Create new memref type with constant folded dims and/or offset/strides. + MemRefType newMemRefType = MemRefType::Builder(memrefType) + .setShape(newShapeConstants) + .setAffineMaps({map}); + (void)dynamicOffsetOperandCount; // unused in opt mode + assert(static_cast(newOperands.size()) == + dynamicOffsetOperandCount + newMemRefType.getNumDynamicDims()); + + // Create new ViewOp. + auto newViewOp = rewriter.create(viewOp.getLoc(), newMemRefType, + viewOp.getOperand(0), newOperands); + // Insert a cast so we have the same type as the old memref type. + rewriter.replaceOpWithNewOp(viewOp, newViewOp, + viewOp.getType()); return matchSuccess(); } }; -// Pattern to rewrite a subview op with constant offset arguments. -class SubViewOpOffsetFolder final : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; +struct ViewOpMemrefCastFolder : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(SubViewOp subViewOp, + PatternMatchResult matchAndRewrite(ViewOp viewOp, PatternRewriter &rewriter) const override { - if (subViewOp.getNumOffsets() == 0) { + Value memrefOperand = viewOp.getOperand(0); + MemRefCastOp memrefCastOp = + dyn_cast_or_null(memrefOperand.getDefiningOp()); + if (!memrefCastOp) return matchFailure(); - } - // Follow all or nothing approach for offsets for now. If all the operands - // for offsets are constants then fold it into the offset of the result - // memref. 
- int64_t baseOffset, resultOffset; - SmallVector baseStrides, resultStrides; - MemRefType subViewType = subViewOp.getType(); - if (failed(getStridesAndOffset(subViewOp.getBaseMemRefType(), baseStrides, - baseOffset)) || - failed(getStridesAndOffset(subViewType, resultStrides, resultOffset)) || - llvm::is_contained(baseStrides, - MemRefType::getDynamicStrideOrOffset()) || - baseOffset == MemRefType::getDynamicStrideOrOffset() || - llvm::any_of(subViewOp.offsets(), [](Value stride) { - return !matchPattern(stride, m_ConstantIndex()); - })) { + Value allocOperand = memrefCastOp.getOperand(); + AllocOp allocOp = dyn_cast_or_null(allocOperand.getDefiningOp()); + if (!allocOp) return matchFailure(); - } - - auto staticOffset = baseOffset; - for (auto offset : llvm::enumerate(subViewOp.offsets())) { - auto defOp = offset.value().getDefiningOp(); - assert(defOp); - assert(baseStrides[offset.index()] > 0); - staticOffset += - cast(defOp).getValue() * baseStrides[offset.index()]; - } - - AffineMap layoutMap = makeStridedLinearLayoutMap( - resultStrides, staticOffset, rewriter.getContext()); - MemRefType newMemRefType = - MemRefType::Builder(subViewType).setAffineMaps(layoutMap); - auto newSubViewOp = rewriter.create( - subViewOp.getLoc(), subViewOp.source(), ArrayRef(), - subViewOp.sizes(), subViewOp.strides(), newMemRefType); - // Insert a memref_cast for compatibility of the uses of the op. - rewriter.replaceOpWithNewOp(subViewOp, newSubViewOp, - subViewOp.getType()); + rewriter.replaceOpWithNewOp(viewOp, viewOp.getType(), allocOperand, + viewOp.operands()); return matchSuccess(); } }; } // end anonymous namespace -void SubViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { - results.insert(context); +void ViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + +//===----------------------------------------------------------------------===// +// XOrOp +//===----------------------------------------------------------------------===// + +OpFoldResult XOrOp::fold(ArrayRef operands) { + /// xor(x, 0) -> x + if (matchPattern(rhs(), m_Zero())) + return lhs(); + /// xor(x,x) -> 0 + if (lhs() == rhs()) + return Builder(getContext()).getZeroAttr(getType()); + + return constFoldBinaryOp(operands, + [](APInt a, APInt b) { return a ^ b; }); } //===----------------------------------------------------------------------===// @@ -2480,71 +2545,6 @@ } //===----------------------------------------------------------------------===// -// FPExtOp -//===----------------------------------------------------------------------===// - -bool FPExtOp::areCastCompatible(Type a, Type b) { - if (auto fa = a.dyn_cast()) - if (auto fb = b.dyn_cast()) - return fa.getWidth() < fb.getWidth(); - if (auto va = a.dyn_cast()) - if (auto vb = b.dyn_cast()) - return va.getShape().equals(vb.getShape()) && - areCastCompatible(va.getElementType(), vb.getElementType()); - return false; -} - -//===----------------------------------------------------------------------===// -// FPTruncOp -//===----------------------------------------------------------------------===// - -bool FPTruncOp::areCastCompatible(Type a, Type b) { - if (auto fa = a.dyn_cast()) - if (auto fb = b.dyn_cast()) - return fa.getWidth() > fb.getWidth(); - if (auto va = a.dyn_cast()) - if (auto vb = b.dyn_cast()) - return va.getShape().equals(vb.getShape()) && - areCastCompatible(va.getElementType(), vb.getElementType()); - return false; -} - 
-//===----------------------------------------------------------------------===//
-// AtomicRMWOp
-//===----------------------------------------------------------------------===//
-
-static LogicalResult verify(AtomicRMWOp op) {
-  if (op.getMemRefType().getRank() != op.getNumOperands() - 2)
-    return op.emitOpError(
-        "expects the number of subscripts to be equal to memref rank");
-  switch (op.kind()) {
-  case AtomicRMWKind::addf:
-  case AtomicRMWKind::maxf:
-  case AtomicRMWKind::minf:
-  case AtomicRMWKind::mulf:
-    if (!op.value().getType().isa<FloatType>())
-      return op.emitOpError()
-             << "with kind '" << stringifyAtomicRMWKind(op.kind())
-             << "' expects a floating-point type";
-    break;
-  case AtomicRMWKind::addi:
-  case AtomicRMWKind::maxs:
-  case AtomicRMWKind::maxu:
-  case AtomicRMWKind::mins:
-  case AtomicRMWKind::minu:
-  case AtomicRMWKind::muli:
-    if (!op.value().getType().isa<IntegerType>())
-      return op.emitOpError()
-             << "with kind '" << stringifyAtomicRMWKind(op.kind())
-             << "' expects an integer type";
-    break;
-  default:
-    break;
-  }
-  return success();
-}
-
-//===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//