diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -187,7 +187,11 @@
     if arg_def.operand_def.kind == OperandKind.SCALAR:
       indexing_maps.append(scalar_map)
     if arg_def.operand_def.is_tensor():
-      indexing_maps.append(tensor_map)
+      idx = arg_def.operand_def.registered_index
+      if idx < len(ins) and ShapedType(ins[idx].type).rank == 0:
+        indexing_maps.append(scalar_map)
+      else:
+        indexing_maps.append(tensor_map)
 
   indexing_maps_attr = ArrayAttr.get(
       [AffineMapAttr.get(am) for am in indexing_maps])
diff --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
--- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
@@ -320,3 +320,18 @@
 
 // CHECK-LABEL: @generalize_elemwise_mul
 // CHECK: = arith.mulf
+
+// -----
+
+// Verifies pointwise ops support rank zero input tensors.
+func @generalize_elemwise_rank_zero(%lhs : tensor<f32>, %rhs : tensor<f32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
+  %0 = linalg.elemwise_binary {fun = #linalg.binary_fn<sub>}
+                              ins(%lhs, %rhs: tensor<f32>, tensor<f32>)
+                              outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
+  return %0: tensor<4x8xf32>
+}
+
+// CHECK-LABEL: @generalize_elemwise_rank_zero
+// CHECK: linalg.generic
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]
+// CHECK: = arith.subf
diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_fill.py b/mlir/test/python/dialects/linalg/opdsl/emit_fill.py
--- a/mlir/test/python/dialects/linalg/opdsl/emit_fill.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_fill.py
@@ -15,6 +15,9 @@
 def fill_poly(value=ScalarDef(T1), O=TensorDef(U, output=True)):
   O[None] = TypeFn.cast_signed(U, value)
 
+@linalg_structured_op
+def fill_rank_zero_poly(I=TensorDef(T1), O=TensorDef(U, output=True)):
+  O[None] = TypeFn.cast_signed(U, I[None])
 
 with Context() as ctx, Location.unknown():
   module = Module.create()
@@ -25,6 +28,8 @@
   # CHECK-DAG: #[[$MAP0:.+]] = affine_map<() -> ()>
   # CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1) -> ()>
   # CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+  # CHECK-DAG: #[[$MAP3:.+]] = affine_map<(d0, d1, d2) -> ()>
+  # CHECK-DAG: #[[$MAP4:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 
   # CHECK-LABEL: @test_fill_0d
   # CHECK: linalg.generic
@@ -42,5 +47,13 @@
   def test_fill_2d(value, init_result):
     return fill_poly(value, outs=[init_result])
 
+  # CHECK-LABEL: @test_fill_rank_zero_3d
+  # CHECK: linalg.generic
+  # CHECK-SAME: indexing_maps = [#[[$MAP3]], #[[$MAP4]]]
+  # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
+  @builtin.FuncOp.from_py_func(
+      RankedTensorType.get([], f32), RankedTensorType.get([4, 8, 16], f32))
+  def test_fill_rank_zero_3d(input, init_result):
+    return fill_rank_zero_poly(input, outs=[init_result])
 
 print(module)
diff --git a/mlir/test/python/integration/dialects/linalg/opsrun.py b/mlir/test/python/integration/dialects/linalg/opsrun.py
--- a/mlir/test/python/integration/dialects/linalg/opsrun.py
+++ b/mlir/test/python/integration/dialects/linalg/opsrun.py
@@ -25,19 +25,19 @@
   %v1 = arith.constant 1.0 : f32
   %v2 = arith.constant 2.0 : f32
 
-  %lhs = memref.alloc() : memref<4x8xf32>
+  %lhs = memref.alloc() : memref<f32>
   %rhs = memref.alloc() : memref<4x8xf32>
   %O0 = memref.alloc() : memref<4x8xf32>
   %O1 = memref.alloc() : memref<4x8xf32>
-  linalg.fill(%v1, %lhs) : f32, memref<4x8xf32>
+  linalg.fill(%v1, %lhs) : f32, memref<f32>
   linalg.fill(%v2, %rhs) : f32, memref<4x8xf32>
   linalg.fill(%v0, %O0) : f32, memref<4x8xf32>
   linalg.fill(%v0, %O1) : f32, memref<4x8xf32>
 
   call @elemwise_exp_add_on_buffers(%lhs, %rhs, %O0) :
-    (memref<4x8xf32>, memref<4x8xf32>, memref<4x8xf32>) -> ()
+    (memref<f32>, memref<4x8xf32>, memref<4x8xf32>) -> ()
   call @elemwise_log_mul_on_buffers(%lhs, %rhs, %O1) :
-    (memref<4x8xf32>, memref<4x8xf32>, memref<4x8xf32>) -> ()
+    (memref<f32>, memref<4x8xf32>, memref<4x8xf32>) -> ()
 
   %c0 = arith.constant 0 : index
   %res0 = memref.load %O0[%c0, %c0] : memref<4x8xf32>
@@ -212,14 +212,14 @@
   with InsertionPoint(module.body):
 
     @builtin.FuncOp.from_py_func(
-        MemRefType.get((4, 8), f32), MemRefType.get((4, 8), f32),
+        MemRefType.get((), f32), MemRefType.get((4, 8), f32),
         MemRefType.get((4, 8), f32))
     def elemwise_exp_add_on_buffers(lhs, rhs, out):
       linalg.elemwise_unary(lhs, outs=[out])
       linalg.elemwise_binary(out, rhs, outs=[out])
 
     @builtin.FuncOp.from_py_func(
-        MemRefType.get((4, 8), f32), MemRefType.get((4, 8), f32),
+        MemRefType.get((), f32), MemRefType.get((4, 8), f32),
         MemRefType.get((4, 8), f32))
     def elemwise_log_mul_on_buffers(lhs, rhs, out):
       linalg.elemwise_unary(lhs, outs=[out], fun=UnaryFn.log)
@@ -251,14 +251,14 @@
   with InsertionPoint(module.body):
 
     @builtin.FuncOp.from_py_func(
-        MemRefType.get((4, 8), f32), MemRefType.get((4, 8), f32),
+        MemRefType.get((), f32), MemRefType.get((4, 8), f32),
         MemRefType.get((4, 8), f32))
     def elemwise_exp_add_on_buffers(lhs, rhs, out):
       linalg.elemwise_unary(lhs, outs=[out], emit_generic=True)
       linalg.elemwise_binary(out, rhs, outs=[out], emit_generic=True)
 
     @builtin.FuncOp.from_py_func(
-        MemRefType.get((4, 8), f32), MemRefType.get((4, 8), f32),
+        MemRefType.get((), f32), MemRefType.get((4, 8), f32),
         MemRefType.get((4, 8), f32))
     def elemwise_log_mul_on_buffers(lhs, rhs, out):
       linalg.elemwise_unary(
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -672,8 +672,10 @@
   AffineMap tensorMap = AffineMap::getMultiDimIdentityMap(
       getNumParallelLoops(), context);
   SmallVector<AffineMap> indexingMaps;
-  for (OpOperand *opOperand : getInputAndOutputOperands())
-    indexingMaps.push_back(isScalar(opOperand) ? scalarMap : tensorMap);
+  for (OpOperand *opOperand : getInputOperands())
+    indexingMaps.push_back(getRank(opOperand) == 0 ? scalarMap : tensorMap);
+  for (OpOperand *opOperand : getOutputOperands())
+    indexingMaps.push_back(tensorMap);
   return Builder(getContext()).getAffineMapArrayAttr(indexingMaps);
 }
 )FMT";
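
The patch above keys the indexing-map choice on each input's rank rather than
on whether the operand was declared as a scalar: a rank-0 tensor (or memref)
input gets the constant-result map that scalars already use, while outputs keep
the multi-dim identity map. Below is a minimal, self-contained sketch of that
selection logic using plain strings in place of AffineMap objects;
build_indexing_maps, input_ranks, and num_loops are illustrative names, not the
emitter's actual API.

# Sketch only: mirrors the patched map selection, assuming every input is
# either rank-0 (scalar-like) or matches the full iteration rank.
def build_indexing_maps(input_ranks, num_loops):
  dims = ", ".join("d%d" % i for i in range(num_loops))
  scalar_map = "(%s) -> ()" % dims            # constant map broadcasts a rank-0 value
  tensor_map = "(%s) -> (%s)" % (dims, dims)  # identity map for full-rank operands
  maps = [scalar_map if rank == 0 else tensor_map for rank in input_ranks]
  maps.append(tensor_map)  # outputs always take the identity map
  return maps

print(build_indexing_maps([0, 2], 2))
# ['(d0, d1) -> ()', '(d0, d1) -> (d0, d1)', '(d0, d1) -> (d0, d1)']

Treating rank-0 inputs like scalars keeps the generated linalg.generic
broadcast-compatible without introducing a new operand kind, which is why the
Python emitter and the C++ YAML generator apply the same rank-equals-zero test.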