The contiguous layout and exposing memref indexing #19851

krzysz00 opened this issue Jan 29, 2025 · 4 comments


krzysz00 commented Jan 29, 2025

The problem

The current suite of memref layouts has a gap that makes optimizations difficult: there's no good way to encode "this has the identity layout but carries an offset around".

For example, if I

%base, %sizes:2, %strides:2, %offset = memref.extract_strided_metadata %memref : memref<MxNxi32> -> ...

, MLIR will instantly constant-fold %offset to 0, because this memref has the identity layout, which, by definition, doesn't apply an offset.

This assumption prevents allocation coalescing. If you do a rewrite that, say, merges

%a = memref.alloc(...) : memref<M x i32>
%b = memref.alloc(...) : memref<N x i32>

into

%ab = memref.alloc(...) : memref<M + N x i32>

, a bunch of code may already have assumed that %b has no offset, and there's no way to find the places where that assumption was made in order to add the new offset back in.

On the other hand, using the strided layout is too weak.

Specifically, if I have a memref<? x ? x i32>, we know that it's still in the identity row-major layout. However, if I want to add an offset, I need memref<? x ? x i32, strided<[?, 1], offset: ?>> ... but here that ? in the strides could be anything. A strided layout doesn't make promises about row-major contiguous indexing.

Those promises are important for optimizations - for example, they let you use affine.linearize_index to compute an index, and there are several places in the vector lowering where isIdentity() is really used for the property that we're in a contiguous layout.

The (proposed) plan: The contiguous layout

There will be one more built-in memref layout, contiguous<N, [offset: O]> where N is an integer and O is either a static value or ?.

This layout, as an affine map, has the form (d0, d1, ..., dN-1) -> (d0, d1, ..., dN-1 + s0), where s0 is the offset (and so might be a constant).

That is, it's exactly what we want: the row-major identity layout with an optional offset. That way, things that need to pattern match on "this is contiguous" can do so, and we can avoid the overly-strict implications of the identity layout.

(Note, for instance, that memref.reshape really just needs a contiguous layout, not the identity.)
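To make the difference concrete, here's a small Python model (illustrative only, not MLIR code; the function names are mine) of the address computation each layout implies:

```python
# Illustrative model of the two address maps. `contiguous_address` is the
# row-major-plus-offset map that contiguous<N, offset: O> promises;
# `strided_address` is everything a general strided layout allows.

def contiguous_address(indices, sizes, offset):
    # Row-major linearization plus an offset.
    addr = 0
    for idx, size in zip(indices, sizes):
        addr = addr * size + idx
    return offset + addr

def strided_address(indices, strides, offset):
    # General strided map: the strides carry no contiguity promise.
    return offset + sum(i * s for i, s in zip(indices, strides))

# For sizes [M, N], the contiguous layout coincides with strides [N, 1]:
M, N, off = 4, 5, 7
for i in range(M):
    for j in range(N):
        assert contiguous_address([i, j], [M, N], off) == strided_address([i, j], [N, 1], off)
```

The point of the proposal, in these terms: contiguous<N, offset: ?> pins the strides to the row-major form and leaves only the offset dynamic, whereas strided<[?, 1], offset: ?> leaves the leading stride unconstrained.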

Radical extensions

We could reasonably change the default layout from the identity AffineMap to contiguous<[rank]>, aka contiguous<[rank], offset: 0>. This has the advantage of making it clearer what's going on and having a simpler representation for the default case.

Open questions: "normalized" memrefs

I need to do more research about what a "normalized" memref is in the context of the MLIR standard library in order to understand whether or not contiguous<N, offset: ?> fits the definition.

Memref index exposure

The current lowering patterns for memref.load, for example, generate a lot of non-trivial arithmetic (multiplying indices by strides, etc.)

We want this arithmetic to be visible to MLIR pre-lowering, so that passes like -int-range-optimizations or integer narrowing can run on it, for example.

The 0-D form (on reflection, I don't think we want this), retained for history

To that end, we propose an operation memref.add_offset and a pass -memref-expand-indexing.

%q = memref.add_offset %p, %i : memref<..., contiguous<N, offset: ?>> is the operation that takes a memref %p and returns one where offset(%q) == offset(%p) + %i, where %i is an index. For the LLVM backend, this could be a getelementptr, while for the SPIR-V one, all these offsets can be accumulated into the index of an access chain.
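As a sketch of the intended semantics (a Python model over a hypothetical descriptor type, not the actual MLIR implementation): memref.add_offset touches only the offset field, leaving the base pointer, sizes, and strides alone.

```python
# Hypothetical model of memref.add_offset on a memref descriptor:
# only the offset changes; base, sizes, and strides are untouched.

from dataclasses import dataclass, replace

@dataclass(frozen=True)
class MemRefDesc:
    base: int        # stand-in for the allocated base pointer
    offset: int
    sizes: tuple
    strides: tuple

def add_offset(p: MemRefDesc, i: int) -> MemRefDesc:
    # offset(q) == offset(p) + i
    return replace(p, offset=p.offset + i)

p = MemRefDesc(base=0x1000, offset=3, sizes=(8,), strides=(1,))
q = add_offset(p, 5)
assert q.offset == p.offset + 5 and q.base == p.base
```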

Then, -memref-expand-indexing will walk through a program and replace all memref<S1 x S2 x ... x Sd x T, contiguous<d, offset: O>> with memref<product(S_i) x T, contiguous<1, offset: ?>> - or perhaps even memref<T, contiguous<0, offset: ?>> - using memref.add_offset to implement the indexing that's implicit in memref operations.

That is, if we have

%buffer = memref.alloc(%s1, %s2) : memref<? x ? x i32, contiguous<2, offset: ?>>
%x = memref.load %buffer[%i0, %i1] : memref<? x ? x i32, contiguous<2, offset: ?>>

we'd rewrite it to something along the lines of

%buffer_alloc = memref.alloc(%s1 * %s2) : memref<?xi32, contiguous<1, offset: ?>> // Or we stick an `alloc_size` parameter on here and skip the reinterpret_cast
%buffer = memref.reinterpret_cast %buffer_alloc sizes() strides() offset(%c0) : memref<i32, contiguous<0, offset: ?>>
%idx = %i1 + %s2 * %i0 // aka affine.linearize_index [%i0, %i1] by [%s1, %s2]
%x_mem = memref.add_offset %buffer, %idx : memref<i32, contiguous<0, offset: ?>>
%x = memref.load %x_mem[] : memref<i32, contiguous<0, offset: ?>>

This exposes the indexing math used by memref and allows it to be operated on, instead of making it implicit in complex operations, which is useful late in a lowering process.
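The equivalence the rewrite relies on can be checked with a small Python model (illustrative, not the pass itself): a 2-D access at [i0, i1] and a flat access at the linearized index hit the same element.

```python
# Illustrative check that a 2-D access and its linearized flat form agree.

def linearize_index(indices, basis):
    # Mirrors affine.linearize_index [%i0, %i1] by [%s1, %s2]:
    # for the 2-D case this is i1 + s2 * i0.
    out = 0
    for idx, size in zip(indices, basis):
        out = out * size + idx
    return out

s1, s2 = 3, 4
flat = list(range(s1 * s2))                                            # the 1-D allocation
grid = [[flat[i0 * s2 + i1] for i1 in range(s2)] for i0 in range(s1)]  # its 2-D view

for i0 in range(s1):
    for i1 in range(s2):
        assert grid[i0][i1] == flat[linearize_index([i0, i1], [s1, s2])]
```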

Note that this does require mid-level vector ops, like vector.transfer_read, to be lowered to vector.load in order to be useful.

The 1-D form

To that end, we propose the pass -memref-linearize-contiguous-indexing.

We still might want memref.add_offset, just as syntactic sugar for "get the old offset, add something to it, reinterpret_cast".

With the 1-D form, we're still doing the linearization as above, but we'll want to keep the size on the memrefs.

That is, the example with the alloc above would be (assuming we allow allocs to have non-trivial offsets)

%buffer = memref.alloc(%s1 * %s2) : memref<?xi32, contiguous<1, offset: ?>>
%idx = %i1 + %s2 * %i0 // aka affine.linearize_index [%i0, %i1] by [%s1, %s2]
%x = memref.load %buffer[%idx] : memref<? x i32, contiguous<1, offset: ?>>

Prior art

The narrow type emulation patterns, as well as some lowerings to SPIR-V, sort of have patterns for this, but they're somewhat hampered by the lack of a contiguous layout.

The narrow type emulation pass could be substantially simplified if the linearization were a separate step, leaving narrow type emulation as "convert all the memref<i[small], contiguous<offset: ?>> to memref<i8, contiguous<offset: ?>>, divide all the arguments to memref.add_offset by the relevant scale factor, and take remainders to extract sub-byte values".
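The scale-factor arithmetic in question, sketched in Python (illustrative; the packing convention - two i4 values per byte, low nibble first - and the function names are my assumptions, not what the pass literally does):

```python
# Hypothetical sketch of narrow-type-emulation arithmetic on a linearized
# index: an i4 element index becomes an i8 (byte) index plus a sub-byte
# bit offset, and loads shift/mask to extract the value.

def i4_to_byte_access(elem_idx):
    bit = elem_idx * 4            # scale the linear index by the element width
    return bit // 8, bit % 8      # (byte index, bit offset within the byte)

def load_i4(buf, elem_idx):
    byte, shift = i4_to_byte_access(elem_idx)
    return (buf[byte] >> shift) & 0xF

buf = [0x21, 0x43]                # packed i4 values 1, 2, 3, 4 (low nibble first)
assert [load_i4(buf, k) for k in range(4)] == [1, 2, 3, 4]
```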

One could even - for the LLVM lowering in particular - introduce a -llvmir-byteify-memrefs pass, which converts all memrefs to i8 memrefs before lowering, emitting a mlir.llvm.sizeof operation if it's needed. That would both expose the remaining bit of indexing math to MLIR optimizations and would cause MLIR to stop blocking the ptradd transition in LLVM - we've got some pretty fundamental usage of getelementptr [not i8] in our lowering. Obviously other targets, like SPIR-V, which need the type information, wouldn't run this pass.

Concrete steps

  • Define the contiguous layout attribute, add it to parsing and printing
  • Review everything that calls getLayout() or isIdentity(), or calls out the StridedLayoutAttr by name, in order to give it a new case for contiguous
  • Implement the memref index exposure pass
@krzysz00 (Author)

Addenda:

  1. On reflection, we should be going to 1D memrefs here - the size of the memref is an important detail in some cases, and makes memref.alloc() simpler to deal with. Then, we don't need add_offset as much.
  2. @Hardcode84 wants memref<T, contiguous<0, offset: ?>> to be something like !ptr.fatptr<T>

@MaheshRavishankar (Contributor)

Instead of just saying contiguous<N>, can you instead take the list of dimensions that are contiguous (and in that order)? This way you can encode column-major/row-major and all combinations thereof.

@krzysz00 (Author)

@MaheshRavishankar I think that'll get me most of what I want? The main catch is that there's a lot of code that wants monotonically-increasing strides from right to left - that is, a lot of code that's asking isIdentity() is really trying to ask for contiguous row-major ... so we'd still want to do a lot of checking for contiguous<[0, 1, 2, ..., N-1], offset: ?>.

@MaheshRavishankar (Contributor)

Yeah, you can have a helper method on the contiguous attribute, something like isContiguousRowMajor, to check for that case (or a factory method to build it). But I think that covers the space of all "interesting" ways of representing strides.
