diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py --- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -16,6 +16,12 @@ "emit_named_structured_op", ] +def isa(cls : Type, ty : Type): + try: + cls(ty) + return True + except ValueError: + return False def prepare_common_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value, @@ -35,6 +41,8 @@ outs, out_types = _infer_structured_outs(op_config, in_arg_defs, ins, out_arg_defs, outs) + result_types = [t for t in out_types if isa(RankedTensorType, t)] + # Extract type vars for input/output based types. type_mapping = dict() # type: Dict[str, Type] for arg_def, arg_element_type in zip( @@ -50,19 +58,19 @@ iterator_types_attr = ArrayAttr.get( [StringAttr.get(s) for s in op_config.iterator_types]) - return (all_arg_defs, in_arg_defs, out_arg_defs, outs, out_types, + return (all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, indexing_maps_attr, iterator_types_attr) def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value, outs: Value = ()): - all_arg_defs, in_arg_defs, out_arg_defs, outs, out_types, \ + all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, \ type_mapping, indexing_maps_attr, iterator_types_attr = \ prepare_common_structured_op(op_config, *ins, outs = outs) generic_op = linalg.GenericOp( - result_tensors=out_types, + result_tensors=result_types, inputs=ins, outputs=outs, indexing_maps=indexing_maps_attr, @@ -82,7 +90,7 @@ body_builder.assign(assignment) body_builder.yield_outputs(*_get_tensor_def_names(*out_arg_defs)) - if len(out_arg_defs) == 1: + if len(result_types) == 1: return generic_op.result else: return generic_op.results @@ -93,7 +101,7 @@ op_class_name: str, *ins: Value, outs: Value = ()): - all_arg_defs, in_arg_defs, out_arg_defs, outs, 
out_types, \ + all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, \ type_mapping, indexing_maps_attr, iterator_types_attr = \ prepare_common_structured_op(op_config, *ins, outs = outs) @@ -101,8 +109,8 @@ raise NotImplementedError( f"Unknown named op_name / op_class_name: {op_name} / {op_class_name}") - named_op = getattr(linalg, op_class_name)(ins, outs, out_types) - if len(out_arg_defs) == 1: + named_op = getattr(linalg, op_class_name)(ins, outs, result_types) + if len(result_types) == 1: return named_op.result else: return named_op.results diff --git a/mlir/test/Bindings/Python/dialects/linalg/opsrun.py b/mlir/test/Bindings/Python/dialects/linalg/opsrun.py new file mode 100644 --- /dev/null +++ b/mlir/test/Bindings/Python/dialects/linalg/opsrun.py @@ -0,0 +1,95 @@ +# RUN: %PYTHON %s 2>&1 | FileCheck %s + +import sys +from mlir.ir import * +from mlir.dialects import builtin +from mlir.dialects import linalg +from mlir.dialects import std +from mlir.passmanager import * +from mlir.execution_engine import * + +# Log everything to stderr and flush so that we have a unified stream to match +# errors/info emitted by MLIR to stderr. +def log(*args): + print(*args, file=sys.stderr) + sys.stderr.flush() + +boilerplate = """ +func @main() -> f32 attributes {llvm.emit_c_interface} { + %v0 = constant 0.0 : f32 + %v1 = constant 1.0 : f32 + %v2 = constant 2.0 : f32 + + %A = memref.alloc() : memref<4x16xf32> + %B = memref.alloc() : memref<16x8xf32> + %C = memref.alloc() : memref<4x8xf32> + linalg.fill(%A, %v1) : memref<4x16xf32>, f32 + linalg.fill(%B, %v2) : memref<16x8xf32>, f32 + linalg.fill(%C, %v0) : memref<4x8xf32>, f32 + + call @matmul_on_buffers(%A, %B, %C) : + (memref<4x16xf32>, memref<16x8xf32>, memref<4x8xf32>) -> () + + // TODO: Link stuff in with python. 
+ // %res = memref.cast %C: memref<4x8xf32> to memref<*xf32> + // call @print_memref_f32(%res) : (memref<*xf32>) -> () + + %c0 = constant 0 : index + %0 = memref.load %C[%c0, %c0] : memref<4x8xf32> + // TODO: Link stuff in with python. + // vector.print %0: f32 + + return %0 : f32 +} + +// TODO: Link stuff in with python. +// func private @print_memref_f32(memref<*xf32>) attributes { llvm.emit_c_interface } +""" + +def transform(module): + import mlir.conversions + import mlir.dialects.linalg.passes + import mlir.transforms + pm = PassManager.parse("func(convert-linalg-to-loops)," + + "convert-scf-to-std," + + "convert-vector-to-llvm," + + "convert-std-to-llvm") + # TODO: For now, roundtrip through strings because linalg IR is invalid + # atm but prints correctly. + # TODO: invalid linalg IR seems to also break the ability to print the first + # operation, so we hack the string for now. + # print(module.body.operations[0]) + # print(module.body.operations[0].entry_block) + program = str(module)[10:-3] + boilerplate + mod = Module.parse(program) + pm.run(mod) + # print(mod) + return mod + +def test(): + with Context() as ctx, Location.unknown(): + module = Module.create() + f32 = F32Type.get() + with InsertionPoint(module.body): + @builtin.FuncOp.from_py_func(MemRefType.get((4, 16), f32), + MemRefType.get((16, 8), f32), + MemRefType.get((4, 8), f32)) + def matmul_on_buffers(lhs, rhs, out): + # TODO: body is not populated. + linalg.matmul(lhs, rhs, outs=[out]) + # TODO: This does not work because the indexing_map symbols have not been + # simplified away and the linalg.generic fails verification. + # linalg.matmul(lhs, rhs, outs=[out], emit_generic=True) + + execution_engine = ExecutionEngine(transform(module)) + + # Prepare arguments: one result f32. + # Arguments must be passed as pointers. + c_float_p = ctypes.c_float * 1 + res = c_float_p(-1.) 
+ execution_engine.invoke("main", res) + + log('RESULT: ', res[0]) + # CHECK: RESULT: 32.0 + +test()