diff --git a/mlir/docs/Tools/LinalgOpDsl.md b/mlir/docs/Tools/LinalgOpDsl.md --- a/mlir/docs/Tools/LinalgOpDsl.md +++ b/mlir/docs/Tools/LinalgOpDsl.md @@ -19,7 +19,7 @@ ```shell # Dump the `core_named_ops.py` module as YAML. -python -m python -m mlir.tools.linalg_opdsl.dump_oplib .ops.core_named_ops +python -m mlir.dialects.linalg.opdsl.dump_oplib .ops.core_named_ops ``` The tool is meant for use during both development and runtime, but not as diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml @@ -11,21 +11,21 @@ - LinalgContractionOpInterface structured_op: !LinalgStructuredOpConfig args: - - !<LinalgTensorDef> + - !LinalgOperandDefConfig name: A usage: input shape: affine_map<()[s0, s1, s2] -> (s0, s2)> - element_type_var: T1 - - !<LinalgTensorDef> + type_var: T1 + - !LinalgOperandDefConfig name: B usage: input shape: affine_map<()[s0, s1, s2] -> (s2, s1)> - element_type_var: T2 - - !<LinalgTensorDef> + type_var: T2 + - !LinalgOperandDefConfig name: C usage: output shape: affine_map<()[s0, s1, s2] -> (s0, s1)> - element_type_var: U + type_var: U indexing_maps: !LinalgIndexingMapsConfig static_indexing_maps: - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)> @@ -73,21 +73,21 @@ - LinalgContractionOpInterface structured_op: !LinalgStructuredOpConfig args: - - !<LinalgTensorDef> + - !LinalgOperandDefConfig name: A usage: input shape: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)> - element_type_var: T1 - - !<LinalgTensorDef> + type_var: T1 + - !LinalgOperandDefConfig name: B usage: input shape: affine_map<()[s0, s1, s2, s3] -> (s0, s3, s2)> - element_type_var: T2 - - !<LinalgTensorDef>
+ type_var: T2 + - !LinalgOperandDefConfig name: C usage: output shape: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)> - element_type_var: U + type_var: U indexing_maps: !LinalgIndexingMapsConfig static_indexing_maps: - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d3)> @@ -136,21 +136,21 @@ - LinalgContractionOpInterface structured_op: !LinalgStructuredOpConfig args: - - !<LinalgTensorDef> + - !LinalgOperandDefConfig name: A usage: input shape: affine_map<()[s0, s1] -> (s0, s1)> - element_type_var: T1 - - !<LinalgTensorDef> + type_var: T1 + - !LinalgOperandDefConfig name: y usage: input shape: affine_map<()[s0, s1] -> (s1)> - element_type_var: T2 - - !<LinalgTensorDef> + type_var: T2 + - !LinalgOperandDefConfig name: x usage: output shape: affine_map<()[s0, s1] -> (s0)> - element_type_var: U + type_var: U indexing_maps: !LinalgIndexingMapsConfig static_indexing_maps: - affine_map<(d0, d1)[s0, s1] -> (d0, d1)> @@ -197,21 +197,21 @@ - LinalgContractionOpInterface structured_op: !LinalgStructuredOpConfig args: - - !<LinalgTensorDef> + - !LinalgOperandDefConfig name: y usage: input shape: affine_map<()[s0, s1] -> (s1)> - element_type_var: T1 - - !<LinalgTensorDef> + type_var: T1 + - !LinalgOperandDefConfig name: A usage: input shape: affine_map<()[s0, s1] -> (s1, s0)> - element_type_var: T2 - - !<LinalgTensorDef> + type_var: T2 + - !LinalgOperandDefConfig name: x usage: output shape: affine_map<()[s0, s1] -> (s0)> - element_type_var: U + type_var: U indexing_maps: !LinalgIndexingMapsConfig static_indexing_maps: - affine_map<(d0, d1)[s0, s1] -> (d1)> @@ -258,21 +258,21 @@ - LinalgContractionOpInterface structured_op: !LinalgStructuredOpConfig args: - - !<LinalgTensorDef> + - !LinalgOperandDefConfig name: A usage: input shape: affine_map<()[s0] -> (s0)> - element_type_var: T1 - - !<LinalgTensorDef> + type_var: T1 + - !LinalgOperandDefConfig name: B usage: input shape: affine_map<()[s0] -> (s0)> - element_type_var: T2 - - !<LinalgTensorDef>
+ type_var: T2 + - !LinalgOperandDefConfig name: C usage: output shape: affine_map<()[s0] -> ()> - element_type_var: U + type_var: U indexing_maps: !LinalgIndexingMapsConfig static_indexing_maps: - affine_map<(d0)[s0] -> (d0)> @@ -319,18 +319,30 @@ and runs them in parallel. The seed operand and the indices of the data element seed the random number generation. The min and max operands limit the range of the generated random numbers. - - Note: The captures are hard-coded till there is capture support on the C++ - side. structured_op: !LinalgStructuredOpConfig args: - - !<LinalgTensorDef> + - !LinalgOperandDefConfig + name: min + usage: input + type_var: F64 + - !LinalgOperandDefConfig + name: max + usage: input + type_var: F64 + - !LinalgOperandDefConfig + name: seed + usage: input + type_var: I32 + - !LinalgOperandDefConfig name: O usage: output shape: affine_map<()[s0, s1] -> (s0, s1)> - element_type_var: T + type_var: T indexing_maps: !LinalgIndexingMapsConfig static_indexing_maps: + - affine_map<(d0, d1)[s0, s1] -> ()> + - affine_map<(d0, d1)[s0, s1] -> ()> + - affine_map<(d0, d1)[s0, s1] -> ()> - affine_map<(d0, d1)[s0, s1] -> (d0, d1)> iterator_types: - parallel @@ -401,11 +413,7 @@ - !ScalarExpression scalar_index: 0 - !ScalarExpression - symbolic_cast: - type_var: I32 - operands: - - !ScalarExpression - scalar_const: '42 : i64' + scalar_arg: seed - !ScalarExpression symbolic_cast: type_var: I32 @@ -439,17 +447,9 @@ fn_name: sub operands: - !ScalarExpression - symbolic_cast: - type_var: F64 - operands: - - !ScalarExpression - scalar_const: '1000 : i64' + scalar_arg: max - !ScalarExpression - symbolic_cast: - type_var: F64 - operands: - - !ScalarExpression - scalar_const: '-1000 : i64' + scalar_arg: min - !ScalarExpression symbolic_cast: type_var: F64 @@ -457,8 +457,4 @@ - !ScalarExpression scalar_const: '2.3283063999999999E-10 : f64' - !ScalarExpression - symbolic_cast: - type_var: F64 - operands: - - !ScalarExpression - scalar_const: '-1000 : i64' + scalar_arg: min diff
--git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir --- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir +++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir @@ -30,16 +30,13 @@ // ----- -func @generalize_fill_rng_2d_f32(%O: tensor<16x32xf32>) -> tensor<16x32xf32> { - %0 = linalg.fill_rng_2d outs(%O : tensor<16x32xf32>) -> tensor<16x32xf32> +func @generalize_fill_rng_2d_f32(%min: f64, %max: f64, %seed: i32, %O: tensor<16x32xf32>) -> tensor<16x32xf32> { + %0 = linalg.fill_rng_2d ins(%min, %max, %seed: f64, f64, i32) outs(%O : tensor<16x32xf32>) -> tensor<16x32xf32> return %0: tensor<16x32xf32> } // CHECK-LABEL: @generalize_fill_rng_2d_f32 -// CHECK-SAME: (%[[O:.+]]: tensor<16x32xf32>) -// CHECK-DAG: %[[MIN:.+]] = constant -1000 : i64 -// CHECK-DAG: %[[MAX:.+]] = constant 1000 : i64 -// CHECK-DAG: %[[SEED:.+]] = constant 42 : i32 +// CHECK-DAG: ^{{.*}}(%[[MIN:.+]]: f64, %[[MAX:.+]]: f64, %[[SEED:.+]]: i32, %[[O:.+]]: f32 // CHECK-DAG: %[[IDX0:.+]] = linalg.index 0 : index // CHECK-DAG: %[[IDX1:.+]] = linalg.index 1 : index // CHECK-DAG: %[[IDX0_CAST:.+]] = index_cast %[[IDX0]] : index to i32 @@ -50,27 +47,24 @@ // CHECK-DAG: %[[VAL1:.+]] = muli %[[VAL0]], %[[CST0]] : i32 // CHECK-DAG: %[[VAL2:.+]] = addi %[[VAL1]], %[[CST1]] : i32 // Skip random number computation for the second index. 
-// CHECK-DAG: %[[MIN_CAST1:.+]] = sitofp %[[MIN]] : i64 to f64 -// CHECK-DAG: %[[MAX_CAST:.+]] = sitofp %[[MAX]] : i64 to f64 -// CHECK-DAG: %[[DIFF:.+]] = subf %[[MAX_CAST]], %[[MIN_CAST1]] : f64 +// CHECK-DAG: %[[DIFF:.+]] = subf %[[MAX]], %[[MIN]] : f64 // CHECK-DAG: %[[CST2:.+]] = constant 2.3283063999999999E-10 : f64 // CHECK-DAG: %[[FACT:.+]] = mulf %[[DIFF]], %[[CST2]] : f64 // CHECK-DAG: %[[VAL4:.+]] = mulf %{{.+}}, %[[FACT]] : f64 -// CHECK-DAG: %[[MIN_CAST2:.+]] = sitofp %[[MIN]] : i64 to f64 -// CHECK-DAG: %[[VAL5:.+]] = addf %[[VAL4]], %[[MIN_CAST2]] : f64 +// CHECK-DAG: %[[VAL5:.+]] = addf %[[VAL4]], %[[MIN]] : f64 // CHECK-DAG: %[[VAL6:.+]] = fptrunc %[[VAL5]] : f64 to f32 // CHECK-NEXT: linalg.yield %[[VAL6]] : f32 // CHECK-NEXT: -> tensor<16x32xf32> // ----- -func @generalize_fill_rng_2d_i32(%O: tensor<16x32xi32>) -> tensor<16x32xi32> { - %0 = linalg.fill_rng_2d outs(%O : tensor<16x32xi32>) -> tensor<16x32xi32> +func @generalize_fill_rng_2d_i32(%min: f64, %max: f64, %seed: i32, %O: tensor<16x32xi32>) -> tensor<16x32xi32> { + %0 = linalg.fill_rng_2d ins(%min, %max, %seed: f64, f64, i32) outs(%O : tensor<16x32xi32>) -> tensor<16x32xi32> return %0: tensor<16x32xi32> } // CHECK-LABEL: @generalize_fill_rng_2d_i32 -// CHECK-SAME: (%[[O:.+]]: tensor<16x32xi32>) +// CHECK: ^{{.*}}(%[[MIN:.+]]: f64, %[[MAX:.+]]: f64, %[[SEED:.+]]: i32, %[[O:.+]]: i32 // Verifies floating point to integer cast. // CHECK: %[[VAL6:.+]] = fptosi %{{.+}} : f64 to i32 // CHECK-NEXT: linalg.yield %[[VAL6]] : i32 diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml --- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml +++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml @@ -19,11 +19,11 @@ Detailed description. structured_op: !LinalgStructuredOpConfig args: - - ! 
+ - !LinalgOperandDefConfig name: O usage: output shape: affine_map<()[s0, s1] -> (s0, s1)> - element_type_var: T + type_var: T indexing_maps: !LinalgIndexingMapsConfig static_indexing_maps: - affine_map<(d0, d1)[s0, s1] -> (d0, d1)> @@ -58,7 +58,7 @@ # ODS-NEXT: }]; # ODS: let arguments = -# ODS-NEXT: Variadic<AnyShaped>:$inputs, +# ODS-NEXT: Variadic<AnyType>:$inputs, # ODS-NEXT: Variadic<AnyShaped>:$outputs # ODS: let builders = @@ -103,18 +103,23 @@ Detailed description. structured_op: !LinalgStructuredOpConfig args: - - !<LinalgTensorDef> + - !LinalgOperandDefConfig + name: value + usage: input + type_var: T + - !LinalgOperandDefConfig name: I usage: input - shape: affine_map<()[s0, s1] -> (s0, s1)> - element_type_var: T - - !<LinalgTensorDef> + shape: affine_map<()[s0, s1] -> (s1, s0)> + type_var: T + - !LinalgOperandDefConfig name: O usage: output shape: affine_map<()[s0, s1] -> (s0, s1)> - element_type_var: T + type_var: T indexing_maps: !LinalgIndexingMapsConfig static_indexing_maps: + - affine_map<(d0, d1)[s0, s1] -> ()> - affine_map<(d0, d1)[s0, s1] -> (d1, d0)> - affine_map<(d0, d1)[s0, s1] -> (d0, d1)> iterator_types: @@ -124,15 +129,23 @@ - !ScalarAssign arg: O value: !ScalarExpression - scalar_arg: I + scalar_apply: + fn_name: add + operands: + - !ScalarExpression + scalar_arg: value + - !ScalarExpression + scalar_arg: I # IMPL-LABEL: Test2Op::iterator_types() # IMPL-NEXT: { getParallelIteratorTypeName(), getParallelIteratorTypeName() } # IMPL: Test2Op::indexing_maps() +# IMPL: "affine_map<(d0, d1)[s0, s1] -> ()>" # IMPL: "affine_map<(d0, d1)[s0, s1] -> (d1, d0)>" # IMPL: "affine_map<(d0, d1)[s0, s1] -> (d0, d1)>" # IMPL: void Test2Op::regionBuilder( # IMPL: ImplicitLocOpBuilder &b, Block &block, ValueRange captures) -# IMPL: yields.push_back(block.getArgument(0)); + +# IMPL: = helper.applyfn__add(block.getArgument(0), block.getArgument(1)); diff --git a/mlir/test/python/dialects/linalg/opsrun.py b/mlir/test/python/dialects/linalg/opsrun.py --- a/mlir/test/python/dialects/linalg/opsrun.py +++
b/mlir/test/python/dialects/linalg/opsrun.py @@ -131,6 +131,33 @@ test_matmul_generic() +def test_fill_builtin(): + with Context() as ctx, Location.unknown(): + module = Module.create() + f64 = F64Type.get() + i32 = IntegerType.get_signless(32) + with InsertionPoint(module.body): + + @builtin.FuncOp.from_py_func(f64, f64, i32, MemRefType.get((4, 16), i32)) + def fill_on_buffers(min, max, seed, out): + linalg.fill_rng_2d(min, max, seed, outs=[out]) + + execution_engine = ExecutionEngine(transform(module, fill_boiler)) + + # TODO: FFI-based solution to allow testing and printing with python code. + # Prepare arguments: one result i32. + # Arguments must be passed as pointers. + c_int_p = ctypes.c_int * 1 + res = c_int_p(-1) + execution_engine.invoke("main", res) + + log("RESULT: ", res[0]) + # CHECK: RESULT: -480 + + +test_fill_builtin() + + def test_fill_generic(): with Context() as ctx, Location.unknown(): module = Module.create() diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp @@ -62,17 +62,13 @@ AffineMap affineMap() { return affineMapAttr.getValue(); } }; -enum class LinalgTensorUsageDef { - input, - output, - temporary, -}; +enum class LinalgOperandDefUsage { input, output }; -struct LinalgTensorDef { +struct LinalgOperandDef { std::string name; - LinalgTensorUsageDef usage; - SerializedAffineMap shape; - std::string elementTypeVar; + LinalgOperandDefUsage usage; + Optional shape; + std::string typeVar; }; enum class LinalgIteratorTypeDef { @@ -114,10 +110,10 @@ }; struct LinalgStructuredOpConfig { - SmallVector args; + SmallVector args; LinalgIndexingMapsConfig indexingMaps; SmallVector iteratorTypes; - SmallVector assignments; + std::vector assignments; }; struct LinalgOpConfig { @@ -131,7 +127,7 @@ // Mapping traits. 
//===----------------------------------------------------------------------===// -LLVM_YAML_IS_SEQUENCE_VECTOR(LinalgTensorDef) +LLVM_YAML_IS_SEQUENCE_VECTOR(LinalgOperandDef) LLVM_YAML_IS_SEQUENCE_VECTOR(SerializedAffineMap) LLVM_YAML_IS_SEQUENCE_VECTOR(LinalgIteratorTypeDef) LLVM_YAML_IS_SEQUENCE_VECTOR(ScalarAssign) @@ -153,8 +149,8 @@ }; /// A structured op models (at most) a single contraction by modeling -/// - A list of named arguments (`LinalgTensorDef`), which can be inputs, -/// outputs, or temporaries. +/// - A list of named arguments (`LinalgOperandDef`), which can be inputs or +/// outputs. /// - List of indexing maps (see `LinalgIndexingMaps`). /// - Iterator types (see `LinalgIteratorTypeDef`). /// - List of scalar level assignment (see `ScalarAssign`). @@ -168,31 +164,30 @@ } }; -/// Maps a named tensor-argument to an operation, consisting of: +/// Maps a named tensor- or scalar-argument to an operation, consisting of: /// - `name`: Must be unique within the operation. -/// - `usage`: How the argument is used (input, output, etc). -/// - `shape`: An AffineMap from all op symbols to the specific shape -/// of this argument. Each shape must be normalized over the same list of -/// symbols and have no dimension inputs. -/// - `element_type_var`: The symbolic type variable that binds to the scalar -/// element type of this TensorDef. +/// - `usage`: How the argument is used (input or output). +/// - `shape`: An optional AffineMap from all op symbols to the shape of the +/// argument. Only tensor-arguments have a shape. Each shape must be +/// normalized over the same list of symbols and have no dimension inputs. +/// - `type_var`: The symbolic type variable that binds to the element or self +/// type of the tensor- or scalar-argument, respectively.
template <> -struct MappingTraits<LinalgTensorDef> { - static void mapping(IO &io, LinalgTensorDef &info) { +struct MappingTraits<LinalgOperandDef> { + static void mapping(IO &io, LinalgOperandDef &info) { io.mapRequired("name", info.name); io.mapRequired("usage", info.usage); - io.mapRequired("shape", info.shape); - io.mapRequired("element_type_var", info.elementTypeVar); + io.mapOptional("shape", info.shape); + io.mapRequired("type_var", info.typeVar); } }; /// Usage enum for a named argument. template <> -struct ScalarEnumerationTraits<LinalgTensorUsageDef> { - static void enumeration(IO &io, LinalgTensorUsageDef &value) { - io.enumCase(value, "input", LinalgTensorUsageDef::input); - io.enumCase(value, "output", LinalgTensorUsageDef::output); - io.enumCase(value, "temporary", LinalgTensorUsageDef::temporary); +struct ScalarEnumerationTraits<LinalgOperandDefUsage> { + static void enumeration(IO &io, LinalgOperandDefUsage &value) { + io.enumCase(value, "input", LinalgOperandDefUsage::input); + io.enumCase(value, "output", LinalgOperandDefUsage::output); } }; @@ -229,7 +224,7 @@ }; /// Models an assignment to a named output. -/// - The `arg` name must match a named output or temporary. +/// - The `arg` name must match a named output. /// - The `value` is a scalar expression for computing the value to /// assign (see `ScalarExpression`). template <> @@ -366,7 +361,7 @@ } static Optional<int> -findTensorDefArgIndex(StringRef name, SmallVectorImpl<LinalgTensorDef> &args) { +findTensorDefArgIndex(StringRef name, SmallVectorImpl<LinalgOperandDef> &args) { for (auto it : llvm::enumerate(args)) { if (it.value().name == name) return it.index(); @@ -376,7 +371,7 @@ // Try to map the TypeVar to a predefined or an argument type. static Optional<std::string> -findTypeValue(StringRef typeVar, SmallVectorImpl<LinalgTensorDef> &args) { +findTypeValue(StringRef typeVar, SmallVectorImpl<LinalgOperandDef> &args) { // Handle all predefined types. if (typeVar == "I32") return std::string("helper.getIntegerType(32)"); @@ -389,7 +384,7 @@ // Search all argument types.
for (auto it : llvm::enumerate(args)) { - if (it.value().elementTypeVar == typeVar) + if (it.value().typeVar == typeVar) return llvm::formatv("block.getArgument({0}).getType()", it.index()) .str(); } @@ -397,8 +392,8 @@ return None; } -static ScalarAssign * -findAssignment(StringRef name, SmallVectorImpl<ScalarAssign> &assignments) { +static ScalarAssign *findAssignment(StringRef name, + std::vector<ScalarAssign> &assignments) { for (auto &assign : assignments) { if (assign.arg == name) return &assign; @@ -445,7 +440,7 @@ /*extraInterfaces=*/[{2}])> { {3} let arguments = (ins - Variadic<AnyShaped>:$inputs, + Variadic<AnyType>:$inputs, Variadic<AnyShaped>:$outputs{4} ); let results = (outs Variadic<AnyRankedTensor>:$result_tensors); @@ -467,7 +462,7 @@ $_builder, $_state, TypeRange(inputs), - TypeRange(outputs)/*, TODO: support captures*/); + TypeRange(outputs)); }]>, OpBuilder< (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs, @@ -485,7 +480,7 @@ $_builder, $_state, TypeRange(inputs), - TypeRange(outputs)/*, TODO: support captures*/); + TypeRange(outputs)); }]>, OpBuilder< (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands, @@ -500,7 +495,7 @@ ]; let printer = [{{ return ::printNamedStructuredOp(p, *this); }]; let parser = [{{ - return ::parseNamedStructuredOp<{0}>(parser, result/*TODO:, captures*/); + return ::parseNamedStructuredOp<{0}>(parser, result); }]; let hasFolder = 1; @@ -768,9 +763,8 @@ size_t generatedAssignmentCount = 0; int localCounter = 0; SmallVector<std::string> stmts; - for (LinalgTensorDef &arg : args) { - if (arg.usage != LinalgTensorUsageDef::output && - arg.usage != LinalgTensorUsageDef::temporary) + for (LinalgOperandDef &arg : args) { + if (arg.usage != LinalgOperandDefUsage::output) continue; // Find the assignment that correlates with the argument.