diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -1713,14 +1713,11 @@ } }]; - // TODO(b/144779634, ravishankarm) : Use different arguments for - // offsets, sizes and strides. let arguments = (ins AnyMemRef:$source, Variadic:$offsets, Variadic:$sizes, - Variadic:$strides, - I32ElementsAttr:$operand_segment_sizes + Variadic:$strides ); let results = (outs AnyMemRef:$result); diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td --- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td +++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td @@ -713,8 +713,7 @@ Vector_Op<"reshape", [AttrSizedOperandSegments, NoSideEffect]>, Arguments<(ins AnyVector:$vector, Variadic:$input_shape, Variadic:$output_shape, - I64ArrayAttr:$fixed_vector_sizes, - I32ElementsAttr:$operand_segment_sizes)>, + I64ArrayAttr:$fixed_vector_sizes)>, Results<(outs AnyVector:$result)> { let summary = "vector reshape operation"; let description = [{ diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -1825,10 +1825,7 @@ ArrayRef attrs) { if (!resultType) resultType = inferSubViewResultType(source.getType().cast()); - auto segmentAttr = b->getI32VectorAttr( - {1, static_cast(offsets.size()), static_cast(sizes.size()), - static_cast(strides.size())}); - build(b, result, resultType, source, offsets, sizes, strides, segmentAttr); + build(b, result, resultType, source, offsets, sizes, strides); result.addAttributes(attrs); } diff --git a/mlir/test/mlir-tblgen/op-result.td b/mlir/test/mlir-tblgen/op-result.td --- a/mlir/test/mlir-tblgen/op-result.td +++ b/mlir/test/mlir-tblgen/op-result.td @@ -45,7 +45,7 @@ } // CHECK-LABEL: OpD definitions -// CHECK: void OpD::build(Builder *, OperationState &odsState, ValueRange operands, ArrayRef attributes) +// CHECK: void OpD::build(Builder *odsBuilder, OperationState &odsState, ValueRange operands, ArrayRef attributes) // CHECK: odsState.addTypes({attr.second.cast().getValue()}); def OpE : NS_Op<"value_attr_as_result_type", [FirstAttrDerivedResultType]> { @@ -54,7 +54,7 @@ } // CHECK-LABEL: OpE definitions -// CHECK: void OpE::build(Builder *, OperationState &odsState, ValueRange operands, ArrayRef attributes) +// CHECK: void OpE::build(Builder *odsBuilder, OperationState &odsState, ValueRange operands, ArrayRef attributes) // CHECK: odsState.addTypes({attr.second.getType()}); def OpF : NS_Op<"one_variadic_result_op", []> { diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -20,6 +20,7 @@ #include "mlir/TableGen/OpInterfaces.h" #include "mlir/TableGen/OpTrait.h" #include "mlir/TableGen/Operator.h" +#include "llvm/ADT/Sequence.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Support/Signals.h" #include "llvm/TableGen/Error.h" @@ -737,7 +738,7 @@ // Signature std::string params = - std::string("Builder *, OperationState &") + builderOpState + + std::string("Builder *odsBuilder, OperationState &") + builderOpState + ", ValueRange operands, ArrayRef attributes"; auto &m = opClass.newMethod("void", "build", params, OpMethod::MP_Static); auto &body = m.body(); @@ -804,7 +805,7 @@ void OpEmitter::genUseAttrAsResultTypeBuilder() { std::string params = - std::string("Builder *, OperationState &") + builderOpState + + std::string("Builder *odsBuilder, OperationState &") + builderOpState + ", ValueRange operands, ArrayRef attributes"; auto &m = opClass.newMethod("void", "build", params, OpMethod::MP_Static); auto &body = m.body(); @@ -1062,6 +1063,20 @@ << ");\n"; } + // If the operation has the operand segment size attribute, add it here. + if (op.getTrait("OpTrait::AttrSizedOperandSegments")) { + body << " " << builderOpState + << ".addAttribute(\"operand_segment_sizes\", " + "odsBuilder->getI32VectorAttr({"; + interleaveComma(llvm::seq(0, op.getNumOperands()), body, [&](int i) { + if (op.getOperand(i).isVariadic()) + body << "static_cast(" << getArgumentName(op, i) << ".size())"; + else + body << "1"; + }); + body << "}));\n"; + } + // Push all attributes to the result. for (const auto &namedAttr : op.getAttributes()) { auto &attr = namedAttr.attr;