diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -72,10 +72,6 @@ void printAttribute( Attribute attr, DialectAsmPrinter &printer) const override; }]; - - // TODO(https://github.com/llvm/llvm-project/issues/57887): Switch to - // _Prefixed accessors. - let emitAccessorPrefix = kEmitAccessorPrefix_Both; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td @@ -65,7 +65,7 @@ let extraClassDeclaration = [{ /// Returns the block arguments. - operand_range getBlockArguments() { return targetOperands(); } + operand_range getBlockArguments() { return getTargetOperands(); } }]; let autogenSerialization = 0; @@ -161,22 +161,22 @@ /// Returns the number of arguments to the true target block. unsigned getNumTrueBlockArguments() { - return trueTargetOperands().size(); + return getTrueTargetOperands().size(); } /// Returns the number of arguments to the false target block. unsigned getNumFalseBlockArguments() { - return falseTargetOperands().size(); + return getFalseTargetOperands().size(); } // Iterator and range support for true target block arguments. operand_range getTrueBlockArguments() { - return trueTargetOperands(); + return getTrueTargetOperands(); } // Iterator and range support for false target block arguments. 
operand_range getFalseBlockArguments() { - return falseTargetOperands(); + return getFalseTargetOperands(); } private: diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td @@ -394,9 +394,9 @@ CArg<"FlatSymbolRefAttr", "nullptr">:$initializer), [{ $_state.addAttribute("type", type); - $_state.addAttribute(sym_nameAttrName($_state.name), sym_name); + $_state.addAttribute(getSymNameAttrName($_state.name), sym_name); if (initializer) - $_state.addAttribute(initializerAttrName($_state.name), initializer); + $_state.addAttribute(getInitializerAttrName($_state.name), initializer); }]>, OpBuilder<(ins "TypeAttr":$type, "ArrayRef":$namedAttrs), [{ @@ -412,9 +412,9 @@ CArg<"FlatSymbolRefAttr", "{}">:$initializer), [{ $_state.addAttribute("type", TypeAttr::get(type)); - $_state.addAttribute(sym_nameAttrName($_state.name), $_builder.getStringAttr(sym_name)); + $_state.addAttribute(getSymNameAttrName($_state.name), $_builder.getStringAttr(sym_name)); if (initializer) - $_state.addAttribute(initializerAttrName($_state.name), initializer); + $_state.addAttribute(getInitializerAttrName($_state.name), initializer); }]> ]; @@ -424,7 +424,7 @@ let extraClassDeclaration = [{ ::mlir::spirv::StorageClass storageClass() { - return this->type().cast<::mlir::spirv::PointerType>().getStorageClass(); + return this->getType().cast<::mlir::spirv::PointerType>().getStorageClass(); } }]; } @@ -509,7 +509,7 @@ bool isOptionalSymbol() { return true; } - Optional getName() { return sym_name(); } + Optional getName() { return getSymName(); } static StringRef getVCETripleAttrName() { return "vce_triple"; } }]; diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp --- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp +++ 
b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp @@ -69,12 +69,12 @@ builder.getIntegerAttr(targetType, targetBits / sourceBits); auto idx = builder.create(loc, targetType, attr); auto lastDim = op->getOperand(op.getNumOperands() - 1); - auto indices = llvm::to_vector<4>(op.indices()); + auto indices = llvm::to_vector<4>(op.getIndices()); // There are two elements if this is a 1-D tensor. assert(indices.size() == 2); indices.back() = builder.create(loc, lastDim, idx); - Type t = typeConverter.convertType(op.component_ptr().getType()); - return builder.create(loc, t, op.base_ptr(), indices); + Type t = typeConverter.convertType(op.getComponentPtr().getType()); + return builder.create(loc, t, op.getBasePtr(), indices); } /// Returns the shifted `targetBits`-bit value with the given offset. @@ -371,7 +371,7 @@ // Assume that getElementPtr() works linearizely. If it's a scalar, the method // still returns a linearized accessing. If the accessing is not linearized, // there will be offset issues. - assert(accessChainOp.indices().size() == 2); + assert(accessChainOp.getIndices().size() == 2); Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp, srcBits, dstBits, rewriter); Value spvLoadOp = rewriter.create( @@ -507,7 +507,7 @@ // 6) store 32-bit value back // The step 1 to step 3 are done by AtomicAnd as one atomic step, and the step // 4 to step 6 are done by AtomicOr as another atomic step. - assert(accessChainOp.indices().size() == 2); + assert(accessChainOp.getIndices().size() == 2); Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1); Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter); diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp --- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp +++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp @@ -174,7 +174,7 @@ // Create the block for the header. 
auto *header = new Block(); // Insert the header. - loopOp.body().getBlocks().insert(getBlockIt(loopOp.body(), 1), header); + loopOp.getBody().getBlocks().insert(getBlockIt(loopOp.getBody(), 1), header); // Create the new induction variable to use. Value adapLowerBound = adaptor.getLowerBound(); @@ -197,13 +197,13 @@ // Move the blocks from the forOp into the loopOp. This is the body of the // loopOp. - rewriter.inlineRegionBefore(forOp->getRegion(0), loopOp.body(), - getBlockIt(loopOp.body(), 2)); + rewriter.inlineRegionBefore(forOp->getRegion(0), loopOp.getBody(), + getBlockIt(loopOp.getBody(), 2)); SmallVector args(1, adaptor.getLowerBound()); args.append(adaptor.getInitArgs().begin(), adaptor.getInitArgs().end()); // Branch into it from the entry. - rewriter.setInsertionPointToEnd(&(loopOp.body().front())); + rewriter.setInsertionPointToEnd(&(loopOp.getBody().front())); rewriter.create(loc, header, args); // Generate the rest of the loop header. @@ -252,12 +252,12 @@ auto selectionOp = rewriter.create(loc, spirv::SelectionControl::None); auto *mergeBlock = - rewriter.createBlock(&selectionOp.body(), selectionOp.body().end()); + rewriter.createBlock(&selectionOp.getBody(), selectionOp.getBody().end()); rewriter.create(loc); OpBuilder::InsertionGuard guard(rewriter); auto *selectionHeaderBlock = - rewriter.createBlock(&selectionOp.body().front()); + rewriter.createBlock(&selectionOp.getBody().front()); // Inline `then` region before the merge block and branch to it. auto &thenRegion = ifOp.getThenRegion(); @@ -367,12 +367,12 @@ return failure(); // Move the while before block as the initial loop header block. - rewriter.inlineRegionBefore(beforeRegion, loopOp.body(), - getBlockIt(loopOp.body(), 1)); + rewriter.inlineRegionBefore(beforeRegion, loopOp.getBody(), + getBlockIt(loopOp.getBody(), 1)); // Move the while after block as the initial loop body block. 
- rewriter.inlineRegionBefore(afterRegion, loopOp.body(), - getBlockIt(loopOp.body(), 2)); + rewriter.inlineRegionBefore(afterRegion, loopOp.getBody(), + getBlockIt(loopOp.getBody(), 2)); // Jump from the loop entry block to the loop header block. rewriter.setInsertionPointToEnd(&entryBlock); diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp --- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp @@ -89,7 +89,7 @@ op->getAttrOfType(descriptorSetName()); IntegerAttr binding = op->getAttrOfType(bindingName()); return llvm::formatv("{0}_{1}_descriptor_set{2}_binding{3}", - kernelModuleName.str(), op.sym_name().str(), + kernelModuleName.str(), op.getSymName().str(), std::to_string(descriptorSet.getInt()), std::to_string(binding.getInt())); } @@ -126,14 +126,14 @@ /// Encodes the SPIR-V module's symbolic name into the name of the entry point /// function. static LogicalResult encodeKernelName(spirv::ModuleOp module) { - StringRef spvModuleName = *module.sym_name(); + StringRef spvModuleName = *module.getSymName(); // We already know that the module contains exactly one entry point function // based on `getKernelGlobalVariables()` call. Update this function's name // to: // {spv_module_name}_{function_name} auto entryPoint = *module.getOps().begin(); - StringRef funcName = entryPoint.fn(); - auto funcOp = module.lookupSymbol(entryPoint.fnAttr()); + StringRef funcName = entryPoint.getFn(); + auto funcOp = module.lookupSymbol(entryPoint.getFnAttr()); StringAttr newFuncName = StringAttr::get(module->getContext(), spvModuleName + "_" + funcName); if (failed(SymbolTable::replaceAllSymbolUses(funcOp, newFuncName, module))) @@ -236,7 +236,7 @@ // LLVM dialect global variable. 
spirv::GlobalVariableOp spirvGlobal = globalVariableMap[operand.index()]; auto pointeeType = - spirvGlobal.type().cast().getPointeeType(); + spirvGlobal.getType().cast().getPointeeType(); auto dstGlobalType = typeConverter->convertType(pointeeType); if (!dstGlobalType) return failure(); diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp --- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp @@ -228,14 +228,14 @@ if (!dstType) return failure(); rewriter.replaceOpWithNewOp( - loadOp, dstType, spirv::LoadOpAdaptor(operands).ptr(), alignment, + loadOp, dstType, spirv::LoadOpAdaptor(operands).getPtr(), alignment, isVolatile, isNonTemporal); return success(); } auto storeOp = cast(op); spirv::StoreOpAdaptor adaptor(operands); - rewriter.replaceOpWithNewOp(storeOp, adaptor.value(), - adaptor.ptr(), alignment, + rewriter.replaceOpWithNewOp(storeOp, adaptor.getValue(), + adaptor.getPtr(), alignment, isVolatile, isNonTemporal); return success(); } @@ -305,19 +305,19 @@ LogicalResult matchAndRewrite(spirv::AccessChainOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto dstType = typeConverter.convertType(op.component_ptr().getType()); + auto dstType = typeConverter.convertType(op.getComponentPtr().getType()); if (!dstType) return failure(); // To use GEP we need to add a first 0 index to go through the pointer. 
- auto indices = llvm::to_vector<4>(adaptor.indices()); - Type indexType = op.indices().front().getType(); + auto indices = llvm::to_vector<4>(adaptor.getIndices()); + Type indexType = op.getIndices().front().getType(); auto llvmIndexType = typeConverter.convertType(indexType); if (!llvmIndexType) return failure(); Value zero = rewriter.create( op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0)); indices.insert(indices.begin(), zero); - rewriter.replaceOpWithNewOp(op, dstType, adaptor.base_ptr(), + rewriter.replaceOpWithNewOp(op, dstType, adaptor.getBasePtr(), indices); return success(); } @@ -330,10 +330,10 @@ LogicalResult matchAndRewrite(spirv::AddressOfOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto dstType = typeConverter.convertType(op.pointer().getType()); + auto dstType = typeConverter.convertType(op.getPointer().getType()); if (!dstType) return failure(); - rewriter.replaceOpWithNewOp(op, dstType, op.variable()); + rewriter.replaceOpWithNewOp(op, dstType, op.getVariable()); return success(); } }; @@ -353,9 +353,9 @@ Location loc = op.getLoc(); // Process `Offset` and `Count`: broadcast and extend/truncate if needed. - Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType, + Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType, typeConverter, rewriter); - Value count = processCountOrOffset(loc, op.count(), srcType, dstType, + Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType, typeConverter, rewriter); // Create a mask with bits set outside [Offset, Offset + Count - 1]. @@ -372,9 +372,9 @@ // Extract unchanged bits from the `Base` that are outside of // [Offset, Offset + Count - 1]. Then `or` with shifted `Insert`. 
Value baseAndMask = - rewriter.create(loc, dstType, op.base(), mask); + rewriter.create(loc, dstType, op.getBase(), mask); Value insertShiftedByOffset = - rewriter.create(loc, dstType, op.insert(), offset); + rewriter.create(loc, dstType, op.getInsert(), offset); rewriter.replaceOpWithNewOp(op, dstType, baseAndMask, insertShiftedByOffset); return success(); @@ -408,14 +408,14 @@ auto signlessType = rewriter.getIntegerType(getBitWidth(srcType)); if (srcType.isa()) { - auto dstElementsAttr = constOp.value().cast(); + auto dstElementsAttr = constOp.getValue().cast(); rewriter.replaceOpWithNewOp( constOp, dstType, dstElementsAttr.mapValues( signlessType, [&](const APInt &value) { return value; })); return success(); } - auto srcAttr = constOp.value().cast(); + auto srcAttr = constOp.getValue().cast(); auto dstAttr = rewriter.getIntegerAttr(signlessType, srcAttr.getValue()); rewriter.replaceOpWithNewOp(constOp, dstType, dstAttr); return success(); @@ -441,9 +441,9 @@ Location loc = op.getLoc(); // Process `Offset` and `Count`: broadcast and extend/truncate if needed. - Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType, + Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType, typeConverter, rewriter); - Value count = processCountOrOffset(loc, op.count(), srcType, dstType, + Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType, typeConverter, rewriter); // Create a constant that holds the size of the `Base`. @@ -468,7 +468,7 @@ Value amountToShiftLeft = rewriter.create(loc, dstType, size, countPlusOffset); Value baseShiftedLeft = rewriter.create( - loc, dstType, op.base(), amountToShiftLeft); + loc, dstType, op.getBase(), amountToShiftLeft); // Shift the result right, filling the bits with the sign bit. Value amountToShiftRight = @@ -494,9 +494,9 @@ Location loc = op.getLoc(); // Process `Offset` and `Count`: broadcast and extend/truncate if needed. 
- Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType, + Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType, typeConverter, rewriter); - Value count = processCountOrOffset(loc, op.count(), srcType, dstType, + Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType, typeConverter, rewriter); // Create a mask with bits set at [0, Count - 1]. @@ -508,7 +508,7 @@ // Shift `Base` by `Offset` and apply the mask on it. Value shiftedBase = - rewriter.create(loc, dstType, op.base(), offset); + rewriter.create(loc, dstType, op.getBase(), offset); rewriter.replaceOpWithNewOp(op, dstType, shiftedBase, mask); return success(); } @@ -538,20 +538,20 @@ ConversionPatternRewriter &rewriter) const override { // If branch weights exist, map them to 32-bit integer vector. ElementsAttr branchWeights = nullptr; - if (auto weights = op.branch_weights()) { + if (auto weights = op.getBranchWeights()) { VectorType weightType = VectorType::get(2, rewriter.getI32Type()); branchWeights = DenseElementsAttr::get(weightType, weights->getValue()); } rewriter.replaceOpWithNewOp( - op, op.condition(), op.getTrueBlockArguments(), + op, op.getCondition(), op.getTrueBlockArguments(), op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(), op.getFalseBlock()); return success(); } }; /// Converts `spv.CompositeExtract` to `llvm.extractvalue` if the container type /// is an aggregate type (struct or array). Otherwise, converts to /// `llvm.extractelement` that operates on vectors.
class CompositeExtractPattern @@ -566,23 +566,23 @@ if (!dstType) return failure(); - Type containerType = op.composite().getType(); + Type containerType = op.getComposite().getType(); if (containerType.isa()) { Location loc = op.getLoc(); - IntegerAttr value = op.indices()[0].cast(); + IntegerAttr value = op.getIndices()[0].cast(); Value index = createI32ConstantOf(loc, rewriter, value.getInt()); rewriter.replaceOpWithNewOp( - op, dstType, adaptor.composite(), index); + op, dstType, adaptor.getComposite(), index); return success(); } rewriter.replaceOpWithNewOp( - op, adaptor.composite(), LLVM::convertArrayToIndices(op.indices())); + op, adaptor.getComposite(), LLVM::convertArrayToIndices(op.getIndices())); return success(); } }; /// Converts `spv.CompositeInsert` to `llvm.insertvalue` if the container type /// is an aggregate type (struct or array). Otherwise, converts to /// `llvm.insertelement` that operates on vectors. class CompositeInsertPattern @@ -597,19 +597,19 @@ if (!dstType) return failure(); - Type containerType = op.composite().getType(); + Type containerType = op.getComposite().getType(); if (containerType.isa()) { Location loc = op.getLoc(); - IntegerAttr value = op.indices()[0].cast(); + IntegerAttr value = op.getIndices()[0].cast(); Value index = createI32ConstantOf(loc, rewriter, value.getInt()); rewriter.replaceOpWithNewOp( - op, dstType, adaptor.composite(), adaptor.object(), index); + op, dstType, adaptor.getComposite(), adaptor.getObject(), index); return success(); } rewriter.replaceOpWithNewOp( - op, adaptor.composite(), adaptor.object(), - LLVM::convertArrayToIndices(op.indices())); + op, adaptor.getComposite(), adaptor.getObject(), + LLVM::convertArrayToIndices(op.getIndices())); return success(); } }; @@ -647,14 +647,14 @@ // this entry point's execution mode.
We set it to be: // __spv__{SPIR-V module name}_{function name}_execution_mode_info_{mode} ModuleOp module = op->getParentOfType(); - spirv::ExecutionModeAttr executionModeAttr = op.execution_modeAttr(); + spirv::ExecutionModeAttr executionModeAttr = op.getExecutionModeAttr(); std::string moduleName; if (module.getName().has_value()) - moduleName = "_" + module.getName().value().str(); + moduleName = "_" + module.getName()->str(); else moduleName = ""; std::string executionModeInfoName = llvm::formatv( - "__spv_{0}_{1}_execution_mode_info_{2}", moduleName, op.fn().str(), + "__spv_{0}_{1}_execution_mode_info_{2}", moduleName, op.getFn().str(), static_cast(executionModeAttr.getValue())); MLIRContext *context = rewriter.getContext(); @@ -669,7 +669,7 @@ auto llvmI32Type = IntegerType::get(context, 32); SmallVector fields; fields.push_back(llvmI32Type); - ArrayAttr values = op.values(); + ArrayAttr values = op.getValues(); if (!values.empty()) { auto arrayType = LLVM::LLVMArrayType::get(llvmI32Type, values.size()); fields.push_back(arrayType); @@ -722,10 +722,10 @@ ConversionPatternRewriter &rewriter) const override { // Currently, there is no support of initialization with a constant value in // SPIR-V dialect. Specialization constants are not considered as well. - if (op.initializer()) + if (op.getInitializer()) return failure(); - auto srcType = op.type().cast(); + auto srcType = op.getType().cast(); auto dstType = typeConverter.convertType(srcType.getPointeeType()); if (!dstType) return failure(); @@ -759,12 +759,12 @@ ? 
LLVM::Linkage::Private : LLVM::Linkage::External; auto newGlobalOp = rewriter.replaceOpWithNewOp( - op, dstType, isConstant, linkage, op.sym_name(), Attribute(), + op, dstType, isConstant, linkage, op.getSymName(), Attribute(), /*alignment=*/0); // Attach location attribute if applicable - if (op.locationAttr()) - newGlobalOp->setAttr(op.locationAttrName(), op.locationAttr()); + if (op.getLocationAttr()) + newGlobalOp->setAttr(op.getLocationAttrName(), op.getLocationAttr()); return success(); } @@ -781,7 +781,7 @@ matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - Type fromType = operation.operand().getType(); + Type fromType = operation.getOperand().getType(); Type toType = operation.getType(); auto dstType = this->typeConverter.convertType(toType); @@ -839,8 +839,8 @@ return failure(); rewriter.template replaceOpWithNewOp( - operation, dstType, predicate, operation.operand1(), - operation.operand2()); + operation, dstType, predicate, operation.getOperand1(), + operation.getOperand2()); return success(); } }; @@ -860,8 +860,8 @@ return failure(); rewriter.template replaceOpWithNewOp( - operation, dstType, predicate, operation.operand1(), - operation.operand2()); + operation, dstType, predicate, operation.getOperand1(), + operation.getOperand2()); return success(); } }; @@ -881,7 +881,7 @@ Location loc = op.getLoc(); Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0); - Value sqrt = rewriter.create(loc, dstType, op.operand()); + Value sqrt = rewriter.create(loc, dstType, op.getOperand()); rewriter.replaceOpWithNewOp(op, dstType, one, sqrt); return success(); } @@ -896,20 +896,20 @@ LogicalResult matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - if (!op.memory_access()) { + if (!op.getMemoryAccess()) { return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter, this->typeConverter, /*alignment=*/0, 
/*isVolatile=*/false, /*isNonTemporal=*/false); } - auto memoryAccess = *op.memory_access(); + auto memoryAccess = *op.getMemoryAccess(); switch (memoryAccess) { case spirv::MemoryAccess::Aligned: case spirv::MemoryAccess::None: case spirv::MemoryAccess::Nontemporal: case spirv::MemoryAccess::Volatile: { unsigned alignment = - memoryAccess == spirv::MemoryAccess::Aligned ? *op.alignment() : 0; + memoryAccess == spirv::MemoryAccess::Aligned ? *op.getAlignment() : 0; bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal; bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile; return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter, @@ -946,7 +946,7 @@ srcType.template cast(), minusOne)) : rewriter.create(loc, dstType, minusOne); rewriter.template replaceOpWithNewOp(notOp, dstType, - notOp.operand(), mask); + notOp.getOperand(), mask); return success(); } }; @@ -1047,7 +1047,7 @@ matchAndRewrite(spirv::LoopOp loopOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // There is no support of loop control at the moment. - if (loopOp.loop_control() != spirv::LoopControl::None) + if (loopOp.getLoopControl() != spirv::LoopControl::None) return failure(); Location loc = loopOp.getLoc(); @@ -1077,7 +1077,7 @@ rewriter.setInsertionPointToEnd(mergeBlock); rewriter.create(loc, terminatorOperands, endBlock); - rewriter.inlineRegionBefore(loopOp.body(), endBlock); + rewriter.inlineRegionBefore(loopOp.getBody(), endBlock); rewriter.replaceOp(loopOp, endBlock->getArguments()); return success(); } @@ -1096,14 +1096,14 @@ // There is no support for `Flatten` or `DontFlatten` selection control at // the moment. This are just compiler hints and can be performed during the // optimization passes. 
- if (op.selection_control() != spirv::SelectionControl::None) + if (op.getSelectionControl() != spirv::SelectionControl::None) return failure(); // `spv.mlir.selection` should have at least two blocks: one selection // header block and one merge block. If no blocks are present, or control // flow branches straight to merge block (two blocks are present), the op is // redundant and it is erased. - if (op.body().getBlocks().size() <= 2) { + if (op.getBody().getBlocks().size() <= 2) { rewriter.eraseOp(op); return success(); } @@ -1140,11 +1140,11 @@ Block *trueBlock = condBrOp.getTrueBlock(); Block *falseBlock = condBrOp.getFalseBlock(); rewriter.setInsertionPointToEnd(currentBlock); - rewriter.create(loc, condBrOp.condition(), trueBlock, - condBrOp.trueTargetOperands(), falseBlock, - condBrOp.falseTargetOperands()); + rewriter.create(loc, condBrOp.getCondition(), trueBlock, + condBrOp.getTrueTargetOperands(), falseBlock, + condBrOp.getFalseTargetOperands()); - rewriter.inlineRegionBefore(op.body(), continueBlock); + rewriter.inlineRegionBefore(op.getBody(), continueBlock); rewriter.replaceOp(op, continueBlock->getArguments()); return success(); } @@ -1167,8 +1167,8 @@ if (!dstType) return failure(); - Type op1Type = operation.operand1().getType(); - Type op2Type = operation.operand2().getType(); + Type op1Type = operation.getOperand1().getType(); + Type op2Type = operation.getOperand2().getType(); if (op1Type == op2Type) { rewriter.template replaceOpWithNewOp(operation, dstType, @@ -1180,13 +1180,13 @@ Value extended; if (isUnsignedIntegerOrVector(op2Type)) { extended = rewriter.template create(loc, dstType, - adaptor.operand2()); + adaptor.getOperand2()); } else { extended = rewriter.template create(loc, dstType, - adaptor.operand2()); + adaptor.getOperand2()); } Value result = rewriter.template create( - loc, dstType, adaptor.operand1(), extended); + loc, dstType, adaptor.getOperand1(), extended); rewriter.replaceOp(operation, result); return success(); } @@ 
-1204,8 +1204,8 @@ return failure(); Location loc = tanOp.getLoc(); - Value sin = rewriter.create(loc, dstType, tanOp.operand()); - Value cos = rewriter.create(loc, dstType, tanOp.operand()); + Value sin = rewriter.create(loc, dstType, tanOp.getOperand()); + Value cos = rewriter.create(loc, dstType, tanOp.getOperand()); rewriter.replaceOpWithNewOp(tanOp, dstType, sin, cos); return success(); } @@ -1232,7 +1232,7 @@ Location loc = tanhOp.getLoc(); Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0); Value multiplied = - rewriter.create(loc, dstType, two, tanhOp.operand()); + rewriter.create(loc, dstType, two, tanhOp.getOperand()); Value exponential = rewriter.create(loc, dstType, multiplied); Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0); Value numerator = @@ -1255,7 +1255,7 @@ auto srcType = varOp.getType(); // Initialization is supported for scalars and vectors only. auto pointerTo = srcType.cast().getPointeeType(); - auto init = varOp.initializer(); + auto init = varOp.getInitializer(); if (init && !pointerTo.isIntOrFloat() && !pointerTo.isa()) return failure(); @@ -1270,7 +1270,7 @@ return success(); } Value allocated = rewriter.create(loc, dstType, size); - rewriter.create(loc, adaptor.initializer(), allocated); + rewriter.create(loc, adaptor.getInitializer(), allocated); rewriter.replaceOp(varOp, allocated); return success(); } @@ -1305,7 +1305,7 @@ // Convert SPIR-V Function Control to equivalent LLVM function attribute MLIRContext *context = funcOp.getContext(); - switch (funcOp.function_control()) { + switch (funcOp.getFunctionControl()) { #define DISPATCH(functionControl, llvmAttr) \ case functionControl: \ newFuncOp->setAttr("passthrough", ArrayAttr::get(context, {llvmAttr})); \ @@ -1374,9 +1374,9 @@ matchAndRewrite(spirv::VectorShuffleOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); - auto components = adaptor.components(); - auto vector1 = adaptor.vector1(); - 
auto vector2 = adaptor.vector2(); + auto components = adaptor.getComponents(); + auto vector1 = adaptor.getVector1(); + auto vector2 = adaptor.getVector2(); int vector1Size = vector1.getType().cast().getNumElements(); int vector2Size = vector2.getType().cast().getNumElements(); if (vector1Size == vector2Size) { @@ -1589,8 +1589,8 @@ // SPIR-V module has a name, add it at the beginning. auto moduleAndName = spvModule.getName().has_value() - ? spvModule.getName().value().str() + "_" + op.sym_name().str() - : op.sym_name().str(); + ? spvModule.getName()->str() + "_" + op.getSymName().str() + : op.getSymName().str(); std::string name = llvm::formatv("{0}_descriptor_set{1}_binding{2}", moduleAndName, std::to_string(descriptorSet.getInt()), diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp @@ -88,19 +88,19 @@ LogicalResult matchAndRewrite(spirv::AccessChainOp accessChainOp, PatternRewriter &rewriter) const override { auto parentAccessChainOp = dyn_cast_or_null( - accessChainOp.base_ptr().getDefiningOp()); + accessChainOp.getBasePtr().getDefiningOp()); if (!parentAccessChainOp) { return failure(); } // Combine indices. 
- SmallVector indices(parentAccessChainOp.indices()); - indices.append(accessChainOp.indices().begin(), - accessChainOp.indices().end()); + SmallVector indices(parentAccessChainOp.getIndices()); + indices.append(accessChainOp.getIndices().begin(), + accessChainOp.getIndices().end()); rewriter.replaceOpWithNewOp( - accessChainOp, parentAccessChainOp.base_ptr(), indices); + accessChainOp, parentAccessChainOp.getBasePtr(), indices); return success(); } @@ -126,23 +126,24 @@ //===----------------------------------------------------------------------===// OpFoldResult spirv::CompositeExtractOp::fold(ArrayRef operands) { - if (auto insertOp = composite().getDefiningOp()) { - if (indices() == insertOp.indices()) - return insertOp.object(); + if (auto insertOp = + getComposite().getDefiningOp()) { + if (getIndices() == insertOp.getIndices()) + return insertOp.getObject(); } if (auto constructOp = - composite().getDefiningOp()) { + getComposite().getDefiningOp()) { auto type = constructOp.getType().cast(); - if (indices().size() == 1 && - constructOp.constituents().size() == type.getNumElements()) { - auto i = indices().begin()->cast(); - return constructOp.constituents()[i.getValue().getSExtValue()]; + if (getIndices().size() == 1 && + constructOp.getConstituents().size() == type.getNumElements()) { + auto i = getIndices().begin()->cast(); + return constructOp.getConstituents()[i.getValue().getSExtValue()]; } } auto indexVector = - llvm::to_vector<8>(llvm::map_range(indices(), [](Attribute attr) { + llvm::to_vector<8>(llvm::map_range(getIndices(), [](Attribute attr) { return static_cast(attr.cast().getInt()); })); return extractCompositeElement(operands[0], indexVector); @@ -154,7 +155,7 @@ OpFoldResult spirv::ConstantOp::fold(ArrayRef operands) { assert(operands.empty() && "spv.Constant has no operands"); - return value(); + return getValue(); } //===----------------------------------------------------------------------===// @@ -164,8 +165,8 @@ OpFoldResult 
spirv::IAddOp::fold(ArrayRef operands) { assert(operands.size() == 2 && "spv.IAdd expects two operands"); // x + 0 = x - if (matchPattern(operand2(), m_Zero())) - return operand1(); + if (matchPattern(getOperand2(), m_Zero())) + return getOperand1(); // According to the SPIR-V spec: // @@ -183,11 +184,11 @@ OpFoldResult spirv::IMulOp::fold(ArrayRef operands) { assert(operands.size() == 2 && "spv.IMul expects two operands"); // x * 0 == 0 - if (matchPattern(operand2(), m_Zero())) - return operand2(); + if (matchPattern(getOperand2(), m_Zero())) + return getOperand2(); // x * 1 = x - if (matchPattern(operand2(), m_One())) - return operand1(); + if (matchPattern(getOperand2(), m_One())) + return getOperand1(); // According to the SPIR-V spec: // @@ -204,7 +205,7 @@ OpFoldResult spirv::ISubOp::fold(ArrayRef operands) { // x - x = 0 - if (operand1() == operand2()) + if (getOperand1() == getOperand2()) return Builder(getContext()).getIntegerAttr(getType(), 0); // According to the SPIR-V spec: @@ -226,7 +227,7 @@ if (Optional rhs = getScalarOrSplatBoolAttr(operands.back())) { // x && true = x if (rhs.value()) - return operand1(); + return getOperand1(); // x && false = false if (!rhs.value()) @@ -262,7 +263,7 @@ // x || false = x if (!rhs.value()) - return operand1(); + return getOperand1(); } return Attribute(); @@ -339,8 +340,8 @@ cast(trueBlock->front())->getAttrs(); auto selectOp = rewriter.create( - selectionOp.getLoc(), trueValue.getType(), brConditionalOp.condition(), - trueValue, falseValue); + selectionOp.getLoc(), trueValue.getType(), + brConditionalOp.getCondition(), trueValue, falseValue); rewriter.create(selectOp.getLoc(), ptrValue, selectOp.getResult(), storeOpAttributes); @@ -371,13 +372,13 @@ // Returns a source value for the given block. Value getSrcValue(Block *block) const { auto storeOp = cast(block->front()); - return storeOp.value(); + return storeOp.getValue(); } // Returns a destination value for the given block. 
Value getDstPtr(Block *block) const { auto storeOp = cast(block->front()); - return storeOp.ptr(); + return storeOp.getPtr(); } }; @@ -406,14 +407,14 @@ // "Before version 1.4, Result Type must be a pointer, scalar, or vector. // Starting with version 1.4, Result Type can additionally be a composite type // other than a vector." - bool isScalarOrVector = trueBrStoreOp.value() + bool isScalarOrVector = trueBrStoreOp.getValue() .getType() .cast() .isScalarOrVector(); // Check that each `spv.Store` uses the same pointer, memory access // attributes and a valid type of the value. - if ((trueBrStoreOp.ptr() != falseBrStoreOp.ptr()) || + if ((trueBrStoreOp.getPtr() != falseBrStoreOp.getPtr()) || !isSameAttrList(trueBrStoreOp, falseBrStoreOp) || !isScalarOrVector) { return failure(); } diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp @@ -106,7 +106,7 @@ // Replace the values directly with the return operands. assert(valuesToRepl.size() == 1 && "spv.ReturnValue expected to only handle one result"); - valuesToRepl.front().replaceAllUsesWith(retValOp.value()); + valuesToRepl.front().replaceAllUsesWith(retValOp.getValue()); } }; } // namespace diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -136,7 +136,7 @@ if (!constOp) { return failure(); } - auto valueAttr = constOp.value(); + auto valueAttr = constOp.getValue(); auto integerValueAttr = valueAttr.dyn_cast(); if (!integerValueAttr) { return failure(); @@ -313,7 +313,7 @@ Optional alignmentAttrValue = None) { // Print optional memory access attribute. if (auto memAccess = (memoryAccessAtrrValue ? 
memoryAccessAtrrValue - : memoryOp.memory_access())) { + : memoryOp.getMemoryAccess())) { elidedAttrs.push_back(kMemoryAccessAttrName); printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\""; @@ -321,7 +321,7 @@ if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) { // Print integer alignment attribute. if (auto alignment = (alignmentAttrValue ? alignmentAttrValue - : memoryOp.alignment())) { + : memoryOp.getAlignment())) { elidedAttrs.push_back(kAlignmentAttrName); printer << ", " << alignment; } @@ -346,7 +346,7 @@ // Print optional memory access attribute. if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue - : memoryOp.memory_access())) { + : memoryOp.getMemoryAccess())) { elidedAttrs.push_back(kSourceMemoryAccessAttrName); printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\""; @@ -354,7 +354,7 @@ if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) { // Print integer alignment attribute. if (auto alignment = (alignmentAttrValue ? 
alignmentAttrValue - : memoryOp.alignment())) { + : memoryOp.getAlignment())) { elidedAttrs.push_back(kSourceAlignmentAttrName); printer << ", " << alignment; } @@ -1086,17 +1086,17 @@ template static void printAccessChain(Op op, ValueRange indices, OpAsmPrinter &printer) { - printer << ' ' << op.base_ptr() << '[' << indices - << "] : " << op.base_ptr().getType() << ", " << indices.getTypes(); + printer << ' ' << op.getBasePtr() << '[' << indices + << "] : " << op.getBasePtr().getType() << ", " << indices.getTypes(); } void spirv::AccessChainOp::print(OpAsmPrinter &printer) { - printAccessChain(*this, indices(), printer); + printAccessChain(*this, getIndices(), printer); } template static LogicalResult verifyAccessChain(Op accessChainOp, ValueRange indices) { - auto resultType = getElementPtrType(accessChainOp.base_ptr().getType(), + auto resultType = getElementPtrType(accessChainOp.getBasePtr().getType(), indices, accessChainOp.getLoc()); if (!resultType) return failure(); @@ -1116,7 +1116,7 @@ } LogicalResult spirv::AccessChainOp::verify() { - return verifyAccessChain(*this, indices()); + return verifyAccessChain(*this, getIndices()); } //===----------------------------------------------------------------------===// @@ -1125,17 +1125,17 @@ void spirv::AddressOfOp::build(OpBuilder &builder, OperationState &state, spirv::GlobalVariableOp var) { - build(builder, state, var.type(), SymbolRefAttr::get(var)); + build(builder, state, var.getType(), SymbolRefAttr::get(var)); } LogicalResult spirv::AddressOfOp::verify() { auto varOp = dyn_cast_or_null( SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(), - variableAttr())); + getVariableAttr())); if (!varOp) { return emitOpError("expected spv.GlobalVariable symbol"); } - if (pointer().getType() != varOp.type()) { + if (getPointer().getType() != varOp.getType()) { return emitOpError( "result type mismatch with the referenced global variable's type"); } @@ -1144,10 +1144,10 @@ template static void 
printAtomicCompareExchangeImpl(T atomOp, OpAsmPrinter &printer) { - printer << " \"" << stringifyScope(atomOp.memory_scope()) << "\" \"" - << stringifyMemorySemantics(atomOp.equal_semantics()) << "\" \"" - << stringifyMemorySemantics(atomOp.unequal_semantics()) << "\" " - << atomOp.getOperands() << " : " << atomOp.pointer().getType(); + printer << " \"" << stringifyScope(atomOp.getMemoryScope()) << "\" \"" + << stringifyMemorySemantics(atomOp.getEqualSemantics()) << "\" \"" + << stringifyMemorySemantics(atomOp.getUnequalSemantics()) << "\" " + << atomOp.getOperands() << " : " << atomOp.getPointer().getType(); } static ParseResult parseAtomicCompareExchangeImpl(OpAsmParser &parser, @@ -1188,18 +1188,18 @@ // "The type of Value must be the same as Result Type. The type of the value // pointed to by Pointer must be the same as Result Type. This type must also // match the type of Comparator." - if (atomOp.getType() != atomOp.value().getType()) + if (atomOp.getType() != atomOp.getValue().getType()) return atomOp.emitOpError("value operand must have the same type as the op " "result, but found ") - << atomOp.value().getType() << " vs " << atomOp.getType(); + << atomOp.getValue().getType() << " vs " << atomOp.getType(); - if (atomOp.getType() != atomOp.comparator().getType()) + if (atomOp.getType() != atomOp.getComparator().getType()) return atomOp.emitOpError( "comparator operand must have the same type as the op " "result, but found ") - << atomOp.comparator().getType() << " vs " << atomOp.getType(); + << atomOp.getComparator().getType() << " vs " << atomOp.getType(); - Type pointeeType = atomOp.pointer() + Type pointeeType = atomOp.getPointer() .getType() .template cast() .getPointeeType(); @@ -1268,9 +1268,9 @@ //===----------------------------------------------------------------------===// void spirv::AtomicExchangeOp::print(OpAsmPrinter &printer) { - printer << " \"" << stringifyScope(memory_scope()) << "\" \"" - << stringifyMemorySemantics(semantics()) << "\" " << 
getOperands() - << " : " << pointer().getType(); + printer << " \"" << stringifyScope(getMemoryScope()) << "\" \"" + << stringifyMemorySemantics(getSemantics()) << "\" " << getOperands() + << " : " << getPointer().getType(); } ParseResult spirv::AtomicExchangeOp::parse(OpAsmParser &parser, @@ -1302,13 +1302,13 @@ } LogicalResult spirv::AtomicExchangeOp::verify() { - if (getType() != value().getType()) + if (getType() != getValue().getType()) return emitOpError("value operand must have the same type as the op " "result, but found ") - << value().getType() << " vs " << getType(); + << getValue().getType() << " vs " << getType(); Type pointeeType = - pointer().getType().cast().getPointeeType(); + getPointer().getType().cast().getPointeeType(); if (getType() != pointeeType) return emitOpError("pointer operand's pointee type must have the same " "as the op result type, but found ") @@ -1500,8 +1500,8 @@ LogicalResult spirv::BitcastOp::verify() { // TODO: The SPIR-V spec validation rules are different for different // versions. 
- auto operandType = operand().getType(); - auto resultType = result().getType(); + auto operandType = getOperand().getType(); + auto resultType = getResult().getType(); if (operandType == resultType) { return emitError("result type must be different from operand type"); } @@ -1530,8 +1530,8 @@ //===----------------------------------------------------------------------===// LogicalResult spirv::PtrCastToGenericOp::verify() { - auto operandType = pointer().getType().cast(); - auto resultType = result().getType().cast(); + auto operandType = getPointer().getType().cast(); + auto resultType = getResult().getType().cast(); spirv::StorageClass operandStorage = operandType.getStorageClass(); if (operandStorage != spirv::StorageClass::Workgroup && @@ -1558,8 +1558,8 @@ //===----------------------------------------------------------------------===// LogicalResult spirv::GenericCastToPtrOp::verify() { - auto operandType = pointer().getType().cast(); - auto resultType = result().getType().cast(); + auto operandType = getPointer().getType().cast(); + auto resultType = getResult().getType().cast(); spirv::StorageClass operandStorage = operandType.getStorageClass(); if (operandStorage != spirv::StorageClass::Generic) @@ -1586,8 +1586,8 @@ //===----------------------------------------------------------------------===// LogicalResult spirv::GenericCastToPtrExplicitOp::verify() { - auto operandType = pointer().getType().cast(); - auto resultType = result().getType().cast(); + auto operandType = getPointer().getType().cast(); + auto resultType = getResult().getType().cast(); spirv::StorageClass operandStorage = operandType.getStorageClass(); if (operandStorage != spirv::StorageClass::Generic) @@ -1615,7 +1615,7 @@ SuccessorOperands spirv::BranchOp::getSuccessorOperands(unsigned index) { assert(index == 0 && "invalid successor index"); - return SuccessorOperands(0, targetOperandsMutable()); + return SuccessorOperands(0, getTargetOperandsMutable()); } 
//===----------------------------------------------------------------------===// @@ -1625,8 +1625,9 @@ SuccessorOperands spirv::BranchConditionalOp::getSuccessorOperands(unsigned index) { assert(index < 2 && "invalid successor index"); - return SuccessorOperands(index == kTrueIndex ? trueTargetOperandsMutable() - : falseTargetOperandsMutable()); + return SuccessorOperands(index == kTrueIndex + ? getTrueTargetOperandsMutable() + : getFalseTargetOperandsMutable()); } ParseResult spirv::BranchConditionalOp::parse(OpAsmParser &parser, @@ -1681,9 +1682,9 @@ } void spirv::BranchConditionalOp::print(OpAsmPrinter &printer) { - printer << ' ' << condition(); + printer << ' ' << getCondition(); - if (auto weights = branch_weights()) { + if (auto weights = getBranchWeights()) { printer << " ["; llvm::interleaveComma(weights->getValue(), printer, [&](Attribute a) { printer << a.cast().getInt(); @@ -1698,7 +1699,7 @@ } LogicalResult spirv::BranchConditionalOp::verify() { - if (auto weights = branch_weights()) { + if (auto weights = getBranchWeights()) { if (weights->getValue().size() != 2) { return emitOpError("must have exactly two branch weights"); } @@ -1717,7 +1718,7 @@ LogicalResult spirv::CompositeConstructOp::verify() { auto cType = getType().cast(); - operand_range constituents = this->constituents(); + operand_range constituents = this->getConstituents(); if (auto coopType = cType.dyn_cast()) { if (constituents.size() != 1) @@ -1828,13 +1829,14 @@ } void spirv::CompositeExtractOp::print(OpAsmPrinter &printer) { - printer << ' ' << composite() << indices() << " : " << composite().getType(); + printer << ' ' << getComposite() << getIndices() << " : " + << getComposite().getType(); } LogicalResult spirv::CompositeExtractOp::verify() { - auto indicesArrayAttr = indices().dyn_cast(); + auto indicesArrayAttr = getIndices().dyn_cast(); auto resultType = - getElementType(composite().getType(), indicesArrayAttr, getLoc()); + getElementType(getComposite().getType(), 
indicesArrayAttr, getLoc()); if (!resultType) return failure(); @@ -1875,29 +1877,30 @@ } LogicalResult spirv::CompositeInsertOp::verify() { - auto indicesArrayAttr = indices().dyn_cast(); + auto indicesArrayAttr = getIndices().dyn_cast(); auto objectType = - getElementType(composite().getType(), indicesArrayAttr, getLoc()); + getElementType(getComposite().getType(), indicesArrayAttr, getLoc()); if (!objectType) return failure(); - if (objectType != object().getType()) { + if (objectType != getObject().getType()) { return emitOpError("object operand type should be ") - << objectType << ", but found " << object().getType(); + << objectType << ", but found " << getObject().getType(); } - if (composite().getType() != getType()) { + if (getComposite().getType() != getType()) { return emitOpError("result type should be the same as " "the composite type, but found ") - << composite().getType() << " vs " << getType(); + << getComposite().getType() << " vs " << getType(); } return success(); } void spirv::CompositeInsertOp::print(OpAsmPrinter &printer) { - printer << " " << object() << ", " << composite() << indices() << " : " - << object().getType() << " into " << composite().getType(); + printer << " " << getObject() << ", " << getComposite() << getIndices() + << " : " << getObject().getType() << " into " + << getComposite().getType(); } //===----------------------------------------------------------------------===// @@ -1922,7 +1925,7 @@ } void spirv::ConstantOp::print(OpAsmPrinter &printer) { - printer << ' ' << value(); + printer << ' ' << getValue(); if (getType().isa()) printer << " : " << getType(); } @@ -1989,7 +1992,7 @@ // ODS already generates checks to make sure the result type is valid. We just // need to additionally check that the value's attribute type is consistent // with the result type. 
- return verifyConstantType(*this, valueAttr(), getType()); + return verifyConstantType(*this, getValueAttr(), getType()); } bool spirv::ConstantOp::isBuildableWith(Type type) { @@ -2081,7 +2084,7 @@ IntegerType intTy = type.dyn_cast(); - if (IntegerAttr intCst = value().dyn_cast()) { + if (IntegerAttr intCst = getValue().dyn_cast()) { if (intTy && intTy.getWidth() == 1) { return setNameFn(getResult(), (intCst.getInt() ? "true" : "false")); } @@ -2115,7 +2118,7 @@ llvm::function_ref setNameFn) { SmallString<32> specialNameBuffer; llvm::raw_svector_ostream specialName(specialNameBuffer); - specialName << variable() << "_addr"; + specialName << getVariable() << "_addr"; setNameFn(getResult(), specialName.str()); } @@ -2124,7 +2127,7 @@ //===----------------------------------------------------------------------===// LogicalResult spirv::ControlBarrierOp::verify() { - return verifyMemorySemantics(getOperation(), memory_semantics()); + return verifyMemorySemantics(getOperation(), getMemorySemantics()); } //===----------------------------------------------------------------------===// @@ -2208,9 +2211,9 @@ } void spirv::EntryPointOp::print(OpAsmPrinter &printer) { - printer << " \"" << stringifyExecutionModel(execution_model()) << "\" "; - printer.printSymbolName(fn()); - auto interfaceVars = interface().getValue(); + printer << " \"" << stringifyExecutionModel(getExecutionModel()) << "\" "; + printer.printSymbolName(getFn()); + auto interfaceVars = getInterface().getValue(); if (!interfaceVars.empty()) { printer << ", "; llvm::interleaveComma(interfaceVars, printer); @@ -2262,9 +2265,9 @@ void spirv::ExecutionModeOp::print(OpAsmPrinter &printer) { printer << " "; - printer.printSymbolName(fn()); - printer << " \"" << stringifyExecutionMode(execution_mode()) << "\""; - auto values = this->values(); + printer.printSymbolName(getFn()); + printer << " \"" << stringifyExecutionMode(getExecutionMode()) << "\""; + auto values = this->getValues(); if (values.empty()) return; 
printer << ", "; @@ -2351,19 +2354,19 @@ void spirv::FuncOp::print(OpAsmPrinter &printer) { // Print function name, signature, and control. printer << " "; - printer.printSymbolName(sym_name()); + printer.printSymbolName(getSymName()); auto fnType = getFunctionType(); function_interface_impl::printFunctionSignature( printer, *this, fnType.getInputs(), /*isVariadic=*/false, fnType.getResults()); - printer << " \"" << spirv::stringifyFunctionControl(function_control()) + printer << " \"" << spirv::stringifyFunctionControl(getFunctionControl()) << "\""; function_interface_impl::printFunctionAttributes( printer, *this, fnType.getNumInputs(), fnType.getNumResults(), {spirv::attributeName()}); // Print the body if this is not an external function. - Region &body = this->body(); + Region &body = this->getBody(); if (!body.empty()) { printer << ' '; printer.printRegion(body, /*printEntryBlockArgs=*/false, @@ -2394,7 +2397,7 @@ "returns 1 value but enclosing function requires ") << fnType.getNumResults() << " results"; - auto retOperandType = retOp.value().getType(); + auto retOperandType = retOp.getValue().getType(); auto fnResultType = fnType.getResult(0); if (retOperandType != fnResultType) return retOp.emitOpError(" return value's type (") @@ -2424,7 +2427,7 @@ // CallableOpInterface Region *spirv::FuncOp::getCallableRegion() { - return isExternal() ? nullptr : &body(); + return isExternal() ? 
nullptr : &getBody(); } // CallableOpInterface @@ -2437,7 +2440,7 @@ //===----------------------------------------------------------------------===// LogicalResult spirv::FunctionCallOp::verify() { - auto fnName = calleeAttr(); + auto fnName = getCalleeAttr(); auto funcOp = dyn_cast_or_null( SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(), fnName)); @@ -2490,7 +2493,7 @@ } Operation::operand_range spirv::FunctionCallOp::getArgOperands() { - return arguments(); + return getArguments(); } //===----------------------------------------------------------------------===// @@ -2599,11 +2602,11 @@ // Print variable name. printer << ' '; - printer.printSymbolName(sym_name()); + printer.printSymbolName(getSymName()); elidedAttrs.push_back(SymbolTable::getSymbolAttrName()); // Print optional initializer - if (auto initializer = this->initializer()) { + if (auto initializer = this->getInitializer()) { printer << " " << kInitializerAttrName << '('; printer.printSymbolName(*initializer); printer << ')'; @@ -2612,7 +2615,7 @@ elidedAttrs.push_back(kTypeAttrName); printVariableDecorations(*this, printer, elidedAttrs); - printer << " : " << type(); + printer << " : " << getType(); } LogicalResult spirv::GlobalVariableOp::verify() { @@ -2649,11 +2652,11 @@ //===----------------------------------------------------------------------===// LogicalResult spirv::GroupBroadcastOp::verify() { - spirv::Scope scope = execution_scope(); + spirv::Scope scope = getExecutionScope(); if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'"); - if (auto localIdTy = localid().getType().dyn_cast()) + if (auto localIdTy = getLocalid().getType().dyn_cast()) if (localIdTy.getNumElements() != 2 && localIdTy.getNumElements() != 3) return emitOpError("localid is a vector and can be with only " " 2 or 3 components, actual number is ") @@ -2667,7 +2670,7 @@ 
//===----------------------------------------------------------------------===// LogicalResult spirv::GroupNonUniformBallotOp::verify() { - spirv::Scope scope = execution_scope(); + spirv::Scope scope = getExecutionScope(); if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'"); @@ -2679,7 +2682,7 @@ //===----------------------------------------------------------------------===// LogicalResult spirv::GroupNonUniformBroadcastOp::verify() { - spirv::Scope scope = execution_scope(); + spirv::Scope scope = getExecutionScope(); if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'"); @@ -2690,7 +2693,7 @@ targetEnv = spirv::lookupTargetEnvOrDefault(spirvModule); if (targetEnv.getVersion() < spirv::Version::V_1_5) { - auto *idOp = id().getDefiningOp(); + auto *idOp = getId().getDefiningOp(); if (!idOp || !isa(idOp)) // for spec constant return emitOpError("id must be the result of a constant op"); @@ -2705,7 +2708,7 @@ template static LogicalResult verifyGroupNonUniformShuffleOp(OpTy op) { - spirv::Scope scope = op.execution_scope(); + spirv::Scope scope = op.getExecutionScope(); if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) return op.emitOpError("execution scope must be 'Workgroup' or 'Subgroup'"); @@ -2756,11 +2759,11 @@ } void spirv::INTELSubgroupBlockReadOp::print(OpAsmPrinter &printer) { - printer << " " << ptr() << " : " << getType(); + printer << " " << getPtr() << " : " << getType(); } LogicalResult spirv::INTELSubgroupBlockReadOp::verify() { - if (failed(verifyBlockReadWritePtrAndValTypes(*this, ptr(), value()))) + if (failed(verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue()))) return failure(); return success(); @@ -2795,11 +2798,12 @@ } void spirv::INTELSubgroupBlockWriteOp::print(OpAsmPrinter &printer) { - printer << " " << ptr() << ", " << 
value() << " : " << value().getType(); + printer << " " << getPtr() << ", " << getValue() << " : " + << getValue().getType(); } LogicalResult spirv::INTELSubgroupBlockWriteOp::verify() { - if (failed(verifyBlockReadWritePtrAndValTypes(*this, ptr(), value()))) + if (failed(verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue()))) return failure(); return success(); @@ -2810,7 +2814,7 @@ //===----------------------------------------------------------------------===// LogicalResult spirv::GroupNonUniformElectOp::verify() { - spirv::Scope scope = execution_scope(); + spirv::Scope scope = getExecutionScope(); if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'"); @@ -2986,7 +2990,7 @@ if (resultType.getNumElements() != 2) return emitOpError("expected result struct type containing two members"); - if (!llvm::all_equal({operand1().getType(), operand2().getType(), + if (!llvm::all_equal({getOperand1().getType(), getOperand2().getType(), resultType.getElementType(0), resultType.getElementType(1)})) return emitOpError( @@ -3035,7 +3039,7 @@ if (resultType.getNumElements() != 2) return emitOpError("expected result struct type containing two members"); - if (!llvm::all_equal({operand1().getType(), operand2().getType(), + if (!llvm::all_equal({getOperand1().getType(), getOperand2().getType(), resultType.getElementType(0), resultType.getElementType(1)})) return emitOpError( @@ -3111,8 +3115,8 @@ void spirv::LoadOp::print(OpAsmPrinter &printer) { SmallVector elidedAttrs; StringRef sc = stringifyStorageClass( - ptr().getType().cast().getStorageClass()); - printer << " \"" << sc << "\" " << ptr(); + getPtr().getType().cast().getStorageClass()); + printer << " \"" << sc << "\" " << getPtr(); printMemoryAccessAttribute(*this, printer, elidedAttrs); @@ -3124,7 +3128,7 @@ // SPIR-V spec : "Result Type is the type of the loaded object. 
It must be a // type with fixed size; i.e., it cannot be, nor include, any // OpTypeRuntimeArray types." - if (failed(verifyLoadStorePtrAndValTypes(*this, ptr(), value()))) { + if (failed(verifyLoadStorePtrAndValTypes(*this, getPtr(), getValue()))) { return failure(); } return verifyMemoryAccessAttribute(*this); @@ -3148,7 +3152,7 @@ } void spirv::LoopOp::print(OpAsmPrinter &printer) { - auto control = loop_control(); + auto control = getLoopControl(); if (control != spirv::LoopControl::None) printer << " control(" << spirv::stringifyLoopControl(control) << ")"; printer << ' '; @@ -3253,33 +3257,33 @@ } Block *spirv::LoopOp::getEntryBlock() { - assert(!body().empty() && "op region should not be empty!"); - return &body().front(); + assert(!getBody().empty() && "op region should not be empty!"); + return &getBody().front(); } Block *spirv::LoopOp::getHeaderBlock() { - assert(!body().empty() && "op region should not be empty!"); + assert(!getBody().empty() && "op region should not be empty!"); // The second block is the loop header block. - return &*std::next(body().begin()); + return &*std::next(getBody().begin()); } Block *spirv::LoopOp::getContinueBlock() { - assert(!body().empty() && "op region should not be empty!"); + assert(!getBody().empty() && "op region should not be empty!"); // The second to last block is the loop continue block. - return &*std::prev(body().end(), 2); + return &*std::prev(getBody().end(), 2); } Block *spirv::LoopOp::getMergeBlock() { - assert(!body().empty() && "op region should not be empty!"); + assert(!getBody().empty() && "op region should not be empty!"); // The last block is the loop merge block. 
- return &body().back(); + return &getBody().back(); } void spirv::LoopOp::addEntryAndMergeBlock() { - assert(body().empty() && "entry and merge block already exist"); - body().push_back(new Block()); + assert(getBody().empty() && "entry and merge block already exist"); + getBody().push_back(new Block()); auto *mergeBlock = new Block(); - body().push_back(mergeBlock); + getBody().push_back(mergeBlock); OpBuilder builder = OpBuilder::atBlockEnd(mergeBlock); // Add a spv.mlir.merge op into the merge block. @@ -3291,7 +3295,7 @@ //===----------------------------------------------------------------------===// LogicalResult spirv::MemoryBarrierOp::verify() { - return verifyMemorySemantics(getOperation(), memory_semantics()); + return verifyMemorySemantics(getOperation(), getMemorySemantics()); } //===----------------------------------------------------------------------===// @@ -3390,14 +3394,14 @@ SmallVector elidedAttrs; - printer << " " << spirv::stringifyAddressingModel(addressing_model()) << " " - << spirv::stringifyMemoryModel(memory_model()); + printer << " " << spirv::stringifyAddressingModel(getAddressingModel()) << " " + << spirv::stringifyMemoryModel(getMemoryModel()); auto addressingModelAttrName = spirv::attributeName(); auto memoryModelAttrName = spirv::attributeName(); elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName, mlir::SymbolTable::getSymbolAttrName()}); - if (Optional triple = vce_triple()) { + if (Optional triple = getVceTriple()) { printer << " requires " << *triple; elidedAttrs.push_back(spirv::ModuleOp::getVCETripleAttrName()); } @@ -3421,12 +3425,12 @@ // duplicated in EntryPointOps. Also verify that the interface specified // comes from globalVariables here to make this check cheaper. 
if (auto entryPointOp = dyn_cast(op)) { - auto funcOp = table.lookup(entryPointOp.fn()); + auto funcOp = table.lookup(entryPointOp.getFn()); if (!funcOp) { return entryPointOp.emitError("function '") - << entryPointOp.fn() << "' not found in 'spv.module'"; + << entryPointOp.getFn() << "' not found in 'spv.module'"; } - if (auto interface = entryPointOp.interface()) { + if (auto interface = entryPointOp.getInterface()) { for (Attribute varRef : interface) { auto varSymRef = varRef.dyn_cast(); if (!varSymRef) { @@ -3446,7 +3450,7 @@ } auto key = std::pair( - funcOp, entryPointOp.execution_model()); + funcOp, entryPointOp.getExecutionModel()); auto entryPtIt = entryPoints.find(key); if (entryPtIt != entryPoints.end()) { return entryPointOp.emitError("duplicate of a previous EntryPointOp"); @@ -3475,23 +3479,23 @@ LogicalResult spirv::ReferenceOfOp::verify() { auto *specConstSym = SymbolTable::lookupNearestSymbolFrom( - (*this)->getParentOp(), spec_constAttr()); + (*this)->getParentOp(), getSpecConstAttr()); Type constType; auto specConstOp = dyn_cast_or_null(specConstSym); if (specConstOp) - constType = specConstOp.default_value().getType(); + constType = specConstOp.getDefaultValue().getType(); auto specConstCompositeOp = dyn_cast_or_null(specConstSym); if (specConstCompositeOp) - constType = specConstCompositeOp.type(); + constType = specConstCompositeOp.getType(); if (!specConstOp && !specConstCompositeOp) return emitOpError( "expected spv.SpecConstant or spv.SpecConstantComposite symbol"); - if (reference().getType() != constType) + if (getReference().getType() != constType) return emitOpError("result type mismatch with the referenced " "specialization constant's type"); @@ -3521,8 +3525,8 @@ //===----------------------------------------------------------------------===// LogicalResult spirv::SelectOp::verify() { - if (auto conditionTy = condition().getType().dyn_cast()) { - auto resultVectorTy = result().getType().dyn_cast(); + if (auto conditionTy = 
getCondition().getType().dyn_cast()) { + auto resultVectorTy = getResult().getType().dyn_cast(); if (!resultVectorTy) { return emitOpError("result expected to be of vector type when " "condition is of vector type"); @@ -3548,7 +3552,7 @@ } void spirv::SelectionOp::print(OpAsmPrinter &printer) { - auto control = selection_control(); + auto control = getSelectionControl(); if (control != spirv::SelectionControl::None) printer << " control(" << spirv::stringifySelectionControl(control) << ")"; printer << ' '; @@ -3598,21 +3602,21 @@ } Block *spirv::SelectionOp::getHeaderBlock() { - assert(!body().empty() && "op region should not be empty!"); + assert(!getBody().empty() && "op region should not be empty!"); // The first block is the loop header block. - return &body().front(); + return &getBody().front(); } Block *spirv::SelectionOp::getMergeBlock() { - assert(!body().empty() && "op region should not be empty!"); + assert(!getBody().empty() && "op region should not be empty!"); // The last block is the loop merge block. - return &body().back(); + return &getBody().back(); } void spirv::SelectionOp::addMergeBlock() { - assert(body().empty() && "entry and merge block already exist"); + assert(getBody().empty() && "entry and merge block already exist"); auto *mergeBlock = new Block(); - body().push_back(mergeBlock); + getBody().push_back(mergeBlock); OpBuilder builder = OpBuilder::atBlockEnd(mergeBlock); // Add a spv.mlir.merge op into the merge block. 
@@ -3682,10 +3686,10 @@ void spirv::SpecConstantOp::print(OpAsmPrinter &printer) { printer << ' '; - printer.printSymbolName(sym_name()); + printer.printSymbolName(getSymName()); if (auto specID = (*this)->getAttrOfType(kSpecIdAttrName)) printer << ' ' << kSpecIdAttrName << '(' << specID.getInt() << ')'; - printer << " = " << default_value(); + printer << " = " << getDefaultValue(); } LogicalResult spirv::SpecConstantOp::verify() { @@ -3693,7 +3697,7 @@ if (specID.getValue().isNegative()) return emitOpError("SpecId cannot be negative"); - auto value = default_value(); + auto value = getDefaultValue(); if (value.isa()) { // Make sure bitwidth is allowed. if (!value.getType().isa()) @@ -3732,19 +3736,19 @@ void spirv::StoreOp::print(OpAsmPrinter &printer) { SmallVector elidedAttrs; StringRef sc = stringifyStorageClass( - ptr().getType().cast().getStorageClass()); - printer << " \"" << sc << "\" " << ptr() << ", " << value(); + getPtr().getType().cast().getStorageClass()); + printer << " \"" << sc << "\" " << getPtr() << ", " << getValue(); printMemoryAccessAttribute(*this, printer, elidedAttrs); - printer << " : " << value().getType(); + printer << " : " << getValue().getType(); printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); } LogicalResult spirv::StoreOp::verify() { // SPIR-V spec : "Pointer is the pointer to store through. Its type must be an // OpTypePointer whose Type operand is the same as the type of Object." 
- if (failed(verifyLoadStorePtrAndValTypes(*this, ptr(), value()))) + if (failed(verifyLoadStorePtrAndValTypes(*this, getPtr(), getValue()))) return failure(); return verifyMemoryAccessAttribute(*this); } @@ -3819,7 +3823,7 @@ spirv::attributeName()}; // Print optional initializer if (getNumOperands() != 0) - printer << " init(" << initializer() << ")"; + printer << " init(" << getInitializer() << ")"; printVariableDecorations(*this, printer, elidedAttrs); printer << " : " << getType(); @@ -3829,14 +3833,14 @@ // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the // object. It cannot be Generic. It must be the same as the Storage Class // operand of the Result Type." - if (storage_class() != spirv::StorageClass::Function) { + if (getStorageClass() != spirv::StorageClass::Function) { return emitOpError( "can only be used to model function-level variables. Use " "spv.GlobalVariable for module-level variables."); } - auto pointerType = pointer().getType().cast(); - if (storage_class() != pointerType.getStorageClass()) + auto pointerType = getPointer().getType().cast(); + if (getStorageClass() != pointerType.getStorageClass()) return emitOpError( "storage class must match result pointer's storage class"); @@ -3877,17 +3881,17 @@ VectorType resultType = getType().cast(); size_t numResultElements = resultType.getNumElements(); - if (numResultElements != components().size()) + if (numResultElements != getComponents().size()) return emitOpError("result type element count (") << numResultElements << ") mismatch with the number of component selectors (" - << components().size() << ")"; + << getComponents().size() << ")"; size_t totalSrcElements = - vector1().getType().cast().getNumElements() + - vector2().getType().cast().getNumElements(); + getVector1().getType().cast().getNumElements() + + getVector2().getType().cast().getNumElements(); - for (const auto &selector : components().getAsValueRange()) { + for (const auto &selector : 
getComponents().getAsValueRange()) { uint32_t index = selector.getZExtValue(); if (index >= totalSrcElements && index != std::numeric_limits().max()) @@ -3925,11 +3929,12 @@ } void spirv::NVCooperativeMatrixLoadOp::print(OpAsmPrinter &printer) { - printer << " " << pointer() << ", " << stride() << ", " << columnmajor(); + printer << " " << getPointer() << ", " << getStride() << ", " + << getColumnmajor(); // Print optional memory access attribute. - if (auto memAccess = memory_access()) + if (auto memAccess = getMemoryAccess()) printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]"; - printer << " : " << pointer().getType() << " as " << getType(); + printer << " : " << getPointer().getType() << " as " << getType(); } static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer, @@ -3952,8 +3957,8 @@ } LogicalResult spirv::NVCooperativeMatrixLoadOp::verify() { - return verifyPointerAndCoopMatrixType(*this, pointer().getType(), - result().getType()); + return verifyPointerAndCoopMatrixType(*this, getPointer().getType(), + getResult().getType()); } //===----------------------------------------------------------------------===// @@ -3983,17 +3988,17 @@ } void spirv::NVCooperativeMatrixStoreOp::print(OpAsmPrinter &printer) { - printer << " " << pointer() << ", " << object() << ", " << stride() << ", " - << columnmajor(); + printer << " " << getPointer() << ", " << getObject() << ", " << getStride() + << ", " << getColumnmajor(); // Print optional memory access attribute. 
- if (auto memAccess = memory_access()) + if (auto memAccess = getMemoryAccess()) printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]"; - printer << " : " << pointer().getType() << ", " << getOperand(1).getType(); + printer << " : " << getPointer().getType() << ", " << getOperand(1).getType(); } LogicalResult spirv::NVCooperativeMatrixStoreOp::verify() { - return verifyPointerAndCoopMatrixType(*this, pointer().getType(), - object().getType()); + return verifyPointerAndCoopMatrixType(*this, getPointer().getType(), + getObject().getType()); } //===----------------------------------------------------------------------===// @@ -4002,12 +4007,12 @@ static LogicalResult verifyCoopMatrixMulAdd(spirv::NVCooperativeMatrixMulAddOp op) { - if (op.c().getType() != op.result().getType()) + if (op.getC().getType() != op.getResult().getType()) return op.emitOpError("result and third operand must have the same type"); - auto typeA = op.a().getType().cast(); - auto typeB = op.b().getType().cast(); - auto typeC = op.c().getType().cast(); - auto typeR = op.result().getType().cast(); + auto typeA = op.getA().getType().cast(); + auto typeB = op.getB().getType().cast(); + auto typeC = op.getC().getType().cast(); + auto typeR = op.getResult().getType().cast(); if (typeA.getRows() != typeR.getRows() || typeA.getColumns() != typeB.getRows() || typeB.getColumns() != typeR.getColumns()) @@ -4050,8 +4055,8 @@ //===----------------------------------------------------------------------===// LogicalResult spirv::INTELJointMatrixLoadOp::verify() { - return verifyPointerAndJointMatrixType(*this, pointer().getType(), - result().getType()); + return verifyPointerAndJointMatrixType(*this, getPointer().getType(), + getResult().getType()); } //===----------------------------------------------------------------------===// @@ -4059,8 +4064,8 @@ //===----------------------------------------------------------------------===// LogicalResult spirv::INTELJointMatrixStoreOp::verify() { - return 
verifyPointerAndJointMatrixType(*this, pointer().getType(), - object().getType()); + return verifyPointerAndJointMatrixType(*this, getPointer().getType(), + getObject().getType()); } //===----------------------------------------------------------------------===// @@ -4068,12 +4073,12 @@ //===----------------------------------------------------------------------===// static LogicalResult verifyJointMatrixMad(spirv::INTELJointMatrixMadOp op) { - if (op.c().getType() != op.result().getType()) + if (op.getC().getType() != op.getResult().getType()) return op.emitOpError("result and third operand must have the same type"); - auto typeA = op.a().getType().cast(); - auto typeB = op.b().getType().cast(); - auto typeC = op.c().getType().cast(); - auto typeR = op.result().getType().cast(); + auto typeA = op.getA().getType().cast(); + auto typeB = op.getB().getType().cast(); + auto typeC = op.getC().getType().cast(); + auto typeR = op.getResult().getType().cast(); if (typeA.getRows() != typeR.getRows() || typeA.getColumns() != typeB.getRows() || typeB.getColumns() != typeR.getColumns()) @@ -4100,11 +4105,11 @@ // We already checked that result and matrix are both of matrix type in the // auto-generated verify method. - auto inputMatrix = matrix().getType().cast(); - auto resultMatrix = result().getType().cast(); + auto inputMatrix = getMatrix().getType().cast(); + auto resultMatrix = getResult().getType().cast(); // Check that the scalar type is the same as the matrix element type. 
- if (scalar().getType() != inputMatrix.getElementType()) + if (getScalar().getType() != inputMatrix.getElementType()) return emitError("input matrix components' type and scaling value must " "have the same type"); @@ -4137,22 +4142,23 @@ printer << ' '; StringRef targetStorageClass = stringifyStorageClass( - target().getType().cast().getStorageClass()); - printer << " \"" << targetStorageClass << "\" " << target() << ", "; + getTarget().getType().cast().getStorageClass()); + printer << " \"" << targetStorageClass << "\" " << getTarget() << ", "; StringRef sourceStorageClass = stringifyStorageClass( - source().getType().cast().getStorageClass()); - printer << " \"" << sourceStorageClass << "\" " << source(); + getSource().getType().cast().getStorageClass()); + printer << " \"" << sourceStorageClass << "\" " << getSource(); SmallVector elidedAttrs; printMemoryAccessAttribute(*this, printer, elidedAttrs); printSourceMemoryAccessAttribute(*this, printer, elidedAttrs, - source_memory_access(), source_alignment()); + getSourceMemoryAccess(), + getSourceAlignment()); printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); Type pointeeType = - target().getType().cast().getPointeeType(); + getTarget().getType().cast().getPointeeType(); printer << " : " << pointeeType; } @@ -4200,10 +4206,10 @@ LogicalResult spirv::CopyMemoryOp::verify() { Type targetType = - target().getType().cast().getPointeeType(); + getTarget().getType().cast().getPointeeType(); Type sourceType = - source().getType().cast().getPointeeType(); + getSource().getType().cast().getPointeeType(); if (targetType != sourceType) return emitOpError("both operands must be pointers to the same type"); @@ -4227,8 +4233,8 @@ //===----------------------------------------------------------------------===// LogicalResult spirv::TransposeOp::verify() { - auto inputMatrix = matrix().getType().cast(); - auto resultMatrix = result().getType().cast(); + auto inputMatrix = getMatrix().getType().cast(); + auto 
resultMatrix = getResult().getType().cast(); // Verify that the input and output matrices have correct shapes. if (inputMatrix.getNumRows() != resultMatrix.getNumColumns()) @@ -4252,9 +4258,9 @@ //===----------------------------------------------------------------------===// LogicalResult spirv::MatrixTimesMatrixOp::verify() { - auto leftMatrix = leftmatrix().getType().cast(); - auto rightMatrix = rightmatrix().getType().cast(); - auto resultMatrix = result().getType().cast(); + auto leftMatrix = getLeftmatrix().getType().cast(); + auto rightMatrix = getRightmatrix().getType().cast(); + auto resultMatrix = getResult().getType().cast(); // left matrix columns' count and right matrix rows' count must be equal if (leftMatrix.getNumColumns() != rightMatrix.getNumRows()) @@ -4329,23 +4335,23 @@ void spirv::SpecConstantCompositeOp::print(OpAsmPrinter &printer) { printer << " "; - printer.printSymbolName(sym_name()); + printer.printSymbolName(getSymName()); printer << " ("; - auto constituents = this->constituents().getValue(); + auto constituents = this->getConstituents().getValue(); if (!constituents.empty()) llvm::interleaveComma(constituents, printer); - printer << ") : " << type(); + printer << ") : " << getType(); } LogicalResult spirv::SpecConstantCompositeOp::verify() { - auto cType = type().dyn_cast(); - auto constituents = this->constituents().getValue(); + auto cType = getType().dyn_cast(); + auto constituents = this->getConstituents().getValue(); if (!cType) return emitError("result type must be a composite type, but provided ") - << type(); + << getType(); if (cType.isa()) return emitError("unsupported composite type ") << cType; @@ -4363,11 +4369,11 @@ dyn_cast(SymbolTable::lookupNearestSymbolFrom( (*this)->getParentOp(), constituent.getAttr())); - if (constituentSpecConstOp.default_value().getType() != + if (constituentSpecConstOp.getDefaultValue().getType() != cType.getElementType(index)) return emitError("has incorrect types of operands: expected ") << 
cType.getElementType(index) << ", but provided " - << constituentSpecConstOp.default_value().getType(); + << constituentSpecConstOp.getDefaultValue().getType(); } return success(); @@ -4406,7 +4412,7 @@ void spirv::SpecConstantOperationOp::print(OpAsmPrinter &printer) { printer << " wraps "; - printer.printGenericOp(&body().front().front()); + printer.printGenericOp(&getBody().front().front()); } LogicalResult spirv::SpecConstantOperationOp::verifyRegions() { @@ -4434,7 +4440,8 @@ //===----------------------------------------------------------------------===// LogicalResult spirv::GLFrexpStructOp::verify() { - spirv::StructType structTy = result().getType().dyn_cast(); + spirv::StructType structTy = + getResult().getType().dyn_cast(); if (structTy.getNumElements() != 2) return emitError("result type must be a struct type with two memebers"); @@ -4444,7 +4451,7 @@ VectorType exponentVecTy = exponentTy.dyn_cast(); IntegerType exponentIntTy = exponentTy.dyn_cast(); - Type operandTy = operand().getType(); + Type operandTy = getOperand().getType(); VectorType operandVecTy = operandTy.dyn_cast(); FloatType operandFTy = operandTy.dyn_cast(); @@ -4480,8 +4487,8 @@ //===----------------------------------------------------------------------===// LogicalResult spirv::GLLdexpOp::verify() { - Type significandType = x().getType(); - Type exponentType = exp().getType(); + Type significandType = getX().getType(); + Type exponentType = getExp().getType(); if (significandType.isa() != exponentType.isa()) return emitOpError("operands must both be scalars or vectors"); @@ -4503,9 +4510,9 @@ //===----------------------------------------------------------------------===// LogicalResult spirv::ImageDrefGatherOp::verify() { - VectorType resultType = result().getType().cast(); + VectorType resultType = getResult().getType().cast(); auto sampledImageType = - sampledimage().getType().cast(); + getSampledimage().getType().cast(); auto imageType = sampledImageType.getImageType().cast(); if 
(resultType.getNumElements() != 4) @@ -4530,8 +4537,8 @@ if (imageMS != spirv::ImageSamplingInfo::SingleSampled) return emitOpError("the MS operand of the underlying image type must be 0"); - spirv::ImageOperandsAttr attr = imageoperandsAttr(); - auto operandArguments = operand_arguments(); + spirv::ImageOperandsAttr attr = getImageoperandsAttr(); + auto operandArguments = getOperandArguments(); return verifyImageOperands(*this, attr, operandArguments); } @@ -4565,8 +4572,8 @@ //===----------------------------------------------------------------------===// LogicalResult spirv::ImageQuerySizeOp::verify() { - spirv::ImageType imageType = image().getType().cast(); - Type resultType = result().getType(); + spirv::ImageType imageType = getImage().getType().cast(); + Type resultType = getResult().getType(); spirv::Dim dim = imageType.getDim(); spirv::ImageSamplingInfo samplingInfo = imageType.getSamplingInfo(); @@ -4668,9 +4675,9 @@ template static auto concatElemAndIndices(Op op) { - SmallVector ret(op.indices().size() + 1); - ret[0] = op.element(); - llvm::copy(op.indices(), ret.begin() + 1); + SmallVector ret(op.getIndices().size() + 1); + ret[0] = op.getElement(); + llvm::copy(op.getIndices(), ret.begin() + 1); return ret; } @@ -4698,7 +4705,7 @@ } LogicalResult spirv::InBoundsPtrAccessChainOp::verify() { - return verifyAccessChain(*this, indices()); + return verifyAccessChain(*this, getIndices()); } //===----------------------------------------------------------------------===// @@ -4724,7 +4731,7 @@ } LogicalResult spirv::PtrAccessChainOp::verify() { - return verifyAccessChain(*this, indices()); + return verifyAccessChain(*this, getIndices()); } //===----------------------------------------------------------------------===// @@ -4732,10 +4739,10 @@ //===----------------------------------------------------------------------===// LogicalResult spirv::VectorTimesScalarOp::verify() { - if (vector().getType() != getType()) + if (getVector().getType() != getType()) 
return emitOpError("vector operand and result type mismatch"); auto scalarType = getType().cast().getElementType(); - if (scalar().getType() != scalarType) + if (getScalar().getType() != scalarType) return emitOpError("scalar operand and result element type match"); return success(); } diff --git a/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp b/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp --- a/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp +++ b/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp @@ -94,16 +94,16 @@ return nullptr; spirv::ModuleOp firstModule = inputModules.front(); - auto addressingModel = firstModule.addressing_model(); - auto memoryModel = firstModule.memory_model(); - auto vceTriple = firstModule.vce_triple(); + auto addressingModel = firstModule.getAddressingModel(); + auto memoryModel = firstModule.getMemoryModel(); + auto vceTriple = firstModule.getVceTriple(); // First check whether there are conflicts between addressing/memory model. // Return early if so. 
for (auto module : inputModules) { - if (module.addressing_model() != addressingModel || - module.memory_model() != memoryModel || - module.vce_triple() != vceTriple) { + if (module.getAddressingModel() != addressingModel || + module.getMemoryModel() != memoryModel || + module.getVceTriple() != vceTriple) { module.emitError("input modules differ in addressing model, memory " "model, and/or VCE triple"); return nullptr; diff --git a/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp @@ -40,7 +40,7 @@ PatternRewriter &rewriter) const override { SmallVector globalVarAttrs; - auto ptrType = op.type().cast(); + auto ptrType = op.getType().cast(); auto structType = VulkanLayoutUtils::decorateType( ptrType.getPointeeType().cast()); @@ -71,11 +71,11 @@ LogicalResult matchAndRewrite(spirv::AddressOfOp op, PatternRewriter &rewriter) const override { auto spirvModule = op->getParentOfType(); - auto varName = op.variableAttr(); + auto varName = op.getVariableAttr(); auto varOp = spirvModule.lookupSymbol(varName); rewriter.replaceOpWithNewOp( - op, varOp.type(), SymbolRefAttr::get(varName.getAttr())); + op, varOp.getType(), SymbolRefAttr::get(varName.getAttr())); return success(); } }; @@ -121,12 +121,12 @@ target.addLegalOp(); target.addDynamicallyLegalOp( [](spirv::GlobalVariableOp op) { - return VulkanLayoutUtils::isLegalType(op.type()); + return VulkanLayoutUtils::isLegalType(op.getType()); }); // Change the type for the direct users. target.addDynamicallyLegalOp([](spirv::AddressOfOp op) { - return VulkanLayoutUtils::isLegalType(op.pointer().getType()); + return VulkanLayoutUtils::isLegalType(op.getPointer().getType()); }); // Change the type for the indirect users. 
@@ -134,7 +134,8 @@ spirv::StoreOp>([&](Operation *op) { for (Value operand : op->getOperands()) { auto addrOp = operand.getDefiningOp(); - if (addrOp && !VulkanLayoutUtils::isLegalType(addrOp.pointer().getType())) + if (addrOp && + !VulkanLayoutUtils::isLegalType(addrOp.getPointer().getType())) return false; } return true; diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp @@ -88,13 +88,13 @@ // instructions in this function. funcOp.walk([&](spirv::AddressOfOp addressOfOp) { auto var = - module.lookupSymbol(addressOfOp.variable()); + module.lookupSymbol(addressOfOp.getVariable()); // TODO: Per SPIR-V spec: "Before version 1.4, the interface’s // storage classes are limited to the Input and Output storage classes. // Starting with version 1.4, the interface’s storage classes are all // storage classes used in declaring all global variables referenced by the // entry point’s call tree." We should consider the target environment here. 
- switch (var.type().cast().getStorageClass()) { + switch (var.getType().cast().getStorageClass()) { case spirv::StorageClass::Input: case spirv::StorageClass::Output: interfaceVarSet.insert(var.getOperation()); @@ -105,7 +105,7 @@ }); for (auto &var : interfaceVarSet) { interfaceVars.push_back(SymbolRefAttr::get( - funcOp.getContext(), cast(var).sym_name())); + funcOp.getContext(), cast(var).getSymName())); } return success(); } @@ -223,7 +223,7 @@ auto zero = spirv::ConstantOp::getZero(indexType, funcOp.getLoc(), rewriter); auto loadPtr = rewriter.create( - funcOp.getLoc(), replacement, zero.constant()); + funcOp.getLoc(), replacement, zero.getConstant()); replacement = rewriter.create(funcOp.getLoc(), loadPtr); } signatureConverter.remapInput(argType.index(), replacement); diff --git a/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp @@ -63,7 +63,7 @@ SmallVector operands; // Collect inserted objects. for (auto insertionOp : insertions) - operands.push_back(insertionOp.object()); + operands.push_back(insertionOp.getObject()); OpBuilder builder(lastCompositeInsertOp); auto compositeConstructOp = builder.create( @@ -84,11 +84,13 @@ LogicalResult RewriteInsertsPass::collectInsertionChain( spirv::CompositeInsertOp op, SmallVectorImpl &insertions) { - auto indicesArrayAttr = op.indices().cast(); + auto indicesArrayAttr = op.getIndices().cast(); // TODO: handle nested composite object. if (indicesArrayAttr.size() == 1) { - auto numElements = - op.composite().getType().cast().getNumElements(); + auto numElements = op.getComposite() + .getType() + .cast() + .getNumElements(); auto index = indicesArrayAttr[0].cast().getInt(); // Need a last index to collect a sequential chain. 
@@ -102,12 +104,12 @@ if (index == 0) return success(); - op = op.composite().getDefiningOp(); + op = op.getComposite().getDefiningOp(); if (!op) return failure(); --index; - indicesArrayAttr = op.indices().cast(); + indicesArrayAttr = op.getIndices().cast(); if ((indicesArrayAttr.size() != 1) || (indicesArrayAttr[0].cast().getInt() != index)) return failure(); diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -642,7 +642,7 @@ static spirv::GlobalVariableOp getPushConstantVariable(Block &body, unsigned elementCount) { for (auto varOp : body.getOps()) { - auto ptrType = varOp.type().dyn_cast(); + auto ptrType = varOp.getType().dyn_cast(); if (!ptrType) continue; @@ -874,7 +874,7 @@ // Special treatment for global variables, whose type requirements are // conveyed by type attributes. if (auto globalVar = dyn_cast(op)) - valueTypes.push_back(globalVar.type()); + valueTypes.push_back(globalVar.getType()); // Make sure the op's operands/results use types that are allowed by the // target environment. 
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp @@ -51,8 +51,8 @@ AliasedResourceMap aliasedResources; moduleOp->walk([&aliasedResources](spirv::GlobalVariableOp varOp) { if (varOp->getAttrOfType("aliased")) { - Optional set = varOp.descriptor_set(); - Optional binding = varOp.binding(); + Optional set = varOp.getDescriptorSet(); + Optional binding = varOp.getBinding(); if (set && binding) aliasedResources[{*set, *binding}].push_back(varOp); } @@ -222,16 +222,16 @@ } if (auto addressOp = dyn_cast(op)) { auto moduleOp = addressOp->getParentOfType(); - auto *varOp = SymbolTable::lookupSymbolIn(moduleOp, addressOp.variable()); + auto *varOp = SymbolTable::lookupSymbolIn(moduleOp, addressOp.getVariable()); return shouldUnify(varOp); } if (auto acOp = dyn_cast(op)) - return shouldUnify(acOp.base_ptr().getDefiningOp()); + return shouldUnify(acOp.getBasePtr().getDefiningOp()); if (auto loadOp = dyn_cast(op)) - return shouldUnify(loadOp.ptr().getDefiningOp()); + return shouldUnify(loadOp.getPtr().getDefiningOp()); if (auto storeOp = dyn_cast(op)) - return shouldUnify(storeOp.ptr().getDefiningOp()); + return shouldUnify(storeOp.getPtr().getDefiningOp()); return false; } @@ -265,7 +265,7 @@ // Collect the element types for all resources in the current set. SmallVector elementTypes; for (spirv::GlobalVariableOp resource : resources) { - Type elementType = getRuntimeArrayElementType(resource.type()); + Type elementType = getRuntimeArrayElementType(resource.getType()); if (!elementType) return; // Unexpected resource variable type. @@ -326,7 +326,7 @@ // Rewrite the AddressOf op to get the address of the canoncical resource. 
auto moduleOp = addressOp->getParentOfType(); auto srcVarOp = cast( - SymbolTable::lookupSymbolIn(moduleOp, addressOp.variable())); + SymbolTable::lookupSymbolIn(moduleOp, addressOp.getVariable())); auto dstVarOp = analysis.getCanonicalResource(srcVarOp); rewriter.replaceOpWithNewOp(addressOp, dstVarOp); return success(); @@ -339,13 +339,13 @@ LogicalResult matchAndRewrite(spirv::AccessChainOp acOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto addressOp = acOp.base_ptr().getDefiningOp(); + auto addressOp = acOp.getBasePtr().getDefiningOp(); if (!addressOp) return rewriter.notifyMatchFailure(acOp, "base ptr not addressof op"); auto moduleOp = acOp->getParentOfType(); auto srcVarOp = cast( - SymbolTable::lookupSymbolIn(moduleOp, addressOp.variable())); + SymbolTable::lookupSymbolIn(moduleOp, addressOp.getVariable())); auto dstVarOp = analysis.getCanonicalResource(srcVarOp); spirv::SPIRVType srcElemType = analysis.getElementType(srcVarOp); @@ -356,7 +356,7 @@ // We have the same bitwidth for source and destination element types. // Thie indices keep the same. 
rewriter.replaceOpWithNewOp( - acOp, adaptor.base_ptr(), adaptor.indices()); + acOp, adaptor.getBasePtr(), adaptor.getIndices()); return success(); } @@ -375,7 +375,7 @@ auto ratioValue = rewriter.create( loc, i32Type, rewriter.getI32IntegerAttr(ratio)); - auto indices = llvm::to_vector<4>(acOp.indices()); + auto indices = llvm::to_vector<4>(acOp.getIndices()); Value oldIndex = indices.back(); indices.back() = rewriter.create(loc, i32Type, oldIndex, ratioValue); @@ -383,7 +383,7 @@ rewriter.create(loc, i32Type, oldIndex, ratioValue)); rewriter.replaceOpWithNewOp( - acOp, adaptor.base_ptr(), indices); + acOp, adaptor.getBasePtr(), indices); return success(); } @@ -399,13 +399,13 @@ auto ratioValue = rewriter.create( loc, i32Type, rewriter.getI32IntegerAttr(ratio)); - auto indices = llvm::to_vector<4>(acOp.indices()); + auto indices = llvm::to_vector<4>(acOp.getIndices()); Value oldIndex = indices.back(); indices.back() = rewriter.create(loc, i32Type, oldIndex, ratioValue); rewriter.replaceOpWithNewOp( - acOp, adaptor.base_ptr(), indices); + acOp, adaptor.getBasePtr(), indices); return success(); } @@ -420,13 +420,13 @@ LogicalResult matchAndRewrite(spirv::LoadOp loadOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto srcPtrType = loadOp.ptr().getType().cast(); + auto srcPtrType = loadOp.getPtr().getType().cast(); auto srcElemType = srcPtrType.getPointeeType().cast(); - auto dstPtrType = adaptor.ptr().getType().cast(); + auto dstPtrType = adaptor.getPtr().getType().cast(); auto dstElemType = dstPtrType.getPointeeType().cast(); Location loc = loadOp.getLoc(); - auto newLoadOp = rewriter.create(loc, adaptor.ptr()); + auto newLoadOp = rewriter.create(loc, adaptor.getPtr()); if (srcElemType == dstElemType) { rewriter.replaceOp(loadOp, newLoadOp->getResults()); return success(); @@ -434,7 +434,7 @@ if (areSameBitwidthScalarType(srcElemType, dstElemType)) { auto castOp = rewriter.create(loc, srcElemType, - newLoadOp.value()); + 
newLoadOp.getValue()); rewriter.replaceOp(loadOp, castOp->getResults()); return success(); @@ -457,19 +457,19 @@ components.reserve(ratio); components.push_back(newLoadOp); - auto acOp = adaptor.ptr().getDefiningOp(); + auto acOp = adaptor.getPtr().getDefiningOp(); if (!acOp) return rewriter.notifyMatchFailure(loadOp, "ptr not spv.AccessChain"); auto i32Type = rewriter.getI32Type(); Value oneValue = spirv::ConstantOp::getOne(i32Type, loc, rewriter); - auto indices = llvm::to_vector<4>(acOp.indices()); + auto indices = llvm::to_vector<4>(acOp.getIndices()); for (int i = 1; i < ratio; ++i) { // Load all subsequent components belonging to this element. indices.back() = rewriter.create( loc, i32Type, indices.back(), oneValue); auto componentAcOp = rewriter.create( - loc, acOp.base_ptr(), indices); + loc, acOp.getBasePtr(), indices); // Assuming little endian, this reads lower-ordered bits of the number // to lower-numbered components of the vector. components.push_back( @@ -504,19 +504,19 @@ matchAndRewrite(spirv::StoreOp storeOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto srcElemType = - storeOp.ptr().getType().cast().getPointeeType(); + storeOp.getPtr().getType().cast().getPointeeType(); auto dstElemType = - adaptor.ptr().getType().cast().getPointeeType(); + adaptor.getPtr().getType().cast().getPointeeType(); if (!srcElemType.isIntOrFloat() || !dstElemType.isIntOrFloat()) return rewriter.notifyMatchFailure(storeOp, "not scalar type"); if (!areSameBitwidthScalarType(srcElemType, dstElemType)) return rewriter.notifyMatchFailure(storeOp, "different bitwidth"); Location loc = storeOp.getLoc(); - Value value = adaptor.value(); + Value value = adaptor.getValue(); if (srcElemType != dstElemType) value = rewriter.create(loc, dstElemType, value); - rewriter.replaceOpWithNewOp(storeOp, adaptor.ptr(), value, + rewriter.replaceOpWithNewOp(storeOp, adaptor.getPtr(), value, storeOp->getAttrs()); return success(); } diff --git 
a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp @@ -151,7 +151,7 @@ // Special treatment for global variables, whose type requirements are // conveyed by type attributes. if (auto globalVar = dyn_cast(op)) - valueTypes.push_back(globalVar.type()); + valueTypes.push_back(globalVar.getType()); // Requirements from values' types SmallVector, 4> typeExtensions; diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp --- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp @@ -46,20 +46,20 @@ } if (auto varOp = getGlobalVariable(id)) { auto addressOfOp = opBuilder.create( - unknownLoc, varOp.type(), SymbolRefAttr::get(varOp.getOperation())); - return addressOfOp.pointer(); + unknownLoc, varOp.getType(), SymbolRefAttr::get(varOp.getOperation())); + return addressOfOp.getPointer(); } if (auto constOp = getSpecConstant(id)) { auto referenceOfOp = opBuilder.create( - unknownLoc, constOp.default_value().getType(), + unknownLoc, constOp.getDefaultValue().getType(), SymbolRefAttr::get(constOp.getOperation())); - return referenceOfOp.reference(); + return referenceOfOp.getReference(); } if (auto constCompositeOp = getSpecConstantComposite(id)) { auto referenceOfOp = opBuilder.create( - unknownLoc, constCompositeOp.type(), + unknownLoc, constCompositeOp.getType(), SymbolRefAttr::get(constCompositeOp.getOperation())); - return referenceOfOp.reference(); + return referenceOfOp.getReference(); } if (auto specConstOperationInfo = getSpecConstantOperation(id)) { return materializeSpecConstantOperation( diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp --- 
a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -1414,7 +1414,7 @@ auto specConstOperationOp = opBuilder.create(loc, resultType); - Region &body = specConstOperationOp.body(); + Region &body = specConstOperationOp.getBody(); // Move the new block into SpecConstantOperation's body. body.getBlocks().splice(body.end(), curBlock->getParent()->getBlocks(), Region::iterator(enclosedBlock)); @@ -1983,17 +1983,17 @@ assert((branchCondOp.getTrueBlock() == target || branchCondOp.getFalseBlock() == target) && "expected target to be either the true or false target"); - if (target == branchCondOp.trueTarget()) + if (target == branchCondOp.getTrueTarget()) opBuilder.create( - branchCondOp.getLoc(), branchCondOp.condition(), blockArgs, + branchCondOp.getLoc(), branchCondOp.getCondition(), blockArgs, branchCondOp.getFalseBlockArguments(), - branchCondOp.branch_weightsAttr(), branchCondOp.trueTarget(), - branchCondOp.falseTarget()); + branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueTarget(), + branchCondOp.getFalseTarget()); else opBuilder.create( - branchCondOp.getLoc(), branchCondOp.condition(), + branchCondOp.getLoc(), branchCondOp.getCondition(), branchCondOp.getTrueBlockArguments(), blockArgs, - branchCondOp.branch_weightsAttr(), branchCondOp.getTrueBlock(), + branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueBlock(), branchCondOp.getFalseBlock()); branchCondOp.erase(); diff --git a/mlir/lib/Target/SPIRV/Serialization/Serialization.cpp b/mlir/lib/Target/SPIRV/Serialization/Serialization.cpp --- a/mlir/lib/Target/SPIRV/Serialization/Serialization.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serialization.cpp @@ -24,7 +24,7 @@ LogicalResult spirv::serialize(spirv::ModuleOp module, SmallVectorImpl &binary, const SerializationOptions &options) { - if (!module.vce_triple()) + if (!module.getVceTriple()) return module.emitError( "module must have 'vce_triple' attribute to be serializeable"); 
diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp --- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp @@ -58,7 +58,8 @@ namespace mlir { namespace spirv { LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) { - if (auto resultID = prepareConstant(op.getLoc(), op.getType(), op.value())) { + if (auto resultID = + prepareConstant(op.getLoc(), op.getType(), op.getValue())) { valueIDMap[op.getResult()] = resultID; return success(); } @@ -66,7 +67,7 @@ } LogicalResult Serializer::processSpecConstantOp(spirv::SpecConstantOp op) { - if (auto resultID = prepareConstantScalar(op.getLoc(), op.default_value(), + if (auto resultID = prepareConstantScalar(op.getLoc(), op.getDefaultValue(), /*isSpec=*/true)) { // Emit the OpDecorate instruction for SpecId. if (auto specID = op->getAttrOfType("spec_id")) { @@ -75,8 +76,8 @@ return failure(); } - specConstIDMap[op.sym_name()] = resultID; - return processName(resultID, op.sym_name()); + specConstIDMap[op.getSymName()] = resultID; + return processName(resultID, op.getSymName()); } return failure(); } @@ -84,7 +85,7 @@ LogicalResult Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op) { uint32_t typeID = 0; - if (failed(processType(op.getLoc(), op.type(), typeID))) { + if (failed(processType(op.getLoc(), op.getType(), typeID))) { return failure(); } @@ -94,7 +95,7 @@ operands.push_back(typeID); operands.push_back(resultID); - auto constituents = op.constituents(); + auto constituents = op.getConstituents(); for (auto index : llvm::seq(0, constituents.size())) { auto constituent = constituents[index].dyn_cast(); @@ -112,9 +113,9 @@ encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpSpecConstantComposite, operands); - specConstIDMap[op.sym_name()] = resultID; + specConstIDMap[op.getSymName()] = resultID; - return processName(resultID, op.sym_name()); + 
return processName(resultID, op.getSymName()); } LogicalResult @@ -199,7 +200,7 @@ operands.push_back(resTypeID); auto funcID = getOrCreateFunctionID(op.getName()); operands.push_back(funcID); - operands.push_back(static_cast(op.function_control())); + operands.push_back(static_cast(op.getFunctionControl())); operands.push_back(fnTypeID); encodeInstructionInto(functionHeader, spirv::Opcode::OpFunction, operands); @@ -310,7 +311,7 @@ // Get TypeID. uint32_t resultTypeID = 0; SmallVector elidedAttrs; - if (failed(processType(varOp.getLoc(), varOp.type(), resultTypeID))) { + if (failed(processType(varOp.getLoc(), varOp.getType(), resultTypeID))) { return failure(); } @@ -320,7 +321,7 @@ auto resultID = getNextID(); // Encode the name. - auto varName = varOp.sym_name(); + auto varName = varOp.getSymName(); elidedAttrs.push_back(SymbolTable::getSymbolAttrName()); if (failed(processName(resultID, varName))) { return failure(); @@ -332,7 +333,7 @@ operands.push_back(static_cast(varOp.storageClass())); // Encode initialization. - if (auto initializer = varOp.initializer()) { + if (auto initializer = varOp.getInitializer()) { auto initializerID = getVariableID(*initializer); if (!initializerID) { return emitError(varOp.getLoc(), @@ -364,7 +365,7 @@ LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) { // Assign s to all blocks so that branches inside the SelectionOp can // resolve properly. - auto &body = selectionOp.body(); + auto &body = selectionOp.getBody(); for (Block &block : body) getOrCreateBlockID(&block); @@ -390,7 +391,7 @@ lastProcessedWasMergeInst = true; encodeInstructionInto( functionBody, spirv::Opcode::OpSelectionMerge, - {mergeID, static_cast(selectionOp.selection_control())}); + {mergeID, static_cast(selectionOp.getSelectionControl())}); return success(); }; if (failed( @@ -420,7 +421,7 @@ // Assign s to all blocks so that branches inside the LoopOp can resolve // properly. 
We don't need to assign for the entry block, which is just for // satisfying MLIR region's structural requirement. - auto &body = loopOp.body(); + auto &body = loopOp.getBody(); for (Block &block : llvm::drop_begin(body)) getOrCreateBlockID(&block); @@ -452,7 +453,7 @@ lastProcessedWasMergeInst = true; encodeInstructionInto( functionBody, spirv::Opcode::OpLoopMerge, - {mergeID, continueID, static_cast(loopOp.loop_control())}); + {mergeID, continueID, static_cast(loopOp.getLoopControl())}); return success(); }; if (failed(processBlock(headerBlock, /*omitLabel=*/false, emitLoopMerge))) @@ -483,12 +484,12 @@ LogicalResult Serializer::processBranchConditionalOp( spirv::BranchConditionalOp condBranchOp) { - auto conditionID = getValueID(condBranchOp.condition()); + auto conditionID = getValueID(condBranchOp.getCondition()); auto trueLabelID = getOrCreateBlockID(condBranchOp.getTrueBlock()); auto falseLabelID = getOrCreateBlockID(condBranchOp.getFalseBlock()); SmallVector arguments{conditionID, trueLabelID, falseLabelID}; - if (auto weights = condBranchOp.branch_weights()) { + if (auto weights = condBranchOp.getBranchWeights()) { for (auto val : weights->getValue()) arguments.push_back(val.cast().getInt()); } @@ -509,26 +510,26 @@ } LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) { - auto varName = addressOfOp.variable(); + auto varName = addressOfOp.getVariable(); auto variableID = getVariableID(varName); if (!variableID) { return addressOfOp.emitError("unknown result for variable ") << varName; } - valueIDMap[addressOfOp.pointer()] = variableID; + valueIDMap[addressOfOp.getPointer()] = variableID; return success(); } LogicalResult Serializer::processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp) { - auto constName = referenceOfOp.spec_const(); + auto constName = referenceOfOp.getSpecConst(); auto constID = getSpecConstID(constName); if (!constID) { return referenceOfOp.emitError( "unknown result for specialization constant ") << 
constName; } - valueIDMap[referenceOfOp.reference()] = constID; + valueIDMap[referenceOfOp.getReference()] = constID; return success(); } @@ -537,21 +538,21 @@ Serializer::processOp(spirv::EntryPointOp op) { SmallVector operands; // Add the ExecutionModel. - operands.push_back(static_cast(op.execution_model())); + operands.push_back(static_cast(op.getExecutionModel())); // Add the function . - auto funcID = getFunctionID(op.fn()); + auto funcID = getFunctionID(op.getFn()); if (!funcID) { return op.emitError("missing for function ") - << op.fn() + << op.getFn() << "; function needs to be defined before spv.EntryPoint is " "serialized"; } operands.push_back(funcID); // Add the name of the function. - spirv::encodeStringLiteralInto(operands, op.fn()); + spirv::encodeStringLiteralInto(operands, op.getFn()); // Add the interface values. - if (auto interface = op.interface()) { + if (auto interface = op.getInterface()) { for (auto var : interface.getValue()) { auto id = getVariableID(var.cast().getValue()); if (!id) { @@ -571,19 +572,19 @@ Serializer::processOp(spirv::ExecutionModeOp op) { SmallVector operands; // Add the function . - auto funcID = getFunctionID(op.fn()); + auto funcID = getFunctionID(op.getFn()); if (!funcID) { return op.emitError("missing for function ") - << op.fn() + << op.getFn() << "; function needs to be serialized before ExecutionModeOp is " "serialized"; } operands.push_back(funcID); // Add the ExecutionMode. - operands.push_back(static_cast(op.execution_mode())); + operands.push_back(static_cast(op.getExecutionMode())); // Serialize values if any. - auto values = op.values(); + auto values = op.getValues(); if (values) { for (auto &intVal : values.getValue()) { operands.push_back(static_cast( @@ -598,7 +599,7 @@ template <> LogicalResult Serializer::processOp(spirv::FunctionCallOp op) { - auto funcName = op.callee(); + auto funcName = op.getCallee(); uint32_t resTypeID = 0; Type resultTy = op.getNumResults() ? 
*op.result_type_begin() : getVoidType(); @@ -609,7 +610,7 @@ auto funcCallID = getNextID(); SmallVector operands{resTypeID, funcCallID, funcID}; - for (auto value : op.arguments()) { + for (auto value : op.getArguments()) { auto valueID = getValueID(value); assert(valueID && "cannot find a value for spv.FunctionCall"); operands.push_back(valueID); diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp @@ -119,7 +119,8 @@ binary.clear(); binary.reserve(moduleSize); - spirv::appendModuleHeader(binary, module.vce_triple()->getVersion(), nextID); + spirv::appendModuleHeader(binary, module.getVceTriple()->getVersion(), + nextID); binary.append(capabilities.begin(), capabilities.end()); binary.append(extensions.begin(), extensions.end()); binary.append(extendedSets.begin(), extendedSets.end()); @@ -166,7 +167,7 @@ } void Serializer::processCapability() { - for (auto cap : module.vce_triple()->getCapabilities()) + for (auto cap : module.getVceTriple()->getCapabilities()) encodeInstructionInto(capabilities, spirv::Opcode::OpCapability, {static_cast(cap)}); } @@ -186,7 +187,7 @@ void Serializer::processExtension() { llvm::SmallVector extName; - for (spirv::Extension ext : module.vce_triple()->getExtensions()) { + for (spirv::Extension ext : module.getVceTriple()->getExtensions()) { extName.clear(); spirv::encodeStringLiteralInto(extName, spirv::stringifyExtension(ext)); encodeInstructionInto(extensions, spirv::Opcode::OpExtension, extName); @@ -1045,11 +1046,11 @@ } else if (auto branchCondOp = dyn_cast(terminator)) { Optional blockOperands; - if (branchCondOp.trueTarget() == block) { - blockOperands = branchCondOp.trueTargetOperands(); + if (branchCondOp.getTrueTarget() == block) { + blockOperands = branchCondOp.getTrueTargetOperands(); } else { - assert(branchCondOp.falseTarget() == block); - 
blockOperands = branchCondOp.falseTargetOperands(); + assert(branchCondOp.getFalseTarget() == block); + blockOperands = branchCondOp.getFalseTargetOperands(); } assert(!blockOperands->empty() && diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp --- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp +++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp @@ -1360,7 +1360,7 @@ os << formatv(" {0}::{1} tblgen_attrVal = this->{2}() & " "static_cast<{0}::{1}>(1 << i);\n", enumAttr.getCppNamespace(), enumAttr.getEnumClassName(), - namedAttr.name); + srcOp.getGetterName(namedAttr.name)); os << formatv( " if (static_cast<{0}>(tblgen_attrVal) == 0) continue;\n", enumAttr.getUnderlyingType()); @@ -1368,7 +1368,7 @@ // For IntEnumAttr, we just need to query the value as a whole. os << " {\n"; os << formatv(" auto tblgen_attrVal = this->{0}();\n", - namedAttr.name); + srcOp.getGetterName(namedAttr.name)); } os << formatv(" auto tblgen_instance = {0}::{1}(tblgen_attrVal);\n", enumAttr.getCppNamespace(), avail.getQueryFnName());