diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml @@ -56,12 +56,78 @@ operands: - !ScalarExpression scalar_arg: A + is_unsigned_cast: false - !ScalarExpression symbolic_cast: type_var: U operands: - !ScalarExpression scalar_arg: B + is_unsigned_cast: false +--- !LinalgOpConfig +metadata: !LinalgOpMetadata + name: matmul_unsigned + cpp_class_name: MatmulUnsignedOp + doc: |- + Performs a unsigned matrix multiplication of two 2D inputs. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + implements: + - LinalgContractionOpInterface +structured_op: !LinalgStructuredOpConfig + args: + - !LinalgOperandDefConfig + name: A + usage: InputOperand + type_var: T1 + shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)> + - !LinalgOperandDefConfig + name: B + usage: InputOperand + type_var: T2 + shape_map: affine_map<()[s0, s1, s2] -> (s1, s2)> + - !LinalgOperandDefConfig + name: C + usage: OutputOperand + type_var: U + shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)> + indexing_maps: !LinalgIndexingMapsConfig + static_indexing_maps: + - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)> + - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d1)> + - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)> + iterator_types: + - parallel + - parallel + - reduction + assignments: + - !ScalarAssign + arg: C + value: !ScalarExpression + scalar_apply: + fn_name: add + operands: + - !ScalarExpression + scalar_arg: C + - !ScalarExpression + scalar_apply: + fn_name: mul + operands: + - !ScalarExpression + symbolic_cast: + type_var: U + operands: + - !ScalarExpression + scalar_arg: A + is_unsigned_cast: true + - !ScalarExpression + symbolic_cast: + type_var: U + operands: + - 
!ScalarExpression + scalar_arg: B + is_unsigned_cast: true --- !LinalgOpConfig metadata: !LinalgOpMetadata name: quantized_matmul @@ -132,12 +198,14 @@ operands: - !ScalarExpression scalar_arg: A + is_unsigned_cast: false - !ScalarExpression symbolic_cast: type_var: U operands: - !ScalarExpression scalar_arg: AZp + is_unsigned_cast: false - !ScalarExpression scalar_apply: fn_name: sub @@ -148,12 +216,14 @@ operands: - !ScalarExpression scalar_arg: B + is_unsigned_cast: false - !ScalarExpression symbolic_cast: type_var: U operands: - !ScalarExpression scalar_arg: BZp + is_unsigned_cast: false --- !LinalgOpConfig metadata: !LinalgOpMetadata name: mmt4d @@ -221,12 +291,14 @@ operands: - !ScalarExpression scalar_arg: lhs + is_unsigned_cast: false - !ScalarExpression symbolic_cast: type_var: AccumType operands: - !ScalarExpression scalar_arg: rhs + is_unsigned_cast: false --- !LinalgOpConfig metadata: !LinalgOpMetadata name: batch_matmul @@ -284,12 +356,14 @@ operands: - !ScalarExpression scalar_arg: A + is_unsigned_cast: false - !ScalarExpression symbolic_cast: type_var: U operands: - !ScalarExpression scalar_arg: B + is_unsigned_cast: false --- !LinalgOpConfig metadata: !LinalgOpMetadata name: quantized_batch_matmul @@ -361,12 +435,14 @@ operands: - !ScalarExpression scalar_arg: A + is_unsigned_cast: false - !ScalarExpression symbolic_cast: type_var: U operands: - !ScalarExpression scalar_arg: AZp + is_unsigned_cast: false - !ScalarExpression scalar_apply: fn_name: sub @@ -377,12 +453,14 @@ operands: - !ScalarExpression scalar_arg: B + is_unsigned_cast: false - !ScalarExpression symbolic_cast: type_var: U operands: - !ScalarExpression scalar_arg: BZp + is_unsigned_cast: false --- !LinalgOpConfig metadata: !LinalgOpMetadata name: matvec @@ -438,12 +516,14 @@ operands: - !ScalarExpression scalar_arg: A + is_unsigned_cast: false - !ScalarExpression symbolic_cast: type_var: U operands: - !ScalarExpression scalar_arg: y + is_unsigned_cast: false --- !LinalgOpConfig 
metadata: !LinalgOpMetadata name: vecmat @@ -499,12 +579,14 @@ operands: - !ScalarExpression scalar_arg: y + is_unsigned_cast: false - !ScalarExpression symbolic_cast: type_var: U operands: - !ScalarExpression scalar_arg: A + is_unsigned_cast: false --- !LinalgOpConfig metadata: !LinalgOpMetadata name: batch_matvec @@ -561,12 +643,14 @@ operands: - !ScalarExpression scalar_arg: A + is_unsigned_cast: false - !ScalarExpression symbolic_cast: type_var: U operands: - !ScalarExpression scalar_arg: B + is_unsigned_cast: false --- !LinalgOpConfig metadata: !LinalgOpMetadata name: dot @@ -621,12 +705,14 @@ operands: - !ScalarExpression scalar_arg: A + is_unsigned_cast: false - !ScalarExpression symbolic_cast: type_var: U operands: - !ScalarExpression scalar_arg: B + is_unsigned_cast: false --- !LinalgOpConfig metadata: !LinalgOpMetadata name: conv_1d @@ -682,12 +768,14 @@ operands: - !ScalarExpression scalar_arg: I + is_unsigned_cast: false - !ScalarExpression symbolic_cast: type_var: U operands: - !ScalarExpression scalar_arg: K + is_unsigned_cast: false --- !LinalgOpConfig metadata: !LinalgOpMetadata name: conv_2d @@ -745,12 +833,14 @@ operands: - !ScalarExpression scalar_arg: I + is_unsigned_cast: false - !ScalarExpression symbolic_cast: type_var: U operands: - !ScalarExpression scalar_arg: K + is_unsigned_cast: false --- !LinalgOpConfig metadata: !LinalgOpMetadata name: conv_3d @@ -811,12 +901,14 @@ operands: - !ScalarExpression scalar_arg: I + is_unsigned_cast: false - !ScalarExpression symbolic_cast: type_var: U operands: - !ScalarExpression scalar_arg: K + is_unsigned_cast: false --- !LinalgOpConfig metadata: !LinalgOpMetadata name: conv_1d_nwc_wcf @@ -887,12 +979,14 @@ operands: - !ScalarExpression scalar_arg: I + is_unsigned_cast: false - !ScalarExpression symbolic_cast: type_var: U operands: - !ScalarExpression scalar_arg: K + is_unsigned_cast: false --- !LinalgOpConfig metadata: !LinalgOpMetadata name: conv_2d_nhwc_hwcf @@ -975,12 +1069,14 @@ operands: - 
!ScalarExpression scalar_arg: I + is_unsigned_cast: false - !ScalarExpression symbolic_cast: type_var: U operands: - !ScalarExpression scalar_arg: K + is_unsigned_cast: false --- !LinalgOpConfig metadata: !LinalgOpMetadata name: conv_2d_nhwc_hwcf_q @@ -1080,12 +1176,14 @@ operands: - !ScalarExpression scalar_arg: I + is_unsigned_cast: false - !ScalarExpression symbolic_cast: type_var: U operands: - !ScalarExpression scalar_arg: IZp + is_unsigned_cast: false - !ScalarExpression scalar_apply: fn_name: sub @@ -1096,12 +1194,14 @@ operands: - !ScalarExpression scalar_arg: K + is_unsigned_cast: false - !ScalarExpression symbolic_cast: type_var: U operands: - !ScalarExpression scalar_arg: KZp + is_unsigned_cast: false --- !LinalgOpConfig metadata: !LinalgOpMetadata name: conv_2d_nchw_fchw @@ -1184,12 +1284,14 @@ operands: - !ScalarExpression scalar_arg: I + is_unsigned_cast: false - !ScalarExpression symbolic_cast: type_var: U operands: - !ScalarExpression scalar_arg: K + is_unsigned_cast: false --- !LinalgOpConfig metadata: !LinalgOpMetadata name: conv_3d_ndhwc_dhwcf @@ -1272,12 +1374,14 @@ operands: - !ScalarExpression scalar_arg: I + is_unsigned_cast: false - !ScalarExpression symbolic_cast: type_var: U operands: - !ScalarExpression scalar_arg: K + is_unsigned_cast: false --- !LinalgOpConfig metadata: !LinalgOpMetadata name: depthwise_conv2D_nhw @@ -1353,12 +1457,14 @@ operands: - !ScalarExpression scalar_arg: I + is_unsigned_cast: false - !ScalarExpression symbolic_cast: type_var: U operands: - !ScalarExpression scalar_arg: K + is_unsigned_cast: false --- !LinalgOpConfig metadata: !LinalgOpMetadata name: depthwise_conv2D_nhw_q @@ -1449,12 +1555,14 @@ operands: - !ScalarExpression scalar_arg: I + is_unsigned_cast: false - !ScalarExpression symbolic_cast: type_var: U operands: - !ScalarExpression scalar_arg: IZp + is_unsigned_cast: false - !ScalarExpression scalar_apply: fn_name: sub @@ -1465,12 +1573,14 @@ operands: - !ScalarExpression scalar_arg: K + 
is_unsigned_cast: false - !ScalarExpression symbolic_cast: type_var: U operands: - !ScalarExpression scalar_arg: KZp + is_unsigned_cast: false --- !LinalgOpConfig metadata: !LinalgOpMetadata name: depthwise_conv2D_nhwc @@ -1549,12 +1659,14 @@ operands: - !ScalarExpression scalar_arg: I + is_unsigned_cast: false - !ScalarExpression symbolic_cast: type_var: U operands: - !ScalarExpression scalar_arg: K + is_unsigned_cast: false --- !LinalgOpConfig metadata: !LinalgOpMetadata name: depthwise_conv2D_nhwc_q @@ -1649,12 +1761,14 @@ operands: - !ScalarExpression scalar_arg: I + is_unsigned_cast: false - !ScalarExpression symbolic_cast: type_var: U operands: - !ScalarExpression scalar_arg: IZp + is_unsigned_cast: false - !ScalarExpression scalar_apply: fn_name: sub @@ -1665,12 +1779,14 @@ operands: - !ScalarExpression scalar_arg: K + is_unsigned_cast: false - !ScalarExpression symbolic_cast: type_var: U operands: - !ScalarExpression scalar_arg: KZp + is_unsigned_cast: false --- !LinalgOpConfig metadata: !LinalgOpMetadata name: pooling_nhwc_sum @@ -1741,6 +1857,7 @@ operands: - !ScalarExpression scalar_arg: I + is_unsigned_cast: false --- !LinalgOpConfig metadata: !LinalgOpMetadata name: pooling_nhwc_max @@ -1811,6 +1928,78 @@ operands: - !ScalarExpression scalar_arg: I + is_unsigned_cast: false +--- !LinalgOpConfig +metadata: !LinalgOpMetadata + name: pooling_nhwc_max_unsigned + cpp_class_name: PoolingNhwcMaxUnsignedOp + doc: |- + Performs unsigned max pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. 
+ implements: + - LinalgConvolutionOpInterface +structured_op: !LinalgStructuredOpConfig + args: + - !LinalgOperandDefConfig + name: I + usage: InputOperand + type_var: T1 + shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1 * + s2 + s3 * s4, s5 * s6 + s7 * s8, s9)> + - !LinalgOperandDefConfig + name: K + usage: InputOperand + type_var: T2 + shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s3, s7)> + - !LinalgOperandDefConfig + name: O + usage: OutputOperand + type_var: U + shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1, s5, + s9)> + - !LinalgOperandDefConfig + name: strides + usage: IndexAttribute + type_var: I64 + attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s2, s6)> + - !LinalgOperandDefConfig + name: dilations + usage: IndexAttribute + type_var: I64 + attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s4, s8)> + indexing_maps: !LinalgIndexingMapsConfig + static_indexing_maps: + - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] + -> (d0, d1 * s2 + d3 * s4, d2 * s6 + d4 * s8, d5)> + - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] + -> (d3, d4)> + - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] + -> (d0, d1, d2, d5)> + iterator_types: + - parallel + - parallel + - parallel + - reduction + - reduction + - parallel + assignments: + - !ScalarAssign + arg: O + value: !ScalarExpression + scalar_apply: + fn_name: max_unsigned + operands: + - !ScalarExpression + scalar_arg: O + - !ScalarExpression + symbolic_cast: + type_var: U + operands: + - !ScalarExpression + scalar_arg: I + is_unsigned_cast: true --- !LinalgOpConfig metadata: !LinalgOpMetadata name: pooling_nchw_max @@ -1881,6 +2070,7 @@ operands: - !ScalarExpression scalar_arg: I + is_unsigned_cast: false --- !LinalgOpConfig metadata: !LinalgOpMetadata name: pooling_nhwc_min @@ -1951,6 +2141,78 @@ operands: 
- !ScalarExpression scalar_arg: I + is_unsigned_cast: false +--- !LinalgOpConfig +metadata: !LinalgOpMetadata + name: pooling_nhwc_min_unsigned + cpp_class_name: PoolingNhwcMinUnsignedOp + doc: |- + Performs unsigned min pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + implements: + - LinalgConvolutionOpInterface +structured_op: !LinalgStructuredOpConfig + args: + - !LinalgOperandDefConfig + name: I + usage: InputOperand + type_var: T1 + shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1 * + s2 + s3 * s4, s5 * s6 + s7 * s8, s9)> + - !LinalgOperandDefConfig + name: K + usage: InputOperand + type_var: T2 + shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s3, s7)> + - !LinalgOperandDefConfig + name: O + usage: OutputOperand + type_var: U + shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1, s5, + s9)> + - !LinalgOperandDefConfig + name: strides + usage: IndexAttribute + type_var: I64 + attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s2, s6)> + - !LinalgOperandDefConfig + name: dilations + usage: IndexAttribute + type_var: I64 + attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s4, s8)> + indexing_maps: !LinalgIndexingMapsConfig + static_indexing_maps: + - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] + -> (d0, d1 * s2 + d3 * s4, d2 * s6 + d4 * s8, d5)> + - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] + -> (d3, d4)> + - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] + -> (d0, d1, d2, d5)> + iterator_types: + - parallel + - parallel + - parallel + - reduction + - reduction + - parallel + assignments: + - !ScalarAssign + arg: O + value: !ScalarExpression + scalar_apply: + fn_name: min_unsigned + operands: + - !ScalarExpression + scalar_arg: O + - !ScalarExpression + symbolic_cast: + 
type_var: U + operands: + - !ScalarExpression + scalar_arg: I + is_unsigned_cast: true --- !LinalgOpConfig metadata: !LinalgOpMetadata name: pooling_ndhwc_sum @@ -2027,6 +2289,7 @@ operands: - !ScalarExpression scalar_arg: I + is_unsigned_cast: false --- !LinalgOpConfig metadata: !LinalgOpMetadata name: pooling_ndhwc_max @@ -2103,6 +2366,7 @@ operands: - !ScalarExpression scalar_arg: I + is_unsigned_cast: false --- !LinalgOpConfig metadata: !LinalgOpMetadata name: pooling_ndhwc_min @@ -2179,6 +2443,7 @@ operands: - !ScalarExpression scalar_arg: I + is_unsigned_cast: false --- !LinalgOpConfig metadata: !LinalgOpMetadata name: fill_rng_2d @@ -2246,6 +2511,7 @@ operands: - !ScalarExpression scalar_const: '2147483647 : i64' + is_unsigned_cast: false - !ScalarExpression symbolic_cast: type_var: F64 @@ -2268,6 +2534,7 @@ operands: - !ScalarExpression scalar_index: 1 + is_unsigned_cast: false - !ScalarExpression scalar_apply: fn_name: add @@ -2286,6 +2553,7 @@ operands: - !ScalarExpression scalar_index: 0 + is_unsigned_cast: false - !ScalarExpression scalar_arg: seed - !ScalarExpression @@ -2294,24 +2562,29 @@ operands: - !ScalarExpression scalar_const: '1103515245 : i64' + is_unsigned_cast: false - !ScalarExpression symbolic_cast: type_var: I32 operands: - !ScalarExpression scalar_const: '12345 : i64' + is_unsigned_cast: false - !ScalarExpression symbolic_cast: type_var: I32 operands: - !ScalarExpression scalar_const: '1103515245 : i64' + is_unsigned_cast: false - !ScalarExpression symbolic_cast: type_var: I32 operands: - !ScalarExpression scalar_const: '12345 : i64' + is_unsigned_cast: false + is_unsigned_cast: false - !ScalarExpression scalar_apply: fn_name: mul @@ -2330,8 +2603,10 @@ operands: - !ScalarExpression scalar_const: '2.3283063999999999E-10 : f64' + is_unsigned_cast: false - !ScalarExpression scalar_arg: min + is_unsigned_cast: false --- !LinalgOpConfig metadata: !LinalgOpMetadata name: soft_plus_2d @@ -2377,6 +2652,7 @@ operands: - !ScalarExpression 
scalar_const: '1.000000e+00 : f64' + is_unsigned_cast: false - !ScalarExpression scalar_apply: fn_name: exp @@ -2387,3 +2663,4 @@ operands: - !ScalarExpression scalar_arg: I + is_unsigned_cast: false diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -196,7 +196,7 @@ // If the cast cannot be performed, a warning will be issued and the // operand returned as-is (which will presumably yield a verification // issue downstream). - Value cast(Type toType, Value operand) { + Value cast(Type toType, Value operand, bool isUnsignedCast) { OpBuilder builder = getBuilder(); auto loc = operand.getLoc(); @@ -204,23 +204,32 @@ return operand; if (auto toIntType = toType.dyn_cast()) { // If operand is floating point, cast directly to the int type. - if (operand.getType().isa()) + if (operand.getType().isa()) { + if (isUnsignedCast) + return builder.create(loc, toType, operand); return builder.create(loc, toType, operand); + } // Cast index operands directly to the int type. if (operand.getType().isIndex()) return builder.create(loc, toType, operand); if (auto fromIntType = operand.getType().dyn_cast()) { - // Either sign extend or truncate. - if (toIntType.getWidth() > fromIntType.getWidth()) + // Either extend or truncate. + if (toIntType.getWidth() > fromIntType.getWidth()) { + if (isUnsignedCast) + return builder.create(loc, toType, operand); return builder.create(loc, toType, operand); + } if (toIntType.getWidth() < fromIntType.getWidth()) return builder.create(loc, toType, operand); } } else if (auto toFloatType = toType.dyn_cast()) { // If operand is integer, cast directly to the float type. // Note that it is unclear how to cast from BF16<->FP16. 
- if (operand.getType().isa()) + if (operand.getType().isa()) { + if (isUnsignedCast) + return builder.create(loc, toFloatType, operand); return builder.create(loc, toFloatType, operand); + } if (auto fromFloatType = operand.getType().dyn_cast()) { if (toFloatType.getWidth() > fromFloatType.getWidth()) return builder.create(loc, toFloatType, operand); @@ -284,6 +293,15 @@ llvm_unreachable("unsupported non numeric type"); } + Value applyfn__max_unsigned(Value lhs, Value rhs) { + OpBuilder builder = getBuilder(); + if (isFloatingPoint(lhs)) + return builder.create(lhs.getLoc(), lhs, rhs); + if (isInteger(lhs)) + return builder.create(lhs.getLoc(), lhs, rhs); + llvm_unreachable("unsupported non numeric type"); + } + Value applyfn__min(Value lhs, Value rhs) { OpBuilder builder = getBuilder(); if (isFloatingPoint(lhs)) @@ -293,6 +311,15 @@ llvm_unreachable("unsupported non numeric type"); } + Value applyfn__min_unsigned(Value lhs, Value rhs) { + OpBuilder builder = getBuilder(); + if (isFloatingPoint(lhs)) + return builder.create(lhs.getLoc(), lhs, rhs); + if (isInteger(lhs)) + return builder.create(lhs.getLoc(), lhs, rhs); + llvm_unreachable("unsupported non numeric type"); + } + void yieldOutputs(ValueRange values) { assert(!values.empty() && "linalg ops must yield outputs"); if (values.empty()) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py @@ -340,6 +340,8 @@ max = PrimFnType("max") min = PrimFnType("min") sub = PrimFnType("sub") + max_unsigned = PrimFnType("max_unsigned") + min_unsigned = PrimFnType("min_unsigned") class ReduceFnType: @@ -365,6 +367,8 @@ mul = PrimFn.mul.reduce max = PrimFn.max.reduce min = PrimFn.min.reduce + max_unsigned = PrimFn.max_unsigned.reduce + min_unsigned = PrimFn.min_unsigned.reduce class PrimApply(TensorExpression): 
@@ -438,8 +442,8 @@ self.operand = operand def to_scalar_expression(self) -> ScalarExpression: - return ScalarSymbolicCast(self.to_type, - self.operand.to_scalar_expression()).expr() + return ScalarSymbolicCast(self.to_type, self.operand.to_scalar_expression(), + False).expr() def visit_tensor_exprs(self, callback): super().visit_tensor_exprs(callback) @@ -449,6 +453,17 @@ return f"cast({self.to_type}, {repr(self.operand)})" +class cast_unsigned(cast): + """Casts the element type to an unsigned type (typically symbolic TypeVar).""" + + def to_scalar_expression(self) -> ScalarExpression: + return ScalarSymbolicCast(self.to_type, self.operand.to_scalar_expression(), + True).expr() + + def __repr__(self): + return f"cast_unsigned({self.to_type}, {repr(self.operand)})" + + class ReduceApply(TensorExpression): """Application of a reduction. diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -230,10 +230,12 @@ return fn(*operand_values) elif expr.symbolic_cast: operand_value = self.expression(expr.symbolic_cast.operand) - return self.cast(expr.symbolic_cast.to_type.name, operand_value) + return self.cast(expr.symbolic_cast.to_type.name, operand_value, + expr.symbolic_cast.is_unsigned_cast) raise NotImplementedError(f"Unimplemented scalar body expression: {expr}") - def cast(self, type_var_name: str, operand: Value) -> Value: + def cast(self, type_var_name: str, operand: Value, + is_unsigned_cast: bool) -> Value: try: to_type = self.type_mapping[type_var_name] except KeyError: @@ -242,29 +244,37 @@ if operand.type == to_type: return operand if _is_integer_type(to_type): - return self._cast_to_integer(to_type, operand) + return self._cast_to_integer(to_type, operand, is_unsigned_cast) elif _is_floating_point_type(to_type): - return self._cast_to_floating_point(to_type, operand) 
+ return self._cast_to_floating_point(to_type, operand, is_unsigned_cast) - def _cast_to_integer(self, to_type: Type, operand: Value) -> Value: + def _cast_to_integer(self, to_type: Type, operand: Value, + is_unsigned_cast: bool) -> Value: to_width = IntegerType(to_type).width operand_type = operand.type if _is_floating_point_type(operand_type): + if is_unsigned_cast: + return std.FPToUIOp(to_type, operand).result return std.FPToSIOp(to_type, operand).result if _is_index_type(operand_type): return std.IndexCastOp(to_type, operand).result # Assume integer. from_width = IntegerType(operand_type).width if to_width > from_width: + if is_unsigned_cast: + return std.ZeroExtendIOp(to_type, operand).result return std.SignExtendIOp(to_type, operand).result elif to_width < from_width: return std.TruncateIOp(to_type, operand).result raise ValueError(f"Unable to cast body expression from {operand_type} to " f"{to_type}") - def _cast_to_floating_point(self, to_type: Type, operand: Value) -> Value: + def _cast_to_floating_point(self, to_type: Type, operand: Value, + is_unsigned_cast: bool) -> Value: operand_type = operand.type if _is_integer_type(operand_type): + if is_unsigned_cast: + return std.UIToFPOp(to_type, operand).result return std.SIToFPOp(to_type, operand).result # Assume FloatType. 
to_width = _get_floating_point_width(to_type) @@ -324,6 +334,13 @@ return std.MaxSIOp(lhs.type, lhs, rhs).result raise NotImplementedError("Unsupported 'max' operand: {lhs}") + def _eval_max_unsigned(self, lhs: Value, rhs: Value) -> Value: + if _is_floating_point_type(lhs.type): + return std.MaxFOp(lhs.type, lhs, rhs).result + if _is_integer_type(lhs.type) or _is_index_type(lhs.type): + return std.MaxUIOp(lhs.type, lhs, rhs).result + raise NotImplementedError("Unsupported 'max_unsigned' operand: {lhs}") + def _eval_min(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): return std.MinFOp(lhs.type, lhs, rhs).result @@ -331,6 +348,12 @@ return std.MinSIOp(lhs.type, lhs, rhs).result raise NotImplementedError("Unsupported 'min' operand: {lhs}") + def _eval_min_unsigned(self, lhs: Value, rhs: Value) -> Value: + if _is_floating_point_type(lhs.type): + return std.MinFOp(lhs.type, lhs, rhs).result + if _is_integer_type(lhs.type) or _is_index_type(lhs.type): + return std.MinUIOp(lhs.type, lhs, rhs).result + raise NotImplementedError("Unsupported 'min_unsigned' operand: {lhs}") def _infer_structured_outs(op_config: LinalgStructuredOpConfig, in_arg_defs: Sequence[OperandDefConfig], diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py @@ -85,15 +85,17 @@ class ScalarSymbolicCast: """A type of ScalarExpression that symbolically casts an operand to a TypeVar.""" - def __init__(self, to_type: TypeVar, operand: "ScalarExpression"): + def __init__(self, to_type: TypeVar, operand: "ScalarExpression", + is_unsigned_cast: bool): self.to_type = to_type self.operand = operand + self.is_unsigned_cast = is_unsigned_cast def expr(self) -> "ScalarExpression": return ScalarExpression(symbolic_cast=self) def __repr__(self): - return 
f"ScalarSymbolicCast({self.to_type}, {self.operand})" + return f"ScalarSymbolicCast({self.to_type}, {self.operand}, {self.is_unsigned_cast})" class ScalarExpression(YAMLObject): @@ -144,7 +146,8 @@ return dict( symbolic_cast=dict( type_var=self.symbolic_cast.to_type.name, - operands=[self.symbolic_cast.operand])) + operands=[self.symbolic_cast.operand], + is_unsigned_cast=self.symbolic_cast.is_unsigned_cast)) else: raise ValueError(f"Unexpected ScalarExpression type: {self}") diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -20,6 +20,20 @@ implements(ContractionOpInterface) C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n]) +@linalg_structured_op +def matmul_unsigned( + A=TensorDef(T1, S.M, S.K), + B=TensorDef(T2, S.K, S.N), + C=TensorDef(U, S.M, S.N, output=True)): + """Performs an unsigned matrix multiplication of two 2D inputs. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + domain(D.m, D.n, D.k) + implements(ContractionOpInterface) + C[D.m, D.n] += cast_unsigned(U, A[D.m, D.k]) * cast_unsigned(U, B[D.k, D.n]) + @linalg_structured_op def quantized_matmul( A=TensorDef(T1, S.M, S.K), @@ -411,6 +425,24 @@ cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) +@linalg_structured_op +def pooling_nhwc_max_unsigned( + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), + O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), + strides=AttributeDef(S.SH, S.SW), + dilations=AttributeDef(S.DH, S.DW)): + """Performs unsigned max pooling. 
+ + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c) + O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_unsigned(D.kh, D.kw)( + cast_unsigned( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) + @linalg_structured_op def pooling_nchw_max( I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW), @@ -447,6 +479,23 @@ cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) +@linalg_structured_op +def pooling_nhwc_min_unsigned( + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), + O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), + strides=AttributeDef(S.SH, S.SW), + dilations=AttributeDef(S.DH, S.DW)): + """Performs unsigned min pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c) + O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_unsigned(D.kh, D.kw)( + cast_unsigned( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) @linalg_structured_op def pooling_ndhwc_sum( diff --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir --- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir +++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir @@ -1,35 +1,108 @@ // RUN: mlir-opt %s -split-input-file -linalg-generalize-named-ops | FileCheck %s -func @generalize_matmul_tensor_f32(%A : tensor<16x8xf32>, %B: tensor<8x32xf32>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> { - %0 = linalg.matmul ins(%A, %B: tensor<16x8xf32>, tensor<8x32xf32>) +// Verifies that different argument types is legal. 
+func @generalize_matmul_tensor_f16f64f32(%A : tensor<16x8xf16>, %B: tensor<8x32xf64>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> { + %0 = linalg.matmul ins(%A, %B: tensor<16x8xf16>, tensor<8x32xf64>) outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32> return %0: tensor<16x32xf32> } -// CHECK-LABEL: @generalize_matmul_tensor_f32 -// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: f32) -// CHECK-NEXT: %[[MUL:.+]] = mulf %[[A_ARG]], %[[B_ARG]] : f32 +// CHECK-LABEL: @generalize_matmul_tensor_f16f64f32 +// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f16, %[[B_ARG:.+]]: f64, %[[C_ARG:.+]]: f32) +// Verify floating point extension and truncation. +// CHECK-NEXT: %[[A_CAST:.+]] = fpext %[[A_ARG]] : f16 to f32 +// CHECK-NEXT: %[[B_CAST:.+]] = fptrunc %[[B_ARG]] : f64 to f32 +// CHECK-NEXT: %[[MUL:.+]] = mulf %[[A_CAST]], %[[B_CAST]] : f32 // CHECK-NEXT: %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32 // CHECK-NEXT: linalg.yield %[[ADD]] : f32 // CHECK-NEXT: -> tensor<16x32xf32> // ----- -func @generalize_matmul_tensor_i32(%A : tensor<16x8xi32>, %B: tensor<8x32xi32>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> { - %0 = linalg.matmul ins(%A, %B: tensor<16x8xi32>, tensor<8x32xi32>) +// Verifies that different argument types is legal. +func @generalize_matmul_tensor_i16i64i32(%A : tensor<16x8xi16>, %B: tensor<8x32xi64>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> { + %0 = linalg.matmul ins(%A, %B: tensor<16x8xi16>, tensor<8x32xi64>) outs(%C: tensor<16x32xi32>) -> tensor<16x32xi32> return %0: tensor<16x32xi32> } -// CHECK-LABEL: @generalize_matmul_tensor_i32 -// CHECK: ^{{.*}}(%[[A_ARG:.+]]: i32, %[[B_ARG:.+]]: i32, %[[C_ARG:.+]]: i32) -// CHECK-NEXT: %[[MUL:.+]] = muli %[[A_ARG]], %[[B_ARG]] : i32 +// CHECK-LABEL: @generalize_matmul_tensor_i16i64i32 +// CHECK: ^{{.*}}(%[[A_ARG:.+]]: i16, %[[B_ARG:.+]]: i64, %[[C_ARG:.+]]: i32) +// Verify signed integer extension and truncation. 
+// CHECK-NEXT: %[[A_CAST:.+]] = sexti %[[A_ARG]] : i16 to i32 +// CHECK-NEXT: %[[B_CAST:.+]] = trunci %[[B_ARG]] : i64 to i32 +// CHECK-NEXT: %[[MUL:.+]] = muli %[[A_CAST]], %[[B_CAST]] : i32 // CHECK-NEXT: %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i32 // CHECK-NEXT: linalg.yield %[[ADD]] : i32 // CHECK-NEXT: -> tensor<16x32xi32> // ----- +func @generalize_matmul_tensor_i16i64f32(%A : tensor<16x8xi16>, %B: tensor<8x32xi64>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> { + %0 = linalg.matmul ins(%A, %B: tensor<16x8xi16>, tensor<8x32xi64>) + outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32> + return %0: tensor<16x32xf32> +} + +// CHECK-LABEL: @generalize_matmul_tensor_i16i64f32 +// Verify signed integer to floating point cast. +// CHECK: = sitofp +// CHECK: = sitofp + +// ----- + +func @generalize_matmul_tensor_f16f64i32(%A : tensor<16x8xf16>, %B: tensor<8x32xf64>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> { + %0 = linalg.matmul ins(%A, %B: tensor<16x8xf16>, tensor<8x32xf64>) + outs(%C: tensor<16x32xi32>) -> tensor<16x32xi32> + return %0: tensor<16x32xi32> +} + +// CHECK-LABEL: @generalize_matmul_tensor_f16f64i32 +// Verify floating point to signed integer cast. +// CHECK: = fptosi +// CHECK: = fptosi + +// ----- + +func @generalize_matmul_unsigned_tensor_i16i64i32(%A : tensor<16x8xi16>, %B: tensor<8x32xi64>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> { + %0 = linalg.matmul_unsigned ins(%A, %B: tensor<16x8xi16>, tensor<8x32xi64>) + outs(%C: tensor<16x32xi32>) -> tensor<16x32xi32> + return %0: tensor<16x32xi32> +} + +// CHECK-LABEL: @generalize_matmul_unsigned_tensor_i16i64i32 +// Verify unsigned integer extension and truncation. 
+// CHECK: = zexti +// CHECK: = trunci + +// ----- + +func @generalize_matmul_unsigned_tensor_i16i64f32(%A : tensor<16x8xi16>, %B: tensor<8x32xi64>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> { + %0 = linalg.matmul_unsigned ins(%A, %B: tensor<16x8xi16>, tensor<8x32xi64>) + outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32> + return %0: tensor<16x32xf32> +} + +// CHECK-LABEL: @generalize_matmul_unsigned_tensor_i16i64f32 +// Verify unsigned integer to floating point cast. +// CHECK: = uitofp +// CHECK: = uitofp + +// ----- + +func @generalize_matmul_unsigned_tensor_f16f64i32(%A : tensor<16x8xf16>, %B: tensor<8x32xf64>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> { + %0 = linalg.matmul_unsigned ins(%A, %B: tensor<16x8xf16>, tensor<8x32xf64>) + outs(%C: tensor<16x32xi32>) -> tensor<16x32xi32> + return %0: tensor<16x32xi32> +} + +// CHECK-LABEL: @generalize_matmul_unsigned_tensor_f16f64i32 +// Verify floating point to unsigned integer cast. +// CHECK: = fptoui +// CHECK: = fptoui + +// ----- + func @generalize_pooling_nhwc_max_f32(%input : tensor<1x4x16x1xf32>, %shape: tensor<2x2xf32>, %output: tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32> { %0 = linalg.pooling_nhwc_max {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>} ins(%input, %shape : tensor<1x4x16x1xf32>, tensor<2x2xf32>) outs(%output : tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32> @@ -51,10 +124,20 @@ } // CHECK-LABEL: @generalize_pooling_nhwc_max_i32 -// CHECK: ^{{.*}}(%[[IN_ARG:.+]]: i32, %[[SHAPE_ARG:.+]]: i32, %[[OUT_ARG:.+]]: i32) -// CHECK-NEXT: %[[MAX:.+]] = maxsi %[[OUT_ARG]], %[[IN_ARG]] : i32 -// CHECK-NEXT: linalg.yield %[[MAX]] : i32 -// CHECK-NEXT: -> tensor<1x2x4x1xi32> +// Verify signed integer maximum. 
+// CHECK: = maxsi + +// ----- + +func @generalize_pooling_nhwc_max_unsigned_i32(%input : tensor<1x4x16x1xi32>, %shape: tensor<2x2xi32>, %output: tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32> { + %0 = linalg.pooling_nhwc_max_unsigned {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>} + ins(%input, %shape : tensor<1x4x16x1xi32>, tensor<2x2xi32>) outs(%output : tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32> + return %0: tensor<1x2x4x1xi32> +} + +// CHECK-LABEL: @generalize_pooling_nhwc_max_unsigned_i32 +// Verify unsigned integer maximum. +// CHECK: = maxui // ----- @@ -79,10 +162,20 @@ } // CHECK-LABEL: @generalize_pooling_nhwc_min_i32 -// CHECK: ^{{.*}}(%[[IN_ARG:.+]]: i32, %[[SHAPE_ARG:.+]]: i32, %[[OUT_ARG:.+]]: i32) -// CHECK-NEXT: %[[MIN:.+]] = minsi %[[OUT_ARG]], %[[IN_ARG]] : i32 -// CHECK-NEXT: linalg.yield %[[MIN]] : i32 -// CHECK-NEXT: -> tensor<1x2x4x1xi32> +// Verify signed integer minimum. +// CHECK: = minsi + +// ----- + +func @generalize_pooling_nhwc_min_unsigned_i32(%input : tensor<1x4x16x1xi32>, %shape: tensor<2x2xi32>, %output: tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32> { + %0 = linalg.pooling_nhwc_min_unsigned {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>} + ins(%input, %shape : tensor<1x4x16x1xi32>, tensor<2x2xi32>) outs(%output : tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32> + return %0: tensor<1x2x4x1xi32> +} + +// CHECK-LABEL: @generalize_pooling_nhwc_min_unsigned_i32 +// Verify unsigned integer minimum. +// CHECK: = minui // ----- @@ -169,122 +262,3 @@ // CHECK-NEXT: %[[LOG:.+]] = math.log %[[SUM]] : f32 // CHECK-NEXT: linalg.yield %[[LOG]] : f32 // CHECK-NEXT: -> tensor<16x32xf32> - -// ----- -// Verifies floating point to integer cast. 
-func @generalize_matmul_tensor_f32_f32_i16(%A : tensor<16x8xf32>, %B: tensor<8x32xf32>, %C: tensor<16x32xi16>) -> tensor<16x32xi16> { - %0 = linalg.matmul ins(%A, %B: tensor<16x8xf32>, tensor<8x32xf32>) - outs(%C: tensor<16x32xi16>) -> tensor<16x32xi16> - return %0: tensor<16x32xi16> -} - -// CHECK-LABEL: @generalize_matmul_tensor_f32_f32_i16 -// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: i16) -// CHECK-NEXT: %[[A_CAST:.+]] = fptosi %[[A_ARG]] : f32 to i16 -// CHECK-NEXT: %[[B_CAST:.+]] = fptosi %[[B_ARG]] : f32 to i16 -// CHECK-NEXT: %[[MUL:.+]] = muli %[[A_CAST]], %[[B_CAST]] : i16 -// CHECK-NEXT: %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i16 -// CHECK-NEXT: linalg.yield %[[ADD]] : i16 -// CHECK-NEXT: -> tensor<16x32xi16> - -// ----- -// Verifies sign extension cast. -func @generalize_matmul_tensor_i8_i8_i32(%A : tensor<16x8xi8>, %B: tensor<8x32xi8>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> { - %0 = linalg.matmul ins(%A, %B: tensor<16x8xi8>, tensor<8x32xi8>) - outs(%C: tensor<16x32xi32>) -> tensor<16x32xi32> - return %0: tensor<16x32xi32> -} - -// CHECK-LABEL: @generalize_matmul_tensor_i8_i8_i32 -// CHECK: ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i8, %[[C_ARG:.+]]: i32) -// CHECK-NEXT: %[[A_CAST:.+]] = sexti %[[A_ARG]] : i8 to i32 -// CHECK-NEXT: %[[B_CAST:.+]] = sexti %[[B_ARG]] : i8 to i32 -// CHECK-NEXT: %[[MUL:.+]] = muli %[[A_CAST]], %[[B_CAST]] : i32 -// CHECK-NEXT: %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i32 -// CHECK-NEXT: linalg.yield %[[ADD]] : i32 -// CHECK-NEXT: -> tensor<16x32xi32> - -// ----- -// Verifies that different argument types is legal. 
-func @generalize_matmul_tensor_i8_i16_i32(%A : tensor<16x8xi8>, %B: tensor<8x32xi16>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> { - %0 = linalg.matmul ins(%A, %B: tensor<16x8xi8>, tensor<8x32xi16>) - outs(%C: tensor<16x32xi32>) -> tensor<16x32xi32> - return %0: tensor<16x32xi32> -} - -// CHECK-LABEL: @generalize_matmul_tensor_i8_i16_i32 -// CHECK: ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i16, %[[C_ARG:.+]]: i32) -// CHECK-NEXT: %[[A_CAST:.+]] = sexti %[[A_ARG]] : i8 to i32 -// CHECK-NEXT: %[[B_CAST:.+]] = sexti %[[B_ARG]] : i16 to i32 -// CHECK-NEXT: %[[MUL:.+]] = muli %[[A_CAST]], %[[B_CAST]] : i32 -// CHECK-NEXT: %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i32 -// CHECK-NEXT: linalg.yield %[[ADD]] : i32 -// CHECK-NEXT: -> tensor<16x32xi32> - -// ----- -// Somewhat non-sensical but checks integer truncation cast. -func @generalize_matmul_tensor_i32_i32_i16(%A : tensor<16x8xi32>, %B: tensor<8x32xi32>, %C: tensor<16x32xi16>) -> tensor<16x32xi16> { - %0 = linalg.matmul ins(%A, %B: tensor<16x8xi32>, tensor<8x32xi32>) - outs(%C: tensor<16x32xi16>) -> tensor<16x32xi16> - return %0: tensor<16x32xi16> -} - -// CHECK-LABEL: @generalize_matmul_tensor_i32_i32_i16 -// CHECK: ^{{.*}}(%[[A_ARG:.+]]: i32, %[[B_ARG:.+]]: i32, %[[C_ARG:.+]]: i16) -// CHECK-NEXT: %[[A_CAST:.+]] = trunci %[[A_ARG]] : i32 to i16 -// CHECK-NEXT: %[[B_CAST:.+]] = trunci %[[B_ARG]] : i32 to i16 -// CHECK-NEXT: %[[MUL:.+]] = muli %[[A_CAST]], %[[B_CAST]] : i16 -// CHECK-NEXT: %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i16 -// CHECK-NEXT: linalg.yield %[[ADD]] : i16 -// CHECK-NEXT: -> tensor<16x32xi16> - -// ----- -// Verifies integer to floating point cast. 
-func @generalize_matmul_tensor_i8_i8_f32(%A : tensor<16x8xi8>, %B: tensor<8x32xi8>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> { - %0 = linalg.matmul ins(%A, %B: tensor<16x8xi8>, tensor<8x32xi8>) - outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32> - return %0: tensor<16x32xf32> -} - -// CHECK-LABEL: @generalize_matmul_tensor_i8_i8_f32 -// CHECK: ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i8, %[[C_ARG:.+]]: f32) -// CHECK-NEXT: %[[A_CAST:.+]] = sitofp %[[A_ARG]] : i8 to f32 -// CHECK-NEXT: %[[B_CAST:.+]] = sitofp %[[B_ARG]] : i8 to f32 -// CHECK-NEXT: %[[MUL:.+]] = mulf %[[A_CAST]], %[[B_CAST]] : f32 -// CHECK-NEXT: %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32 -// CHECK-NEXT: linalg.yield %[[ADD]] : f32 -// CHECK-NEXT: -> tensor<16x32xf32> - -// ----- -// Verifies floating point extension cast. -func @generalize_matmul_tensor_f16_f16_f32(%A : tensor<16x8xf16>, %B: tensor<8x32xf16>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> { - %0 = linalg.matmul ins(%A, %B: tensor<16x8xf16>, tensor<8x32xf16>) - outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32> - return %0: tensor<16x32xf32> -} - -// CHECK-LABEL: @generalize_matmul_tensor_f16_f16_f32 -// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f16, %[[B_ARG:.+]]: f16, %[[C_ARG:.+]]: f32) -// CHECK-NEXT: %[[A_CAST:.+]] = fpext %[[A_ARG]] : f16 to f32 -// CHECK-NEXT: %[[B_CAST:.+]] = fpext %[[B_ARG]] : f16 to f32 -// CHECK-NEXT: %[[MUL:.+]] = mulf %[[A_CAST]], %[[B_CAST]] : f32 -// CHECK-NEXT: %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32 -// CHECK-NEXT: linalg.yield %[[ADD]] : f32 -// CHECK-NEXT: -> tensor<16x32xf32> - -// ----- -// Verifies floating point truncation. 
-func @generalize_matmul_tensor_f64_f64_f32(%A : tensor<16x8xf64>, %B: tensor<8x32xf64>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> { - %0 = linalg.matmul ins(%A, %B: tensor<16x8xf64>, tensor<8x32xf64>) - outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32> - return %0: tensor<16x32xf32> -} - -// CHECK-LABEL: @generalize_matmul_tensor_f64_f64_f32 -// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f64, %[[B_ARG:.+]]: f64, %[[C_ARG:.+]]: f32) -// CHECK-NEXT: %[[A_CAST:.+]] = fptrunc %[[A_ARG]] : f64 to f32 -// CHECK-NEXT: %[[B_CAST:.+]] = fptrunc %[[B_ARG]] : f64 to f32 -// CHECK-NEXT: %[[MUL:.+]] = mulf %[[A_CAST]], %[[B_CAST]] : f32 -// CHECK-NEXT: %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32 -// CHECK-NEXT: linalg.yield %[[ADD]] : f32 -// CHECK-NEXT: -> tensor<16x32xf32> diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml --- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml +++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml @@ -43,12 +43,14 @@ operands: - !ScalarExpression scalar_const: '42 : i64' + is_unsigned_cast: false - !ScalarExpression symbolic_cast: type_var: T operands: - !ScalarExpression scalar_index: 1 + is_unsigned_cast: true # ODS-LABEL: def Test1Op : LinalgStructuredBase_Op<"test1" @@ -84,9 +86,9 @@ # IMPL-LABEL: void Test1Op::regionBuilder( # IMPL: ImplicitLocOpBuilder &b, Block &block) # IMPL: Value [[VAL0:[a-z0-9]+]] = helper.constant("42 : i64"); -# IMPL-DAG: Value [[VAL1:[a-z0-9]+]] = helper.cast(block.getArgument(0).getType(), [[VAL0]]); +# IMPL-DAG: Value [[VAL1:[a-z0-9]+]] = helper.cast(block.getArgument(0).getType(), [[VAL0]], false); # IMPL-DAG: Value [[VAL2:[a-z0-9]+]] = helper.index(1); -# IMPL-DAG: Value [[VAL3:[a-z0-9]+]] = helper.cast(block.getArgument(0).getType(), [[VAL2]]); +# IMPL-DAG: Value [[VAL3:[a-z0-9]+]] = helper.cast(block.getArgument(0).getType(), [[VAL2]], true); # IMPL-DAG: Value [[VAL4:[a-z0-9]+]] = helper.applyfn__add([[VAL1]], 
[[VAL3]]); diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py b/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py --- a/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py +++ b/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py @@ -29,6 +29,15 @@ C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n]) +@linalg_structured_op +def matmul_unsigned_poly( + A=TensorDef(T1, S.M, S.K), + B=TensorDef(T2, S.K, S.N), + C=TensorDef(U, S.M, S.N, output=True)): + domain(D.m, D.n, D.k) + C[D.m, D.n] += cast_unsigned(U, A[D.m, D.k]) * cast_unsigned(U, B[D.k, D.n]) + + @linalg_structured_op def conv_poly( I=TensorDef(T1, S.N, S.IH, S.IW, S.C), @@ -54,6 +63,17 @@ cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) +@linalg_structured_op +def pooling_max_unsigned_poly( + I=TensorDef(T1, S.N, S.H, S.W, S.C), + K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), + O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), + strides=AttributeDef(S.SH, S.SW), + dilations=AttributeDef(S.DH, S.DW)): + domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c) + O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_unsigned(D.kh, D.kw)( + cast_unsigned( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) @linalg_structured_op def pooling_min_poly( @@ -67,6 +87,17 @@ cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) +@linalg_structured_op +def pooling_min_unsigned_poly( + I=TensorDef(T1, S.N, S.H, S.W, S.C), + K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), + O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), + strides=AttributeDef(S.SH, S.SW), + dilations=AttributeDef(S.DH, S.DW)): + domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c) + O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_unsigned(D.kh, D.kw)( + cast_unsigned( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) @linalg_structured_op def fill_rng_poly( @@ -147,6 +178,15 @@ def test_i8i8i32_matmul(lhs, rhs, 
init_result): return matmul_poly(lhs, rhs, outs=[init_result]) + # CHECK-LABEL: @test_i8i8i32_matmul_unsigned + # CHECK: = zexti + # CHECK: = zexti + @builtin.FuncOp.from_py_func( + RankedTensorType.get((4, 16), i8), RankedTensorType.get((16, 8), i8), + RankedTensorType.get((4, 8), i32)) + def test_i8i8i32_matmul_unsigned(lhs, rhs, init_result): + return matmul_unsigned_poly(lhs, rhs, outs=[init_result]) + # CHECK-LABEL: @test_i8i16i32_matmul # CHECK: ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i16, %[[C_ARG:.+]]: i32) # CHECK-NEXT: %[[A_CAST:.+]] = sexti %[[A_ARG]] : i8 to i32 @@ -189,6 +229,15 @@ def test_i8i8f32_matmul(lhs, rhs, init_result): return matmul_poly(lhs, rhs, outs=[init_result]) + # CHECK-LABEL: @test_i8i8f32_matmul_unsigned + # CHECK: = uitofp + # CHECK: = uitofp + @builtin.FuncOp.from_py_func( + RankedTensorType.get((4, 16), i8), RankedTensorType.get((16, 8), i8), + RankedTensorType.get((4, 8), f32)) + def test_i8i8f32_matmul_unsigned(lhs, rhs, init_result): + return matmul_unsigned_poly(lhs, rhs, outs=[init_result]) + # CHECK-LABEL: @test_f16f16f32_matmul # CHECK: ^{{.*}}(%[[A_ARG:.+]]: f16, %[[B_ARG:.+]]: f16, %[[C_ARG:.+]]: f32) # CHECK-NEXT: %[[A_CAST:.+]] = fpext %[[A_ARG]] : f16 to f32 @@ -252,6 +301,16 @@ return pooling_max_poly( input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2]) + # CHECK-LABEL: @test_f32i32_max_unsigned_pooling + # CHECK: = fptoui + # CHECK: = maxui + @builtin.FuncOp.from_py_func( + RankedTensorType.get((4, 16), f32), RankedTensorType.get((2, 2), f32), + RankedTensorType.get((2, 4), i32)) + def test_f32i32_max_unsigned_pooling(input, shape, init_result): + return pooling_max_unsigned_poly( + input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2]) + # CHECK-LABEL: @test_f32f32_max_pooling # CHECK: linalg.generic # CHECK-SAME: indexing_maps = [#[[$CONV_MAP_I]], #[[$POOL_MAP_K]], #[[$CONV_MAP_O]]] @@ -268,6 +327,7 @@ input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2]) # CHECK-LABEL: 
@test_f32i32_min_pooling + # CHECK: = fptosi # CHECK: = minsi @builtin.FuncOp.from_py_func( RankedTensorType.get((4, 16), f32), RankedTensorType.get((2, 2), f32), @@ -276,6 +336,16 @@ return pooling_min_poly( input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2]) + # CHECK-LABEL: @test_f32i32_min_unsigned_pooling + # CHECK: = fptoui + # CHECK: = minui + @builtin.FuncOp.from_py_func( + RankedTensorType.get((4, 16), f32), RankedTensorType.get((2, 2), f32), + RankedTensorType.get((2, 4), i32)) + def test_f32i32_min_unsigned_pooling(input, shape, init_result): + return pooling_min_unsigned_poly( + input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2]) + # CHECK-LABEL: @test_f32f32_min_pooling # CHECK: = minf @builtin.FuncOp.from_py_func( diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp @@ -95,6 +95,7 @@ // NOTE: This must be of arity 1, but to break the self-referential cycle, // we use a heap allocated vector. std::vector operands; + bool isUnsignedCast; }; struct ScalarExpression { @@ -278,6 +279,7 @@ static void mapping(IO &io, ScalarSymbolicCast &info) { io.mapRequired("type_var", info.typeVar); io.mapRequired("operands", info.operands); + io.mapRequired("is_unsigned_cast", info.isUnsignedCast); } }; @@ -986,9 +988,10 @@ return None; } std::string cppIdent = llvm::formatv("value{0}", ++localCounter); - stmts.push_back(llvm::formatv("Value {0} = helper.cast({1}, {2});", - cppIdent, typeCppValue.getValue(), - *operandCppValue)); + stmts.push_back( + llvm::formatv("Value {0} = helper.cast({1}, {2}, {3});", cppIdent, + typeCppValue.getValue(), *operandCppValue, + expression.symbolicCast->isUnsignedCast)); return cppIdent; } emitError(genContext.getLoc()) << "unknown ScalarExpression type";