diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
@@ -46,6 +46,7 @@
   let hasCanonicalizer = 1;
   let hasOperationAttrVerify = 1;
   let hasConstantMaterializer = 1;
+  let usePropertiesForAttributes = 1;
   let extraClassDeclaration = [{
     /// Attribute name used to to memoize indexing maps for named ops.
     constexpr const static ::llvm::StringLiteral
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -27,8 +27,10 @@
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/OperationSupport.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 
@@ -128,10 +130,21 @@
                              SmallVectorImpl<Type> &inputTypes,
                              SmallVectorImpl<Type> &outputTypes,
                              bool addOperandSegmentSizes = true) {
-  SMLoc inputsOperandsLoc, outputsOperandsLoc;
+  SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;
   SmallVector<OpAsmParser::UnresolvedOperand, 4> inputsOperands,
       outputsOperands;
 
+  // Optional `<{...}>` properties. Parsing directly into a DictionaryAttr
+  // rejects non-dictionary input (e.g. `<5 : i32>`) with a diagnostic; the
+  // operand_segment_sizes handling below relies on propertiesAttr being a
+  // dictionary whenever it is set.
+  if (succeeded(parser.parseOptionalLess())) {
+    DictionaryAttr properties;
+    if (parser.parseAttribute(properties) || parser.parseGreater())
+      return failure();
+    result.propertiesAttr = properties;
+  }
+  attrsLoc = parser.getCurrentLocation();
   if (parser.parseOptionalAttrDict(result.attributes))
     return failure();
 
@@ -159,10 +172,35 @@
     return failure();
 
   if (addOperandSegmentSizes) {
-    result.addAttribute("operand_segment_sizes",
-                        parser.getBuilder().getDenseI32ArrayAttr(
-                            {static_cast<int32_t>(inputsOperands.size()),
-                             static_cast<int32_t>(outputsOperands.size())}));
+    // This is a bit complex because we're trying to be backward compatible
+    // with operation syntax that mixes the inherent attributes and the
+    // discardable ones in the same dictionary. If the properties are used, we
+    // set the operand_segment_sizes there directly. Otherwise we append it to
+    // the discardable attributes dictionary where it is handled by the
+    // generic Operation::create(...) method.
+    DenseI32ArrayAttr segmentSizes = parser.getBuilder().getDenseI32ArrayAttr(
+        {static_cast<int32_t>(inputsOperands.size()),
+         static_cast<int32_t>(outputsOperands.size())});
+    if (result.propertiesAttr) {
+      // `set` rather than `append`: a user-written operand_segment_sizes
+      // entry would otherwise produce a duplicate dictionary key.
+      NamedAttrList attrs = result.propertiesAttr.cast<DictionaryAttr>();
+      attrs.set("operand_segment_sizes", segmentSizes);
+      result.propertiesAttr = attrs.getDictionary(parser.getContext());
+    } else {
+      result.addAttribute("operand_segment_sizes", segmentSizes);
+    }
+  }
+  if (!result.propertiesAttr) {
+    std::optional<RegisteredOperationName> info =
+        result.name.getRegisteredInfo();
+    if (info) {
+      if (failed(info->verifyInherentAttrs(result.attributes, [&]() {
+            return parser.emitError(attrsLoc)
+                   << "'" << result.name.getStringRef() << "' op ";
+          })))
+        return failure();
+    }
   }
   return success();
 }
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -163,6 +163,26 @@
 
 // -----
 
+func.func @depthwise_conv_2d_input_nhwc_filter_wrong_stride_element_type_properties(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
+  // expected-error @+1 {{invalid properties {dilations = dense<1> : vector<2xi64>, operand_segment_sizes = array<i32: 2, 1>, strides = dense<2.000000e+00> : vector<2xf32>} for op linalg.depthwise_conv_2d_nhwc_hwc: Invalid attribute `strides` in property conversion: dense<2.000000e+00> : vector<2xf32>}}
+  linalg.depthwise_conv_2d_nhwc_hwc <{dilations = dense<1> : vector<2xi64>, strides = dense<2.0> : vector<2xf32>}>
+    ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>)
+    outs(%output: memref<1x56x56x96xf32>)
+  return
+}
+
+// -----
+
+func.func @depthwise_conv_2d_input_nhwc_filter_properties_not_dictionary(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
+  // expected-error @+1 {{invalid kind of attribute specified}}
+  linalg.depthwise_conv_2d_nhwc_hwc <5 : i32>
+    ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>)
+    outs(%output: memref<1x56x56x96xf32>)
+  return
+}
+
+// -----
+
 func.func @depthwise_conv_2d_input_nhwc_filter_wrong_stride_element_type(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
   // expected-error @+1 {{op attribute 'strides' failed to satisfy constraint: 64-bit signless int elements attribute of shape [2]}}
   linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : vector<2xi64>, strides = dense<2.0> : vector<2xf32>}