diff --git a/mlir/include/mlir/TableGen/AttrOrTypeDef.h b/mlir/include/mlir/TableGen/AttrOrTypeDef.h --- a/mlir/include/mlir/TableGen/AttrOrTypeDef.h +++ b/mlir/include/mlir/TableGen/AttrOrTypeDef.h @@ -16,6 +16,7 @@ #include "mlir/Support/LLVM.h" #include "mlir/TableGen/Builder.h" +#include "mlir/TableGen/Dialect.h" #include "mlir/TableGen/Trait.h" namespace llvm { @@ -58,6 +59,11 @@ /// Get the parameter name. StringRef getName() const; + /// Get the name of the C++ accessor method for this parameter: "get" + /// followed by the parameter name converted from snake_case to + /// UpperCamelCase, e.g. 'input_zp' -> 'getInputZp', 'value' -> 'getValue'. + std::string getAccessorName() const; + /// If specified, get the custom allocator code for this parameter. Optional getAllocator() const; diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -147,8 +147,8 @@ cast(op).quantization_info()) { auto quantizationInfo = cast(op).quantization_info(); int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth(); - int64_t inZp = quantizationInfo.getValue().getInput_zp(); - int64_t outZp = quantizationInfo.getValue().getOutput_zp(); + int64_t inZp = quantizationInfo.getValue().getInputZp(); + int64_t outZp = quantizationInfo.getValue().getOutputZp(); // Compute the maximum value that can occur in the intermediate buffer.
int64_t zpAdd = inZp + outZp; @@ -1847,7 +1847,7 @@ } else if (elementTy.isa() && !padOp.quantization_info()) { constantAttr = rewriter.getIntegerAttr(elementTy, 0); } else if (elementTy.isa() && padOp.quantization_info()) { - int64_t value = padOp.quantization_info().getValue().getInput_zp(); + int64_t value = padOp.quantization_info().getValue().getInputZp(); constantAttr = rewriter.getIntegerAttr(elementTy, value); } if (constantAttr) diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp @@ -202,7 +202,7 @@ if (isQuantized) { auto quantizationInfo = op->getAttr("quantization_info").cast(); - int64_t iZp = quantizationInfo.getInput_zp(); + int64_t iZp = quantizationInfo.getInputZp(); int64_t intMin = APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth()) @@ -274,8 +274,8 @@ if (isQuantized) { auto quantizationInfo = op->getAttr("quantization_info").cast(); - auto iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInput_zp()); - auto kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeight_zp()); + auto iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInputZp()); + auto kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp()); auto iZpVal = rewriter.create(loc, iZp); auto kZpVal = rewriter.create(loc, kZp); @@ -366,8 +366,8 @@ if (isQuantized) { auto quantizationInfo = op->getAttr("quantization_info").cast(); - iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInput_zp()); - kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeight_zp()); + iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInputZp()); + kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp()); } auto weightShape = weightTy.getShape(); @@ -378,7 +378,7 @@ if (isQuantized) { auto quantizationInfo = op->getAttr("quantization_info").cast(); - int64_t iZp = 
quantizationInfo.getInput_zp(); + int64_t iZp = quantizationInfo.getInputZp(); int64_t intMin = APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth()) @@ -542,9 +542,9 @@ auto quantizationInfo = op.quantization_info().getValue(); auto aZp = rewriter.create( - loc, rewriter.getI32IntegerAttr(quantizationInfo.getA_zp())); + loc, rewriter.getI32IntegerAttr(quantizationInfo.getAZp())); auto bZp = rewriter.create( - loc, rewriter.getI32IntegerAttr(quantizationInfo.getB_zp())); + loc, rewriter.getI32IntegerAttr(quantizationInfo.getBZp())); rewriter.replaceOpWithNewOp( op, TypeRange{op.getType()}, ValueRange{adaptor.a(), adaptor.b(), aZp, bZp}, zeroTensor); @@ -652,9 +652,9 @@ auto quantizationInfo = op.quantization_info().getValue(); auto inputZp = rewriter.create( - loc, rewriter.getI32IntegerAttr(quantizationInfo.getInput_zp())); + loc, rewriter.getI32IntegerAttr(quantizationInfo.getInputZp())); auto outputZp = rewriter.create( - loc, rewriter.getI32IntegerAttr(quantizationInfo.getWeight_zp())); + loc, rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp())); Value matmul = rewriter .create( @@ -892,8 +892,7 @@ if (op.quantization_info()) { auto quantizationInfo = op.quantization_info().getValue(); auto inputZp = rewriter.create( - loc, - b.getIntegerAttr(accETy, quantizationInfo.getInput_zp())); + loc, b.getIntegerAttr(accETy, quantizationInfo.getInputZp())); Value offset = rewriter.create(loc, accETy, countI, inputZp); poolVal = @@ -930,7 +929,7 @@ auto quantizationInfo = op.quantization_info().getValue(); auto outputZp = rewriter.create( loc, b.getIntegerAttr(scaled.getType(), - quantizationInfo.getOutput_zp())); + quantizationInfo.getOutputZp())); scaled = rewriter.create(loc, scaled, outputZp) .getResult(); } diff --git a/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp b/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp --- a/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp @@ -145,7 +145,7 @@ DenseIntElementsAttr 
spirv::lookupLocalWorkGroupSize(Operation *op) { if (auto entryPoint = spirv::lookupEntryPointABI(op)) - return entryPoint.getLocal_size(); + return entryPoint.getLocalSize(); return {}; } diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp @@ -135,7 +135,7 @@ funcOp.getLoc(), executionModel.getValue(), funcOp, interfaceVars); // Specifies the spv.ExecutionModeOp. - auto localSizeAttr = entryPointAttr.getLocal_size(); + auto localSizeAttr = entryPointAttr.getLocalSize(); if (localSizeAttr) { auto values = localSizeAttr.getValues(); SmallVector localSize(values); diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -347,7 +347,7 @@ } else if (elementTy.isa() && !op.quantization_info()) { constantAttr = rewriter.getIntegerAttr(elementTy, 0); } else if (elementTy.isa() && op.quantization_info()) { - auto value = op.quantization_info().getValue().getInput_zp(); + auto value = op.quantization_info().getValue().getInputZp(); constantAttr = rewriter.getIntegerAttr(elementTy, value); } diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp --- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp @@ -214,7 +214,7 @@ weight = createOpAndInfer( rewriter, loc, UnrankedTensorType::get(weightETy), weight, weightPaddingVal, nullptr, - rewriter.getAttr(quantInfo.getWeight_zp())); + rewriter.getAttr(quantInfo.getWeightZp())); } else { weight = createOpAndInfer(rewriter, loc, @@ -278,7 +278,7 @@ input = createOpAndInfer( rewriter, loc, UnrankedTensorType::get(inputETy), input, 
inputPaddingVal, nullptr, - rewriter.getAttr(quantInfo.getInput_zp())); + rewriter.getAttr(quantInfo.getInputZp())); } else { input = createOpAndInfer(rewriter, loc, UnrankedTensorType::get(inputETy), diff --git a/mlir/lib/TableGen/AttrOrTypeDef.cpp b/mlir/lib/TableGen/AttrOrTypeDef.cpp --- a/mlir/lib/TableGen/AttrOrTypeDef.cpp +++ b/mlir/lib/TableGen/AttrOrTypeDef.cpp @@ -215,6 +215,12 @@ return def->getArgName(index)->getValue(); } +std::string AttrOrTypeParameter::getAccessorName() const { + assert(!getName().empty() && "parameter has empty name"); + return "get" + + llvm::convertToCamelFromSnakeCase(getName(), /*capitalizeFirst=*/true); +} + Optional AttrOrTypeParameter::getAllocator() const { return getDefValue("allocator"); } diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp --- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp @@ -7,16 +7,12 @@ //===----------------------------------------------------------------------===// #include "AttrOrTypeFormatGen.h" -#include "mlir/Support/LogicalResult.h" #include "mlir/TableGen/AttrOrTypeDef.h" #include "mlir/TableGen/Class.h" #include "mlir/TableGen/CodeGenHelpers.h" #include "mlir/TableGen/Format.h" #include "mlir/TableGen/GenInfo.h" #include "mlir/TableGen/Interfaces.h" -#include "llvm/ADT/Sequence.h" -#include "llvm/ADT/SetVector.h" -#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/StringSet.h" #include "llvm/Support/CommandLine.h" #include "llvm/TableGen/Error.h" @@ -31,13 +27,6 @@ // Utility Functions //===----------------------------------------------------------------------===// -std::string mlir::tblgen::getParameterAccessorName(StringRef name) { - assert(!name.empty() && "parameter has empty name"); - auto ret = "get" + name.str(); - ret[3] = llvm::toUpper(ret[3]); // uppercase first letter of the name - return ret; -} - /// Find all the AttrOrTypeDef for the specified dialect. If no dialect /// specified and can only find one dialect's defs, use that.
static void collectAllDefs(StringRef selectedDialect, @@ -288,7 +277,7 @@ void DefGen::emitAccessors() { for (auto &param : params) { Method *m = defCls.addMethod( - param.getCppAccessorType(), getParameterAccessorName(param.getName()), + param.getCppAccessorType(), param.getAccessorName(), def.genStorageClass() ? Method::Const : Method::ConstDeclaration); // Generate accessor definitions only if we also generate the storage // class. Otherwise, let the user define the exact accessor definition. diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.h b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.h --- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.h +++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.h @@ -20,11 +20,6 @@ void generateAttrOrTypeFormat(const AttrOrTypeDef &def, MethodBody &parser, MethodBody &printer); -/// From the parameter name, get the name of the accessor function in camelcase. -/// The first letter of the parameter is upper-cased and prefixed with "get". -/// E.g. 'value' -> 'getValue'. -std::string getParameterAccessorName(llvm::StringRef name); - } // namespace tblgen } // namespace mlir diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp --- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp @@ -58,7 +58,7 @@ /// Generate the code to check whether the parameter should be printed.
MethodBody &genPrintGuard(FmtContext &ctx, MethodBody &os) const { - std::string self = getParameterAccessorName(getName()) + "()"; + std::string self = param.getAccessorName() + "()"; ctx.withSelf(self); os << tgfmt("($_self", &ctx); if (llvm::Optional defaultValue = getParam().getDefaultValue()) { @@ -718,7 +718,7 @@ void DefFormat::genVariablePrinter(ParameterElement *el, FmtContext &ctx, MethodBody &os, bool skipGuard) { const AttrOrTypeParameter &param = el->getParam(); - ctx.withSelf(getParameterAccessorName(param.getName()) + "()"); + ctx.withSelf(param.getAccessorName() + "()"); // Guard the printer on the presence of optional parameters and that they // aren't equal to their default values (if they have one). @@ -812,8 +812,7 @@ if (auto *ref = dyn_cast(arg)) param = ref->getArg(); os << ",\n" - << getParameterAccessorName(cast(param)->getName()) - << "()"; + << cast(param)->getParam().getAccessorName() << "()"; } os.unindent() << ");\n"; }