diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py
--- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -35,6 +35,10 @@
   outs, out_types = _infer_structured_outs(op_config, in_arg_defs, ins,
                                            out_arg_defs, outs)
 
+  # TODO: This hack only works on buffers.
+  result_types = []
+  # result_types = [t for t in out_types if t == RankedTensorType.get(t.shape())]
+
   # Extract type vars for input/output based types.
   type_mapping = dict()  # type: Dict[str, Type]
   for arg_def, arg_element_type in zip(
@@ -50,19 +54,19 @@
   iterator_types_attr = ArrayAttr.get(
       [StringAttr.get(s) for s in op_config.iterator_types])
 
-  return (all_arg_defs, in_arg_defs, out_arg_defs, outs, out_types,
+  return (all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types,
           type_mapping, indexing_maps_attr, iterator_types_attr)
 
 
 def emit_generic_structured_op(op_config: LinalgStructuredOpConfig,
                                *ins: Value,
                                outs: Value = ()):
-  all_arg_defs, in_arg_defs, out_arg_defs, outs, out_types, \
+  all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, \
       type_mapping, indexing_maps_attr, iterator_types_attr = \
       prepare_common_structured_op(op_config, *ins, outs = outs)
 
   generic_op = linalg.GenericOp(
-      result_tensors=out_types,
+      result_tensors=result_types,
       inputs=ins,
       outputs=outs,
       indexing_maps=indexing_maps_attr,
@@ -93,7 +97,7 @@
                             op_class_name: str,
                             *ins: Value,
                             outs: Value = ()):
-  all_arg_defs, in_arg_defs, out_arg_defs, outs, out_types, \
+  all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, \
       type_mapping, indexing_maps_attr, iterator_types_attr = \
       prepare_common_structured_op(op_config, *ins, outs = outs)
 
@@ -101,8 +105,8 @@
     raise NotImplementedError(
         f"Unknown named op_name / op_class_name: {op_name} / {op_class_name}")
 
-  named_op = getattr(linalg, op_class_name)(ins, outs, out_types)
-  if len(out_arg_defs) == 1:
+  named_op = getattr(linalg, op_class_name)(ins, outs, result_types)
+  if len(result_types) == 1:
     return named_op.result
   else:
     return named_op.results
diff --git a/mlir/test/Bindings/Python/dialects/linalg/opsrun.py b/mlir/test/Bindings/Python/dialects/linalg/opsrun.py
--- a/mlir/test/Bindings/Python/dialects/linalg/opsrun.py
+++ b/mlir/test/Bindings/Python/dialects/linalg/opsrun.py
@@ -8,13 +8,25 @@
 from mlir.passmanager import *
 from mlir.execution_engine import *
 
-def lowerToLLVM(module):
+def transform(module):
+  # print(module)
+  # print(module.body.operations[0].print(print_generic_op_form=True))
+  # print(module.body.operations[0].entry_block.operations[0].regions[0])
+
   import mlir.conversions
   import mlir.transforms
-  pm = PassManager.parse("print-op-stats")
+  pm = PassManager.parse("print-op-stats, print-op-stats")
   #pm = PassManager.parse("module(func(print-op-stats))")
+  #pm = PassManager.parse("convert-std-to-llvm")
   #pm = PassManager.parse("convert-linalg-to-loops")
-  pm.run(module)
+
+  # Roundtrip through strings for now because the linalg IR is invalid at the
+  # moment but prints correctly.
+
+  print(module)
+  pm.run(Module.parse(module.__str__()))
+  # pm.run(module)
+  return module
 
 
 def test():
@@ -23,19 +35,18 @@
   f32 = F32Type.get()
   memref_type = MemRefType.get((2, 3, 4), f32)
   with InsertionPoint(module.body):
-    # @builtin.FuncOp.from_py_func(MemRefType.get((4, 16), f32),
-    #                              MemRefType.get((16, 8), f32),
-    #                              MemRefType.get((4, 8), f32))
-    # def matmul_on_buffers(lhs, rhs, out):
-    #   linalg.matmul(lhs, rhs, outs=[out])
-
-    @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), f32),
-                                 RankedTensorType.get((16, 8), f32))
-    def named_form(lhs, rhs):
-      return linalg.matmul(lhs, rhs, outs=[rhs], emit_generic=True)
+    @builtin.FuncOp.from_py_func(MemRefType.get((4, 16), f32),
+                                 MemRefType.get((16, 8), f32),
+                                 MemRefType.get((4, 8), f32))
+    def matmul_on_buffers(lhs, rhs, out):
+      linalg.matmul(lhs, rhs, outs=[out])
+
+    # @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), f32),
+    #                              RankedTensorType.get((16, 8), f32))
+    # def named_form(lhs, rhs):
+    #   return linalg.matmul(lhs, rhs, outs=[rhs], emit_generic=False)
 
-  print(module)
-  print(lowerToLLVM(module))
+  transform(module)
 
   # execution_engine = ExecutionEngine(lowerToLLVM(module))
   # # Prepare arguments: two input floats and one result.
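
Note on the TODO in prepare_common_structured_op: the commented-out filter hints at the eventual tensor path, where only tensor outs produce SSA results and buffer (memref) outs produce none. A minimal sketch of that filter, assuming the Python bindings expose RankedTensorType.isinstance; infer_result_types is an illustrative name, not part of this patch:

  from mlir.ir import RankedTensorType

  def infer_result_types(out_types):
    # Hypothetical sketch, not part of the patch. Buffer (memref) outs yield
    # no op results; only tensor outs become results of the structured op.
    return [t for t in out_types if RankedTensorType.isinstance(t)]

With such a filter in place, len(result_types) == 1 in emit_named_structured_op would hold exactly for ops with a single tensor out, which is what the named_op.result / named_op.results split above relies on.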