diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -562,7 +562,8 @@ doc = "C(m, n) += A(m, k) * B(k, n)", indexing_maps = #matmul_accesses, library_call = "linalg_matmul", - n_views = [2, 1], + args_in = 2, + args_out = 1, iterator_types = ["parallel", "parallel", "reduction"] } ``` @@ -634,7 +635,7 @@ let builders = [ OpBuilder< "OpBuilder &builder, OperationState &result, ArrayRef resultTypes, " - "ValueRange args, int64_t inputCount, int64_t outputCount, " + "ValueRange args, int64_t argsIn, int64_t argsOut, " "ArrayRef indexingMaps, ArrayRef iteratorTypes, " "function_ref = nullptr"> ]; @@ -689,7 +690,8 @@ doc = "C(m, n) += A(m, k) * B(k, n)", indexing_maps = #matmul_accesses, library_call = "linalg_matmul", - n_views = [2, 1], + args_in = 2, + args_out = 1, iterator_types = ["parallel", "parallel", "reduction"] } ``` @@ -768,7 +770,7 @@ let builders = [ OpBuilder< "OpBuilder &builder, OperationState &result, ArrayRef resultTypes, " - "ValueRange args, int64_t inputCount, int64_t outputCount, " + "ValueRange args, int64_t argsIn, int64_t argsOut, " "ArrayRef indexingMaps, ArrayRef iteratorTypes, " "function_ref " "= nullptr"> diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -72,12 +72,11 @@ void GenericOp::build( OpBuilder &builder, OperationState &result, ArrayRef resultTypes, - ValueRange args, int64_t inputCount, int64_t outputCount, + ValueRange args, int64_t argsIn, int64_t argsOut, ArrayRef indexingMaps, ArrayRef iteratorTypes, function_ref bodyBuild) { - build(builder, result, resultTypes, args, - builder.getI64IntegerAttr(inputCount), - builder.getI64IntegerAttr(outputCount), + build(builder, result, resultTypes, args, builder.getI64IntegerAttr(argsIn), + builder.getI64IntegerAttr(argsOut), builder.getAffineMapArrayAttr(indexingMaps), builder.getStrArrayAttr(iteratorTypes), /*doc=*/nullptr, /*library_call=*/nullptr); @@ -96,13 +95,12 @@ void IndexedGenericOp::build( OpBuilder &builder, OperationState &result, ArrayRef resultTypes, - ValueRange args, int64_t inputCount, int64_t outputCount, + ValueRange args, int64_t argsIn, int64_t argsOut, ArrayRef indexingMaps, ArrayRef iteratorTypes, function_ref bodyBuild) { - build(builder, result, resultTypes, args, - builder.getI64IntegerAttr(inputCount), - builder.getI64IntegerAttr(outputCount), + build(builder, result, resultTypes, args, builder.getI64IntegerAttr(argsIn), + builder.getI64IntegerAttr(argsOut), builder.getAffineMapArrayAttr(indexingMaps), builder.getStrArrayAttr(iteratorTypes), /*doc=*/nullptr, /*library_call=*/nullptr);