This is an archive of the discontinued LLVM Phabricator instance.

[mlir][Python][Linalg] Adding const, capture, and index support to the OpDSL.
ClosedPublic

Authored by gysit on Apr 27 2021, 7:04 AM.

Details

Summary

The patch extends the OpDSL with support for:

  • Constant values
  • Captured scalar parameters
  • Access to the iteration indices via the index operation
  • Predefined floating-point and integer types

So far the patch only supports emitting the new nodes; the C++/yaml path is not fully implemented (there is no implementation at all on the C++ side) and is untested. The fill_rng_2d operation defined in emit_structured_generic.py exercises the new DSL constructs.

Diff Detail

Event Timeline

gysit created this revision.Apr 27 2021, 7:04 AM
gysit requested review of this revision.Apr 27 2021, 7:04 AM
gysit added a comment.Apr 27 2021, 7:11 AM

The fill_rng_2d operation defined below illustrates the new features implemented in this revision:

@linalg_structured_op
def fill_rng_2d(A=TensorDef(T, S.M, S.N, output=True),
                min=CaptureDef(F64),
                max=CaptureDef(F64),
                seed=CaptureDef(I32)):
  multiplier = const(I32, 1103515245)
  increment = const(I32, 12345)
  temp1 = (cast(I32, index(D.m)) + seed) * multiplier + increment
  temp2 = (cast(I32, index(D.n)) + temp1) * multiplier + increment
  scaling = (max - min) * const(F64, 2.3283064e-10)
  A[D.m, D.n] = cast(T, cast(F64, temp2) * scaling + min)
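
For readers following along, the body above is essentially a small linear congruential generator hashed over the two iteration indices. A rough plain-Python rendering of the per-element computation (a hypothetical helper for illustration only, ignoring the i32 wrap-around that the emitted muli/addi operations apply) looks like this:

def fill_rng_2d_element(m, n, min_val, max_val, seed):
  # Illustration only -- not part of the patch; i32 overflow is ignored.
  multiplier, increment = 1103515245, 12345
  temp1 = (m + seed) * multiplier + increment     # hash the row index
  temp2 = (n + temp1) * multiplier + increment    # mix in the column index
  scaling = (max_val - min_val) * 2.3283064e-10   # roughly (max - min) / 2**32
  return int(float(temp2) * scaling + min_val)    # cast back to the element type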

Running /usr/bin/python3.9 /usr/local/google/home/gysit/Repos/llvm-project/mlir/test/Bindings/Python/dialects/linalg/opdsl/emit_structured_generic.py emits the following IR:

func @test_fill_rng_2d(%arg0: tensor<4x16xi32>, %arg1: f64, %arg2: f64, %arg3: i32) -> tensor<4x16xi32> {
  %0 = linalg.generic {indexing_maps = [#map3], iterator_types = ["parallel", "parallel"]} outs(%arg0 : tensor<4x16xi32>) {
  ^bb0(%arg4: i32):  // no predecessors
    %1 = linalg.index 1 : index
    %2 = index_cast %1 : index to i32
    %3 = linalg.index 0 : index
    %4 = index_cast %3 : index to i32
    %5 = addi %4, %arg3 : i32
    %c1103515245_i32 = constant 1103515245 : i32
    %6 = muli %5, %c1103515245_i32 : i32
    %c12345_i32 = constant 12345 : i32
    %7 = addi %6, %c12345_i32 : i32
    %8 = addi %2, %7 : i32
    %c1103515245_i32_0 = constant 1103515245 : i32
    %9 = muli %8, %c1103515245_i32_0 : i32
    %c12345_i32_1 = constant 12345 : i32
    %10 = addi %9, %c12345_i32_1 : i32
    %11 = sitofp %10 : i32 to f64
    %12 = subf %arg2, %arg1 : f64
    %cst = constant 2.3283063999999999E-10 : f64
    %13 = mulf %12, %cst : f64
    %14 = mulf %11, %13 : f64
    %15 = addf %14, %arg1 : f64
    %16 = fptosi %15 : f64 to i32
    linalg.yield %16 : i32
  } -> tensor<4x16xi32>
  return %0 : tensor<4x16xi32>
}

I expect some iterations will be needed to make the code stable and to add support on the C++ side.

stellaraccident accepted this revision.Apr 27 2021, 9:22 PM

This looks good to me; I will leave the rest of the review to Nicolas. We do have docs in the docs/ tree for this language, and it would be good to update them.

mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/comprehension.py
200

I'm trying to understand what a CaptureDef can be: I assume from the examples that it must be a primitive/scalar that is supplied as an SSA value. Can it be anything else (i.e. a whole-tensor value that does not participate in indexing/iteration)? Do you foresee this being different from captured attributes (which the old language supports)?

(you don't need to answer me inline if the questions have clear answers: just extend the docs, either here or in the language guide in the docs/ folder)

338

Can you say more about what the allowable types/values are and how this is emitted?

340

I think you want Any here. We likely need to further constrain it later but can let it be anything for now.

350

Somewhat obvious, but can you document more what the dimension index is resolved against?

mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py
25

Can we just name this ScalarCapture (not abbreviated) for consistency?

mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
13 (On Diff #340827)

Mind submitting these changes NFC separately?

This revision is now accepted and ready to land.Apr 27 2021, 9:22 PM
gysit updated this revision to Diff 341215.Apr 28 2021, 8:00 AM

Address comments:

  • Update the documentation
  • Improve docstrings
  • Rename ScalarCap to ScalarCapture
  • Restrict constants to floats and ints
  • Use Any instead of object
gysit marked 4 inline comments as done.Apr 28 2021, 8:16 AM

Thanks for the detailed review comments!

mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/comprehension.py
200

I was not aware of the attribute functionality (@nicolasvasilache pointed me to a TC example), and yes, there is indeed some overlap. The idea of CaptureDef is to add scalar parameters to the parameter list of the operation, as is done for some Linalg structured operations (e.g. the fill operation). Capturing values directly would make it harder to isolate the operation or turn it into a library call.
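
To illustrate the intent with a hypothetical example (fill_scalar is not part of the patch): the captured scalar surfaces as an ordinary SSA operand of the generated operation rather than being baked into the region, similar to the value operand of linalg.fill:

@linalg_structured_op
def fill_scalar(O=TensorDef(T, S.M, S.N, output=True),
                value=CaptureDef(F64)):
  # The capture is not indexed by the indexing maps; it simply becomes an
  # extra scalar argument of the generated operation.
  O[D.m, D.n] = cast(T, value)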

I hope that explains the idea but I am open to other solutions!

mlir/docs/Tools/LinalgOpDsl.md
131

Do we need the "floating point or integer value" part? This should also work with vectors, right?

138

Can these be generalized to all scalar and vector types (in a subsequent commit)?

One of the ideas with such custom ops was to be able to just program with vectors if we wanted.

mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/comprehension.py
200

-> "Defines an SSA value captured by the operation" ?

202

-> "The captured SSA values are not indexed by the indexing_maps of the structured op (as opposed to memrefs and tensors)" ?

338

I'd just go for "scalar and vector constants".

mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py
32

I'd just spell "captures" everywhere; "caps" makes me think of lower/upper case.
Since this is Python, I wouldn't be super surprised to see caps changing code around :)

39

Defensive type programming, how fancy! :)

mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/types.py
73

Not for this CL, but I am wondering whether we should have these at all (vs. a mechanism that just allows building all scalar + vector types)?

mlir/test/Bindings/Python/dialects/linalg/opdsl/emit_structured_generic.py
35

Very nice looking!
Can we spell out the constant into a named LHS? Is it an epsilon or something else?

nicolasvasilache accepted this revision.Apr 28 2021, 8:48 AM

Nice progress, thanks @gysit !

gysit updated this revision to Diff 341242.Apr 28 2021, 9:36 AM
gysit marked 8 inline comments as done.

Address comments:

  • Rename caps to captures
  • Update documentation
  • Cleaner fill_rng definition
gysit added inline comments.Apr 28 2021, 9:39 AM
mlir/docs/Tools/LinalgOpDsl.md
131

At the moment we are limited to floats and integers, since we check this on the Python side and only have code to emit integer and float attributes. Extending this is possible.
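
As a rough sketch of what that Python-side check amounts to (a hypothetical helper, not the actual patch code):

def _validate_const_value(value):
  # Only Python ints and floats are accepted for now, since only integer and
  # float attribute emission is implemented.
  if not isinstance(value, (int, float)):
    raise ValueError(f"const expects an int or float value, got {type(value)}")
  return value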

138

Yes, this is possible but requires a slightly different design, I believe. At the moment, I32 etc. are simply names that carry no further information. I think we should change that, as you suggested, to provide a mechanism for generating different vector types etc.

mlir/test/Bindings/Python/dialects/linalg/opdsl/emit_structured_generic.py
35

It is the inverse of the maximal random number (max_int_32).
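
For context, a quick plain-Python check (illustration only) shows that the constant is essentially the reciprocal of 2**32:

print(1 / 2**32)              # 2.3283064365386963e-10 -- the constant above, truncated
print(2.3283064e-10 * 2**32)  # ~0.99999998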

gysit added inline comments.Apr 28 2021, 9:44 AM
mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/comprehension.py
338

I will change this once we support vector constants and types.

mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/types.py
73

Yes, I almost commented the same. In some prior versions of this code, I had a more evolved type system, but it was incomplete and I stripped it in favor of adding complexity later. I'm fine with this as a stepping stone, but it would be good to figure out where it leads.