Changeset View
Changeset View
Standalone View
Standalone View
mlir/lib/Bindings/Python/mlir/dialects/_linalg_ops_ext.py
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from typing import Optional, Sequence, Union | |||||
from ..ir import * | |||||
from ._ods_common import get_default_loc_context | |||||
class InitTensorOp:
  """Extends the linalg.init_tensor op.

  Maps to the `linalg.init_tensor` operation; see
  https://mlir.llvm.org/docs/Dialects/Linalg/#linalginit_tensor-mlirlinalginittensorop
  for the op's semantics.
  """

  def __init__(self,
               sizes: Union[Sequence[int], Sequence[Value]],
               element_type: Type,
               *,
               loc=None,
               ip=None):
    """Constructs an `init_tensor` with either static or dynamic sizes."""
    ctx = get_default_loc_context(loc)
    dynamic_operands = []
    # TODO: Refactor the InitTensorOp to take an element type attribute and
    # then use normal result type inference, unifying the Python and C++ side
    # with a standard mechanism (versus stashing that in builders).
    if sizes and isinstance(sizes[0], Value):
      # Dynamic sizes: each size is an SSA operand, and every entry in the
      # static_sizes attribute is marked unknown with the -1 sentinel.
      dynamic_operands.extend(sizes)
      static_sizes = [-1] * len(sizes)
    else:
      # Static sizes: the integers go straight into the attribute and the
      # result type; no operands are needed.
      static_sizes = sizes
    result_type = RankedTensorType.get(static_sizes, element_type)
    index_type = IndexType.get(ctx)
    attrs = {
        "static_sizes":
            ArrayAttr.get(
                [IntegerAttr.get(index_type, s) for s in static_sizes],
                context=ctx)
    }
    op = self.build_generic(results=[result_type],
                            operands=dynamic_operands,
                            attributes=attrs,
                            loc=loc,
                            ip=ip)
    OpView.__init__(self, op)
class StructuredOpMixin: | class StructuredOpMixin: | ||||
"""All structured ops use the same mixin class.""" | """All structured ops use the same mixin class.""" | ||||
def __init__(self, inputs, outputs=(), results=(), loc=None, ip=None): | def __init__(self, inputs, outputs=(), results=(), loc=None, ip=None): | ||||
if outputs and results: | if outputs and results: | ||||
raise ValueError( | raise ValueError( | ||||
"Structured ops must have outputs or results, but not both.") | "Structured ops must have outputs or results, but not both.") | ||||
Show All 15 Lines |
Can you add a class doc? Even just to redirect to https://mlir.llvm.org/docs/Dialects/Linalg/#linalginit_tensor-mlirlinalginittensorop, or at minimum mention that it maps to the linalg.init_tensor op.