diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td @@ -58,8 +58,6 @@ private: llvm::StringMap namedStructuredOpRegionBuilders; }]; - - let emitAccessorPrefix = kEmitAccessorPrefix_Both; } // Define the function attribute enums matching the OpDSL functions. diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -771,8 +771,9 @@ // TODO: reevalute the need for a cast when a better mechanism exists. //========================================================================// - ValueRange inputs() { - return cast(*this->getOperation()).inputs(); + ValueRange getInputs() { + return cast(*this->getOperation()) + .getInputs(); } int64_t getNumInputs() { @@ -780,7 +781,7 @@ .getNumInputs(); } - ValueRange outputs() { + ValueRange getOutputs() { return cast(*this->getOperation()) .getOutputs(); } @@ -922,7 +923,7 @@ // The 'DestinationStyleOpInterface' provides access to the methods relevant // for destination-style ops. A destination-style operation has 'n' input // arguments and 'm' output arguments. Each op that wants to implement -// DestinationStyleOpInterface needs to define inputs() and getOutputs() +// DestinationStyleOpInterface needs to define getInputs() and getOutputs() // methods. def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> { let cppNamespace = "::mlir::linalg"; @@ -930,18 +931,18 @@ //===------------------------------------------------------------------===// // Num input/output arguments handling. //===------------------------------------------------------------------===// - // `inputs` must be defined by each op that wants to implement the + // `getInputs` must be defined by each op that wants to implement the // DestinationStyleOpInterface. InterfaceMethod< /*desc=*/[{ Return the input shape operands. }], /*retTy=*/"ValueRange", - /*methodName=*/"inputs", + /*methodName=*/"getInputs", /*args=*/(ins) >, - // These special methods rely on `inputs` and `outputs` being defined by - // each op that wants to implement the DestinationStyleOpInterface. + // These special methods rely on `getInputs` and `getOutputs` being defined + // by each op that wants to implement the DestinationStyleOpInterface. InterfaceMethod< /*desc=*/[{ Return the number of inputs. diff --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp @@ -360,7 +360,7 @@ SmallVector initOrAllocTensorOps; SmallVector fillOps; fillOps.reserve(op.getNumOutputs()); - for (auto it : llvm::zip(op.outputs(), neutralElements)) { + for (auto it : llvm::zip(op.getOutputs(), neutralElements)) { Value rankedTensor = std::get<0>(it); auto t = rankedTensor.getType().cast(); RankedTensorType newT = RankedTensorType::Builder(t).insertDim( @@ -403,7 +403,7 @@ // Step 3. Handle operands. // Compute the new input tensors. - auto newInputs = llvm::to_vector<4>(op.inputs()); + auto newInputs = llvm::to_vector<4>(op.getInputs()); // Add a single shape-only tensor to carry the dimensions without resorting to // more complex inversions. newInputs.push_back(b.create( @@ -433,7 +433,7 @@ // multi-reduction support is available. SmallVector results; for (auto it : - llvm::zip(genericOp->getResults(), op.outputs(), combinerOps)) { + llvm::zip(genericOp->getResults(), op.getOutputs(), combinerOps)) { Value reindexedOutput = std::get<0>(it); Value originalOutput = std::get<1>(it); auto originalOutputType = originalOutput.getType().cast(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -1334,9 +1334,9 @@ // Determine whether `linalgOp` can be generated with this generator if (linalgOp.getNumInputs() != 2 || linalgOp.getNumOutputs() != 1) return; - lhsShaped = linalgOp.inputs()[0]; - rhsShaped = linalgOp.inputs()[1]; - resShaped = linalgOp.outputs()[0]; + lhsShaped = linalgOp.getInputs()[0]; + rhsShaped = linalgOp.getInputs()[1]; + resShaped = linalgOp.getOutputs()[0]; lhsShapedType = lhsShaped.getType().dyn_cast(); rhsShapedType = rhsShaped.getType().dyn_cast(); resShapedType = resShaped.getType().dyn_cast(); diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml --- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml +++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml @@ -167,9 +167,9 @@ # ODS-NEXT: LogicalResult verifyIndexingMapRequiredAttributes(); # IMPL: getSymbolBindings(Test2Op self) -# IMPL: cst2 = self.strides().getValues()[0]; +# IMPL: cst2 = self.getStrides().getValues()[0]; # IMPL-NEXT: getAffineConstantExpr(cst2, context) -# IMPL: cst3 = self.strides().getValues()[1]; +# IMPL: cst3 = self.getStrides().getValues()[1]; # IMPL-NEXT: getAffineConstantExpr(cst3, context) # IMPL: Test2Op::getIndexingMaps() diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp @@ -670,7 +670,7 @@ {0}::getNumRegionArgs(), {0}::getRegionBuilder()); } void {0}::print(OpAsmPrinter &p) {{ - ::printNamedStructuredOp(p, getOperation(), inputs(), outputs()); + ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs()); } )FMT"; @@ -857,7 +857,7 @@ // {1}: Symbol position // {2}: Attribute index static const char structuredOpAccessAttrFormat[] = R"FMT( -int64_t cst{1} = self.{0}().getValues()[{2}]; +int64_t cst{1} = self.get{0}().getValues()[{2}]; exprs.push_back(getAffineConstantExpr(cst{1}, context)); )FMT"; // Update all symbol bindings mapped to an attribute. @@ -868,8 +868,10 @@ for (auto &en : llvm::enumerate(arg.indexAttrMap->affineMap().getResults())) { if (auto symbol = en.value().dyn_cast()) { + std::string argName = arg.name; + argName[0] = toupper(argName[0]); symbolBindings[symbol.getPosition()] = - llvm::formatv(structuredOpAccessAttrFormat, arg.name, + llvm::formatv(structuredOpAccessAttrFormat, argName, symbol.getPosition(), en.index()); } }