diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml @@ -1,7 +1,7 @@ --- !LinalgOpConfig metadata: !LinalgOpMetadata name: matmul - cpp_op_name: MatmulOp + cpp_class_name: MatmulOp doc: |- Performs a matrix multiplication of two 2D inputs. @@ -63,7 +63,7 @@ --- !LinalgOpConfig metadata: !LinalgOpMetadata name: batch_matmul - cpp_op_name: BatchMatmulOp + cpp_class_name: BatchMatmulOp doc: |- Performs a batched matrix multiplication of two 3D inputs. @@ -126,7 +126,7 @@ --- !LinalgOpConfig metadata: !LinalgOpMetadata name: matvec - cpp_op_name: MatvecOp + cpp_class_name: MatvecOp doc: |- Performs a matrix-vector multiplication. @@ -187,7 +187,7 @@ --- !LinalgOpConfig metadata: !LinalgOpMetadata name: vecmat - cpp_op_name: VecmatOp + cpp_class_name: VecmatOp doc: |- Performs a vector-matrix multiplication. @@ -248,7 +248,7 @@ --- !LinalgOpConfig metadata: !LinalgOpMetadata name: dot - cpp_op_name: DotOp + cpp_class_name: DotOp doc: |- Performs a dot product of two vectors to a scalar result. @@ -305,4 +305,151 @@ operands: - !ScalarExpression scalar_arg: B +--- !LinalgOpConfig +metadata: !LinalgOpMetadata + name: fill_rng_2d + cpp_class_name: FillRng2DOp + doc: |- + Fills the output tensor with pseudo random numbers. + The operation generations pseudo random numbers using a linear congruential + generator. It provides no guarantees regarding the distribution of the + generated random numbers. Instead of generating the random numbers + sequentially, it instantiates one random number generator per data element + and runs them in parallel. The seed operand and the indices of the data + element seed the random number generation. The min and max operands limit + the range of the generated random numbers. +structured_op: !LinalgStructuredOpConfig + args: + - ! + name: O + usage: output + shape: affine_map<()[s0, s1] -> (s0, s1)> + element_type_var: T + captures: + - ! + name: min + type_var: F64 + - ! + name: max + type_var: F64 + - ! + name: seed + type_var: I32 + indexing_maps: !LinalgIndexingMapsConfig + static_indexing_maps: + - affine_map<(d0, d1)[s0, s1] -> (d0, d1)> + iterator_types: + - parallel + - parallel + assignments: + - !ScalarAssign + arg: O + value: !ScalarExpression + symbolic_cast: + type_var: T + operands: + - !ScalarExpression + scalar_apply: + fn_name: add + operands: + - !ScalarExpression + scalar_apply: + fn_name: mul + operands: + - !ScalarExpression + scalar_apply: + fn_name: add + operands: + - !ScalarExpression + symbolic_cast: + type_var: F64 + operands: + - !ScalarExpression + scalar_const: '2147483647 : i64' + - !ScalarExpression + symbolic_cast: + type_var: F64 + operands: + - !ScalarExpression + scalar_apply: + fn_name: add + operands: + - !ScalarExpression + scalar_apply: + fn_name: mul + operands: + - !ScalarExpression + scalar_apply: + fn_name: add + operands: + - !ScalarExpression + symbolic_cast: + type_var: I32 + operands: + - !ScalarExpression + scalar_index: 1 + - !ScalarExpression + scalar_apply: + fn_name: add + operands: + - !ScalarExpression + scalar_apply: + fn_name: mul + operands: + - !ScalarExpression + scalar_apply: + fn_name: add + operands: + - !ScalarExpression + symbolic_cast: + type_var: I32 + operands: + - !ScalarExpression + scalar_index: 0 + - !ScalarExpression + scalar_capture: seed + - !ScalarExpression + symbolic_cast: + type_var: I32 + operands: + - !ScalarExpression + scalar_const: '1103515245 : i64' + - !ScalarExpression + symbolic_cast: + type_var: I32 + operands: + - !ScalarExpression + scalar_const: '12345 : i64' + - !ScalarExpression + symbolic_cast: + type_var: I32 + operands: + - !ScalarExpression + scalar_const: '1103515245 : i64' + - !ScalarExpression + symbolic_cast: + type_var: I32 + operands: + - !ScalarExpression + scalar_const: '12345 : i64' + - !ScalarExpression + scalar_apply: + fn_name: mul + operands: + - !ScalarExpression + scalar_apply: + fn_name: sub + operands: + - !ScalarExpression + scalar_capture: max + - !ScalarExpression + scalar_capture: min + - !ScalarExpression + symbolic_cast: + type_var: F64 + operands: + - !ScalarExpression + scalar_const: '2.3283063999999999E-10 : f64' + - !ScalarExpression + scalar_capture: min \ No newline at end of file diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -62,26 +62,36 @@ static ParseResult parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, SmallVectorImpl &inputTypes, - SmallVectorImpl &outputTypes); + SmallVectorImpl &outputTypes, + SmallVectorImpl &operandSegmentSizes); template static void printCommonStructuredOpParts(OpAsmPrinter &p, NamedStructuredOpType op); /// Specific parsing and printing for named structured ops created by ods-gen. +static ParseResult +parseNamedStructuredOpCaptures(OpAsmParser &parser, OperationState &result, + SmallVectorImpl &captures, + SmallVectorImpl &operandSegmentSizes); + template static ParseResult parseNamedStructuredOpRegion(OpAsmParser &parser, Region ®ion, TypeRange inputTypes, TypeRange outputTypes, - ArrayRef captures = {}); + ValueRange capturedValues); static ParseResult parseNamedStructuredOpResults(OpAsmParser &parser, SmallVectorImpl &resultTypes); template -static ParseResult -parseNamedStructuredOp(OpAsmParser &parser, OperationState &result, - ArrayRef captures = {}); +static ParseResult parseNamedStructuredOp(OpAsmParser &parser, + OperationState &result, + bool parseCaptures = false); + +template +static void printNamedStructuredOpCaptures(OpAsmPrinter &p, + NamedStructuredOpType op); static void printNamedStructuredOpResults(OpAsmPrinter &p, TypeRange resultTypes); @@ -220,14 +230,15 @@ class RegionBuilderHelper { public: - RegionBuilderHelper(Block &block) : block(block) {} + RegionBuilderHelper(MLIRContext *context, Block &block) + : context(context), block(block) {} // Generates operations to cast the given operand to a specified type. // If the cast cannot be performed, a warning will be issued and the // operand returned as-is (which will presumably yield a verification // issue downstream). Value cast(Type toType, Value operand) { - OpBuilder builder = getBuilder(operand); + OpBuilder builder = getBuilder(); auto loc = operand.getLoc(); if (operand.getType() == toType) @@ -236,11 +247,14 @@ // If operand is floating point, cast directly to the int type. if (operand.getType().isa()) return builder.create(loc, toType, operand); + // Cast index operands directly to the int type. + if (operand.getType().isIndex()) + return builder.create(loc, toType, operand); if (auto fromIntType = operand.getType().dyn_cast()) { // Either sign extend or truncate. if (toIntType.getWidth() > fromIntType.getWidth()) return builder.create(loc, toType, operand); - else if (toIntType.getWidth() < fromIntType.getWidth()) + if (toIntType.getWidth() < fromIntType.getWidth()) return builder.create(loc, toType, operand); } } else if (auto toFloatType = toType.dyn_cast()) { @@ -251,7 +265,7 @@ if (auto fromFloatType = operand.getType().dyn_cast()) { if (toFloatType.getWidth() > fromFloatType.getWidth()) return builder.create(loc, toFloatType, operand); - else if (toFloatType.getWidth() < fromFloatType.getWidth()) + if (toFloatType.getWidth() < fromFloatType.getWidth()) return builder.create(loc, toFloatType, operand); } } @@ -262,19 +276,28 @@ } Value applyfn__add(Value lhs, Value rhs) { - OpBuilder builder = getBuilder(lhs); + OpBuilder builder = getBuilder(); if (isFloatingPoint(lhs)) return builder.create(lhs.getLoc(), lhs, rhs); - else if (isInteger(lhs)) + if (isInteger(lhs)) return builder.create(lhs.getLoc(), lhs, rhs); llvm_unreachable("unsupported non numeric type"); } + Value applyfn__sub(Value lhs, Value rhs) { + OpBuilder builder = getBuilder(); + if (isFloatingPoint(lhs)) + return builder.create(lhs.getLoc(), lhs, rhs); + if (isInteger(lhs)) + return builder.create(lhs.getLoc(), lhs, rhs); + llvm_unreachable("unsupported non numeric type"); + } + Value applyfn__mul(Value lhs, Value rhs) { - OpBuilder builder = getBuilder(lhs); + OpBuilder builder = getBuilder(); if (isFloatingPoint(lhs)) return builder.create(lhs.getLoc(), lhs, rhs); - else if (isInteger(lhs)) + if (isInteger(lhs)) return builder.create(lhs.getLoc(), lhs, rhs); llvm_unreachable("unsupported non numeric type"); } @@ -284,18 +307,39 @@ if (values.empty()) return; Value first = values.front(); - OpBuilder builder = getBuilder(first); + OpBuilder builder = getBuilder(); builder.create(first.getLoc(), values); } + Value constant(std::string value) { + OpBuilder builder = getBuilder(); + Location loc = builder.getUnknownLoc(); + Attribute valueAttr = parseAttribute(value, builder.getContext()); + return builder.create(loc, valueAttr.getType(), valueAttr); + } + + Value index(int64_t dim) { + OpBuilder builder = getBuilder(); + return builder.create(builder.getUnknownLoc(), dim); + } + + Type getIntegerType(unsigned width) { + return IntegerType::get(context, width); + } + + Type getFloat32Type() { return Float32Type::get(context); } + + Type getFloat64Type() { return Float64Type::get(context); } + private: + MLIRContext *context; Block █ bool isFloatingPoint(Value value) { return value.getType().isa(); } bool isInteger(Value value) { return value.getType().isa(); } - OpBuilder getBuilder(Value value) { - OpBuilder builder(value.getContext()); + OpBuilder getBuilder() { + OpBuilder builder(context); builder.setInsertionPointToEnd(&block); return builder; } @@ -613,9 +657,14 @@ dictAttr.getValue().end()); // Parsing is shared with named ops, except for the region. - SmallVector inputTypes, outputTypes; - if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes)) + SmallVector inputTypes, outputTypes; + SmallVector operandSegmentSizes; + if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes, + operandSegmentSizes)) return failure(); + result.addAttribute( + "operand_segment_sizes", + parser.getBuilder().getI32VectorAttr(operandSegmentSizes)); // Optional attributes may be added. if (succeeded(parser.parseOptionalKeyword("attrs"))) @@ -1524,7 +1573,6 @@ MemRefType::Builder(type).setShape(newSizes).setAffineMaps({layout})); } - template unsigned getMaxPosOfType(ArrayRef exprArrays) { unsigned pos = 0; @@ -2812,9 +2860,10 @@ static ParseResult parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, SmallVectorImpl &inputTypes, - SmallVectorImpl &outputTypes) { - llvm::SMLoc inputsOperandsLoc, outputsOperandsLoc; - SmallVector inputsOperands, outputsOperands; + SmallVectorImpl &outputTypes, + SmallVectorImpl &operandSegmentSizes) { + llvm::SMLoc inputOperandLocs, outputOperandLocs; + SmallVector inputOperands, outputOperands; parser.parseOptionalAttrDict(result.attributes); @@ -2822,29 +2871,28 @@ if (parser.parseLParen()) return failure(); - inputsOperandsLoc = parser.getCurrentLocation(); - if (parser.parseOperandList(inputsOperands) || + inputOperandLocs = parser.getCurrentLocation(); + if (parser.parseOperandList(inputOperands) || parser.parseColonTypeList(inputTypes) || parser.parseRParen()) return failure(); } if (succeeded(parser.parseOptionalKeyword("outs"))) { - outputsOperandsLoc = parser.getCurrentLocation(); - if (parser.parseLParen() || parser.parseOperandList(outputsOperands) || + outputOperandLocs = parser.getCurrentLocation(); + if (parser.parseLParen() || parser.parseOperandList(outputOperands) || parser.parseColonTypeList(outputTypes) || parser.parseRParen()) return failure(); } - if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc, + if (parser.resolveOperands(inputOperands, inputTypes, inputOperandLocs, result.operands) || - parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc, + parser.resolveOperands(outputOperands, outputTypes, outputOperandLocs, result.operands)) return failure(); - result.addAttribute("operand_segment_sizes", - parser.getBuilder().getI32VectorAttr( - {static_cast(inputsOperands.size()), - static_cast(outputsOperands.size())})); + // Set the operand segment sizes. + operandSegmentSizes.append({static_cast(inputOperands.size()), + static_cast(outputOperands.size())}); return success(); } @@ -2861,16 +2909,41 @@ // Specific parsing and printing for named structured ops created by ods-gen. //===----------------------------------------------------------------------===// +static ParseResult +parseNamedStructuredOpCaptures(OpAsmParser &parser, OperationState &result, + SmallVectorImpl &captures, + SmallVectorImpl &operandSegmentSizes) { + llvm::SMLoc captureOperandLocs; + SmallVector captureOperands; + SmallVector captureTypes; + + if (succeeded(parser.parseOptionalKeyword("captures"))) { + if (parser.parseLParen()) + return failure(); + + captureOperandLocs = parser.getCurrentLocation(); + if (parser.parseOperandList(captureOperands) || + parser.parseColonTypeList(captureTypes) || parser.parseRParen()) + return failure(); + } + + if (parser.resolveOperands(captureOperands, captureTypes, captureOperandLocs, + captures)) + return failure(); + + // Add the capture operands and set the operand segment size. + result.addOperands(captures); + operandSegmentSizes.append({static_cast(captureOperands.size())}); + return success(); +} + template static ParseResult parseNamedStructuredOpRegion(OpAsmParser &parser, Region ®ion, TypeRange inputTypes, TypeRange outputTypes, - ArrayRef captures) { + ValueRange capturedValues) { ParseResult res = success(); OpBuilder opBuilder(parser.getBuilder().getContext()); - // Resolve `captures` into `capturedValues` at parse time so we can build the - // region with captures. - SmallVector capturedValues; fillStructuredOpRegion( opBuilder, region, inputTypes, outputTypes, capturedValues, [&](unsigned expected, unsigned actual) { @@ -2894,18 +2967,29 @@ } template -static ParseResult -parseNamedStructuredOp(OpAsmParser &parser, OperationState &result, - ArrayRef captures) { - // TODO: Enable when ods-gen supports captures. - assert(captures.empty() && "unexpected captures for named structured ops"); - SmallVector inputTypes, outputTypes; - if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes)) +static ParseResult parseNamedStructuredOp(OpAsmParser &parser, + OperationState &result, + bool parseCaptures) { + SmallVector inputTypes, outputTypes; + SmallVector captures; + SmallVector operandSegmentSizes; + if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes, + operandSegmentSizes)) + return failure(); + + // TODO: simplify condition once all named operations have captures. + if (parseCaptures && parseNamedStructuredOpCaptures(parser, result, captures, + operandSegmentSizes)) return failure(); + // Add the captures and set the operand segment sizes. + result.addAttribute( + "operand_segment_sizes", + parser.getBuilder().getI32VectorAttr(operandSegmentSizes)); + // TODO: consider merging results parsing into region parsing. // Need to wait for declarative assembly resolution to decide. - SmallVector outputTensorsTypes; + SmallVector outputTensorsTypes; if (parseNamedStructuredOpResults(parser, outputTensorsTypes)) return failure(); result.addTypes(outputTensorsTypes); @@ -2919,6 +3003,16 @@ return success(); } +template +static void printNamedStructuredOpCaptures(OpAsmPrinter &p, + NamedStructuredOpType op) { + // TODO: access captures directly once all named operations have captures. + OperandRange captures = + op->getOperands().drop_front(op.inputs().size() + op.outputs().size()); + if (!captures.empty()) + p << " captures(" << captures << " : " << captures.getTypes() << ")"; +} + static void printNamedStructuredOpResults(OpAsmPrinter &p, TypeRange resultTypes) { if (resultTypes.empty()) @@ -2935,11 +3029,10 @@ // See generated code in mlir-linalg-yaml-gen.cpp "linalg.memoized_indexing_maps"}); - // Printing is shared with generic ops, except for the region and - // attributes. + // Printing is shared with generic ops, except for region, attributes, and + // captures. printCommonStructuredOpParts(p, op); - - // Results printing. + printNamedStructuredOpCaptures(p, op); printNamedStructuredOpResults(p, op.result_tensors().getTypes()); // Region is elided. diff --git a/mlir/python/mlir/dialects/_linalg_ops_ext.py b/mlir/python/mlir/dialects/_linalg_ops_ext.py --- a/mlir/python/mlir/dialects/_linalg_ops_ext.py +++ b/mlir/python/mlir/dialects/_linalg_ops_ext.py @@ -89,10 +89,10 @@ class StructuredOpMixin: """All structured ops use the same mixin class.""" - def __init__(self, inputs, outputs=(), results=(), loc=None, ip=None): + def __init__(self, inputs, outputs=(), captures=(), results=(), loc=None, ip=None): super().__init__( self.build_generic(results=list(results), - operands=[list(inputs), list(outputs)], + operands=[list(inputs), list(outputs), list(captures)], loc=loc, ip=ip)) @@ -103,5 +103,6 @@ if ("__init__" not in parent_opview_cls.__dict__ and hasattr(parent_opview_cls, "inputs") and hasattr(parent_opview_cls, "outputs") and + hasattr(parent_opview_cls, "captures") and hasattr(parent_opview_cls, "result_tensors")): return StructuredOpMixin diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py @@ -337,16 +337,18 @@ class const(TensorExpression): """Returns the given constant floating point or integer value.""" - def __init__(self, type_var: TypeVar, value: Any): - if not isinstance(type_var, TypeVar): - raise ValueError(f"const requires a TypeVar. Got: {repr(type_var)}") - if not (isinstance(value, float) or isinstance(value, int)): - raise ValueError(f"const requires int or float. Got: {type(value)}") - self.type_var = type_var - self.value = value + def __init__(self, value: Any): + with _ir.Context(): + if isinstance(value, float): + self.value = str(_ir.FloatAttr.get_f64(float(value))) + elif isinstance(value, int): + self.value = str( + _ir.IntegerAttr.get(_ir.IntegerType.get_signless(64), int(value))) + else: + raise ValueError(f"const requires int or float. Got: {type(value)}") def to_scalar_expression(self) -> ScalarExpression: - return ScalarConst(self.type_var, self.value).expr() + return ScalarConst(self.value).expr() def __repr__(self): return f"const({self.type_var}, {self.value})" diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py @@ -299,16 +299,15 @@ exprs=list(affine_map.results)) def to_yaml_custom_dict(self): - self_dict = dict( - args=self.ordered_tensor_args, - captures=self.ordered_capture_args, - # TODO: Refactor the hierarchy internally when supporting more - # than static (preserving this serialized form). - indexing_maps=LinalgIndexingMapsConfig( - static_indexing_maps=self.indexing_maps), - iterator_types=self.iterator_types, - assignments=self.assignments, - ) + self_dict = dict(args=self.ordered_tensor_args) + if len(self.ordered_capture_args) != 0: + self_dict["captures"] = self.ordered_capture_args + # TODO: Refactor the hierarchy internally when supporting more + # than static (preserving this serialized form). + self_dict["indexing_maps"] = LinalgIndexingMapsConfig( + static_indexing_maps=self.indexing_maps) + self_dict["iterator_types"] = self.iterator_types + self_dict["assignments"] = self.assignments return self_dict def __repr__(self): diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -153,9 +153,9 @@ raise NotImplementedError( f"Unknown named op_name / op_class_name: {op_name} / {op_class_name}") - named_op = getattr(linalg, op_class_name)(ins, outs, result_types) + named_op = getattr(linalg, op_class_name)(ins, outs, captures, result_types) linalgDialect = ctx.get_dialect_descriptor("linalg") - fill_builtin_region(linalgDialect, named_op.operation) + fill_builtin_region(linalgDialect, named_op.operation, list(captures)) # Note: mlir-linalg-ods-yaml-gen.cpp uses a special linalg.memoized_indexing_maps # attribute that the non-yaml path does not. The non-yaml path hardcodes the # indexing_maps in C++ directly. @@ -201,9 +201,12 @@ raise ValueError(f"Capture {expr.scalar_capture.capture} is not bound for " f"this structured op.") elif expr.scalar_const: - return self.constant(expr.scalar_const.type_var.name, expr.scalar_const.value) + value_attr = Attribute.parse(expr.scalar_const.value) + return std.ConstantOp(value_attr.type, value_attr).result elif expr.scalar_index: - return self.index(expr.scalar_index.dim) + dim_attr = IntegerAttr.get(IntegerType.get_signless(64), + expr.scalar_index.dim) + return linalg.IndexOp(IndexType.get(), dim_attr).result elif expr.scalar_apply: try: fn = getattr(self, f"_eval_{expr.scalar_apply.fn_name}") @@ -220,25 +223,6 @@ return self.cast(expr.symbolic_cast.to_type.name, operand_value) raise NotImplementedError(f"Unimplemented scalar body expression: {expr}") - def constant(self, type_var_name: str, value: Any) -> Value: - try: - type = self.type_mapping[type_var_name] - except KeyError: - raise ValueError(f"Unbound type variable '{type_var_name}' (" - f"expected one of {self.type_mappings.keys()}") - try: - if(_is_floating_point_type(type)): - return std.ConstantOp(type, FloatAttr.get(type, float(value))).result - elif(_is_integer_type(type)): - return std.ConstantOp(type, IntegerAttr.get(type, int(value))).result - except ValueError: - raise ValueError(f"Unable to cast value {value} to type {type}") - raise NotImplementedError(f"Unimplemented constant type {type}") - - def index(self, dim: int) -> Value: - dim_attr = IntegerAttr.get(IntegerType.get_signless(64), dim) - return linalg.IndexOp(IndexType.get(), dim_attr).result - def cast(self, type_var_name: str, operand: Value) -> Value: try: to_type = self.type_mapping[type_var_name] diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py @@ -71,15 +71,14 @@ class ScalarConst: """A type of ScalarExpression representing a constant.""" - def __init__(self, type_var: TypeVar, value: Any): - self.type_var = type_var + def __init__(self, value: str): self.value = value def expr(self) -> "ScalarExpression": return ScalarExpression(scalar_const=self) def __repr__(self): - return f"(ScalarConst({self.type_var}, {self.value})" + return f"(ScalarConst({self.value})" class ScalarIndex: """A type of ScalarExpression accessing an iteration index.""" @@ -151,8 +150,7 @@ elif self.scalar_capture: return dict(scalar_capture=self.scalar_capture.capture) elif self.scalar_const: - return dict(scalar_const=dict(type_var=self.scalar_const.type_var.name, - attributes=[self.scalar_const.value])) + return dict(scalar_const=self.scalar_const.value) elif self.scalar_index: return dict(scalar_index=self.scalar_index.dim) elif self.symbolic_cast: diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -68,3 +68,27 @@ """ implements(ContractionOpInterface) C[None] += cast(U, A[D.m]) * cast(U, B[D.m]) + +@linalg_structured_op +def fill_rng_2d(O=TensorDef(T, S.M, S.N, output=True), + min=CaptureDef(F64), + max=CaptureDef(F64), + seed=CaptureDef(I32)): + """Fills the output tensor with pseudo random numbers. + + The operation generations pseudo random numbers using a linear congruential + generator. It provides no guarantees regarding the distribution of the + generated random numbers. Instead of generating the random numbers + sequentially, it instantiates one random number generator per data element + and runs them in parallel. The seed operand and the indices of the data + element seed the random number generation. The min and max operands limit + the range of the generated random numbers. + """ + multiplier = cast(I32, const(1103515245)) + increment = cast(I32, const(12345)) + rand1 = (cast(I32, index(D.m)) + seed) * multiplier + increment + rand2 = (cast(I32, index(D.n)) + rand1) * multiplier + increment + inv_range = cast(F64, const(2.3283064e-10)) + offset = cast(F64, const(2147483647)) + scaling = (max - min) * inv_range + O[D.m, D.n] = cast(T, (offset + cast(F64, rand2)) * scaling + min) diff --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir --- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir +++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir @@ -28,6 +28,50 @@ // CHECK-NEXT: linalg.yield %[[ADD]] : i32 // CHECK-NEXT: -> tensor<16x32xi32> +// ----- + +func @generalize_fill_rng_2d_f32(%O: tensor<16x32xf32>, %min: f64, %max: f64, %seed: i32) -> tensor<16x32xf32> { + %0 = linalg.fill_rng_2d outs(%O : tensor<16x32xf32>) + captures(%min, %max, %seed: f64, f64, i32) -> tensor<16x32xf32> + return %0: tensor<16x32xf32> +} + +// CHECK-LABEL: @generalize_fill_rng_2d_f32 +// CHECK-SAME: (%[[O:.+]]: tensor<16x32xf32>, %[[MIN:.+]]: f64, %[[MAX:.+]]: f64, %[[SEED:.+]]: i32) +// CHECK-DAG: %[[IDX0:.+]] = linalg.index 0 : index +// CHECK-DAG: %[[IDX1:.+]] = linalg.index 1 : index +// CHECK-DAG: %[[IDX0_CAST:.+]] = index_cast %[[IDX0]] : index to i32 +// CHECK-DAG: %[[IDX1_CAST:.+]] = index_cast %[[IDX1]] : index to i32 +// CHECK-DAG: %[[RND0:.+]] = addi %[[IDX0_CAST]], %[[SEED]] : i32 +// CHECK-DAG: %[[CST0:.+]] = constant 1103515245 : i32 +// CHECK-DAG: %[[CST1:.+]] = constant 12345 : i32 +// CHECK-DAG: %[[RND1:.+]] = muli %[[RND0]], %[[CST0]] : i32 +// CHECK-DAG: %[[RND2:.+]] = addi %[[RND1]], %[[CST1]] : i32 +// Skip random number computation for the second index. +// CHECK-DAG: %[[DIFF:.+]] = subf %[[MAX]], %[[MIN]] : f64 +// CHECK-DAG: %[[CST2:.+]] = constant 2.3283063999999999E-10 : f64 +// CHECK-DAG: %[[FACT:.+]] = mulf %[[DIFF]], %[[CST2]] : f64 +// CHECK-DAG: %[[RND4:.+]] = mulf %{{.+}}, %[[FACT]] : f64 +// CHECK-DAG: %[[RND5:.+]] = addf %[[RND4]], %[[MIN]] : f64 +// CHECK-DAG: %[[RND6:.+]] = fptrunc %[[RND5]] : f64 to f32 +// CHECK-NEXT: linalg.yield %[[RND6]] : f32 +// CHECK-NEXT: -> tensor<16x32xf32> + +// ----- + +func @generalize_fill_rng_2d_i32(%O: tensor<16x32xi32>, %min: f64, %max: f64, %seed: i32) -> tensor<16x32xi32> { + %0 = linalg.fill_rng_2d outs(%O : tensor<16x32xi32>) + captures(%min, %max, %seed: f64, f64, i32) -> tensor<16x32xi32> + return %0: tensor<16x32xi32> +} + +// CHECK-LABEL: @generalize_fill_rng_2d_i32 +// CHECK-SAME: (%[[O:.+]]: tensor<16x32xi32>, %[[MIN:.+]]: f64, %[[MAX:.+]]: f64, %[[SEED:.+]]: i32) +// Verifies floating point to integer cast. +// CHECK: %[[RND6:.+]] = fptosi %{{.+}} : f64 to i32 +// CHECK-NEXT: linalg.yield %[[RND6]] : i32 +// CHECK-NEXT: -> tensor<16x32xi32> + // ----- // Verifies floating point to integer cast. func @generalize_matmul_tensor_f32_f32_i16(%A : tensor<16x8xf32>, %B: tensor<8x32xf32>, %C: tensor<16x32xi16>) -> tensor<16x32xi16> { diff --git a/mlir/test/lit.cfg.py b/mlir/test/lit.cfg.py --- a/mlir/test/lit.cfg.py +++ b/mlir/test/lit.cfg.py @@ -21,7 +21,7 @@ config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell) # suffixes: A list of file extensions to treat as test files. -config.suffixes = ['.td', '.mlir', '.toy', '.ll', '.tc', '.py', '.test'] +config.suffixes = ['.td', '.mlir', '.toy', '.ll', '.tc', '.yaml', '.py', '.test'] # test_source_root: The root path where tests are located. config.test_source_root = os.path.dirname(__file__) diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml @@ -0,0 +1,152 @@ +# RUN: mlir-linalg-ods-yaml-gen %s --o-ods-decl=- | FileCheck %s --check-prefix=ODS +# RUN: mlir-linalg-ods-yaml-gen %s --o-impl=- | FileCheck %s --check-prefix=IMPL + +# @linalg_structured_op +# def test1(O=TensorDef(T, S.M, S.N, output=True), value=CaptureDef(T)): +# """Title. + +# Detailed description. +# """ +# O[D.m, D.n] = cast(T, const(42)) + cast(T, index(D.n)) - value + +--- !LinalgOpConfig +metadata: !LinalgOpMetadata + name: test1 + cpp_class_name: Test1Op + doc: |- + Title. + + Detailed description. +structured_op: !LinalgStructuredOpConfig + args: + - ! + name: O + usage: output + shape: affine_map<()[s0, s1] -> (s0, s1)> + element_type_var: T + captures: + - ! + name: value + type_var: T + indexing_maps: !LinalgIndexingMapsConfig + static_indexing_maps: + - affine_map<(d0, d1)[s0, s1] -> (d0, d1)> + iterator_types: + - parallel + - parallel + assignments: + - !ScalarAssign + arg: O + value: !ScalarExpression + scalar_apply: + fn_name: sub + operands: + - !ScalarExpression + scalar_apply: + fn_name: add + operands: + - !ScalarExpression + symbolic_cast: + type_var: T + operands: + - !ScalarExpression + scalar_const: '42 : i64' + - !ScalarExpression + symbolic_cast: + type_var: T + operands: + - !ScalarExpression + scalar_index: 1 + - !ScalarExpression + scalar_capture: value + +# ODS-LABEL: def Test1Op : LinalgStructuredBase_Op<"test1" + +# ODS: let summary = [{ Title. }]; +# ODS-NEXT: let description = [{ +# ODS-NEXT: Detailed description. +# ODS-NEXT: }]; + +# ODS: let arguments = +# ODS-NEXT: Variadic:$inputs, +# ODS-NEXT: Variadic:$outputs, +# ODS-NEXT: Variadic:$captures + +# ODS: let builders = +# ODS: $_state.addOperands(inputs); +# ODS-NEXT: $_state.addOperands(outputs); +# ODS-NEXT: $_state.addOperands(captures); +# ODS-NEXT: $_state.addAttribute( +# ODS-NEXT: "operand_segment_sizes", +# ODS-NEXT: $_builder.getI32VectorAttr({ +# ODS-NEXT: static_cast(inputs.size()), +# ODS-NEXT: static_cast(outputs.size()), +# ODS-NEXT: static_cast(captures.size())})); +# ODS-NEXT: createAndFillStructuredOpRegion( +# ODS-NEXT: $_builder, +# ODS-NEXT: $_state, +# ODS-NEXT: TypeRange(inputs), +# ODS-NEXT: TypeRange(outputs), +# ODS-NEXT: captures); + +# IMPL-LABEL: void Test1Op::regionBuilder +# IMPL-SAME: (Block &block, ValueRange captures) +# IMPL: Value [[VAL0:[a-z0-9]+]] = helper.constant("42 : i64"); +# IMPL-DAG: Value [[VAL1:[a-z0-9]+]] = helper.cast(block.getArgument(0).getType(), [[VAL0]]); +# IMPL-DAG: Value [[VAL2:[a-z0-9]+]] = helper.index(1); +# IMPL-DAG: Value [[VAL3:[a-z0-9]+]] = helper.cast(block.getArgument(0).getType(), [[VAL2]]); +# IMPL-DAG: Value [[VAL4:[a-z0-9]+]] = helper.applyfn__add([[VAL1]], [[VAL3]]); +# IMPL-DAG: Value [[VAL5:[a-z0-9]+]] = helper.applyfn__sub([[VAL4]], captures[0]); + + +# @linalg_structured_op +# def test2(I=TensorDef(T, S.M, S.N), +# O=TensorDef(T, S.M, S.N, output=True)): +# """Title. + +# Detailed description. +# """ +# O[D.m, D.n] = I[D.n, D.m] + +--- !LinalgOpConfig +metadata: !LinalgOpMetadata + name: test2 + cpp_class_name: Test2Op + doc: |- + Title. + + Detailed description. +structured_op: !LinalgStructuredOpConfig + args: + - ! + name: I + usage: input + shape: affine_map<()[s0, s1] -> (s0, s1)> + element_type_var: T + - ! + name: O + usage: output + shape: affine_map<()[s0, s1] -> (s0, s1)> + element_type_var: T + indexing_maps: !LinalgIndexingMapsConfig + static_indexing_maps: + - affine_map<(d0, d1)[s0, s1] -> (d1, d0)> + - affine_map<(d0, d1)[s0, s1] -> (d0, d1)> + iterator_types: + - parallel + - parallel + assignments: + - !ScalarAssign + arg: O + value: !ScalarExpression + scalar_arg: I + +# IMPL-LABEL: Test2Op::iterator_types() +# IMPL-NEXT: { getParallelIteratorTypeName(), getParallelIteratorTypeName() } + +# IMPL: Test2Op::indexing_maps() +# IMPL: "affine_map<(d0, d1)[s0, s1] -> (d1, d0)>" +# IMPL: "affine_map<(d0, d1)[s0, s1] -> (d0, d1)>" + +# IMPL: void Test2Op::regionBuilder(Block &block, ValueRange captures) +# IMPL: yields.push_back(block.getArgument(0)); diff --git a/mlir/test/python/dialects/linalg/opdsl/assignments.py b/mlir/test/python/dialects/linalg/opdsl/assignments.py --- a/mlir/test/python/dialects/linalg/opdsl/assignments.py +++ b/mlir/test/python/dialects/linalg/opdsl/assignments.py @@ -27,3 +27,58 @@ B=TensorDef(T, S.K, S.N), C=TensorDef(U, S.M, S.N, output=True)): C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n]) + +# CHECK: --- +# CHECK-LABEL: constants +# CHECK: assignments: +# CHECK: - +# CHECK: arg: O +# CHECK: scalar_apply: +# CHECK: fn_name: sub +# CHECK: operands: +# CHECK: scalar_apply: +# CHECK: fn_name: add +# CHECK: operands: +# CHECK: symbolic_cast: +# CHECK: type_var: T +# CHECK: operands: +# CHECK: scalar_const: '3.1415926535897931 : f64' +# CHECK: symbolic_cast: +# CHECK: type_var: T +# CHECK: operands: +# CHECK: scalar_const: '42 : i64' +# CHECK: symbolic_cast: +# CHECK: type_var: T +# CHECK: operands: +# CHECK: scalar_const: '1.{{[0]*}}e+03 : f64' +@linalg_structured_op +def constants(O=TensorDef(T, S.M, S.K, output=True)): + pi = cast(T, const(3.1415926535897931)) + cst42 = cast(T, const(42)) + cst1000 = cast(T, const(1e+3)) + O[D.m, D.n] = pi + cst42 - cst1000 + +# CHECK: --- +# CHECK-LABEL: indices +# CHECK: assignments: +# CHECK: - +# CHECK: arg: O +# CHECK: scalar_apply: +# CHECK: fn_name: add +# CHECK: operands: +# CHECK: scalar_index: 1 +# CHECK: scalar_index: 0 +@linalg_structured_op +def indices(O=TensorDef(T, S.M, S.K, output=True)): + O[D.m, D.n] = index(D.n) + index(D.m) + +# CHECK: --- +# CHECK-LABEL: fill +# CHECK: assignments: +# CHECK: - +# CHECK: arg: O +# CHECK: scalar_capture: value +@linalg_structured_op +def fill(O=TensorDef(T, S.M, S.K, output=True), + value=CaptureDef(T)): + O[D.m, D.n] = value \ No newline at end of file diff --git a/mlir/test/python/dialects/linalg/opdsl/captures.py b/mlir/test/python/dialects/linalg/opdsl/captures.py new file mode 100644 --- /dev/null +++ b/mlir/test/python/dialects/linalg/opdsl/captures.py @@ -0,0 +1,14 @@ +# RUN: %PYTHON -m mlir.dialects.linalg.opdsl.dump_oplib --file %s | FileCheck %s + +from mlir.dialects.linalg.opdsl.lang import * + +# CHECK: --- +# CHECK-LABEL: fill +# CHECK: captures: +# CHECK: - ! +# CHECK: name: value +# CHECK: type_var: T +@linalg_structured_op +def fill(O=TensorDef(T, S.M, S.K, output=True), + value=CaptureDef(T)): + O[D.m, D.n] = value \ No newline at end of file diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py b/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py --- a/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py +++ b/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py @@ -24,17 +24,18 @@ C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n]) @linalg_structured_op -def fill_rng_2d(A=TensorDef(T, S.M, S.N, output=True), - min=CaptureDef(F64), - max=CaptureDef(F64), - seed=CaptureDef(I32)): - multiplier = const(I32, 1103515245) - increment = const(I32, 12345) - temp1 = (cast(I32, index(D.m)) + seed) * multiplier + increment - temp2 = (cast(I32, index(D.n)) + temp1) * multiplier + increment - inv_randmax = const(F64, 2.3283064e-10) - scaling = (max - min) * inv_randmax - A[D.m, D.n] = cast(T, cast(F64, temp2) * scaling + min) +def fill_rng(O=TensorDef(T, S.M, S.N, output=True), + min=CaptureDef(F64), + max=CaptureDef(F64), + seed=CaptureDef(I32)): + multiplier = cast(I32, const(1103515245)) + increment = cast(I32, const(12345)) + rand1 = (cast(I32, index(D.m)) + seed) * multiplier + increment + rand2 = (cast(I32, index(D.n)) + rand1) * multiplier + increment + inv_range = cast(F64, const(2.3283064e-10)) + offset = cast(F64, const(2147483647)) + scaling = (max - min) * inv_range + O[D.m, D.n] = cast(T, (offset + cast(F64, rand2)) * scaling + min) with Context() as ctx, Location.unknown(): module = Module.create() @@ -154,27 +155,29 @@ def test_f64f64f32_matmul(lhs, rhs, init_result): return matmul_poly(lhs, rhs, outs=[init_result]) - # CHECK-LABEL: @test_fill_rng_2d + # CHECK-LABEL: @test_fill_rng # CHECK-SAME: %{{.*}} tensor<4x16xi32>, %[[MIN:.+]]: f64, %[[MAX:.+]]: f64, %[[SEED:.+]]: i32 # CHECK-DAG: %[[IDX0:.+]] = linalg.index 0 : index # CHECK-DAG: %[[IDX1:.+]] = linalg.index 1 : index # CHECK-DAG: %[[IDX0_CAST:.+]] = index_cast %[[IDX0]] : index to i32 # CHECK-DAG: %[[IDX1_CAST:.+]] = index_cast %[[IDX1]] : index to i32 # CHECK-DAG: %[[RND0:.+]] = addi %[[IDX0_CAST]], %[[SEED]] : i32 - # CHECK-DAG: %[[CST0:.+]] = constant 1103515245 : i32 - # CHECK-DAG: %[[CST1:.+]] = constant 12345 : i32 - # CHECK-DAG: %[[RND1:.+]] = muli %[[RND0]], %[[CST0]] : i32 - # CHECK-DAG: %[[RND2:.+]] = addi %[[RND1]], %[[CST1]] : i32 - # CHECK: %[[RND3:.+]] = sitofp %{{.*}} : i32 to f64 + # CHECK-DAG: %[[CST0:.+]] = constant 1103515245 : i64 + # CHECK-DAG: %[[CST0_CAST:.+]] = trunci %[[CST0]] : i64 to i32 + # CHECK-DAG: %[[CST1:.+]] = constant 12345 : i64 + # CHECK-DAG: %[[CST1_CAST:.+]] = trunci %[[CST1]] : i64 to i32 + # CHECK-DAG: %[[RND1:.+]] = muli %[[RND0]], %[[CST0_CAST]] : i32 + # CHECK-DAG: %[[RND2:.+]] = addi %[[RND1]], %[[CST1_CAST]] : i32 + # Skip random number computation for the second index. # CHECK-DAG: %[[DIFF:.+]] = subf %[[MAX]], %[[MIN]] : f64 - # CHECK-DAG: %[[CST2:.+]] = constant 2.3283063999999999E-10 : f64 - # CHECK-DAG: %[[FACT:.+]] = mulf %[[DIFF]], %[[CST2]] : f64 - # CHECK-DAG: %[[RND4:.+]] = mulf %[[RND3]], %[[FACT]] : f64 + # CHECK-DAG: %[[CST3:.+]] = constant 2.3283063999999999E-10 : f64 + # CHECK-DAG: %[[FACT:.+]] = mulf %[[DIFF]], %[[CST3]] : f64 + # CHECK-DAG: %[[RND4:.+]] = mulf %{{.+}}, %[[FACT]] : f64 # CHECK-DAG: %[[RND5:.+]] = addf %[[RND4]], %[[MIN]] : f64 # CHECK-DAG: %{{.*}} = fptosi %[[RND5]] : f64 to i32 @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), i32), f64, f64, i32) - def test_fill_rng_2d(init_result, min, max, seed): - return fill_rng_2d(outs=[init_result], captures=[min, max, seed]) + def test_fill_rng(init_result, min, max, seed): + return fill_rng(outs=[init_result], captures=[min, max, seed]) print(module) diff --git a/mlir/test/python/dialects/linalg/opdsl/tensors.py b/mlir/test/python/dialects/linalg/opdsl/tensors.py new file mode 100644 --- /dev/null +++ b/mlir/test/python/dialects/linalg/opdsl/tensors.py @@ -0,0 +1,24 @@ +# RUN: %PYTHON -m mlir.dialects.linalg.opdsl.dump_oplib --file %s | FileCheck %s + +from mlir.dialects.linalg.opdsl.lang import * + +# CHECK: --- +# CHECK-LABEL: matmul +# CHECK: args: +# CHECK: name: A +# CHECK: usage: input +# CHECK: shape: affine_map<()[s0, s1, s2] -> (s0, s2)> +# CHECK: element_type_var: T +# CHECK: name: B +# CHECK: usage: input +# CHECK: shape: affine_map<()[s0, s1, s2] -> (s2, s1)> +# CHECK: element_type_var: T +# CHECK: name: C +# CHECK: usage: output +# CHECK: shape: affine_map<()[s0, s1, s2] -> (s0, s1)> +# CHECK: element_type_var: U +@linalg_structured_op +def matmul(A=TensorDef(T, S.M, S.K), + B=TensorDef(T, S.K, S.N), + C=TensorDef(U, S.M, S.N, output=True)): + C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n]) diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py --- a/mlir/test/python/dialects/linalg/ops.py +++ b/mlir/test/python/dialects/linalg/ops.py @@ -164,7 +164,7 @@ # CHECK-NEXT: std.mulf{{.*}} (f32, f32) -> f32 # CHECK-NEXT: std.addf{{.*}} (f32, f32) -> f32 # CHECK-NEXT: linalg.yield{{.*}} (f32) -> () - # CHECK-NEXT: {linalg.memoized_indexing_maps{{.*}}operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : + # CHECK-NEXT: {linalg.memoized_indexing_maps{{.*}}operand_segment_sizes = dense<[2, 1, 0]> : vector<3xi32>} : # CHECK-SAME: (tensor<4x16xf32>, tensor<16x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32> return linalg.matmul(lhs, rhs, outs=[init_result.result]) diff --git a/mlir/test/python/dialects/linalg/opsrun.py b/mlir/test/python/dialects/linalg/opsrun.py --- a/mlir/test/python/dialects/linalg/opsrun.py +++ b/mlir/test/python/dialects/linalg/opsrun.py @@ -14,7 +14,7 @@ print(*args, file=sys.stderr) sys.stderr.flush() -boilerplate = """ +matmul_boiler = """ func @main() -> f32 attributes {llvm.emit_c_interface} { %v0 = constant 0.0 : f32 %v1 = constant 1.0 : f32 @@ -27,7 +27,7 @@ linalg.fill(%B, %v2) : memref<16x8xf32>, f32 linalg.fill(%C, %v0) : memref<4x8xf32>, f32 - call @matmul_on_buffers(%A, %B, %C) : + call @matmul_on_buffers(%A, %B, %C) : (memref<4x16xf32>, memref<16x8xf32>, memref<4x8xf32>) -> () %c0 = constant 0 : index @@ -38,7 +38,26 @@ } """ -def transform(module): +fill_boiler = """ +func @main() -> i32 attributes {llvm.emit_c_interface} { + %min = constant -1000.0 : f64 + %max = constant 1000.0 : f64 + %seed = constant 42 : i32 + + %O = memref.alloc() : memref<4x16xi32> + + call @fill_on_buffers(%O, %min, %max, %seed) : + (memref<4x16xi32>, f64, f64, i32) -> () + + %c0 = constant 0 : index + %0 = memref.load %O[%c0, %c0] : memref<4x16xi32> + + // TODO: FFI-based solution to allow testing and printing with python code. + return %0 : i32 +} +""" + +def transform(module, boilerplate): import mlir.conversions import mlir.dialects.linalg.passes import mlir.transforms @@ -48,13 +67,13 @@ mod = Module.parse( str(module.operation.regions[0].blocks[0].operations[0].operation) + boilerplate) - pm = PassManager.parse("func(convert-linalg-to-loops, convert-scf-to-std)," + - "convert-vector-to-llvm," + + pm = PassManager.parse("func(convert-linalg-to-loops, convert-scf-to-std)," + + "convert-vector-to-llvm," + "convert-std-to-llvm") pm.run(mod) return mod -def test_builtin(): +def test_matmul_builtin(): with Context() as ctx, Location.unknown(): module = Module.create() f32 = F32Type.get() @@ -64,8 +83,8 @@ MemRefType.get((4, 8), f32)) def matmul_on_buffers(lhs, rhs, out): linalg.matmul(lhs, rhs, outs=[out]) - - execution_engine = ExecutionEngine(transform(module)) + + execution_engine = ExecutionEngine(transform(module, matmul_boiler)) # TODO: FFI-based solution to allow testing and printing with python code. # Prepare arguments: one result f32. @@ -77,9 +96,9 @@ log('RESULT: ', res[0]) # CHECK: RESULT: 32.0 -test_builtin() +test_matmul_builtin() -def test_generic(): +def test_matmul_generic(): with Context() as ctx, Location.unknown(): module = Module.create() f32 = F32Type.get() @@ -89,8 +108,8 @@ MemRefType.get((4, 8), f32)) def matmul_on_buffers(lhs, rhs, out): linalg.matmul(lhs, rhs, outs=[out], emit_generic=True) - - execution_engine = ExecutionEngine(transform(module)) + + execution_engine = ExecutionEngine(transform(module, matmul_boiler)) # TODO: FFI-based solution to allow testing and printing with python code. # Prepare arguments: one result f32. @@ -102,4 +121,55 @@ log('RESULT: ', res[0]) # CHECK: RESULT: 32.0 -test_generic() +test_matmul_generic() + +def test_fill_builtin(): + with Context() as ctx, Location.unknown(): + module = Module.create() + f64 = F64Type.get() + i32 = IntegerType.get_signless(32) + with InsertionPoint(module.body): + @builtin.FuncOp.from_py_func(MemRefType.get((4, 16), i32), + f64, f64, i32) + def fill_on_buffers(out, min, max, seed): + linalg.fill_rng_2d(outs=[out], captures=[min, max, seed]) + + execution_engine = ExecutionEngine(transform(module, fill_boiler)) + + # TODO: FFI-based solution to allow testing and printing with python code. + # Prepare arguments: one result i32. + # Arguments must be passed as pointers. + c_int_p = ctypes.c_int * 1 + res = c_int_p(-1) + execution_engine.invoke("main", res) + + log('RESULT: ', res[0]) + # CHECK: RESULT: -480 + +test_fill_builtin() + +def test_fill_generic(): + with Context() as ctx, Location.unknown(): + module = Module.create() + f64 = F64Type.get() + i32 = IntegerType.get_signless(32) + with InsertionPoint(module.body): + @builtin.FuncOp.from_py_func(MemRefType.get((4, 16), i32), + f64, f64, i32) + def fill_on_buffers(out, min, max, seed): + linalg.fill_rng_2d(outs=[out], captures=[min, max, seed], + emit_generic=True) + + execution_engine = ExecutionEngine(transform(module, fill_boiler)) + + # TODO: FFI-based solution to allow testing and printing with python code. + # Prepare arguments: one result i32. + # Arguments must be passed as pointers. + c_int_p = ctypes.c_int * 1 + res = c_int_p(-1) + execution_engine.invoke("main", res) + + log('RESULT: ', res[0]) + # CHECK: RESULT: -480 + +test_fill_generic() \ No newline at end of file diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp @@ -51,7 +51,7 @@ struct LinalgOpMetadata { std::string name; - std::string cppOpName; + std::string cppClassName; Optional doc; SmallVector implements; }; @@ -75,6 +75,11 @@ std::string elementTypeVar; }; +struct LinalgCaptureDef { + std::string name; + std::string typeVar; +}; + enum class LinalgIteratorTypeDef { parallel, reduction, @@ -102,6 +107,9 @@ struct ScalarExpression { Optional arg; + Optional capture; + Optional constant; + Optional index; Optional apply; Optional symbolicCast; }; @@ -113,9 +121,10 @@ struct LinalgStructuredOpConfig { SmallVector args; + Optional> captures; LinalgIndexingMapsConfig indexingMaps; SmallVector iteratorTypes; - SmallVector assignments; + std::vector assignments; }; struct LinalgOpConfig { @@ -130,6 +139,7 @@ //===----------------------------------------------------------------------===// LLVM_YAML_IS_SEQUENCE_VECTOR(LinalgTensorDef) +LLVM_YAML_IS_SEQUENCE_VECTOR(LinalgCaptureDef) LLVM_YAML_IS_SEQUENCE_VECTOR(SerializedAffineMap) LLVM_YAML_IS_SEQUENCE_VECTOR(LinalgIteratorTypeDef) LLVM_YAML_IS_SEQUENCE_VECTOR(ScalarAssign) @@ -160,6 +170,7 @@ struct MappingTraits { static void mapping(IO &io, LinalgStructuredOpConfig &info) { io.mapRequired("args", info.args); + io.mapOptional("captures", info.captures); io.mapRequired("indexing_maps", info.indexingMaps); io.mapRequired("iterator_types", info.iteratorTypes); io.mapRequired("assignments", info.assignments); @@ -184,6 +195,18 @@ } }; +/// Maps a named capture-argument to an operation, consisting of: +/// - `name`: Must be unique within the operation. +/// - `type_var`: The symbolic type variable that binds to the type of this +/// CaptureDef. +template <> +struct MappingTraits { + static void mapping(IO &io, LinalgCaptureDef &info) { + io.mapRequired("name", info.name); + io.mapRequired("type_var", info.typeVar); + } +}; + /// Usage enum for a named argument. template <> struct ScalarEnumerationTraits { @@ -208,7 +231,7 @@ struct MappingTraits { static void mapping(IO &io, LinalgOpMetadata &info) { io.mapRequired("name", info.name); - io.mapRequired("cpp_op_name", info.cppOpName); + io.mapRequired("cpp_class_name", info.cppClassName); io.mapOptional("doc", info.doc); io.mapOptional("implements", info.implements); } @@ -247,6 +270,9 @@ struct MappingTraits { static void mapping(IO &io, ScalarExpression &info) { io.mapOptional("scalar_arg", info.arg); + io.mapOptional("scalar_capture", info.capture); + io.mapOptional("scalar_const", info.constant); + io.mapOptional("scalar_index", info.index); io.mapOptional("scalar_apply", info.apply); io.mapOptional("symbolic_cast", info.symbolicCast); } @@ -371,16 +397,50 @@ } static Optional -findTypeVarArgIndex(StringRef typeVar, SmallVectorImpl &args) { +findCaptureDefArgIndex(StringRef name, + llvm::Optional> captures) { + if (!captures.hasValue()) + return None; + for (auto it : llvm::enumerate(captures.getValue())) { + if (it.value().name == name) + return it.index(); + } + return None; +} + +// Try to map the TypeVar to a predefined, argument, or capture type. +static Optional +findTypeValue(StringRef typeVar, SmallVectorImpl &args, + Optional> captures) { + // Handle all predefined types. + if (typeVar == "I32") + return std::string("helper.getIntegerType(32)"); + if (typeVar == "I64") + return std::string("helper.getIntegerType(64)"); + if (typeVar == "F32") + return std::string("helper.getFloat32Type()"); + if (typeVar == "F64") + return std::string("helper.getFloat64Type()"); + + // Search all argument types. for (auto it : llvm::enumerate(args)) { if (it.value().elementTypeVar == typeVar) - return it.index(); + return llvm::formatv("block.getArgument({0}).getType()", it.index()) + .str(); + } + + // Search all capture types. + if (!captures.hasValue()) + return None; + for (auto it : llvm::enumerate(captures.getValue())) { + if (it.value().typeVar == typeVar) + return llvm::formatv("captures[{0}].getType()", it.index()).str(); } return None; } -static ScalarAssign * -findAssignment(StringRef name, SmallVectorImpl &assignments) { +static ScalarAssign *findAssignment(StringRef name, + std::vector &assignments) { for (auto &assign : assignments) { if (assign.arg == name) return &assign; @@ -428,7 +488,8 @@ {3} let arguments = (ins Variadic:$inputs, - Variadic:$outputs{4} + Variadic:$outputs, + Variadic:$captures{4} ); let results = (outs Variadic:$result_tensors); let regions = (region AnyRegion:$region); @@ -436,38 +497,45 @@ let skipDefaultBuilders = 1; let builders = [ OpBuilder< - (ins "ValueRange":$inputs, "ValueRange":$outputs), + (ins "ValueRange":$inputs, "ValueRange":$outputs, + CArg<"ValueRange", "{{}">:$captures), [{{ $_state.addOperands(inputs); $_state.addOperands(outputs); + $_state.addOperands(captures); $_state.addAttribute( "operand_segment_sizes", $_builder.getI32VectorAttr({{ static_cast(inputs.size()), - static_cast(outputs.size())})); + static_cast(outputs.size()), + static_cast(captures.size())})); createAndFillStructuredOpRegion<{0}>( $_builder, $_state, TypeRange(inputs), - TypeRange(outputs)/*, TODO: support captures*/); + TypeRange(outputs), + captures); }]>, OpBuilder< (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs, - "ValueRange":$outputs), + "ValueRange":$outputs, CArg<"ValueRange", "{{}">:$captures), [{{ $_state.addOperands(inputs); $_state.addOperands(outputs); + $_state.addOperands(captures); $_state.addTypes(resultTensorTypes); $_state.addAttribute( "operand_segment_sizes", $_builder.getI32VectorAttr({{ static_cast(inputs.size()), - static_cast(outputs.size())})); + static_cast(outputs.size()), + static_cast(captures.size())})); createAndFillStructuredOpRegion<{0}>( $_builder, $_state, TypeRange(inputs), - TypeRange(outputs)/*, TODO: support captures*/); + TypeRange(outputs), + captures); }]>, OpBuilder< (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands, @@ -482,7 +550,7 @@ ]; let printer = [{{ return ::printNamedStructuredOp(p, *this); }]; let parser = [{{ - return ::parseNamedStructuredOp<{0}>(parser, result/*TODO:, captures*/); + return ::parseNamedStructuredOp<{0}>(parser, result, true); }]; let hasFolder = 1; let hasCanonicalizer = 1; @@ -498,6 +566,7 @@ // Generic methods. static unsigned getNumRegionArgs(); + std::string getLibraryCallName(); {7} }]; @@ -563,10 +632,10 @@ interfaceNameList = interleaveToString(opConfig.metadata->implements, ", "); - os << llvm::formatv(structuredOpOdsHeaderFormat, opConfig.metadata->cppOpName, - opConfig.metadata->name, interfaceNameList, doc, attrList, - opConfig.structuredOp->args.size(), attrBuilder, - attrMethods); + os << llvm::formatv( + structuredOpOdsHeaderFormat, opConfig.metadata->cppClassName, + opConfig.metadata->name, interfaceNameList, doc, attrList, + opConfig.structuredOp->args.size(), attrBuilder, attrMethods); return success(); } @@ -578,7 +647,7 @@ return success(); raw_ostream &os = genContext.defns(); - StringRef className = opConfig.metadata->cppOpName; + StringRef className = opConfig.metadata->cppClassName; // Implementation banner. std::string bannerComment = llvm::formatv("Implementation of {0}", className); @@ -734,17 +803,25 @@ { // Generates a regionBuilder method. Parameters. // {0}: Class name - // {1}: Statements + // {1}: Number of args + // {2}: Number of captures + // {3}: Statements static const char structuredOpRegionBuilderFormat[] = R"FMT( void {0}::regionBuilder(Block &block, ValueRange captures) {{ - RegionBuilderHelper helper(block); + assert({1} > 0 && block.getNumArguments() == {1} && + "{0} regionBuilder expects {1} (>=0) args"); + assert(captures.size() == {2} && + "{0} regionBuilder expects {2} captures"); + RegionBuilderHelper helper(block.getArgument(0).getContext(), block); SmallVector yields; - {1} + {3} helper.yieldOutputs(yields); } )FMT"; auto &args = opConfig.structuredOp->args; auto &assignments = opConfig.structuredOp->assignments; + Optional> captures = + opConfig.structuredOp->captures; size_t generatedAssignmentCount = 0; int localCounter = 0; SmallVector stmts; @@ -769,12 +846,37 @@ Optional argIndex = findTensorDefArgIndex(*expression.arg, args); if (!argIndex) { emitError(genContext.getLoc()) - << "scalar argument not defined on the op: " << arg.name; + << "scalar argument not defined on the op: " << *expression.arg; return None; } return std::string( llvm::formatv("block.getArgument({0})", *argIndex)); - } else if (expression.apply) { + } + if (expression.capture) { + // Capture reference. + Optional captureIndex = + findCaptureDefArgIndex(*expression.capture, captures); + if (!captureIndex) { + emitError(genContext.getLoc()) + << "catpure not defined on the op: " << *expression.capture; + } + return std::string(llvm::formatv("captures[{0}]", *captureIndex)); + } + if (expression.constant) { + std::string cppIdent = llvm::formatv("value{0}", ++localCounter); + stmts.push_back( + llvm::formatv(R"FMT(Value {0} = helper.constant("{1}");)FMT", + cppIdent, expression.constant)); + return cppIdent; + } + if (expression.index) { + // Access an iteration index. + std::string cppIdent = llvm::formatv("value{0}", ++localCounter); + stmts.push_back(llvm::formatv("Value {0} = helper.index({1});", + cppIdent, *expression.index)); + return cppIdent; + } + if (expression.apply) { // Apply function. // Recursively generate operands. SmallVector operandCppValues; @@ -790,7 +892,8 @@ expression.apply->fnName, interleaveToString(operandCppValues, ", "))); return cppIdent; - } else if (expression.symbolicCast) { + } + if (expression.symbolicCast) { // Symbolic cast. // Operands must be arity 1. if (expression.symbolicCast->operands.size() != 1) { @@ -803,29 +906,23 @@ if (!operandCppValue) return None; - // Try to map the TypeVar to an arg index (which map to block arg - // indices), since we can just get that type directly. - // TODO: Handle free type variables which do not map to an argument. - Optional typeArgIndex = - findTypeVarArgIndex(expression.symbolicCast->typeVar, args); - if (!typeArgIndex) { + Optional typeCppValue = + findTypeValue(expression.symbolicCast->typeVar, args, captures); + if (!typeCppValue) { emitError(genContext.getLoc()) << "type variable " << expression.symbolicCast->typeVar - << ", used in a symbolic cast must map to an argument but it " - << "does not"; + << ", used in a symbolic cast must map to a predefined, " + << "argument, or capture type but it does not"; return None; } - std::string typeCppValue = - llvm::formatv("block.getArgument({0}).getType()", *typeArgIndex); std::string cppIdent = llvm::formatv("value{0}", ++localCounter); stmts.push_back(llvm::formatv("Value {0} = helper.cast({1}, {2});", - cppIdent, typeCppValue, + cppIdent, typeCppValue.getValue(), *operandCppValue)); return cppIdent; - } else { - emitError(genContext.getLoc()) << "unknown ScalarExpression type"; - return None; } + emitError(genContext.getLoc()) << "unknown ScalarExpression type"; + return None; }; Optional cppValue = generateExpression(assignment->value); if (!cppValue) @@ -837,8 +934,11 @@ return emitError(genContext.getLoc()) << "mismatched number of assignments vs output arguments"; - os << llvm::formatv(structuredOpRegionBuilderFormat, className, - interleaveToString(stmts, "\n ")); + int64_t numOfArgs = args.size(); + int64_t numOfCaptures = + captures.hasValue() ? captures.getValue().size() : 0; + os << llvm::formatv(structuredOpRegionBuilderFormat, className, numOfArgs, + numOfCaptures, interleaveToString(stmts, "\n ")); } // Canonicalizers and folders. @@ -937,7 +1037,7 @@ } genContext.setLoc(NameLoc::get( - Identifier::get(opConfig.metadata->cppOpName, &mlirContext))); + Identifier::get(opConfig.metadata->cppClassName, &mlirContext))); if (failed(generateOp(opConfig, genContext))) { return 1; }