diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -18,6 +18,12 @@ include "mlir/Interfaces/ViewLikeInterface.td" include "mlir/IR/SymbolInterfaces.td" +/// A TypeAttr for memref types. +def MemRefTypeAttr + : TypeAttrBase<"::mlir::MemRefType", "memref type attribute"> { + let constBuilderCall = "::mlir::TypeAttr::get($0)"; +} + class MemRef_Op traits = []> : Op { let printer = [{ return ::print(p, *this); }]; @@ -597,14 +603,14 @@ def MemRef_GlobalOp : MemRef_Op<"global", [Symbol]> { let summary = "declare or define a global memref variable"; let description = [{ - The `memref.global` operation declares or defines a named global variable. - The backing memory for the variable is allocated statically and is described - by the type of the variable (which should be a statically shaped memref - type). The operation is a declaration if no `inital_value` is specified, - else it is a definition. The `initial_value` can either be a unit attribute - to represent a definition of an uninitialized global variable, or an - elements attribute to represent the definition of a global variable with an - initial value. The global variable can also be marked constant using the + The `memref.global` operation declares or defines a named global memref + variable. The backing memory for the variable is allocated statically and is + described by the type of the variable (which should be a statically shaped + memref type). The operation is a declaration if no `initial_value` is + specified, else it is a definition. The `initial_value` can either be a unit + attribute to represent a definition of an uninitialized global variable, or + an elements attribute to represent the definition of a global variable with + an initial value. The global variable can also be marked constant using the `constant` unit attribute. 
Writing to such constant global variables is undefined. @@ -633,7 +639,7 @@ let arguments = (ins SymbolNameAttr:$sym_name, OptionalAttr:$sym_visibility, - TypeAttr:$type, + MemRefTypeAttr:$type, OptionalAttr:$initial_value, UnitAttr:$constant ); diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -434,7 +434,7 @@ LogicalResult matchAndRewrite(memref::GlobalOp global, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - MemRefType type = global.type().cast(); + MemRefType type = global.type(); if (!isConvertibleAndHasIdentityMaps(type)) return failure(); diff --git a/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp --- a/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp +++ b/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp @@ -46,7 +46,7 @@ auto global = globalBuilder.create( constantOp.getLoc(), (Twine("__constant_") + os.str()).str(), /*sym_visibility=*/globalBuilder.getStringAttr("private"), - /*type=*/typeConverter.convertType(type), + /*type=*/typeConverter.convertType(type).cast(), /*initial_value=*/constantOp.getValue().cast(), /*constant=*/true); symbolTable.insert(global);