diff --git a/mlir/include/mlir-c/AffineMap.h b/mlir/include/mlir-c/AffineMap.h --- a/mlir/include/mlir-c/AffineMap.h +++ b/mlir/include/mlir-c/AffineMap.h @@ -169,6 +169,17 @@ MLIR_CAPI_EXPORTED MlirAffineMap mlirAffineMapGetMinorSubMap(MlirAffineMap affineMap, intptr_t numResults); +/// Computes the simplified affine maps obtained by dropping the symbols that +/// do not appear in any of the individual maps in `affineMaps`. +/// Asserts that all maps in `affineMaps` are normalized to the same number of +/// dims and symbols. +/// Calls `populateResult(result, idx, m)` once per input map, in order, to +/// store compressed map `m` at entry `idx` of the caller-owned `result`; +/// this avoids transferring ownership of an array across the C API. +MLIR_CAPI_EXPORTED void mlirAffineMapCompressUnusedSymbols( + MlirAffineMap *affineMaps, intptr_t size, void *result, + void (*populateResult)(void *res, intptr_t idx, MlirAffineMap m)); + #ifdef __cplusplus } #endif diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h --- a/mlir/include/mlir/IR/AffineMap.h +++ b/mlir/include/mlir/IR/AffineMap.h @@ -340,6 +340,11 @@ /// Drop the dims that are not used. AffineMap compressUnusedDims(AffineMap map); +/// Drop the dims that are not used by any of the individual maps in `maps`. +/// Asserts that all maps in `maps` are normalized to the same number of +/// dims and symbols. +SmallVector<AffineMap> compressUnusedDims(ArrayRef<AffineMap> maps); + /// Drop the dims that are not listed in `unusedDims`. AffineMap compressDims(AffineMap map, const llvm::SmallDenseSet<unsigned> &unusedDims); @@ -347,6 +352,11 @@ /// Drop the symbols that are not used. AffineMap compressUnusedSymbols(AffineMap map); +/// Drop the symbols that are not used by any of the individual maps in `maps`. +/// Asserts that all maps in `maps` are normalized to the same number of +/// dims and symbols. +SmallVector<AffineMap> compressUnusedSymbols(ArrayRef<AffineMap> maps); + /// Drop the symbols that are not listed in `unusedSymbols`.
AffineMap compressSymbols(AffineMap map, const llvm::SmallDenseSet<unsigned> &unusedSymbols); diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp --- a/mlir/lib/Bindings/Python/IRAffine.cpp +++ b/mlir/lib/Bindings/Python/IRAffine.cpp @@ -538,6 +538,23 @@ printAccum.parts.append(")"); return printAccum.join(); }) + .def_static("compress_unused_symbols", + [](py::list affineMaps, DefaultingPyMlirContext context) { + SmallVector<MlirAffineMap> maps; + pyListToVector<PyAffineMap, MlirAffineMap>( + affineMaps, maps, "attempting to create an AffineMap"); + std::vector<MlirAffineMap> compressed(affineMaps.size()); + auto populate = [](void *result, intptr_t idx, + MlirAffineMap m) { + static_cast<MlirAffineMap *>(result)[idx] = (m); + }; + mlirAffineMapCompressUnusedSymbols( + maps.data(), maps.size(), compressed.data(), populate); + std::vector<PyAffineMap> res; + for (auto m : compressed) + res.push_back(PyAffineMap(context->getRef(), m)); + return res; + }) .def_property_readonly( "context", [](PyAffineMap &self) { return self.getContext().getObject(); }, diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py --- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -19,6 +19,13 @@ "emit_named_structured_op", ] +def isa(cls : Type, ty : Type): + try: + cls(ty) + return True + except ValueError: + return False + def prepare_common_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value, outs: Value): @@ -37,6 +44,8 @@ outs, out_types = _infer_structured_outs(op_config, in_arg_defs, ins, out_arg_defs, outs) + result_types = [t for t in out_types if isa(RankedTensorType, t)] + # Extract type vars for input/output based types. type_mapping = dict() # type: Dict[str, Type] for arg_def, arg_element_type in zip( @@ -48,30 +57,37 @@ # Emit the generic op. # TODO: Support emission of pure memref form.
indexing_maps_attr = ArrayAttr.get( - [AffineMapAttr.get(am) for am in op_config.indexing_maps]) + [AffineMapAttr.get(am) + # TODO: linalg verification does not currently allow symbols. + # Compress them for now. + for am in AffineMap.compress_unused_symbols(op_config.indexing_maps, Context.current)]) iterator_types_attr = ArrayAttr.get( [StringAttr.get(s) for s in op_config.iterator_types]) + sparse_attr = ArrayAttr.get( + [BoolAttr.get(False) for s in list(ins) + list(outs) if isa(RankedTensorType, s.type)]) + if len(sparse_attr) == 0: + sparse_attr = None - return (all_arg_defs, in_arg_defs, out_arg_defs, outs, out_types, - type_mapping, indexing_maps_attr, iterator_types_attr) + return (all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, + type_mapping, indexing_maps_attr, iterator_types_attr, sparse_attr) def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value, outs: Value = ()): - all_arg_defs, in_arg_defs, out_arg_defs, outs, out_types, \ - type_mapping, indexing_maps_attr, iterator_types_attr = \ + all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, \ + type_mapping, indexing_maps_attr, iterator_types_attr, sparse_attr = \ prepare_common_structured_op(op_config, *ins, outs = outs) generic_op = linalg.GenericOp( - result_tensors=out_types, + result_tensors=result_types, inputs=ins, outputs=outs, indexing_maps=indexing_maps_attr, iterator_types=iterator_types_attr, doc=None, # TODO: Make optional. library_call=None, # TODO: Make optional. - sparse=BoolAttr.get(False)) # TODO: Make optional. + sparse=sparse_attr) # TODO: Make optional. # Construct the body.
block_arg_names = _get_tensor_def_names(*in_arg_defs, *out_arg_defs) @@ -84,7 +100,7 @@ body_builder.assign(assignment) body_builder.yield_outputs(*_get_tensor_def_names(*out_arg_defs)) - if len(out_arg_defs) == 1: + if len(result_types) == 1: return generic_op.result else: return generic_op.results @@ -95,8 +111,8 @@ op_class_name: str, *ins: Value, outs: Value = ()): - all_arg_defs, in_arg_defs, out_arg_defs, outs, out_types, \ - type_mapping, indexing_maps_attr, iterator_types_attr = \ + all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, \ + type_mapping, indexing_maps_attr, iterator_types_attr, sparse_attr = \ prepare_common_structured_op(op_config, *ins, outs = outs) # If we get here, there must exist a builtin class `op_class_name`. @@ -107,11 +123,16 @@ raise NotImplementedError( f"Unknown named op_name / op_class_name: {op_name} / {op_class_name}") - named_op = getattr(linalg, op_class_name)(ins, outs, out_types) + named_op = getattr(linalg, op_class_name)(ins, outs, result_types) linalgDialect = ctx.get_dialect_descriptor("linalg") fill_builtin_region(linalgDialect, named_op.operation) + # Note: mlir-linalg-ods-yaml-gen.cpp uses a special linalg.memoized_indexing_maps + # attribute that the non-yaml path does not. The non-yaml path hardcodes the + # indexing_maps in C++ directly. + named_op.operation.attributes["linalg.memoized_indexing_maps"] = indexing_maps_attr + # iterator_types are hardcoded in C++ both in the yaml and non-yaml path.
- if len(out_arg_defs) == 1: + if len(result_types) == 1: return named_op.result else: return named_op.results diff --git a/mlir/lib/CAPI/IR/AffineMap.cpp b/mlir/lib/CAPI/IR/AffineMap.cpp --- a/mlir/lib/CAPI/IR/AffineMap.cpp +++ b/mlir/lib/CAPI/IR/AffineMap.cpp @@ -137,3 +137,14 @@ intptr_t numResults) { return wrap(unwrap(affineMap).getMinorSubMap(numResults)); } + +void mlirAffineMapCompressUnusedSymbols( + MlirAffineMap *affineMaps, intptr_t size, void *result, + void (*populateResult)(void *res, intptr_t idx, MlirAffineMap m)) { + SmallVector<AffineMap> maps; + for (intptr_t idx = 0; idx < size; ++idx) + maps.push_back(unwrap(affineMaps[idx])); + intptr_t idx = 0; + for (auto m : mlir::compressUnusedSymbols(maps)) + populateResult(result, idx++, wrap(m)); +} diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp --- a/mlir/lib/IR/AffineMap.cpp +++ b/mlir/lib/IR/AffineMap.cpp @@ -543,6 +543,41 @@ return compressDims(map, unusedDims); } +static SmallVector<AffineMap> +compressUnusedImpl(ArrayRef<AffineMap> maps, + llvm::function_ref<AffineMap(AffineMap)> compressionFun) { + if (maps.empty()) + return SmallVector<AffineMap>(); + SmallVector<AffineExpr> allExprs; + allExprs.reserve(maps.size() * maps.front().getNumResults()); + unsigned numDims = maps.front().getNumDims(), + numSymbols = maps.front().getNumSymbols(); + for (auto m : maps) { + assert(numDims == m.getNumDims() && numSymbols == m.getNumSymbols() && + "expected maps with same num dims and symbols"); + llvm::append_range(allExprs, m.getResults()); + } + AffineMap unifiedMap = compressionFun( + AffineMap::get(numDims, numSymbols, allExprs, maps.front().getContext())); + unsigned unifiedNumDims = unifiedMap.getNumDims(), + unifiedNumSymbols = unifiedMap.getNumSymbols(); + ArrayRef<AffineExpr> unifiedResults = unifiedMap.getResults(); + SmallVector<AffineMap> res; + res.reserve(maps.size()); + for (auto m : maps) { + res.push_back(AffineMap::get(unifiedNumDims, unifiedNumSymbols, + unifiedResults.take_front(m.getNumResults()), + m.getContext())); + unifiedResults =
unifiedResults.drop_front(m.getNumResults()); + } + return res; +} + +SmallVector<AffineMap> mlir::compressUnusedDims(ArrayRef<AffineMap> maps) { + return compressUnusedImpl(maps, + [](AffineMap m) { return compressUnusedDims(m); }); +} + AffineMap mlir::compressSymbols(AffineMap map, const llvm::SmallDenseSet<unsigned> &unusedSymbols) { @@ -576,6 +611,11 @@ return compressSymbols(map, unusedSymbols); } +SmallVector<AffineMap> mlir::compressUnusedSymbols(ArrayRef<AffineMap> maps) { + return compressUnusedImpl( + maps, [](AffineMap m) { return compressUnusedSymbols(m); }); +} + AffineMap mlir::simplifyAffineMap(AffineMap map) { SmallVector<AffineExpr, 8> exprs; for (auto e : map.getResults()) { diff --git a/mlir/test/Bindings/Python/dialects/linalg/opdsl/emit_structured_generic.py b/mlir/test/Bindings/Python/dialects/linalg/opdsl/emit_structured_generic.py --- a/mlir/test/Bindings/Python/dialects/linalg/opdsl/emit_structured_generic.py +++ b/mlir/test/Bindings/Python/dialects/linalg/opdsl/emit_structured_generic.py @@ -37,9 +37,9 @@ # Note that these all have the same indexing maps. We verify the first and # then do more permutation tests on casting and body generation # behavior.
- # CHECK: #[[$MAPA:.+]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)> - # CHECK: #[[$MAPB:.+]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d1)> - # CHECK: #[[$MAPC:.+]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)> + # CHECK: #[[$MAPA:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)> + # CHECK: #[[$MAPB:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)> + # CHECK: #[[$MAPC:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)> # CHECK-LABEL: func @test_matmul_mono # CHECK-SAME: %[[A:.+]]: tensor<4x16xf32> diff --git a/mlir/test/Bindings/Python/dialects/linalg/ops.py b/mlir/test/Bindings/Python/dialects/linalg/ops.py --- a/mlir/test/Bindings/Python/dialects/linalg/ops.py +++ b/mlir/test/Bindings/Python/dialects/linalg/ops.py @@ -94,6 +94,7 @@ init_result = linalg.InitTensorOp([4, 8], f32) # First check the named form with custom format # CHECK: linalg.matmul + # CHECK-NOT: linalg.memoized_indexing_maps # CHECK-SAME: ins(%{{.*}} : tensor<4x16xf32>, tensor<16x8xf32>) # CHECK-SAME: outs(%{{.*}} : tensor<4x8xf32>) # CHECK-SAME: -> tensor<4x8xf32> @@ -118,7 +119,7 @@ # CHECK-NEXT: std.mulf{{.*}} (f32, f32) -> f32 # CHECK-NEXT: std.addf{{.*}} (f32, f32) -> f32 # CHECK-NEXT: linalg.yield{{.*}} (f32) -> () - # CHECK-NEXT: {operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : + # CHECK-NEXT: {linalg.memoized_indexing_maps{{.*}}operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : # CHECK-SAME: (tensor<4x16xf32>, tensor<16x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32> return linalg.matmul(lhs, rhs, outs=[init_result.result]) diff --git a/mlir/test/Bindings/Python/dialects/linalg/opsrun.py b/mlir/test/Bindings/Python/dialects/linalg/opsrun.py new file mode 100644 --- /dev/null +++ b/mlir/test/Bindings/Python/dialects/linalg/opsrun.py @@ -0,0 +1,106 @@ +# RUN: %PYTHON %s 2>&1 | FileCheck %s + +import ctypes +import sys +from mlir.ir import * +from mlir.dialects import builtin +from mlir.dialects import linalg +from mlir.dialects import std +from mlir.passmanager import * +from mlir.execution_engine
import * + +# Log everything to stderr and flush so that we have a unified stream to match +# errors/info emitted by MLIR to stderr. +def log(*args): + print(*args, file=sys.stderr) + sys.stderr.flush() + +boilerplate = """ +func @main() -> f32 attributes {llvm.emit_c_interface} { + %v0 = constant 0.0 : f32 + %v1 = constant 1.0 : f32 + %v2 = constant 2.0 : f32 + + %A = memref.alloc() : memref<4x16xf32> + %B = memref.alloc() : memref<16x8xf32> + %C = memref.alloc() : memref<4x8xf32> + linalg.fill(%A, %v1) : memref<4x16xf32>, f32 + linalg.fill(%B, %v2) : memref<16x8xf32>, f32 + linalg.fill(%C, %v0) : memref<4x8xf32>, f32 + + call @matmul_on_buffers(%A, %B, %C) : + (memref<4x16xf32>, memref<16x8xf32>, memref<4x8xf32>) -> () + + %c0 = constant 0 : index + %0 = memref.load %C[%c0, %c0] : memref<4x8xf32> + + // TODO: FFI-based solution to allow testing and printing with python code. + return %0 : f32 +} +""" + +def transform(module): + import mlir.conversions + import mlir.dialects.linalg.passes + import mlir.transforms + + # TODO: Allow cloning functions from one module to another. + # Atm we have to resort to string concatenation. + mod = Module.parse( + str(module.operation.regions[0].blocks[0].operations[0].operation) + + boilerplate) + pm = PassManager.parse("func(convert-linalg-to-loops, convert-scf-to-std)," + + "convert-vector-to-llvm," + + "convert-std-to-llvm") + pm.run(mod) + return mod + +def test_builtin(): + with Context() as ctx, Location.unknown(): + module = Module.create() + f32 = F32Type.get() + with InsertionPoint(module.body): + @builtin.FuncOp.from_py_func(MemRefType.get((4, 16), f32), + MemRefType.get((16, 8), f32), + MemRefType.get((4, 8), f32)) + def matmul_on_buffers(lhs, rhs, out): + linalg.matmul(lhs, rhs, outs=[out]) + + execution_engine = ExecutionEngine(transform(module)) + + # TODO: FFI-based solution to allow testing and printing with python code. + # Prepare arguments: one result f32. + # Arguments must be passed as pointers.
+ c_float_p = ctypes.c_float * 1 + res = c_float_p(-1.) + execution_engine.invoke("main", res) + + log('RESULT: ', res[0]) + # CHECK: RESULT: 32.0 + +test_builtin() + +def test_generic(): + with Context() as ctx, Location.unknown(): + module = Module.create() + f32 = F32Type.get() + with InsertionPoint(module.body): + @builtin.FuncOp.from_py_func(MemRefType.get((4, 16), f32), + MemRefType.get((16, 8), f32), + MemRefType.get((4, 8), f32)) + def matmul_on_buffers(lhs, rhs, out): + linalg.matmul(lhs, rhs, outs=[out], emit_generic=True) + + execution_engine = ExecutionEngine(transform(module)) + + # TODO: FFI-based solution to allow testing and printing with python code. + # Prepare arguments: one result f32. + # Arguments must be passed as pointers. + c_float_p = ctypes.c_float * 1 + res = c_float_p(-1.) + execution_engine.invoke("main", res) + + log('RESULT: ', res[0]) + # CHECK: RESULT: 32.0 + +test_generic()